In [3]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from torchvision.io import read_image
import timm
from timm import create_model

import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import mean_squared_error

import glob
import gc

from data_loaders import PetFinderDataModule, columns
from transforms import train_transforms, test_transforms, mixup

In [19]:
# Seed every RNG this notebook relies on so that fold assignment, augmentation
# and weight initialisation repeat across "Restart & Run All".
seed = 999
np.random.seed(seed)  # scikit-learn falls back to NumPy's global RNG when random_state is None
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

torch.backends.cudnn.deterministic = True
# use_deterministic_algorithms is a function, not a flag: assigning `= True` to it
# only rebinds the name and leaves deterministic mode switched off.
# warn_only=True keeps training running (with a warning) when an op has no
# deterministic implementation instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

In [33]:
df = pd.read_csv("data/train.csv")
df['normalised_score'] = df['Pawpularity'] / 100

# Sturges' rule (https://www.statology.org/sturges-rule/) picks the number of bins.
# The score is bucketed into equal-width bins and StratifiedKFold stratifies on the
# bin id, so each fold gets a similar score distribution.
n_bins = int(np.ceil(1 + np.log2(len(df))))
df['bins'] = pd.cut(df['normalised_score'], bins=n_bins, labels=False)

n_folds = 5
# random_state pins the shuffle; without it the folds differ on every run even
# though `seed` was set for torch in the seeding cell.
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

df['fold'] = -1
# Look the column position up instead of assuming 'fold' is the last column.
fold_column = df.columns.get_loc('fold')
# split() yields (train positions, held-out positions); a row's fold id is the
# split in which it is held out.
for fold_index, (_, held_out_index) in enumerate(skf.split(df.index, df['bins'])):
    df.iloc[held_out_index, fold_column] = fold_index

df['fold'] = df['fold'].astype('int')

