In [None]:
import os
import sys
import gc
sys.path.append("../input/pytorch-image-models")

import time
import numpy as np
import pandas as pd
from easydict import EasyDict as edict

# visualization
import matplotlib.pyplot as plt
import seaborn as sns

# image
import PIL
from PIL import Image
import albumentations as albu

# model validation
import sklearn
from sklearn.model_selection import StratifiedKFold

# model
import torch
import torch.nn as nn
import torchvision
import timm

# pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningDataModule, LightningModule

In [None]:
# For notebook commit
# GET_CV stays True only while the visible test.csv holds at most 8 rows;
# a larger test file switches the notebook to inference-only mode.
test = pd.read_csv('../input/petfinder-pawpularity-score/test.csv')
GET_CV = not (len(test) > 8)
if GET_CV:
    print('this submission notebook will compute CV score, but commit notebook will not')

In [None]:
# Manual override: unconditionally selects inference/submission mode,
# replacing whatever value the test-set size check assigned to GET_CV.
GET_CV = False

In [None]:
# Load the frame matching the current mode: train.csv when computing CV,
# test.csv when producing a submission.
df = pd.read_csv(
    '../input/petfinder-pawpularity-score/train.csv'
    if GET_CV
    else '../input/petfinder-pawpularity-score/test.csv'
)

# 1. Model Config

In [None]:
# Central experiment configuration; the rest of the notebook reads from `cfg`.
__C = edict()
cfg = __C

# model: timm architecture name and the local file holding its pretrained weights
cfg.model = edict()
cfg.model.name = 'vit_base_patch16_224'
cfg.model.weight = '../input/vit-base-models-pretrained-pytorch/jx_vit_base_p16_224-80ecf9dd.pth'

# optimizer: `name` is a dotted path that the model evaluates to obtain the class
cfg.optim = edict()
cfg.optim.name = 'torch.optim.AdamW'
cfg.optim.lr = 2e-5
cfg.optim.max_epochs = 20

# lr_schedule: `params` are forwarded as keyword arguments to the scheduler
# NOTE(review): eta_min (1e-4) is larger than the base lr (2e-5), so within a
# cycle the learning rate moves *up* toward eta_min instead of decaying.
# Left unchanged here -- confirm whether this is intended.
cfg.lr_sched = edict()
cfg.lr_sched.name = 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts'
cfg.lr_sched.params = {'eta_min':0.0001, 'T_0':20}

# dataset: keyword arguments unpacked into torch.utils.data.DataLoader
cfg.trainloader = edict()
cfg.trainloader.batch_size = 32
cfg.trainloader.drop_last = True
cfg.trainloader.num_workers = 4
cfg.trainloader.pin_memory = False
cfg.trainloader.shuffle = True

cfg.valloader = edict()
cfg.valloader.batch_size = 16
# Keep the final partial batch: dropping it would silently exclude validation
# samples from every fold's score.
cfg.valloader.drop_last = False
cfg.valloader.num_workers = 4
cfg.valloader.pin_memory = False
cfg.valloader.shuffle = False

# data transform and augmentation
cfg.transform = edict()
cfg.transform.img_size = 224
cfg.transform.normalize_mean = [.5, .5, .5]
cfg.transform.normalize_std = [.5, .5, .5]

# data directory: CV mode points at the train split, otherwise at the test split
cfg.data = edict()
if GET_CV:
    cfg.data.df_dir = '../input/petfinder-pawpularity-score/train.csv'
    cfg.data.image_dir = '../input/petfinder-pawpularity-score/train'
else:
    cfg.data.df_dir = '../input/petfinder-pawpularity-score/test.csv'
    cfg.data.image_dir = '../input/petfinder-pawpularity-score/test'

# 2. Data Pipeline

