**Complete training code is explained on the Jarvislabs [Blog](https://jarvislabs.ai/blogs/pet)** 

Also, please lower the batch size if you're planning to train on Kaggle.

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master') ## Importing Timm Library

In [None]:
import fastai
import os
import warnings
from pprint import pprint
from glob import glob
from tqdm import tqdm
from fastai.vision.all import *
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from timm import create_model
from sklearn.model_selection import StratifiedKFold
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import albumentations

from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform
import random

import timm
import cv2

from torch.cuda.amp import autocast, GradScaler

import warnings 
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import os

# Directory that receives the per-fold CSV logs and the final fold checkpoints.
model_output = '/kaggle/working/models/exp0'
# exist_ok=True replaces the check-then-create pattern: no race between the
# existence test and the creation, and re-running the cell is a no-op.
os.makedirs(model_output, exist_ok=True)



In [None]:
class Config:
    """Static hyper-parameters and paths shared by the cells of this notebook."""

    # --- model ---
    model_name = 'swin_base_patch4_window12_384'  # architecture name handed to timm
    pretrained = True                             # forwarded to timm.create_model
    targets = 1                                   # output size of the regression layer

    # --- data ---
    train_dir = '../input/petfinder-pawpularity-score/train'      # train image directory
    train_csv = '../input/abhi-folds-petfinder/train_10folds.csv'  # train csv with fold ids
    image_size = 384                              # side length used by crop / resize
    num_workers = 8                               # DataLoader worker processes

    # --- optimisation ---
    epochs = 10
    lr = 5e-5
    batch_size = 16                               # validation uses twice this value
    weight_decay = 1e-4

    # --- experiment control ---
    seed = 42
    n_fold = 10
    trn_fold = list(range(n_fold))                # folds that get trained: 0..9
    train = True

In [None]:
def seed_torch(seed=42):
    """Seed the random number generators this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)
    and switches cuDNN to deterministic, non-autotuned mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the hash seed of
    # the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Both CUDA calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # cover every GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different kernels from run to run, which defeats
    # the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed everything once, up front, with the configured seed.
seed_torch(seed=Config.seed)

In [None]:
class PetDataset:
    """Map-style dataset yielding ``(image, target)`` pairs.

    Args:
        df: DataFrame with an ``Id`` column (image file stem) and a
            ``Pawpularity`` column (the regression target).
        image_path: Directory containing the ``<Id>.jpg`` files.
        augmentations: Callable invoked as ``augmentations(image=...)`` that
            returns a mapping with an ``"image"`` entry, or ``None`` to use
            the decoded image as is.
    """

    def __init__(self, df,image_path, augmentations):
        self.image_path = image_path
        self.df = df
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, item):
        """Load, augment and tensorise the sample at position ``item``.

        Returns:
            A tuple ``(image, target)``: ``image`` is a float32 tensor with
            the channel axis moved to the front, ``target`` is the
            ``Pawpularity`` value divided by 100 as a float tensor.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded.
        """
        id_ = self.df.Id.iloc[item]
        path = f'{self.image_path}/{id_}.jpg'

        targets = self.df.Pawpularity.iloc[item]

        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR

        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]

        # HWC -> CHW, the layout the model consumes.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        image = torch.tensor(image, dtype=torch.float)

        # Scale the target by 1/100 (the metric multiplies it back by 100).
        targets = torch.tensor(targets, dtype=torch.float) / 100
        return image, targets
            

In [None]:
def get_transforms(*, data):
    """Build the albumentations pipeline for one data split.

    Args:
        data: ``'train'`` for the augmenting pipeline (random resized crop,
            flip, shift/scale/rotate and colour jitter) or ``'valid'`` for a
            plain resize. Both end with the same normalisation.

    Returns:
        An ``albumentations.Compose`` producing ``Config.image_size`` square
        images. No tensor conversion is included.

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    # Reject unknown splits up front: falling through used to return None,
    # which silently disabled resizing and normalisation downstream.
    if data not in ('train', 'valid'):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    # Shared by both pipelines so the statistics cannot drift apart.
    normalize = albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )

    if data == 'train':
        return albumentations.Compose(
            transforms=[
                albumentations.RandomResizedCrop(Config.image_size, Config.image_size, scale=(0.85, 1.0)),
                albumentations.HorizontalFlip(p=0.5),
                albumentations.ShiftScaleRotate(p=0.5),
                albumentations.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1,p=0.4),
                albumentations.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7),
                albumentations.RandomBrightnessContrast(brightness_limit=(-0.2,0.2), contrast_limit=(-0.2, 0.2), p=0.7),
                normalize,
            ])

    return albumentations.Compose([
        albumentations.Resize(Config.image_size, Config.image_size),
        normalize,
    ])



