In [None]:
%%capture
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7
!pip install timm
!pip install nb_black
%load_ext nb_black

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use("ggplot")

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

import timm

import gc
import os
import time
import random
from datetime import datetime

from PIL import Image
from tqdm.notebook import tqdm
from sklearn import model_selection, metrics
from sklearn.model_selection import StratifiedKFold

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2

In [None]:
# Index of the cross-validation fold held out as the validation set; rows from
# every other fold are used for training.
FOLD = 0

In [None]:
# For parallelization in TPUs
# NOTE(review): presumably makes torch_xla run float tensors as bfloat16 on the
# TPU — confirm against the torch_xla version installed by the setup cell.
os.environ["XLA_USE_BF16"] = "1"
# NOTE(review): presumably an upper bound for the XLA tensor allocator (units
# not visible here) — confirm before tuning.
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

In [None]:
def seed_everything(seed):
    """
    Seed every random number generator used by this notebook for reproducibility.

    Arguments:
        seed {int} -- Value applied to the Python, NumPy and PyTorch generators
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Each of these callables takes the integer seed as its only argument.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed)

    # cuDNN flags: turn benchmark mode off and request deterministic behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(28)

In [None]:
# general global variables
# Root of the competition data plus its train/test image folders.
DATA_PATH = "../input/petfinder-pawpularity-score"
TRAIN_PATH = "../input/petfinder-pawpularity-score/train"
TEST_PATH = "../input/petfinder-pawpularity-score/test"
# Pretrained checkpoint path; not referenced by any cell visible in this notebook.
MODEL_PATH = (
    "../input/vit-base-models-pretrained-pytorch/jx_vit_base_p16_224-80ecf9dd.pth"
)
# BASE_DIR / IMG_DIR are the pair PetFinderDataset._load_image joins to locate
# "<BASE_DIR>/<IMG_DIR>/<Id>.jpg". BASE_DIR holds the same string as DATA_PATH.
BASE_DIR = "../input/petfinder-pawpularity-score"
IMG_DIR = "train"

# model specific global variables
IMG_SIZE = 384  # side length (pixels) the augmentation pipelines resize/crop to
BATCH_SIZE = 16  # per-DataLoader batch size
LR = 4e-05  # base learning rate; _run() multiplies it by the XLA world size
GAMMA = 0.7  # not referenced by any cell visible in this notebook
N_EPOCHS = 4  # number of epochs _run() passes to fit_tpu()


class TrainConfig:
    """Namespace bundling the training hyper-parameters defined above.

    The visible cells read ``batch_size``, ``num_workers`` and ``img_size``;
    ``epochs`` and ``lr`` mirror N_EPOCHS and LR, which ``_run()`` reads directly.
    """

    batch_size = BATCH_SIZE
    num_workers = 4  # DataLoader worker processes
    epochs = N_EPOCHS
    lr = LR
    img_size = IMG_SIZE

In [None]:
# Load the training metadata from the competition directory configured in
# DATA_PATH (instead of repeating the literal path), then preview it.
train_df = pd.read_csv(f"{DATA_PATH}/train.csv")
train_df.head()

In [None]:
N_FOLDS = 5

# Stratify on a binned copy of the continuous target: Pawpularity is cut into
# 10 bins and the integer bin id is what StratifiedKFold balances across folds.
target = pd.cut(train_df["Pawpularity"], bins=10, labels=False)

train_df["kfold"] = -1
skf = StratifiedKFold(n_splits=N_FOLDS)
kfold_col = train_df.columns.get_loc("kfold")

# skf.split yields *positional* row indices, so assign through .iloc; label-based
# .loc would pick the wrong rows (or raise) for any non-default index.
for fold, (_, val_idx) in enumerate(skf.split(target, target)):
    train_df.iloc[val_idx, kfold_col] = fold

train_df.head()

In [None]:
# Rows belonging to the held-out fold become the validation set; everything
# else stays in the training set. The helper column is dropped from both.
in_holdout_fold = train_df["kfold"] == FOLD

val_df = train_df.loc[in_holdout_fold].drop(columns=["kfold"])
train_df = train_df.loc[~in_holdout_fold].drop(columns=["kfold"])

print(f"Train Size: {train_df.shape}")
print(f"Validation Size: {val_df.shape}")

In [None]:
def get_train_augs():
    """Build the albumentations pipeline applied to training images.

    Returns:
        A.Compose -- crop to ``TrainConfig.img_size``, colour jitter, rare
        grayscale, flip, shift/scale/rotate, cutout, then tensor conversion
    """
    side = TrainConfig.img_size

    # Colour transforms grouped under a single A.OneOf.
    colour_jitter = A.OneOf(
        [
            A.HueSaturationValue(
                hue_shift_limit=0.2,
                sat_shift_limit=0.2,
                val_shift_limit=0.2,
                p=0.9,
            ),
            A.RandomBrightnessContrast(
                brightness_limit=0.2, contrast_limit=0.2, p=0.9
            ),
        ],
        p=0.9,
    )

    steps = [
        A.RandomResizedCrop(side, side),
        colour_jitter,
        A.ToGray(p=0.01),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=0, p=0.5),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.0)


def get_valid_augs():
    """Build the validation pipeline: resize to ``TrainConfig.img_size``, then
    convert to a tensor (no random augmentation)."""
    side = TrainConfig.img_size
    resize = A.Resize(height=side, width=side, p=1.0)
    to_tensor = ToTensorV2(p=1.0)
    return A.Compose([resize, to_tensor], p=1.0)

In [None]:
class PetFinderDataset(Dataset):
    """Dataset yielding ``(image, Pawpularity)`` pairs for the rows of a DataFrame.

    Arguments:
        df {pd.DataFrame} -- Must provide the ``Id`` and ``Pawpularity`` columns
        augs {callable} -- Optional albumentations-style transform; it is called
            as ``augs(image=...)`` and must return a mapping with an ``"image"`` key
    """

    def __init__(self, df, augs=None):
        self.df = df
        self.augs = augs

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return the (optionally augmented) image and target at position ``index``."""
        image = self._load_image(self.df["Id"].iloc[index])

        # Apply image augmentations if available
        if self.augs:
            image = self.augs(image=image)["image"]

        return image, self.df["Pawpularity"].iloc[index]

    def _load_image(self, image_id):
        """Read ``<BASE_DIR>/<IMG_DIR>/<image_id>.jpg`` as float32 RGB scaled by 1/255.

        Raises:
            FileNotFoundError -- if OpenCV cannot read the file (missing/corrupt)
        """
        image_path = f"{BASE_DIR}/{IMG_DIR}/{image_id}.jpg"
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)

        # cv2.imread signals failure by returning None instead of raising; fail
        # here with the offending path rather than with an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0
        return image