In [None]:
def build_transform(cfg):
    """Assemble the image preprocessing pipeline described by ``cfg.transform``.

    The returned callable converts its input to a tensor, normalizes it with
    ``cfg.transform.normalize_mean`` / ``cfg.transform.normalize_std`` and then
    resizes it to a square of side ``cfg.transform.img_size`` -- in that order.
    """
    tv = torchvision.transforms
    side = cfg.transform.img_size
    steps = [
        tv.ToTensor(),
        tv.Normalize(cfg.transform.normalize_mean, cfg.transform.normalize_std),
        tv.Resize(size = (side, side)),
    ]
    return tv.Compose(steps)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Labelled image dataset yielding ``(image, target)`` pairs.

    Images are read from ``<data_dir>/<Id>.jpg``; targets are the
    ``Pawpularity`` column divided by 100.
    """

    def __init__(self, df, data_dir, transforms):
        """
        Args:
            df: DataFrame with ``Id`` and ``Pawpularity`` columns.
            data_dir: Directory containing the ``.jpg`` files.
            transforms: Callable applied to each loaded PIL image.
        """
        # Work on a private copy: scaling the caller's frame in place would
        # divide it by 100 again each time another Dataset is built from it.
        self.df = df.copy()
        self.df['Pawpularity'] = self.df['Pawpularity'] / 100.0
        self.data_dir = data_dir
        self.transforms = transforms

    def __len__(self):
        """Number of rows (samples) in the frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(transformed image, target tensor of shape (1,))`` for row ``index``."""
        # Positional lookup, so the dataset does not depend on index labels.
        img_path = os.path.join(self.data_dir, self.df['Id'].iloc[index])+'.jpg'
        # Close the file handle promptly and force three channels so the
        # 3-channel normalization also works for grayscale/RGBA files.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = self.transforms(img)
        label = torch.tensor(self.df['Pawpularity'].iloc[index], dtype = torch.float32).reshape(1)
        return img, label
    
class TestDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset yielding transformed images only.

    Images are read from ``<data_dir>/<Id>.jpg``.
    """

    def __init__(self, df, data_dir, transform):
        """
        Args:
            df: DataFrame with an ``Id`` column.
            data_dir: Directory containing the ``.jpg`` files.
            transform: Callable applied to each loaded PIL image.
        """
        self.df = df
        self.data_dir = data_dir
        # Store the argument itself; previously a global named `transforms`
        # was read here and the parameter was ignored.
        self.transforms = transform

    def __len__(self):
        """Number of rows (samples) in the frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the transformed image for row ``index``."""
        # Positional lookup, so the dataset does not depend on index labels.
        img_path = os.path.join(self.data_dir, self.df['Id'].iloc[index])+'.jpg'
        # Close the file handle promptly and force three channels so the
        # 3-channel normalization also works for grayscale/RGBA files.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = self.transforms(img)
        return img
    
class Dataloader(LightningDataModule):
    """Lightning data module wrapping one train/validation split.

    Each dataloader request builds a fresh ``Dataset`` from the matching frame
    and wraps it in a ``torch.utils.data.DataLoader`` configured from
    ``cfg.trainloader`` or ``cfg.valloader``.
    """

    def __init__(self, cfg, train_df, val_df):
        super().__init__()
        self.cfg = cfg
        self.train_df, self.val_df = train_df, val_df

    def _create_dataset(self, train = True):
        """Build the Dataset for the training frame (``train=True``) or the validation frame."""
        frame = self.train_df if train else self.val_df
        return Dataset(frame, self.cfg.data.image_dir, build_transform(self.cfg))

    def train_dataloader(self):
        """DataLoader over the training frame, using ``cfg.trainloader`` options."""
        return torch.utils.data.DataLoader(
            self._create_dataset(train=True), **self.cfg.trainloader)

    def val_dataloader(self):
        """DataLoader over the validation frame, using ``cfg.valloader`` options."""
        return torch.utils.data.DataLoader(
            self._create_dataset(train=False), **self.cfg.valloader)

