# Two-stage Finetuning

The model is finetuned in two stages:
1. Finetune the pretrained backbone, together with a classifier, to predict the metadata.
2. Finetune the backbone obtained from stage 1, together with a regressor, to predict the Pawpularity score.

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torchvision
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from typing import Optional
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn, optim
from torchvision.io import read_image
from torchvision.transforms import ConvertImageDtype, Resize, Normalize
from torchvision.transforms import RandomHorizontalFlip, Compose
from pytorch_lightning.callbacks import ModelCheckpoint

# Do not truncate rows when a DataFrame is displayed.
pd.set_option("display.max_rows", None)
# Use the 'ggplot' style sheet for all matplotlib figures.
plt.style.use('ggplot')

# Dataset & Datamodule

In [None]:
# Duplicate images, credit to https://www.kaggle.com/valleyzw/petfinder-duplicate-images

SIMILAR_IDS = [
    '13d215b4c71c3dc603cd13fc3ec80181_373c763f5218610e9b3f82b12ada8ae5',
    '5ef7ba98fc97917aec56ded5d5c2b099_67e97de8ec7ddcda59a58b027263cdcc',
    '839087a28fa67bf97cdcaf4c8db458ef_a8f044478dba8040cc410e3ec7514da1',
    '1feb99c2a4cac3f3c4f8a4510421d6f5_264845a4236bc9b95123dde3fb809a88',
    '3c50a7050df30197e47865d08762f041_def7b2f2685468751f711cc63611e65b',
    '37ae1a5164cd9ab4007427b08ea2c5a3_3f0222f5310e4184a60a7030da8dc84b',
    '5a642ecc14e9c57a05b8e010414011f2_c504568822c53675a4f425c8e5800a36',
    '2a8409a5f82061e823d06e913dee591c_86a71a412f662212fe8dcd40fdaee8e6',
    '3c602cbcb19db7a0998e1411082c487d_a8bb509cd1bd09b27ff5343e3f36bf9e',
    '0422cd506773b78a6f19416c98952407_0b04f9560a1f429b7c48e049bcaffcca',
    '68e55574e523cf1cdc17b60ce6cc2f60_9b3267c1652691240d78b7b3d072baf3',
    '1059231cf2948216fcc2ac6afb4f8db8_bca6811ee0a78bdcc41b659624608125',
    '5da97b511389a1b62ef7a55b0a19a532_8ffde3ae7ab3726cff7ca28697687a42',
    '78a02b3cb6ed38b2772215c0c0a7f78e_c25384f6d93ca6b802925da84dfa453e',
    '08440f8c2c040cf2941687de6dc5462f_bf8501acaeeedc2a421bac3d9af58bb7',
    '0c4d454d8f09c90c655bd0e2af6eb2e5_fe47539e989df047507eaa60a16bc3fd',
    '5a5c229e1340c0da7798b26edf86d180_dd042410dc7f02e648162d7764b50900',
    '871bb3cbdf48bd3bfd5a6779e752613e_988b31dd48a1bc867dbc9e14d21b05f6',
    'dbf25ce0b2a5d3cb43af95b2bd855718_e359704524fa26d6a3dcd8bfeeaedd2e',
    '43bd09ca68b3bcdc2b0c549fd309d1ba_6ae42b731c00756ddd291fa615c822a1',
    '43ab682adde9c14adb7c05435e5f2e0e_9a0238499efb15551f06ad583a6fa951',
    'a9513f7f0c93e179b87c01be847b3e4c_b86589c3e85f784a5278e377b726a4d4',
    '38426ba3cbf5484555f2b5e9504a6b03_6cb18e0936faa730077732a25c3dfb94',
    '589286d5bfdc1b26ad0bf7d4b7f74816_cd909abf8f425d7e646eebe4d3bf4769',
    '9f5a457ce7e22eecd0992f4ea17b6107_b967656eb7e648a524ca4ffbbc172c06',
    'b148cbea87c3dcc65a05b15f78910715_e09a818b7534422fb4c688f12566e38f',
    '3877f2981e502fe1812af38d4f511fd2_902786862cbae94e890a090e5700298b',
    '8f20c67f8b1230d1488138e2adbb0e64_b190f25b33bd52a8aae8fd81bd069888',
    '221b2b852e65fe407ad5fd2c8e9965ef_94c823294d542af6e660423f0348bf31',
    '2b737750362ef6b31068c4a4194909ed_41c85c2c974cc15ca77f5ababb652f84',
    '01430d6ae02e79774b651175edd40842_6dc1ae625a3bfb50571efedc0afc297c',
    '72b33c9c368d86648b756143ab19baeb_763d66b9cf01069602a968e573feb334',
    '03d82e64d1b4d99f457259f03ebe604d_dbc47155644aeb3edd1bd39dba9b6953',
    '851c7427071afd2eaf38af0def360987_b49ad3aac4296376d7520445a27726de',
    '54563ff51aa70ea8c6a9325c15f55399_b956edfd0677dd6d95de6cb29a85db9c',
    '87c6a8f85af93b84594a36f8ffd5d6b8_d050e78384bd8b20e7291b3efedf6a5b',
    '04201c5191c3b980ae307b20113c8853_16d8e12207ede187e65ab45d7def117b'
]

