### Forked from https://www.kaggle.com/alibaba19/fast-ai-training-pipeline

In [None]:
!pip install timm

In [None]:
import math
import timm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path
from PIL import Image
from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error

# Target device for tensors and models: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Competition data root and the sub-folder holding the training images.
ROOT = Path('../input/petfinder-pawpularity-score')
TRAIN_IMG_PATH = ROOT / 'train'

# Column names used to look up the image id and the regression target.
ID_COL = 'Id'
TARGET_COL = 'Pawpularity'

# Run configuration: RNG seed, square image side used by the transforms,
# DataLoader batch size, epochs passed to fine_tune, and number of CV folds.
SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 128
N_EPOCHS = 3
N_SPLITS = 5
# Tensorflow EfficientNet B0 Noisy-Student
MODEL_NAME = 'tf_efficientnet_b0_ns'

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and puts cuDNN into its deterministic,
    non-autotuning configuration.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so the function does not depend on `os` / `random`
    # leaking into the namespace through a star import elsewhere in the file.
    import os
    import random

    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every CUDA generator; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking on, cuDNN may select different kernels from run to
    # run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, up front, with the notebook-wide SEED.
seed_everything(SEED)

In [None]:
# Load the training table from the competition folder and keep the target
# column on its own; `target` is reused for the CV score and the histogram.
train_df = pd.read_csv(ROOT / 'train.csv')
target = train_df[TARGET_COL]
# Notebook display of the first rows.
train_df.head()

In [None]:
# Per-channel statistics shared by both pipelines (the standard ImageNet
# mean / std), defined once so train and validation cannot drift apart.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop + resize to IMG_SIZE x IMG_SIZE, random
# 90-degree rotations / flips / transpose, normalisation, then conversion to
# a tensor.
train_transform = A.Compose([
    # `scale` is the fraction of the source image area to crop, so its upper
    # bound is capped at 1.0: a crop cannot cover more than the whole image.
    A.RandomResizedCrop(IMG_SIZE, IMG_SIZE, scale=(0.85, 1.0)),
    A.RandomRotate90(),
    A.Flip(),
    A.Transpose(),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2(),
])

# Validation pipeline: deterministic resize, same normalisation, to tensor.
valid_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2(),
])

In [None]:
class PetDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` float tensor pairs.

    Args:
        df: DataFrame with an ``ID_COL`` column (image file stem) and, in
            ``'train'`` mode, a ``TARGET_COL`` column.
        data_dir: Directory containing the ``<id>.jpg`` files; ``str`` or
            ``Path``.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry.
        mode: ``'train'`` reads the target from ``df``; any other value
            yields a constant ``0`` target.
    """

    def __init__(self, df, data_dir, transform=None, mode='train'):
        self.df = df
        # Normalise to Path so plain strings work with the `/` join below.
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.mode = mode

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and return the sample at positional index ``idx``.

        Returns:
            Tuple ``(image, target)`` of float tensors placed on ``device``.
        """
        row = self.df.iloc[idx]
        img_path = self.data_dir / f'{row[ID_COL]}.jpg'
        # Context manager guarantees the file handle is released even if
        # decoding fails.
        with Image.open(img_path) as im:
            img = np.array(im.convert('RGB'))
        if self.transform is not None:
            img = self.transform(image=img)['image']
        if isinstance(img, np.ndarray):
            # No transform (the default) or one that does not end in a
            # tensor conversion: build the channel-first tensor here so
            # `.float()` below does not fail on a NumPy array. The copy makes
            # the buffer contiguous, which torch.from_numpy requires.
            img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
        tgt = row[TARGET_COL] if self.mode == 'train' else 0
        return img.float().to(device), torch.tensor(tgt).float().to(device)

In [None]:
class CustomModel(nn.Module):
    """Single-output regression network: timm backbone plus a fastai head.

    The backbone's ``global_pool`` and ``classifier`` are replaced with
    ``nn.Identity`` and its output is fed to a head built by
    ``create_head(num_features, 1)``. A final ReLU keeps the prediction
    non-negative.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name=MODEL_NAME, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Disable the stock pooling and classifier; the custom head below
        # takes over from the backbone's output.
        backbone.global_pool = nn.Identity()
        backbone.classifier = nn.Identity()
        self.model = backbone
        self.head = create_head(backbone.num_features, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        """Run backbone, head and ReLU in sequence and return the result."""
        features = self.model(x)
        return self.act(self.head(features))

In [None]:
# K-fold cross-validation: train a fresh model per fold, collect its
# out-of-fold predictions, and save the learner and raw weights.
kfold = KFold(n_splits=N_SPLITS, random_state=SEED, shuffle=True)
# One out-of-fold prediction slot per row of train_df.
oof_pred = torch.zeros(len(train_df))
criterion = MSELossFlat()
for fold, (train_idx, valid_idx) in enumerate(kfold.split(train_df)):
    print('='*5, f'Start Fold: {fold}', '='*5)
    # KFold.split yields positional indices, so select rows by position;
    # label-based `.loc` is only correct for a default RangeIndex.
    train_x, valid_x = train_df.iloc[train_idx], train_df.iloc[valid_idx]

    train_ds = PetDataset(train_x, TRAIN_IMG_PATH, train_transform)
    valid_ds = PetDataset(valid_x, TRAIN_IMG_PATH, valid_transform)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)
    dls = DataLoaders(train_dl, valid_dl)

    model = CustomModel(MODEL_NAME).to(device)
    learner = Learner(dls, model, loss_func=criterion, metrics=rmse)
    learner.fine_tune(N_EPOCHS)

    pred, tgt = learner.get_preds(dl=valid_dl)
    # Flatten both to 1-D CPU tensors so they line up element-wise.
    pred = pred.detach().cpu().view(-1)
    tgt = tgt.detach().cpu().view(-1)
    oof_pred[valid_idx] = pred

    # sqrt(MSE) instead of `squared=False`, which scikit-learn deprecated and
    # later removed; this form works on every scikit-learn version.
    fold_rmse = math.sqrt(mean_squared_error(tgt.numpy(), pred.numpy()))
    print(f'Fold: {fold}, RMSE: {fold_rmse}')

    learner.save(f'learner_fold_{fold}')
    torch.save(learner.model.state_dict(), f'./fold_{fold}.pth')

    torch.cuda.empty_cache()

In [None]:
# Overall RMSE of the out-of-fold predictions against the true targets.
# Computed as sqrt(MSE) because `squared=False` was deprecated and later
# removed from scikit-learn's mean_squared_error.
print(f'CV Score = {math.sqrt(mean_squared_error(target, oof_pred.numpy()))}')

In [None]:
# Overlay the target distribution and the out-of-fold prediction
# distribution. Transparency and a legend keep both histograms visible and
# identifiable where they overlap.
plt.figure(figsize=(10, 6))
plt.hist(target, bins=30, alpha=0.5, label='target')
plt.hist(oof_pred.numpy(), bins=30, alpha=0.5, label='oof prediction')
plt.legend()