In [None]:
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSE(yhat, y) + eps)``.

    Args:
        eps: Constant added under the square root (default ``1e-6``).
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.mse = nn.MSELoss()

    def forward(self, yhat, y):
        """Return the RMSE between predictions ``yhat`` and targets ``y``."""
        mean_squared = self.mse(yhat, y)
        return (mean_squared + self.eps).sqrt()

In [None]:
class PetModel(nn.Module):
    """timm backbone whose head is replaced by a single linear layer.

    Args:
        config: Object exposing ``model_name`` and ``pretrained`` (forwarded
            to ``timm.create_model``) and ``targets`` (output size of the
            linear layer).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        backbone = timm.create_model(config.model_name, pretrained=config.pretrained)
        # Remember the width of the original head's input, then strip the head
        # so the backbone emits features instead of class scores.
        self.n_features = backbone.head.in_features
        backbone.head = nn.Identity()

        self.backbone = backbone
        self.fc = nn.Linear(self.n_features, config.targets)

    def forward(self, image):
        """Run ``image`` through the backbone and the linear layer."""
        return self.fc(self.backbone(image))

In [None]:
# m = PetModel(Config)
# # 
# del m

In [None]:
import gc
# Run a full garbage-collection pass to release unreferenced objects.
gc.collect()

In [None]:
# Cell
def mse(inp,targ):
    "Mean squared error between `inp` and `targ`."
    # flatten_check hands back the (inp, targ) pair that mse_loss consumes.
    flat_inp, flat_targ = flatten_check(inp, targ)
    return F.mse_loss(flat_inp, flat_targ)

# Cell
def _rmse(inp, targ): return 100*torch.sqrt(F.mse_loss(F.sigmoid(inp.flatten()), targ))
# Wrap `_rmse` in an AccumMetric and attach a (placeholder) docstring to it.
rmse = AccumMetric(_rmse)
rmse.__doc__ = "rrr"

In [None]:
def run(Config):
    """Train one model per fold listed in ``Config.trn_fold``.

    For every fold the rows whose ``kfold`` equals the fold number form the
    validation set and all other rows form the training set. A fresh
    ``PetModel`` is trained with fastai's one-cycle schedule in fp16, logging
    to ``<model_output>/<fold>logs.csv`` and saving the final weights to
    ``<model_output>/fold_<fold>``.

    Args:
        Config: Configuration object exposing ``trn_fold``, ``train_csv``,
            ``train_dir``, ``batch_size``, ``num_workers``, ``epochs``, ``lr``
            and ``weight_decay``, plus whatever ``PetModel`` reads from it.
    """
    # fastai records a plain-function metric under the function's own name,
    # so the tracker callbacks must monitor exactly that name.
    metric_name = _rmse.__name__

    # The fold assignments do not change between folds: read the CSV once.
    df = pd.read_csv(Config.train_csv)[['Id', 'Pawpularity', 'kfold']]

    for fold_num in Config.trn_fold:
        print('*****************************************')
        print(f'Training Fold {fold_num}')
        print('*****************************************')

        kernel_type = model_output

        is_valid = df.kfold == fold_num
        training_fold = df[~is_valid].reset_index(drop=True)
        validation_fold = df[is_valid].reset_index(drop=True)

        train_ds = PetDataset(training_fold, Config.train_dir, augmentations=get_transforms(data='train'))
        valid_ds = PetDataset(validation_fold, Config.train_dir, augmentations=get_transforms(data='valid'))

        print(f'- Training samples: {len(train_ds)}\n- Validation Samples : {len(valid_ds)}')

        bs = Config.batch_size
        # Training batches are reshuffled every epoch; validation keeps file
        # order and uses a doubled batch size.
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=bs, num_workers=Config.num_workers,
                                               shuffle=True, pin_memory=False)
        valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=bs*2, num_workers=Config.num_workers,
                                               shuffle=False, pin_memory=False)

        dls = DataLoaders(train_dl, valid_dl)

        model = PetModel(Config)

        # RMSE is lower-is-better. The metric name contains neither "loss" nor
        # "error", so the comparison direction is spelled out explicitly
        # rather than left to be inferred from the name.
        early_stop = EarlyStoppingCallback(monitor=metric_name, comp=np.less, min_delta=0.1, patience=5)
        save_callback = SaveModelCallback(monitor=metric_name, comp=np.less, every_epoch=True)
        logger = CSVLogger(f'{model_output}/{fold_num}logs.csv')
        learn = Learner(
            dls,
            model,
            loss_func=BCEWithLogitsLossFlat(),
            metrics=[_rmse],
            cbs=[early_stop, save_callback, logger],
        ).to_fp16()

        learn.fit_one_cycle(Config.epochs, Config.lr, wd=Config.weight_decay)

        learn.save(f'{kernel_type}/fold_{fold_num}')



In [None]:
# Kick off training for every fold listed in Config.trn_fold.
run(Config)