In [None]:
# Module-level datasets and loaders (no sampler, default ordering).
# Note: _run() takes no arguments and builds its own datasets/loaders, so the
# training run does not consume the objects created here.
train_dataset = PetFinderDataset(df=train_df, augs=get_train_augs())
valid_dataset = PetFinderDataset(df=val_df, augs=get_valid_augs())

# Training loader: batches of TrainConfig.batch_size read by
# TrainConfig.num_workers worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=TrainConfig.batch_size,
    pin_memory=False,
    num_workers=TrainConfig.num_workers,
)

# Validation loader: same settings, validation-time transforms.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=TrainConfig.batch_size,
    pin_memory=False,
    num_workers=TrainConfig.num_workers,
)

In [None]:
# Show the ViT architectures timm knows about (names matching "vit*"); the
# list is the cell's displayed output.
print("Available Vision Transformer Models: ")
timm.list_models("vit*")

In [None]:
class ViT(nn.Module):
    """timm ``vit_base_patch32_384`` whose classification head is replaced by a
    single-output linear layer, plus per-epoch train/validate helpers.

    Arguments:
        pretrained {bool} -- Forwarded to ``timm.create_model`` (default: False)
    """

    def __init__(self, pretrained=False):
        super().__init__()

        self.model = timm.create_model("vit_base_patch32_384", pretrained=pretrained)
        # Swap the original head for one producing a single value per image.
        self.model.head = nn.Linear(self.model.head.in_features, 1)

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _move_batch(data, target, device):
        """Move one batch to ``device``; other device types are returned as-is."""
        if device.type == "cuda":
            # .to(device) honours the device index, unlike a bare .cuda().
            return data.to(device), target.to(device)
        if device.type == "xla":
            # Targets are cast to int64 here; the notebook's RMSELoss converts
            # them to float itself.
            return (
                data.to(device, dtype=torch.float32),
                target.to(device, dtype=torch.int64),
            )
        return data, target

    def train_one_epoch(self, train_loader, criterion, optimizer, device):
        """Run one optimisation pass over ``train_loader``.

        Arguments:
            train_loader -- Sized iterable of ``(data, target)`` batches
            criterion -- Callable ``criterion(output, target)`` returning a scalar loss
            optimizer -- Optimizer over this module's parameters
            device {torch.device} -- ``cuda``/``xla`` batches are moved to it

        Returns:
            Mean batch loss for the epoch, as a tensor detached from autograd
        """
        # keep track of training loss
        epoch_loss = 0.0

        self.model.train()
        for i, (data, target) in enumerate(train_loader):
            data, target = self._move_batch(data, target, device)

            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = self(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # Accumulate a detached copy so the running total does not keep
            # every batch's autograd graph alive for the whole epoch.
            epoch_loss += loss.detach()

            # perform a single optimization step (parameter update)
            if device.type == "xla":
                xm.optimizer_step(optimizer)

                if i % 100 == 0:
                    xm.master_print(f"\tBATCH {i+1}/{len(train_loader)} - LOSS: {loss}")

            else:
                optimizer.step()

        return epoch_loss / len(train_loader)

    def validate_one_epoch(self, valid_loader, criterion, device):
        """Evaluate on ``valid_loader`` without tracking gradients.

        Arguments:
            valid_loader -- Sized iterable of ``(data, target)`` batches
            criterion -- Callable ``criterion(output, target)`` returning a scalar loss
            device {torch.device} -- ``cuda``/``xla`` batches are moved to it

        Returns:
            Mean batch loss over the loader
        """
        # keep track of validation loss
        valid_loss = 0.0

        self.model.eval()
        with torch.no_grad():
            for data, target in valid_loader:
                data, target = self._move_batch(data, target, device)
                # forward pass and batch loss
                output = self(data)
                valid_loss += criterion(output, target)

        return valid_loss / len(valid_loader)

In [None]:
def fit_tpu(
    model, epochs, device, criterion, optimizer, train_loader, valid_loader=None
):
    """Train ``model`` for ``epochs`` epochs on an XLA device, optionally validating.

    Arguments:
        model -- Object exposing ``train_one_epoch`` and ``validate_one_epoch``
        epochs {int} -- Number of epochs to run
        device -- XLA device handed to ``pl.ParallelLoader`` and the model helpers
        criterion -- Loss callable forwarded to the model helpers
        optimizer -- Optimizer forwarded to ``train_one_epoch``
        train_loader -- Loader wrapped in a ``pl.ParallelLoader`` every epoch
        valid_loader -- Optional loader; validation is skipped when ``None``

    Returns:
        dict -- ``{"train_loss": [...], "valid_losses": [...]}`` with one entry
        per epoch (key names kept as-is for existing callers)
    """
    # Lowest validation loss seen so far (float("inf") instead of np.Inf,
    # which NumPy 2.0 removed).
    valid_loss_min = float("inf")

    # keeping track of losses as they happen
    train_losses = []
    valid_losses = []

    for epoch in range(1, epochs + 1):
        gc.collect()

        # Samplers such as DistributedSampler need set_epoch() before the
        # loader is iterated, otherwise every epoch reuses the same shuffle.
        train_sampler = getattr(train_loader, "sampler", None)
        if hasattr(train_sampler, "set_epoch"):
            train_sampler.set_epoch(epoch)

        para_train_loader = pl.ParallelLoader(train_loader, [device])

        xm.master_print(f"{'='*50}")
        xm.master_print(f"EPOCH {epoch} - TRAINING...")
        train_loss = model.train_one_epoch(
            para_train_loader.per_device_loader(device), criterion, optimizer, device
        )
        xm.master_print(
            f"\n\t[TRAIN] EPOCH {epoch} - LOSS: {train_loss}\n"
        )
        train_losses.append(train_loss)
        gc.collect()

        if valid_loader is not None:
            para_valid_loader = pl.ParallelLoader(valid_loader, [device])
            xm.master_print(f"EPOCH {epoch} - VALIDATING...")
            valid_loss = model.validate_one_epoch(
                para_valid_loader.per_device_loader(device), criterion, device
            )
            xm.master_print(f"\t[VALID] LOSS: {valid_loss}\n")
            valid_losses.append(valid_loss)
            gc.collect()

            # Only a genuine improvement updates the running minimum, so the
            # message always compares against the best epoch so far.
            if valid_loss <= valid_loss_min:
                # Epoch 1 just establishes the baseline; nothing to report yet.
                if epoch != 1:
                    xm.master_print(
                        "Validation loss decreased ({:.4f} --> {:.4f}).  Saving model ...".format(
                            valid_loss_min, valid_loss
                        )
                    )
                # Checkpointing is currently disabled; to enable it:
                #     xm.save(model.state_dict(), 'best_model.pth')
                valid_loss_min = valid_loss

    return {
        "train_loss": train_losses,
        "valid_losses": valid_losses,
    }

In [None]:
# Build the network once at module level with pretrained timm weights;
# _run() later moves this same object onto the XLA device.
model = ViT(pretrained=True)

In [None]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable
    parameters (those with ``requires_grad`` set)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the number of trainable parameters, with thousands separators.
# (wandb.watch(model) logging is currently disabled.)
print(f"The model has {count_parameters(model):,} trainable parameters")

In [None]:
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSE(yhat, y) + eps)``.

    Arguments:
        eps {float} -- Added under the square root before it is taken
            (default: 1e-6)
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat, y):
        """Return the scalar RMSE between predictions ``yhat`` and targets ``y``.

        ``y`` is converted to float. If the two tensors hold the same number of
        elements but differ in shape (e.g. predictions ``(B, 1)`` vs targets
        ``(B,)``), ``y`` is reshaped to match ``yhat`` so the error is computed
        element-wise rather than over a broadcast ``(B, B)`` grid.
        """
        y = y.float()
        if yhat.shape != y.shape and yhat.numel() == y.numel():
            y = y.reshape(yhat.shape)
        return torch.sqrt(self.mse(yhat, y) + self.eps)

