In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from torchvision.io import read_image
import timm
from timm import create_model

import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import mean_squared_error

import glob
import gc

from data_loaders import PetFinderDataModule, columns
from transforms import train_transforms, test_transforms, mixup

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed=999
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms = True

DATA_DIR = "data"
TRAIN_DIR = "data/train"
TEST_DIR = "data/test"
OUTPUT_DIR = "output"

In [3]:
df = pd.read_csv("data/train.csv")
df['normalised_score'] = df['Pawpularity'] / 100

# Sturges rule https://www.statology.org/sturges-rule/
# We use the bins split our data based on Pawpularity into multiple bins, to perform StratifiedKFold later
n_bins = int(np.ceil(1 + (np.log2(len(df)))))
df['bins'] = pd.cut(df['normalised_score'], bins=n_bins, labels=False)

In [4]:
NUM_CLASSES = 50
oof_predictions = { "ids": [], "predictions": [], "target": [], "fold": [] }

class PawpularityModel(pl.LightningModule):
    def __init__(self, model_name="swin_large_patch4_window7_224", pretrained=True):
        super().__init__()
        self.validation_step_outputs = []
        self.training_step_outputs = []

        self.fold = 1

        self.backbone = create_model(model_name, pretrained=pretrained, num_classes=NUM_CLASSES, in_chans=3).to('cuda')
        self.criterion = nn.BCEWithLogitsLoss()
        
    def forward(self, input, features):
        x = self.backbone(input)

        x = torch.cat([x, features], dim=1)
        x = torch.sigmoid(x).sum(1) / (NUM_CLASSES + len(features))

        return x

    def training_step(self, batch, batch_indexes):
        loss, predictions, labels, rmse = self.step(batch, 'train')
        self.training_step_outputs.append({ "rmse": rmse, "loss": loss })

        return { 'loss': loss, 'predictions': predictions, 'labels': labels }

    def validation_step(self, batch, batch_indexes):
        loss, predictions, labels, rmse = self.step(batch, 'val')
        self.validation_step_outputs.append({ "rmse": rmse, "loss": loss })

        image_ids, _, _, _ = batch
        oof_predictions["ids"].append(image_ids)
        oof_predictions["predictions"].append(predictions.detach().numpy())
        oof_predictions["target"].append(labels.detach().numpy())
        oof_predictions["fold"].append([self.fold] * len(image_ids))
        
        return { 'loss': loss, 'predictions': predictions, 'labels': labels }
    
    def step(self, batch, mode):
        image_ids, features, images, labels = batch
        labels = labels.float() / 100.0

        images = train_transforms(images) if mode == "train" else test_transforms(images)

        if torch.rand(1)[0] < 0.5 and mode == 'train' and len(images) > 1:
            mix_images, target_a, target_b, lam = mixup(images, labels, alpha=1.0)
            logits = self.forward(mix_images, features).squeeze(-1)
            loss = self.criterion(logits, target_a) * lam + (1 - lam) * self.criterion(logits, target_b)
        else:
            logits = self.forward(images, features).squeeze(-1)
            loss = self.criterion(logits, labels)

        predictions = logits.detach().cpu() * 100
        labels = labels.detach().cpu() * 100
        
        rmse = mean_squared_error(predictions, labels, squared=False) # loss uses BCELoss, while we still calculate RMSE to check
        rmse = torch.tensor(rmse, dtype=torch.float32)

        self.log(f'{mode}_loss', loss)
        
        return loss, predictions, labels, rmse

    def on_train_epoch_end(self):
        rsmes = [x["rmse"] for x in self.training_step_outputs]
        rsme = torch.stack(rsmes).mean()

        self.log(f'train_rmse', rsme, prog_bar=True)

        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        rsmes = [x["rmse"] for x in self.validation_step_outputs]
        rsme = torch.stack(rsmes).mean()

        self.log(f'val_rmse', rsme, prog_bar=True)
        
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 20, eta_min=1e-4)

        return [optimizer], [scheduler]

In [5]:
torch.cuda.empty_cache()
gc.collect()

n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

df_test = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
for fold_index, (train_index, val_index) in enumerate(skf.split(df.index, df['bins'])):
    df_train = df.iloc[train_index]
    df_val = df.iloc[val_index]

    df_train = df_train.reset_index(drop=True)
    df_val = df_val.reset_index(drop=True)

    data_module = PetFinderDataModule(
        df_train=df_train, 
        df_val=df_val, 
        df_test=df_test, 
        train_dir=TRAIN_DIR, 
        val_dir=TRAIN_DIR, 
        test_dir=TEST_DIR, 
        batch_size=8,
        image_size=224
    )

    model_name = "swin_large_patch4_window7_224"
    model = PawpularityModel(model_name=model_name, pretrained=True)

    early_stopping = EarlyStopping(monitor="val_loss")
    lr_monitor = callbacks.LearningRateMonitor()
    loss_checkpoint = callbacks.ModelCheckpoint(dirpath=os.path.join(OUTPUT_DIR, "model_checkpoints"), filename="best_loss", monitor="val_loss", save_top_k=1, mode="min", save_last=False)

    logger = TensorBoardLogger(os.path.join(OUTPUT_DIR, "logs"))

    trainer = pl.Trainer(max_epochs=1, callbacks=[lr_monitor, loss_checkpoint, early_stopping], logger=logger)
    trainer.fit(model, datamodule=data_module)
    trainer.validate(model, datamodule=data_module)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: output\logs\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | backbone  | SwinTransformer   | 195 M 
1 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
195 M     Trainable params
0         Non-trainable params
195 M     Total params
780.289   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:  64%|██████▍   | 633/992 [2:02:14<1:09:19, 11.59s/it, v_num=0]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Validation DataLoader 0:   1%|          | 2/248 [00:07<15:10,  3.70s/it]

In [12]:
for ids, predictions in zip(oof_predictions["ids"], oof_predictions["predictions"]):
    for batch_ids, batch_predictions in zip(ids, predictions):
        print(batch_ids, batch_predictions.detach().numpy())

0009c66b9439883ba2750fb825e1d7db 59.48245
006cda7fec46a527f9f627f4722a2304 59.402954
006fe962f5f7e2c5f527b2e27e28ed6d 60.606064
0075ec6503412f21cf65ac5f43d80440 58.60515
0009c66b9439883ba2750fb825e1d7db 58.964478
006cda7fec46a527f9f627f4722a2304 58.74107
006fe962f5f7e2c5f527b2e27e28ed6d 60.1452
0075ec6503412f21cf65ac5f43d80440 58.093388
0009c66b9439883ba2750fb825e1d7db 58.964478
006cda7fec46a527f9f627f4722a2304 58.74107
006fe962f5f7e2c5f527b2e27e28ed6d 60.1452
0075ec6503412f21cf65ac5f43d80440 58.093388
0007de18844b0dbbb5e1f607da0606e0 58.884144
001dc955e10590d3ca4673f034feeef2 58.32569
005017716086b8d5e118dd9fe26459b1 59.138496
00655425c10d4c082dd7eeb97fa4fb17 58.69689
0007de18844b0dbbb5e1f607da0606e0 58.385883
001dc955e10590d3ca4673f034feeef2 57.50579
005017716086b8d5e118dd9fe26459b1 58.714615
00655425c10d4c082dd7eeb97fa4fb17 57.986305
0007de18844b0dbbb5e1f607da0606e0 58.385883
001dc955e10590d3ca4673f034feeef2 57.50579
005017716086b8d5e118dd9fe26459b1 58.714615
00655425c10d4c082dd7eeb