# Split every "<first>_<second>" identifier into a two-column frame whose
# columns are named `first` and `second` (one row per duplicate pair).
SIMILAR_PAIRS = (
    pd.Series(SIMILAR_IDS)
    .str.extract(r"(?P<first>\w+)_(?P<second>\w+)")
)

In [None]:
# Remove duplicate images with lower Pawpularity scores
# following https://www.kaggle.com/c/petfinder-pawpularity-score/discussion/285140

def drop_duplicates(meta, pairs):
    """
    Return metadata where only the duplicate with the highest Pawpularity score
    is kept.

    Arguments
    ---------
        meta: DataFrame with at least the columns `Id` and `Pawpularity`
        pairs: DataFrame with the columns `first` and `second`, each row
            holding the Ids of two duplicate images

    Returns
    -------
    A new DataFrame (the input is not modified) with `Id` as the first column
    and a fresh `0..n-1` index, without the lower-scored image of every pair.

    Note
    ----
    When both images of a pair have the same score, the `second` one is
    dropped. Pairs for which either Id is absent from `meta` are ignored, so
    the function also works on a subset of the metadata or with no pairs.
    """
    indexed = meta.set_index('Id')
    scores = indexed['Pawpularity']
    # Look up the score of each side of every pair; Ids missing from `meta`
    # map to NaN instead of raising a KeyError.
    first_scores = pairs['first'].map(scores)
    second_scores = pairs['second'].map(scores)
    both_present = first_scores.notna() & second_scores.notna()
    # The image to discard is the lower-scored one (`second` on ties).
    losers = pairs['first'].where(first_scores < second_scores, pairs['second'])
    losers = losers[both_present]
    kept = indexed.loc[~indexed.index.isin(set(losers))]
    return kept.reset_index()

In [None]:
class PetfinderDataset(Dataset):
    """Training/Testing dataset of Petfinder profiles."""

    def __init__(
        self,
        train: bool = True,
        img_transform = None,
        meta_transform = None,
        score_transform = None,
        dirpath: str = '../input/petfinder-pawpularity-score'
    ):
        """
        Arguments
        ---------
            train: Whether the training dataset or the testing dataset
            img_transform: Transformation of images
            meta_transform: Transformation of metadata
            score_transform: Transformation of pawpularity scores
            dirpath: Directory holding `train.csv`/`test.csv` and the
                `train`/`test` image folders

        Note
        ----
        `score_transform` is not supported if `train` is `False`.
        Known duplicate images are removed from the training metadata.
        """
        self.dirpath = dirpath
        self.img_dir = 'train' if train else 'test'
        csv_name = 'train.csv' if train else 'test.csv'
        self.meta = pd.read_csv(os.path.join(self.dirpath, csv_name))
        if train:
            self.meta = drop_duplicates(self.meta, SIMILAR_PAIRS)
        # Every column except the identifier (and the target, when present)
        # is treated as a metadata feature.
        self.metacols = self.meta.columns.drop(
            ['Id', 'Pawpularity'] if train else 'Id'
        )
        self.train = train
        self.img_transform = img_transform
        self.meta_transform = meta_transform
        self.score_transform = score_transform

    def __len__(self):
        """Return the number of samples (rows of the metadata table)."""
        return len(self.meta.index)

    def __getitem__(self, idx):
        """
        Return the image, metadata and score given the index of a sample.

        Note
        ----
        If `self.train` is `False`, the returned score will be -1.
        """
        # Fetch the row once by position instead of three label lookups
        row = self.meta.iloc[idx]
        img_path = os.path.join(
            self.dirpath, self.img_dir,
            f"{row['Id']}.jpg"
        )
        img = read_image(img_path)
        meta = row[self.metacols].values.astype(np.float32)
        score = np.float32(row['Pawpularity'] if self.train else -1.0)
        # Apply transformations
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.meta_transform is not None:
            meta = self.meta_transform(meta)
        if self.train and self.score_transform is not None:
            score = self.score_transform(score)
        return img, meta, score

