In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import copy
import numpy as np
import random
import pandas as pd
import os

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.io import read_image
from timm import create_model

import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule, LightningModule

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

import glob

We already have two directories, `train` and `test`.

Let us check the number of training and test instances.

In [None]:
# Count the directory entries of each split and report them.
train_entries = os.listdir('data/train')
print(f"Number of training instances: {len(train_entries)}")
test_entries = os.listdir('data/test')
print(f"Number of test instances: {len(test_entries)}")

In [None]:
class PetFinderDataset(Dataset):
    """Image dataset indexed by the ``Id`` column of a dataframe.

    Each item is the resized image read from ``<image_dir>/<Id>.jpg``. When the
    dataframe also has a ``Pawpularity`` column, items are ``(image, label)``
    pairs; otherwise only the image is returned.
    """

    def __init__(self, df, image_dir, image_size=224):
        """Remember the ids, the optional labels and the resize transform.

        Args:
            df: Dataframe with an ``Id`` column and, optionally, ``Pawpularity``.
            image_dir: Directory that holds the ``.jpg`` files.
            image_size: Side length of the square the images are resized to.
        """
        has_labels = "Pawpularity" in df.keys()

        self.X = df["Id"].values
        self.y = df["Pawpularity"].values if has_labels else None
        self.image_dir = image_dir
        self.transform = T.Resize([image_size, image_size])

    def __len__(self):
        """Return the number of ids in the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Load, resize and return the image at ``idx`` (with its label if known)."""
        file_name = self.X[idx] + '.jpg'
        resized = self.transform(read_image(os.path.join(self.image_dir, file_name)))

        # Unlabelled data (e.g. the test split) yields the image alone.
        if self.y is None:
            return resized

        return resized, self.y[idx]