In [None]:
class PawpularityModel(pl.LightningModule):
    """Image + metadata regressor for the Pawpularity score.

    A timm backbone (created with ``num_classes=0``) embeds the image; the
    embedding is concatenated with the per-sample feature vector and mapped to a
    single logit by a linear layer. Labels are divided by 100 and optimised with
    ``BCEWithLogitsLoss``; RMSE on the original label scale is logged once per
    epoch as ``train_rmse`` / ``val_rmse``.
    """

    def __init__(self, model_name="swin_large_patch4_window7_224", pretrained=True):
        """
        Args:
            model_name: architecture name passed to timm's ``create_model``.
            pretrained: forwarded to ``create_model``.
        """
        super().__init__()
        # Per-batch records collected during an epoch; consumed and cleared in the
        # matching on_*_epoch_end hook.
        self.validation_step_outputs = []
        self.training_step_outputs = []

        # No explicit .to('cuda'): Lightning moves the whole module to the Trainer's
        # device. Moving only the backbone here failed on CPU-only machines and left
        # `fc` on a different device from the backbone outside a Trainer.
        self.backbone = create_model(model_name, pretrained=pretrained, num_classes=0, in_chans=3)
        self.dropout = nn.Dropout(0.1)

        num_features = self.backbone.num_features
        # Kept as a one-layer Sequential so parameter names stay `fc.0.*`.
        self.fc = nn.Sequential(
            nn.Linear(num_features + len(columns), 1)
        )

        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input, features):
        """Return one raw logit per sample, shape ``(batch, 1)``.

        Args:
            input: image batch fed to the backbone.
            features: per-sample feature vectors, concatenated to the image
                embedding along dim 1 (so it must be 2-D with the same batch size).
        """
        x = self.backbone(input)
        x = self.dropout(x)

        x = torch.cat([x, features], dim=1)
        x = self.fc(x)

        return x

    def training_step(self, batch, batch_indexes):
        """Run one training batch and record its metrics for the epoch summary."""
        loss, predictions, labels, rmse = self.step(batch, 'train')
        self.training_step_outputs.append(self._batch_record(loss, predictions, labels, rmse))

        return { 'loss': loss, 'predictions': predictions, 'labels': labels }

    def validation_step(self, batch, batch_indexes):
        """Run one validation batch and record its metrics for the epoch summary."""
        loss, predictions, labels, rmse = self.step(batch, 'val')
        self.validation_step_outputs.append(self._batch_record(loss, predictions, labels, rmse))

        return { 'loss': loss, 'predictions': predictions, 'labels': labels }

    def step(self, batch, mode):
        """Shared train/val logic.

        Args:
            batch: ``(image_ids, features, images, labels)`` with labels on the
                0-100 scale.
            mode: ``'train'`` selects the training transforms and enables mixup on
                roughly half of the batches; any other value uses the test
                transforms. Also used as the prefix of the logged loss name.

        Returns:
            ``(loss, predictions, labels, rmse)`` where ``predictions`` and
            ``labels`` are detached CPU tensors on the 0-100 scale and ``rmse`` is
            this batch's RMSE as a 0-dim float32 tensor.
        """
        image_ids, features, images, labels = batch
        labels = labels.float() / 100.0

        images = train_transforms(images) if mode == "train" else test_transforms(images)

        # Check the cheap conditions first so validation batches never draw from
        # (and advance) the global RNG.
        use_mixup = mode == 'train' and len(images) > 1 and torch.rand(1).item() < 0.5
        if use_mixup:
            mix_images, target_a, target_b, lam = mixup(images, labels, alpha=1.0)
            logits = self.forward(mix_images, features).squeeze(1)
            loss = self.criterion(logits, target_a) * lam + (1 - lam) * self.criterion(logits, target_b)
        else:
            logits = self.forward(images, features).squeeze(1)
            loss = self.criterion(logits, labels)

        # .float() keeps the metric in float32 even if the logits are half precision.
        predictions = logits.detach().float().sigmoid().cpu() * 100
        labels = labels.detach().cpu() * 100

        # The loss is BCE; RMSE on the original scale is tracked as a sanity metric.
        # Computed in torch: sklearn's `mean_squared_error(..., squared=False)` is
        # deprecated and removed in recent releases.
        # NOTE(review): on mixup batches this compares predictions for the mixed
        # images against the unmixed labels, so train_rmse is presumably
        # pessimistic - confirm that is acceptable.
        rmse = torch.sqrt(torch.mean((predictions - labels) ** 2))

        self.log(f'{mode}_loss', loss)

        return loss, predictions, labels, rmse

    @staticmethod
    def _batch_record(loss, predictions, labels, rmse):
        """Bundle one batch's metrics, including what is needed for an exact epoch RMSE."""
        squared_error = (predictions - labels).double() ** 2
        return {
            "rmse": rmse,
            "loss": loss.detach(),  # do not keep a reference to the autograd graph
            "sse": squared_error.sum(),
            "count": squared_error.numel(),
        }

    @staticmethod
    def _epoch_rmse(records):
        """Exact RMSE over all recorded samples.

        Averaging per-batch RMSEs is not the dataset RMSE (mean of roots differs
        from root of mean, and a smaller final batch is over-weighted), so the
        squared errors are pooled before taking the root.
        """
        total_sse = torch.stack([record["sse"] for record in records]).sum()
        total_count = sum(record["count"] for record in records)
        return torch.sqrt(total_sse / total_count).float()

    def on_train_epoch_end(self):
        """Log the epoch-level training RMSE and reset the per-batch records."""
        if self.training_step_outputs:  # nothing to aggregate if no batch ran
            self.log('train_rmse', self._epoch_rmse(self.training_step_outputs), prog_bar=True)

        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Log the epoch-level validation RMSE and reset the per-batch records."""
        if self.validation_step_outputs:  # nothing to aggregate if no batch ran
            self.log('val_rmse', self._epoch_rmse(self.validation_step_outputs), prog_bar=True)

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW with a cosine warm-restart schedule (restart period ``T_0=20``)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        # NOTE(review): eta_min (1e-4) is larger than the base lr (1e-5), so within
        # each cycle the learning rate moves from 1e-5 up towards 1e-4 rather than
        # decaying. Values are left unchanged here - confirm whether this is
        # intentional before altering them.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 20, eta_min=1e-4)

        return [optimizer], [scheduler]