In [None]:
# Install timm with the %pip magic so the package lands in the environment of
# the running kernel (a bare `!pip` may target a different interpreter).
# NOTE(review): the version is unpinned; consider pinning it to the timm release
# the checkpoints were trained with.
%pip install -q timm

In [None]:
import os
import math
import random
import time
from tqdm import tqdm

import numpy as np
import pandas as pd

import cv2
import albumentations
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import timm
from sklearn.metrics import mean_squared_error

import gc
gc.enable()

import warnings
warnings.filterwarnings("ignore")

# Config

In [None]:
def get_train_file_path(id):
    """Return the path of the training JPEG that belongs to image ``id``.

    The path is the module-level ``train_img_path`` directory, a forward
    slash, the id, and a ``.jpg`` extension.
    """
    return "{}/{}.jpg".format(train_img_path, id)

In [None]:
# Names of the inference pipelines (keys of get_inference_transforms) whose
# predictions are averaged for every image.
TTA_LIST = ['transforms_test','tta_flip','tta_shift_scale_rotate','tta_rotate']
DIM = 384  # side length of the square image passed to the transforms
seed = 42  # base seed; the fold number is added to it during inference
train_img_path = '../input/petfinder-pawpularity-score/train'

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to the device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
backbone = 'swin_large_patch4_window12_384'  # model name given to timm.create_model

# Table with one row per image; the 'Id', 'kfold' and 'Pawpularity' columns
# are read later in the notebook. 'file_path' is derived from 'Id'.
data = pd.read_csv('../input/pawpular-models/pawpularity_folds.csv')
data['file_path'] = data['Id'].apply(get_train_file_path)

# Utils