class PetFinderDataModule(LightningDataModule):
    """Lightning data module that serves train/validation loaders from two dataframes.

    Both loaders wrap a ``PetFinderDataset`` that reads images from ``image_dir``.
    """

    def __init__(self, df_train, df_val, image_dir="data/train", batch_size=8, num_workers=0):
        """Store the split dataframes and the loader settings.

        Args:
            df_train: Dataframe describing the training split.
            df_val: Dataframe describing the validation split.
            image_dir: Directory holding the images of both splits.
            batch_size: Number of samples per batch for both loaders.
            num_workers: Worker processes per loader (0 loads in the main process).
        """
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.image_dir = image_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_loader(self, df, shuffle):
        """Build a DataLoader over ``df`` using the configured settings."""
        return DataLoader(
            PetFinderDataset(df, self.image_dir),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.df_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (not shuffled, so evaluation order is stable)."""
        return self._make_loader(self.df_val, shuffle=False)


In [None]:
# Work with the first 50 rows only.
df = pd.read_csv("data/train.csv").head(50)

# The same frame is passed for both splits; only the validation loader is used here.
preview_loader = PetFinderDataModule(df, df).val_dataloader()
images, labels = next(iter(preview_loader))

fig = plt.figure(figsize=(12, 12))
for position, (image, label) in enumerate(zip(images[:16], labels[:16]), start=1):
    ax = fig.add_subplot(4, 4, position)
    # Move the first axis last so imshow receives the layout it expects.
    ax.imshow(image.permute(1, 2, 0))
    ax.axis('off')
    ax.set_title(f'Pawpularity: {int(label)}')

Now, let's define our model and train it.

In [None]:
# Per-channel statistics passed to T.Normalize below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB

# Random augmentations applied to training batches only.
_train_augmentations = [
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
]

# Training pipeline: augment, convert to float, then normalise.
train_transforms = T.Compose(
    _train_augmentations
    + [
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Evaluation pipeline: float conversion and normalisation only, no augmentation.
test_transforms = T.Compose(
    [
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# https://arxiv.org/abs/1710.09412v2
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """Blend every sample of a batch with a randomly chosen partner from the same batch.

    A single mixing coefficient ``lam`` is drawn from ``Beta(alpha, alpha)`` and a
    random permutation of the batch picks each sample's partner.

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Targets, indexed along their first dimension like ``x``.
        alpha: Concentration parameter of the Beta distribution; must be > 0.

    Returns:
        ``(mixed_x, target_a, target_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``target_a`` is ``y``,
        ``target_b`` is ``y[perm]`` and ``lam`` is the sampled coefficient.

    Raises:
        AssertionError: If ``alpha`` is not positive or the batch has a single sample.
    """
    assert alpha > 0, "alpha should be larger than 0"
    assert x.size(0) > 1, "Mixup cannot be applied to a single instance."

    lam = np.random.beta(alpha, alpha)
    # Draw the permutation directly on x's device so indexing a GPU batch does
    # not go through a CPU index tensor.
    rand_index = torch.randperm(x.size(0), device=x.device)
    # Index only the batch dimension so inputs of any rank are accepted.
    mixed_x = lam * x + (1 - lam) * x[rand_index]
    # y may live on a different device than x; .to() is a no-op when they match.
    target_a, target_b = y, y[rand_index.to(y.device)]
    return mixed_x, target_a, target_b, lam

# Defining Model

In [None]:
class PawpularityModel(pl.LightningModule):
    """Regression model: a timm Swin-Tiny backbone followed by a dropout + linear head.

    Image transforms are applied inside ``step`` (not in ``forward``), so callers
    that use ``forward`` directly must transform their inputs themselves.
    """

    def __init__(self):
        """Build the backbone, the single-output head, the loss and the transforms."""
        super().__init__()
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.validation_step_outputs = []
        self.training_step_outputs = []

        # Backbone is created with num_classes=0; its output width is num_features.
        self.backbone = create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0, in_chans=3)
        num_features = self.backbone.num_features
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 1)
        )

        self.criterion = nn.MSELoss()
        self.train_transforms = train_transforms
        self.test_transforms = test_transforms

    def forward(self, input):
        """Run the backbone and the head on an already-transformed batch."""
        return self.fc(self.backbone(input))

    def step(self, batch, mode):
        """Shared train/val step: transform, optionally mix up, predict and log the loss.

        Args:
            batch: ``(images, labels)`` pair from the dataloader.
            mode: ``'train'`` selects the training transforms and enables mixup;
                any other value selects the test transforms. Also used as the
                prefix of the logged ``<mode>_loss`` metric.

        Returns:
            ``(loss, predictions, labels)`` with predictions and labels detached
            and moved to the CPU.
        """
        images, labels = batch
        labels = labels.float()
        images = self.train_transforms(images) if mode == "train" else self.test_transforms(images)

        # Mixup is applied to roughly half of the training batches and needs >1 sample.
        if torch.rand(1)[0] < 0.5 and mode == 'train' and len(images) > 1:
            mix_images, target_a, target_b, lam = mixup(images, labels, alpha=0.5)
            logits = self.forward(mix_images).squeeze(1)
            # Loss is the same convex combination that was used to mix the inputs.
            loss = self.criterion(logits, target_a) * lam + (1 - lam) * self.criterion(logits, target_b)
        else:
            logits = self.forward(images).squeeze(1)
            loss = self.criterion(logits, labels)

        predictions = logits.detach().cpu()
        labels = labels.detach().cpu()

        # The per-step debug print of predictions/labels was removed; the loss is logged instead.
        self.log(f'{mode}_loss', loss)

        return loss, predictions, labels

    def training_step(self, batch, batch_indexes):
        """Run ``step`` in train mode and return loss, predictions and labels."""
        loss, predictions, labels = self.step(batch, 'train')
        return { 'loss': loss, 'predictions': predictions, 'labels': labels }

    def validation_step(self, batch, batch_indexes):
        """Run ``step`` in val mode and return loss, predictions and labels."""
        loss, predictions, labels = self.step(batch, 'val')
        return { 'loss': loss, 'predictions': predictions, 'labels': labels }

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealing-with-warm-restarts schedule."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        # NOTE(review): eta_min (1e-4) is larger than the optimizer lr (1e-5), so this
        # schedule moves the lr up toward eta_min within a cycle instead of decaying
        # it - confirm this is intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0 = 20, eta_min=1e-4)

        return [optimizer], [scheduler]

In [None]:
# Stratifying on the raw Pawpularity values treats every distinct score as its own
# class; with few rows per score StratifiedKFold warns, or raises when no class has
# n_splits members. Bin the scores first (Sturges' rule for the bin count) and
# stratify on the bin index instead.
num_bins = int(np.floor(1 + np.log2(len(df))))
strat_bins = pd.cut(df["Pawpularity"], bins=num_bins, labels=False)

# Fixed seed so the folds are identical across Restart & Run All.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(skf.split(df["Id"], strat_bins)):
    # split() yields positional indices, so rows are selected with iloc rather than loc.
    df_train = df.iloc[train_idx].reset_index(drop=True)
    df_val = df.iloc[val_idx].reset_index(drop=True)

    data_module = PetFinderDataModule(df_train, df_val)
    model = PawpularityModel()

    # Stop early and keep only the best checkpoint, both judged on validation loss.
    early_stopping = EarlyStopping(monitor="val_loss")
    lr_monitor = callbacks.LearningRateMonitor()
    loss_checkpoint = callbacks.ModelCheckpoint(filename="best_loss", monitor="val_loss", save_top_k=1, mode="min", save_last=False)

    # Each fold gets its own version_N directory under logs/swin_224/lightning_logs.
    logger = TensorBoardLogger("logs/swin_224")

    trainer = pl.Trainer(max_epochs=10, callbacks=[lr_monitor, loss_checkpoint, early_stopping], logger=logger)
    trainer.fit(model, datamodule=data_module)

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# The trainer logs through TensorBoardLogger("logs/swin_224"), which writes to
# <save_dir>/lightning_logs/version_N, so the first run lives in version_0 there.
event_files = sorted(glob.glob('./logs/swin_224/lightning_logs/version_0/events*'))
if not event_files:
    raise FileNotFoundError(
        "No TensorBoard event file found under ./logs/swin_224/lightning_logs/version_0; "
        "run the training cell first."
    )
path = event_files[0]

# size_guidance 0 keeps every scalar event instead of a downsampled subset.
event_acc = EventAccumulator(path, size_guidance={'scalars': 0})
event_acc.Reload()

# Map each scalar tag to the list of its logged values, in logging order.
scalars = {
    tag: [event.value for event in event_acc.Scalars(tag)]
    for tag in event_acc.Tags()['scalars']
}

In [None]:
import seaborn as sns
sns.set()

fig, (ax_lr, ax_loss) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: learning rate recorded by LearningRateMonitor.
ax_lr.plot(scalars['lr-AdamW'])
ax_lr.set_xlabel('epoch')
ax_lr.set_ylabel('lr')
ax_lr.set_title('adamw lr')

# Right panel: the logged losses. The model's criterion is nn.MSELoss, so these
# values are MSE (not RMSE). Train and val losses are logged at different
# cadences, so the x axis is the index of the logged point rather than the epoch.
ax_loss.plot(scalars['train_loss'], label='train_loss')
ax_loss.plot(scalars['val_loss'], label='val_loss')
ax_loss.legend()
ax_loss.set_ylabel('mse')
ax_loss.set_xlabel('logged point')
ax_loss.set_title('train/val mse')
plt.show()

In [None]:
cols = 4

model = PawpularityModel()
# ModelCheckpoint saves under the logger directory, i.e.
# <save_dir>/lightning_logs/version_N/checkpoints for TensorBoardLogger("logs/swin_224").
checkpoint = torch.load(
    "logs/swin_224/lightning_logs/version_0/checkpoints/best_loss.ckpt",
    map_location="cpu",
)

model.load_state_dict(checkpoint['state_dict'])
# Switch to eval mode so the Dropout in the head is disabled while predicting.
model.eval()

data_module = PetFinderDataModule(df, df)
images, labels = next(iter(data_module.val_dataloader()))

# forward() expects transformed inputs (float conversion + normalisation); during
# training those transforms are applied in step(), so apply them here explicitly.
with torch.no_grad():
    predictions = model(model.test_transforms(images)).squeeze(1)

# Ceil division: exactly as many rows as needed, no spare empty row.
rows = int(np.ceil(len(images) / cols))

figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 8), squeeze=False)
axes = ax.ravel()
for axis in axes:
    # Hide every frame up front so unused grid cells stay blank.
    axis.set_axis_off()

for axis, image, label, prediction in zip(axes, images, labels, predictions):
    axis.imshow(image.permute(1, 2, 0))
    axis.set_title(f"{round(prediction.item())}, Actual: {int(label)}")

plt.tight_layout(pad=1)
plt.show()