In [None]:
class PetfinderDataModule(pl.LightningDataModule):
    """Data module of Petfinder profiles."""

    def __init__(
        self,
        image_size: int = 224,
        batch_size: int = 64,
        num_validation: int = 128
    ):
        """
        Arguments
        ---------
            image_size: Size of square images after transformations
            batch_size: Batch size loading training/validation dataset
            num_validation: Number of observations in validation dataset
        """
        super().__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_validation = num_validation
        # Populated by setup()
        self.dataset = None
        self.trainset = None
        self.valset = None
        self.predictset = None

    def setup(self, stage: Optional[str] = None):
        """
        Create the datasets needed for the given stage.

        Note
        ----
        The training/validation split is drawn only once per data module;
        later calls (e.g. one per `trainer.fit`) reuse it so that validation
        samples never leak into the training set between finetuning stages.
        """
        # Transformations
        prerequisite = [  # required transformations for pretrained model
            ConvertImageDtype(torch.float32),
            Resize((self.image_size, self.image_size)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        augmentation = [RandomHorizontalFlip()]  # data augmentation
        train_transform = Compose(prerequisite + augmentation)
        eval_transform = Compose(prerequisite)  # deterministic, no flipping
        # Split training set and validation set
        if stage in (None, 'fit') and self.dataset is None:
            self.dataset = PetfinderDataset(
                train=True, img_transform=train_transform
            )
            self.trainset, val_split = random_split(
                self.dataset,
                [len(self.dataset)-self.num_validation, self.num_validation]
            )
            # Same validation indices, but read through a dataset without
            # random augmentation so validation metrics are reproducible.
            eval_dataset = PetfinderDataset(
                train=True, img_transform=eval_transform
            )
            self.valset = torch.utils.data.Subset(
                eval_dataset, val_split.indices
            )
        # Load dataset for prediction
        if stage == 'predict' and self.predictset is None:
            self.predictset = PetfinderDataset(
                train=False, img_transform=eval_transform
            )

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return DataLoader(
            self.trainset, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the loader over the validation split."""
        return DataLoader(self.valset, batch_size=self.batch_size)

    def predict_dataloader(self):
        """
        Return the loader over the prediction set.

        Note
        ----
        The whole prediction set is served as a single batch.
        """
        return DataLoader(self.predictset, batch_size=len(self.predictset))

    def num_meta(self):
        """
        Return number of features in the metadata.

        Note
        ----
        Must be called after running self.setup().
        """
        return len(self.dataset.metacols)

    def meta_odds(self):
        """
        Return the odds against features in the metadata.

        The value for each feature is `(1 - p) / p`, where `p` is the mean of
        that feature over the training metadata; a feature whose mean is 0
        therefore yields an infinite value.

        Note
        ----
        Must be called after running self.setup().
        """
        pos_rate = self.dataset.meta.loc[:, self.dataset.metacols].mean()
        pos_rate = torch.from_numpy(pos_rate.values).float()
        return (1 - pos_rate) / pos_rate

# Model & Training/Validation Step

In [None]:
class Regressor(nn.Module):
    """Small regression head turning latent features into a bounded score."""

    def __init__(
        self,
        num_feats: int,
        dropout: float = 0.5,
        scale: float = 100.0
    ):
        """
        Arguments
        ---------
            num_feats: Dimension of the incoming feature vectors
            dropout: Dropout rate, used on the input and in the hidden layer
            scale: Upper bound of the output; predictions lie in `[0, scale]`.
        """
        super().__init__()
        # Hidden width is the (truncated) square root of the input width
        hidden_dim = int(num_feats ** 0.5)
        self.dropout = nn.Dropout(dropout)
        hidden_block = [
            nn.Linear(num_feats, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout),
            nn.Sigmoid(),
        ]
        self.layer1 = nn.Sequential(*hidden_block)
        output_block = [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
        self.layer2 = nn.Sequential(*output_block)
        self.scale = scale

    def forward(self, feats):
        """Return scores of shape `(batch, 1)` scaled into `[0, scale]`."""
        hidden = self.layer1(self.dropout(feats))
        # The final sigmoid yields values in [0, 1]; stretch them to the scale
        return self.scale * self.layer2(hidden)

In [None]:
class PawpularityPredictor(pl.LightningModule):
    """
    Transfer learning model with two-stage finetuning.

    A shared backbone feeds two heads: a linear classifier predicting the
    binary metadata (stage 1) and a regressor predicting the Pawpularity
    score (stage 2). Only the head of the active training phase is trainable.
    """

    def __init__(
        self,
        backbone: str = 'resnet_18',
        training_phase: str = 'regression',
        num_meta: int = 12,
        pos_weight: Optional[torch.Tensor] = None,
        classification_threshold: float = 0.5,
        regressor_kwargs: Optional[dict] = None
    ):
        """
        Arguments
        ---------
            backbone: Backbone model to fine tune
            training_phase: Indicator of classification or regression
            num_meta: Number of features in metadata
            pos_weight: Weight of positive samples passed to classification
                loss; defaults to a vector of ones of length `num_meta`
            classification_threshold: Threshold for binary classification
            regressor_kwargs: Keyword arguments passed to regressor

        Raises
        ------
            ValueError: If `training_phase` or `backbone` is not supported
        """
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # `load_from_checkpoint` rebuilds the same architecture instead of
        # silently falling back to the defaults. `pos_weight` is left out
        # because the loss module holds it and it is restored with the
        # model weights.
        self.save_hyperparameters(ignore=['pos_weight'])
        if training_phase not in ('classification', 'regression'):
            raise ValueError('phase must be either classification or regression')
        if backbone == 'resnet_18':
            self.backbone = torchvision.models.resnet18(pretrained=True)
            num_feats = self.backbone.fc.in_features
            # Drop the ImageNet head; the backbone now outputs raw features
            self.backbone.fc = nn.Identity()
        else:
            raise ValueError('backbone model not supported')
        if pos_weight is None:
            pos_weight = torch.ones(num_meta)
        if regressor_kwargs is None:
            regressor_kwargs = {}
        self.classifier = nn.Linear(num_feats, num_meta)
        self.regressor = Regressor(num_feats, **regressor_kwargs)
        self.lossfn_classification = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        self.lossfn_regression = nn.MSELoss()
        self.classification_threshold = classification_threshold
        self.training_phase = training_phase
        self.freeze_part_by_training_phase()

    def freeze_part_by_training_phase(self):
        """Freeze part of model according to the internal training phase."""
        classifying = self.training_phase == 'classification'
        # Exactly one of the two heads is trainable at any time
        self.classifier.requires_grad_(classifying)
        self.regressor.requires_grad_(not classifying)

    def forward(self, imgs):
        """Return the predicted Pawpularity scores for a batch of images."""
        return self.regressor(self.backbone(imgs))

    def _shared_step(self, batch, split):
        """Compute and log the loss and metric of the active phase for `split`."""
        imgs, meta, scores = batch
        feats = self.backbone(imgs)
        if self.training_phase == 'classification':
            logits = self.classifier(feats)
            loss = self.lossfn_classification(logits, meta)
            self.log(f'Loss:classification/{split}', loss)
            preds = torch.sigmoid(logits) > self.classification_threshold
            acc = (preds == meta).float().mean()  # batch accuracy
            self.log(f'Accuracy/{split}', acc)
        else:
            preds = self.regressor(feats)
            loss = self.lossfn_regression(preds, scores.unsqueeze(-1))
            self.log(f'Loss:regression/{split}', loss)
            self.log(f'RMSE/{split}', torch.sqrt(loss))
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss of the active phase for one batch."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and metric of the active phase."""
        self._shared_step(batch, 'validation')

    def configure_optimizers(self):
        """Return the AdamW optimizer used in both training phases."""
        return optim.AdamW(self.parameters(), lr=1e-3)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the predicted scores for the images of a batch."""
        imgs, _, _ = batch
        return self(imgs)

# Training with Two-stage Finetuning

In [None]:
# Data settings
IMAGE_SIZE = 224  # side length of the square images fed to the model
BATCH_SIZE = 64  # batch size of the training/validation loaders
NUM_VALIDATION = 128  # number of samples held out for validation

# Model settings
CLASSIFICATION_THRESHOLD = 0.5  # probability cut-off for metadata predictions
REGRESSOR_KWARGS = {'dropout': 0.1}  # keyword arguments passed to Regressor

# Training schedule
NUM_EPOCHS_CLASSIFICATION = 15  # stage 1: backbone + classifier
NUM_EPOCHS_REGRESSION = 60  # stage 2: backbone + regressor
REGRESSOR_CHECKPOINT_PERIOD = 15  # save a regressor checkpoint every N epochs

In [None]:
# Build the data module and run setup() right away: num_meta() and
# meta_odds() are queried below before any trainer is created.
datamodule = PetfinderDataModule(
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_validation=NUM_VALIDATION
)
datamodule.setup()

## Finetune Backbone & Classifier

In [None]:
# Stage 1: train the backbone and the metadata classifier. Positive labels
# are weighted by the odds against each metadata feature.
model = PawpularityPredictor(
    backbone='resnet_18',
    training_phase='classification',
    num_meta=datamodule.num_meta(),
    pos_weight=datamodule.meta_odds(),
    classification_threshold=CLASSIFICATION_THRESHOLD,
    regressor_kwargs=REGRESSOR_KWARGS
)
# Keep the checkpoint with the highest validation accuracy; stage 2 starts
# from it.
checkpoint_callback_classifier = ModelCheckpoint(
    monitor='Accuracy/validation',
    mode='max',
    filename='classifier-{epoch}'
)
logger_classification = pl.loggers.CSVLogger('./logs_classification')
trainer = pl.Trainer(
    gpus=1,
    max_epochs=NUM_EPOCHS_CLASSIFICATION,
    callbacks=[checkpoint_callback_classifier],
    logger=logger_classification
)

trainer.fit(model, datamodule=datamodule)

## Finetune Backbone & Regressor

In [None]:
# Stage 2: start from the best stage-1 checkpoint and train the backbone
# together with the regressor. The constructor arguments are passed
# explicitly so the model is rebuilt with the same configuration as in
# stage 1 (in particular the regressor settings) rather than with the
# constructor defaults.
model = PawpularityPredictor.load_from_checkpoint(
    checkpoint_callback_classifier.best_model_path,
    backbone='resnet_18',
    training_phase='regression',
    num_meta=datamodule.num_meta(),
    pos_weight=datamodule.meta_odds(),
    classification_threshold=CLASSIFICATION_THRESHOLD,
    regressor_kwargs=REGRESSOR_KWARGS
)
# Keep every periodic checkpoint (save_top_k=-1) so that the diagnostics can
# inspect the regressor at several points of the training.
checkpoint_callback_regressor = ModelCheckpoint(
    filename='regressor-{epoch}',
    every_n_epochs=REGRESSOR_CHECKPOINT_PERIOD,
    save_top_k=-1
)
logger_regression = pl.loggers.CSVLogger('./logs_regression')
trainer = pl.Trainer(
    gpus=1,
    max_epochs=NUM_EPOCHS_REGRESSION,
    callbacks=[checkpoint_callback_regressor],
    logger=logger_regression
)

trainer.fit(model, datamodule=datamodule)

# Diagnostics

In [None]:
def diagnose_predictions(model, datamodule, num_epochs):
    """
    Output a scatter plot of the actual/predicted Pawpularity scores.

    One batch of training samples and the validation set are scored with
    `model`; each panel shows predicted against actual scores together with
    the batch RMSE. The figure is saved under `./diagnostics` and displayed.

    Arguments
    ---------
        model: Model whose `forward` maps images to predicted scores
        datamodule: Data module providing `trainset` and `valset`
        num_epochs: Number of regression epochs, used in the title/filename
    """
    # Setup figure
    fig, axes = plt.subplots(ncols=2, figsize=(12,6))
    lims = (-2, 102)
    for ax in axes:
        ax.set_xlabel('Actual Pawpularity Score')
        ax.set_ylabel('Predicted Pawpularity Score')
        ax.set_xlim(*lims)
        ax.set_ylim(*lims)
        ax.plot(lims, lims, color='C3')  # diagonal = perfect prediction
    axes[0].set_title('Training Samples')
    axes[1].set_title('Validation Set')
    fig.suptitle(f'Regressor Trained after {num_epochs} Epochs', fontsize=16)

    # Visualize training/validation set: the whole validation set in one
    # batch, and an equally sized random batch of training samples
    batch_size = len(datamodule.valset)
    dataloaders = (
        DataLoader(datamodule.trainset, batch_size=batch_size, shuffle=True),
        DataLoader(datamodule.valset, batch_size=batch_size)
    )
    # Score in evaluation mode so that dropout is disabled and batch-norm
    # running statistics are neither used batch-wise nor updated; the
    # previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        for ax, dataloader in zip(axes, dataloaders):
            # Plot actual/predicted scores
            imgs, _, scores = next(iter(dataloader))
            with torch.no_grad():
                preds = model(imgs)
                mse = model.lossfn_regression(preds, scores.unsqueeze(-1))
            ax.scatter(
                scores.cpu().numpy(), preds.squeeze().cpu().numpy(), c='C1'
            )
            # Add text of RMSE
            rmse = torch.sqrt(mse).item()
            textstr = f'RMSE = {rmse:.2f}'
            props = dict(boxstyle='round', facecolor='C4', alpha=0.5)
            ax.text(
                0.05, 0.95, textstr, transform=ax.transAxes,
                verticalalignment='top', bbox=props
            )
    finally:
        model.train(was_training)

    # Output
    os.makedirs('./diagnostics', exist_ok=True)
    fig.savefig(f'./diagnostics/regressor-epoch={num_epochs}.png')
    plt.show()

In [None]:
# Epochs to inspect: 0 (state before regression training) followed by every
# multiple of the checkpoint period up to the last saved checkpoint.
num_checkpoints = NUM_EPOCHS_REGRESSION // REGRESSOR_CHECKPOINT_PERIOD
checkpoint_epochs = list(range(
    0,
    (num_checkpoints + 1) * REGRESSOR_CHECKPOINT_PERIOD,
    REGRESSOR_CHECKPOINT_PERIOD
))

# Diagnose predictions during regressor training
for epoch in checkpoint_epochs:
    # Epoch 0 is the best stage-1 model; later epochs are the periodic
    # regressor checkpoints, whose file names carry the 0-based epoch index.
    ckptpath = (
        checkpoint_callback_classifier.best_model_path
        if epoch == 0
        else os.path.join(
            logger_regression.log_dir,
            f'checkpoints/regressor-epoch={epoch-1}.ckpt'
        )
    )
    model = PawpularityPredictor.load_from_checkpoint(ckptpath)
    diagnose_predictions(model, datamodule, epoch)

# Inference

In [None]:
# The predict dataloader serves the whole prediction set as one batch, so the
# result is unpacked as a single-element sequence.
preds, = trainer.predict(model, datamodule=datamodule)

In [None]:
# Output predictions
predictions = pd.DataFrame({
    'Id': datamodule.predictset.meta['Id'],
    'Pawpularity': preds.squeeze().cpu().numpy()
})
predictions.to_csv('./submission.csv', index=False)