In [None]:
def set_seed(seed = 42):
    """Make the notebook's random behaviour repeatable.

    Seeds NumPy, Python's ``random`` module and PyTorch (through
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``), puts cuDNN into
    deterministic mode with benchmarking off, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed (int): Value handed to every seeding call. Defaults to 42.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN needs both switches: deterministic kernels on, auto-tuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Environment values must be strings. Python reads PYTHONHASHSEED at
    # interpreter start-up, so this is picked up by processes started later.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
def get_valid_transforms():
    """Build the validation pipeline: resize, normalize, convert to tensor.

    Returns:
        albumentations.Compose: Resizes to ``DIM`` x ``DIM``, normalizes with
        the mean/std below, then applies ``ToTensorV2``.
    """
    resize = albumentations.Resize(DIM,DIM)
    normalize = albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    to_tensor = ToTensorV2(p=1.0)
    return albumentations.Compose([resize, normalize, to_tensor])

In [None]:
def get_inference_transforms(image_size):
    """Build the inference / test-time-augmentation (TTA) pipelines.

    Every pipeline ends with the same tail: resize to ``image_size`` x
    ``image_size``, mean/std normalization, and conversion to a tensor with
    ``ToTensorV2``. The pipelines differ only in the augmentation that runs
    before that tail.

    Args:
        image_size (int): Side length of the square output image. This
            argument is required; it has no default.

    Returns:
        Dict[str, albumentations.Compose]: Pipelines keyed, in this order, by
        ``transforms_test`` (no augmentation), ``tta_flip``, ``tta_rotate``,
        ``tta_shift_scale_rotate``, ``tta_hue_saturation_value`` and
        ``tta_random_brightness_contrast``.
    """

    def _pipeline(*augmentations):
        # Prepend the given augmentations to the shared tail. New transform
        # objects are created on every call so pipelines never share instances.
        return albumentations.Compose(
            [
                *augmentations,
                albumentations.Resize(image_size, image_size),
                albumentations.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255.0,
                    p=1.0,
                ),
                ToTensorV2(p=1.0),
            ]
        )

    transforms_dict = {
        "transforms_test": _pipeline(),
        "tta_flip": _pipeline(albumentations.HorizontalFlip(p=1)),
        "tta_rotate": _pipeline(albumentations.Rotate(limit=180, p=1)),
        "tta_shift_scale_rotate": _pipeline(
            albumentations.ShiftScaleRotate(
                shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=1
            )
        ),
        # NOTE(review): limits of 0.2 are tiny next to the library's default
        # limits for this transform, so this pipeline may be close to a no-op.
        # Values are left untouched; confirm the intended magnitude.
        "tta_hue_saturation_value": _pipeline(
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2,
                sat_shift_limit=0.2,
                val_shift_limit=0.2,
                p=1,
            )
        ),
        "tta_random_brightness_contrast": _pipeline(
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=1
            )
        ),
    }

    return transforms_dict

# Dataset

In [None]:
class PawpularityDataset(Dataset):
    """Image-only dataset that reads the pictures listed in a DataFrame.

    Args:
        df: DataFrame with a ``file_path`` column holding one image path per row.
        transforms: Optional callable invoked as ``transforms(image=img)``;
            the ``"image"`` entry of its result is what ``__getitem__`` returns.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the transforms, if any.

        Raises:
            FileNotFoundError: If the image cannot be read from its path.
        """
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise FileNotFoundError(
                f"Could not read image (missing or unreadable): {img_path}"
            )
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img

# Model

In [None]:
class PawpularityModel(nn.Module):
    """A timm backbone followed by a small fully connected head.

    Args:
        backbone: Model name passed to ``timm.create_model``.
        pretrained: Forwarded unchanged to ``timm.create_model``.
    """

    def __init__(self, backbone, pretrained=True):
        super().__init__()

        # Create the feature extractor, remember the input width of its
        # classifier, then remove that classifier.
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        self.n_features = self.backbone.head.in_features
        self.backbone.reset_classifier(0)

        # Head: n_features -> n_features // 2 -> 256 -> 1
        hidden_width = self.n_features//2
        head_layers = [
            nn.Linear(self.n_features, hidden_width),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(hidden_width, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, images):
        """Run the backbone, then the head; return the final Linear(256, 1) output."""
        return self.fc(self.backbone(images))

# Engine

In [None]:
def oof_out(dataloader,model,device):
    """Run inference over ``dataloader`` and return predictions scaled to 0-100.

    Args:
        dataloader: Iterable yielding batches of image tensors.
        model: Module whose output for a batch can be squeezed on dim 1.
        device: Device each batch is moved to before the forward pass.

    Returns:
        np.ndarray: 1-D array holding ``sigmoid(model(batch)) * 100`` for every
        sample, in iteration order. An empty array when the dataloader
        yields no batches.
    """
    model.eval()
    batch_preds = []

    with torch.no_grad():
        for images in dataloader:
            logits = model(images.to(device))
            scores = logits.sigmoid().squeeze(1).cpu().numpy()
            batch_preds.append(scores * 100)

    if not batch_preds:
        # np.concatenate raises on an empty list; return an empty result instead.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_preds)

In [None]:
def make_oofs(n_folds=5, batch_size=32, num_workers=4):
    """Compute TTA-averaged out-of-fold predictions for every fold.

    For each fold the fold's checkpoint is loaded once, every pipeline named
    in ``TTA_LIST`` is run over the rows of ``data`` whose ``kfold`` equals the
    fold, and the resulting predictions are averaged.

    Args:
        n_folds (int): Number of folds (and checkpoints) to evaluate.
        batch_size (int): Batch size used by the DataLoader.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        tuple: ``(oof_pred, oof_tar, oof_id, oof_fold)``, four lists with one
        entry per fold: averaged predictions, ``Pawpularity`` targets,
        ``Id`` values, and the fold number repeated once per row.

    Raises:
        ValueError: If ``TTA_LIST`` is empty (nothing to average).
    """
    if not TTA_LIST:
        raise ValueError("TTA_LIST must name at least one transform")

    oof_id, oof_fold, oof_pred, oof_tar = [], [], [], []

    # pretrained=False: the fold checkpoint loaded below replaces the model's
    # weights, so fetching pretrained weights first would be wasted work and
    # would fail on a machine without network access.
    model = PawpularityModel(backbone=backbone, pretrained=False)
    model.to(device)
    model.eval()

    transform = get_inference_transforms(DIM)

    for fold in tqdm(range(n_folds)):
        valid = data[data['kfold']==fold].reset_index(drop=True)

        # The checkpoint depends only on the fold, so load it once per fold
        # rather than once per TTA pass.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model_path = f"../input/pawpular-models/12/swin_baseline_11_fold_{fold}.pth"
        model.load_state_dict(torch.load(model_path,map_location=device))

        tta_preds = 0
        for tta in TTA_LIST:
            valid_dataset = PawpularityDataset(valid,transform[tta])
            valid_loader = DataLoader(
                valid_dataset,
                batch_size=batch_size,
                pin_memory=True,
                drop_last=False,
                num_workers=num_workers,
            )

            # Re-seed right before each pass so the random augmentations are
            # reproducible from run to run.
            set_seed(seed+fold)

            tta_preds = tta_preds + oof_out(valid_loader, model, device)

        tta_preds = tta_preds/len(TTA_LIST)

        # Store this fold's results
        oof_id.append(valid['Id'].values)
        oof_tar.append(valid['Pawpularity'].values)
        oof_pred.append(tta_preds)
        oof_fold.append([fold]*len(valid))

    return oof_pred,oof_tar,oof_id,oof_fold

In [None]:
# Run TTA inference for every fold; each returned name is a list with one entry per fold.
oof_pred,oof_tar,oof_id,oof_fold = make_oofs()

In [None]:
# Compute the overall out-of-fold RMSE across all folds
oof = np.concatenate(oof_pred)
true = np.concatenate(oof_tar)
# NOTE: `id` shadows the builtin id(); the name is kept because the next cell reads it.
id = np.concatenate(oof_id)
folds = np.concatenate(oof_fold)
# Take the square root of the MSE instead of passing squared=False: that
# argument was deprecated and later removed from scikit-learn.
RMSE = np.sqrt(mean_squared_error(true,oof))
print('Overall OOF RMSE with TTA =',RMSE)

In [None]:
# Collect the per-image OOF results in one table, write it to CSV, and preview it
df_oof = pd.DataFrame({"id": id, "target": true, "pred": oof, "fold": folds})
df_oof.to_csv('swin_009_oof_tta.csv', index=False)
df_oof.head()