# 3. Vision Transformer model

In [None]:
def RMSELoss(yhat,y):
    """Root-mean-square error between ``yhat`` and ``y`` as a scalar tensor."""
    residual = yhat - y
    return residual.pow(2).mean().sqrt()

In [None]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    # No custom constructor needed; nn.Module's own initializer is sufficient.
    def forward(self, x):
        return x

class ViT_model(pl.LightningModule):
    """timm backbone with a frozen body and a single-output sigmoid head.

    The model is trained with binary cross-entropy against targets in
    ``[0, 1]``; RMSE is reported as an auxiliary metric (on the 0-100 scale
    during validation).
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: Config providing ``model.name``, ``model.weight``,
                ``optim.name``, ``optim.lr``, ``lr_sched.name`` and
                ``lr_sched.params``.
        """
        super().__init__()
        self.cfg = cfg
        self.backbone = timm.create_model(cfg.model.name, pretrained=False)
        # map_location='cpu' lets the weight file load regardless of the device
        # it was saved from; the module is moved to its target device later.
        self.backbone.load_state_dict(torch.load(cfg.model.weight, map_location='cpu'))
        # Freeze every pretrained parameter. The replacement head created below
        # is a fresh module, so it is the only trainable part.
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Take the head width from the backbone; fall back to 768 when the
        # existing head does not expose `in_features`.
        in_features = getattr(self.backbone.head, 'in_features', 768)
        self.backbone.head = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCELoss()

    def forward(self, input):
        """Return sigmoid-activated backbone outputs for ``input``."""
        x = self.backbone(input)
        x = self.sigmoid(x)
        return x

    def _step(self, batch):
        """Run the model on an ``(img, target)`` batch; return ``(pred, target, loss)``."""
        img, target = batch
        pred = self(img)
        loss = self.criterion(pred, target)
        return pred, target, loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log ``train_loss`` / ``train_rmse``."""
        pred, target, loss = self._step(batch)
        # The metric is for reporting only, so keep it out of the autograd graph.
        metric = RMSELoss(pred.detach(), target)
        # Log through self.log; a 'log' entry in the returned dict is not
        # picked up by the logger.
        self.log('train_loss', loss)
        self.log('train_rmse', metric)
        return {'loss': loss, 'rmse': metric}

    def validation_step(self, batch, batch_idx):
        """Return the batch loss, the batch RMSE (0-100 scale) and the raw predictions/targets."""
        with torch.no_grad():
            pred, target, loss = self._step(batch)
            rmse = RMSELoss(pred*100.0, target*100.0)
        # pred/target are passed on so the epoch metric can be computed over
        # all samples at once.
        return {'val_loss': loss, 'val_rmse': rmse, 'pred': pred, 'target': target}

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs, print a summary and log ``val_loss`` / ``val_rmse``."""
        if not outputs:
            return
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # RMSE over every validation sample; the mean of per-batch RMSE values
        # is not the same quantity.
        preds = torch.cat([x['pred'] for x in outputs])
        targets = torch.cat([x['target'] for x in outputs])
        avg_rmse = RMSELoss(preds*100.0, targets*100.0)
        print(f"Epoch {self.current_epoch} loss:{avg_loss} rmse:{avg_rmse}")
        self.log('val_loss', avg_loss)
        self.log('val_rmse', avg_rmse)

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler named in the config."""
        # NOTE(review): eval() executes the config strings as Python. That is
        # acceptable only while cfg is defined inside this notebook; never feed
        # it values from an untrusted source.
        optimizer = eval(self.cfg.optim.name)(self.parameters(), lr = self.cfg.optim.lr)
        schedule = eval(self.cfg.lr_sched.name)(optimizer = optimizer, **self.cfg.lr_sched.params)
        return [optimizer], [schedule]

# 4. Training with Pytorch Lightning

In [None]:
# train model and compute cv score
if GET_CV:
    # Fixed seed so fold membership and head initialisation repeat across runs.
    seed_everything(42)
    kfold = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 42)
    for fold, (train_idx, val_idx) in enumerate(kfold.split(df["Id"], df["Pawpularity"])):
        print('fold {} training start'.format(fold+1))
        start = time.time()
        # train_test_split -- kfold.split yields positional indices, so select by position
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)
        # define datamodule
        dataloader = Dataloader(cfg, train_df, val_df)
        # define model
        model = ViT_model(cfg)
        # define callbacks (both monitor the same metric, lower is better)
        early_stopping = EarlyStopping(monitor="val_rmse", mode="min")
        lr_monitor = callbacks.LearningRateMonitor()
        loss_checkpoint = callbacks.ModelCheckpoint(
            filename = None,
            monitor="val_rmse",
            save_top_k=1,
            mode="min",
            save_last=False,
            )
        logger = TensorBoardLogger(cfg.model.name)

        trainer = pl.Trainer(
            logger=logger,
            max_epochs=cfg.optim.max_epochs,
            callbacks=[lr_monitor, loss_checkpoint, early_stopping],
            # fall back to CPU when no GPU is visible instead of failing
            gpus = 1 if torch.cuda.is_available() else 0)
        trainer.fit(model, datamodule = dataloader)

        # Report inside the loop so every fold gets its own timing line.
        elapse = time.time() - start
        print('fold {} complete -- {} seconds elapsed'.format(fold + 1, elapse))
else:
    print('This Notebook is for commit, not for computing CV score')

# 5. Inference & Submission

In [None]:
# inference on test set
if not GET_CV:
    # Run on the GPU when one is visible, otherwise on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Kept under the notebook-level name `transforms`, since other cells may
    # read that global.
    transforms = build_transform(cfg)
    dataset = TestDataset(df,cfg.data.image_dir, transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size = 128)
    # fine-tuned checkpoints whose predictions are averaged
    weight_paths = [
        '../input/finetuned-pawpularity-vit/epoch19-step4939.ckpt',
        '../input/finetuned-pawpularity-vit/epoch19-step4939 (1).ckpt',
        '../input/finetuned-pawpularity-vit/epoch19-step4939 (2).ckpt',
        '../input/finetuned-pawpularity-vit/epoch19-step4939 (3).ckpt',
        '../input/finetuned-pawpularity-vit/epoch19-step4939 (4).ckpt',
    ]
    # Load onto the CPU first; each model is moved to `device` afterwards.
    weights = [torch.load(path, map_location='cpu')['state_dict'] for path in weight_paths]
    # one model per checkpoint, in eval mode on the target device
    models = []
    for weight in weights:
        model = ViT_model(cfg)
        model.load_state_dict(weight)
        models.append(model.eval().to(device))
    # inference by averaging the outputs of all loaded models
    pawpularity = []
    processed = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            pred = torch.zeros(batch.shape[0], 1, device=device)
            for model in models:
                pred += model(batch)
            # mean over the ensemble, then back to the 0-100 score scale
            pred = pred / len(models)
            pred *= 100.0
            # Move results off the device right away so predictions do not
            # accumulate in GPU memory.
            pawpularity.append(pred.cpu().numpy().reshape(-1,))
            # Count the rows actually seen; the last batch may be smaller.
            processed += batch.shape[0]
            print(f'{processed} image processed')
    pawpularity = np.concatenate(pawpularity) if pawpularity else np.empty(0, dtype=np.float32)
    df['Pawpularity'] = pawpularity
    # .copy() so the dtype cast below edits an independent frame, not a slice of df
    submission = df[['Id','Pawpularity']].copy()
    submission['Pawpularity'] = submission['Pawpularity'].astype(float)
    # index=False keeps the file to the two columns Id,Pawpularity
    submission.to_csv('submission.csv', index=False)
    print('submission complete')