In [None]:
def _run():
    """Build the data pipeline, train the module-level ``model`` and save it.

    Reads the module-level ``train_df``, ``val_df``, ``model``, ``LR`` and
    ``N_EPOCHS``. Each DataLoader is driven by a ``DistributedSampler`` keyed on
    the XLA world size and ordinal, so every replica sees its own shard and the
    training data is shuffled.
    """
    train_dataset = PetFinderDataset(df=train_df, augs=get_train_augs())
    valid_dataset = PetFinderDataset(df=val_df, augs=get_valid_augs())

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )

    # The samplers must actually be handed to the loaders; without them the
    # training set is read in file order and never sharded across replicas.
    train_loader = DataLoader(
        train_dataset,
        batch_size=TrainConfig.batch_size,
        sampler=train_sampler,
        pin_memory=False,
        num_workers=TrainConfig.num_workers,
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=TrainConfig.batch_size,
        sampler=valid_sampler,
        pin_memory=False,
        num_workers=TrainConfig.num_workers,
    )

    criterion = RMSELoss()
    device = xm.xla_device()
    model.to(device)

    # Scale the base learning rate by the number of participating replicas.
    lr = LR * xm.xrt_world_size()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    xm.master_print(f"INITIALIZING TRAINING ON {xm.xrt_world_size()} TPU CORES")
    start_time = datetime.now()
    xm.master_print(f"Start Time: {start_time}")

    fit_tpu(
        model=model,
        epochs=N_EPOCHS,
        device=device,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        valid_loader=valid_loader,
    )

    xm.master_print(f"Execution time: {datetime.now() - start_time}")

    # NOTE: the "5e" tag in the file name is hard-coded and does not track N_EPOCHS.
    xm.master_print("Saving Model")
    xm.save(
        model.state_dict(), f'model_5e_{datetime.now().strftime("%Y%m%d-%H%M")}.pth'
    )

In [None]:
# Start training in the notebook process itself; the cells shown here call
# _run() directly rather than through xmp.spawn.
_run()