Credit to https://www.kaggle.com/debarshichanda/tpu-training.
I modified parts of the code to get it working.

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev


In [None]:
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version 1.7
!pip install timm
!pip install pretrainedmodels

In [None]:
import os
import gc
import cv2
import copy
import time
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#import torch_optimizer as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from torch.cuda import amp

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils import class_weight

from tqdm.notebook import tqdm
from collections import defaultdict
from datetime import datetime
import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
import pretrainedmodels

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

In [None]:
# PyTorch/XLA runtime switches (set before any XLA tensors are created).
#  - XLA_USE_BF16=1: float tensors are stored as bfloat16 on the TPU; this is a
#    precision/speed setting, not a parallelism one.
#  - XLA_TENSOR_ALLOCATOR_MAXSIZE: presumably caps XLA's tensor allocator cache
#    (100 MB here) — TODO confirm against the PyTorch/XLA docs.
os.environ.update(
    {
        "XLA_USE_BF16": "1",
        "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",
    }
)

In [None]:
class CFG:
    """Run-wide hyper-parameters, read as class attributes (e.g. ``CFG.lr``)."""

    # --- model / data ---
    model_name = 'resnet18d'  # alternative tried: 'tf_efficientnet_b4_ns'
    num_classes = 5
    img_size = 380

    # --- cross-validation / reproducibility ---
    n_fold = 5
    seed = 42

    # --- optimisation ---
    num_epochs = 10
    batch_size = 64  # per-process batch size (16 * 4)
    lr = 1e-4
    weight_decay = 1e-6

    # --- LR schedule: not read by the visible training code (no scheduler is built) ---
    scheduler = 'CosineAnnealingWarmRestarts'
    T_max = 10
    T_0 = 10
    min_lr = 1e-6

    # --- loss: only used by LabelSmoothingLoss, which the training run currently leaves disabled ---
    smoothing = 0.2

In [None]:
def set_seed(seed=42):
    """Seed the RNGs used by this notebook so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and CUDA), and
    puts cuDNN into deterministic, non-benchmarking mode.

    Note: ``PYTHONHASHSEED`` is read only when an interpreter starts, so setting
    it here does not change hash randomisation of the running process. It only
    affects Python interpreters launched later that inherit this environment.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed):
        seed_torch(seed)
    # cuDNN: trade autotuned speed for run-to-run determinism.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
    os.environ['PYTHONHASHSEED'] = str(seed)
    
# Seed everything once, before fold splitting and model construction.
set_seed(seed=CFG.seed)

In [None]:
# Kaggle input layout for the Cassava Leaf Disease Classification competition.
ROOT_DIR = "../input/cassava-leaf-disease-classification"
TRAIN_DIR = f"{ROOT_DIR}/train_images"
TEST_DIR = f"{ROOT_DIR}/test_images"  # not read by the visible training code

In [None]:
# Training metadata; the code below relies on its `image_id` and `label` columns.
train_csv_path = f"{ROOT_DIR}/train.csv"
df = pd.read_csv(train_csv_path)
df  # notebook display

In [None]:
# Assign every row to one of CFG.n_fold label-stratified CV folds ("kfold" column).
# StratifiedKFold is used without shuffling, so membership depends only on row
# order and labels.
# NOTE(review): `fold` stays bound at module level after the loop and other cells
# may read it — check before renaming the loop variable.
df["kfold"] = -1
skf = StratifiedKFold(n_splits=CFG.n_fold)
for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df.label)):
    # split() yields *positional* indices; map them to index labels so the
    # assignment is correct even when df's index is not 0..n-1.
    df.loc[df.index[val_idx], "kfold"] = fold

df['kfold'] = df['kfold'].astype(int)

In [None]:
class CassavaLeafDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs.

    Args:
        root_dir: Directory containing the image files.
        df: DataFrame with ``image_id`` (file name relative to ``root_dir``)
            and ``label`` columns.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=img)["image"]``.
    """

    def __init__(self, root_dir, df, transforms=None):
        super().__init__()
        self.root_dir = root_dir
        self.df = df
        # Cache columns as arrays so __getitem__ indexes positionally,
        # independent of the DataFrame's index labels.
        self.labels = df['label'].values
        self.image_ids = df['image_id'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load image `index` as RGB, apply the transforms, return ``(img, label)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path = os.path.join(self.root_dir, self.image_ids[index])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR; convert to the RGB order the Normalize stats assume.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        label = self.labels[index]

        if self.transforms:
            img = self.transforms(image=img)["image"]

        return img, label

In [None]:
def _imagenet_normalize():
    """Return a fresh Normalize transform with ImageNet channel mean/std."""
    return A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )


# Albumentations pipelines keyed by split; each ends with normalisation and
# ToTensorV2 (numpy HWC image -> torch CHW tensor).
data_transforms = {
    "train": A.Compose(
        [
            A.RandomResizedCrop(CFG.img_size, CFG.img_size),
            A.HorizontalFlip(p=0.5),
            # Disabled augmentations (kept for experimentation):
            # A.Transpose(p=0.5),
            # A.VerticalFlip(p=0.5),
            # A.ShiftScaleRotate(p=0.5),
            # A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2,
            #                      val_shift_limit=0.2, p=0.5),
            # A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1),
            #                            contrast_limit=(-0.1, 0.1), p=0.5),
            _imagenet_normalize(),
            # A.CoarseDropout(p=0.5),
            # A.Cutout(p=0.5),
            ToTensorV2(),
        ],
        p=1.0,
    ),
    "valid": A.Compose(
        [
            # NOTE(review): train rescales random crops (RandomResizedCrop), while
            # valid takes a crop at native resolution; the Resize below is a no-op
            # because CenterCrop already produced img_size x img_size. Possible
            # train/valid scale mismatch — confirm this is intended.
            A.CenterCrop(CFG.img_size, CFG.img_size, p=1.0),
            A.Resize(CFG.img_size, CFG.img_size),
            _imagenet_normalize(),
            ToTensorV2(),
        ],
        p=1.0,
    ),
}

In [None]:
import tqdm
from tqdm.notebook import tqdm as tqdm

In [None]:
# class EffNet(nn.Module):
#     def __init__(self, n_classes, pretrained=True):
#         super(EffNet, self).__init__()
#         self.model = timm.create_model(CFG.model_name, pretrained=True)
#         num_features = self.model.classifier.in_features
#         self.model.classifier = nn.Linear(num_features, CFG.num_classes)

#     def forward(self, x):
#         x = self.model(x)
#         return x

class EffNet(nn.Module):
    """timm backbone + global average pooling + linear classification head.

    Despite the name, the backbone is whatever ``CFG.model_name`` selects.
    The backbone's own classifier is left in place (unused) so state-dict
    keys stay the same as before.

    Args:
        n_classes: Number of output logits.
        pretrained: Whether timm loads pretrained backbone weights.
    """

    def __init__(self, n_classes, pretrained=True):
        super().__init__()
        self.model = timm.create_model(CFG.model_name, pretrained=pretrained)
        # Size the head from the backbone rather than hard-coding 512, which
        # only matches some ResNets; timm models expose their width as num_features.
        self.logit = nn.Linear(self.model.num_features, n_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, n_classes)."""
        features = self.model.forward_features(x)
        # Global average pool the (N, C, H, W) feature map down to (N, C).
        pooled = F.adaptive_avg_pool2d(features, 1).flatten(1)
        return self.logit(pooled)

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training epoch.

    Args:
        model: Module to train (switched to train mode).
        train_loader: Iterable of ``(data, target)`` batches with ``len()``.
        criterion: Loss callable ``criterion(output, target)``.
        optimizer: Optimizer stepped through ``xm.optimizer_step``.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss, mean accuracy)`` over batches, as tensors on `device`.
        The per-batch average is exact when all batches have equal size,
        e.g. with ``drop_last=True``.
    """
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    model.train()
    for data, target in tqdm(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        # Detach before accumulating: adding `loss` itself would chain every
        # step's autograd graph onto the running total and keep it alive.
        epoch_loss += loss.detach()
        epoch_accuracy += (output.argmax(dim=1) == target).float().mean()

        # PyTorch/XLA's optimizer step (reduces gradients across replicas
        # when there are several, then steps the optimizer).
        xm.optimizer_step(optimizer)

    return epoch_loss / len(train_loader), epoch_accuracy / len(train_loader)

def validate_one_epoch(model, valid_loader, criterion, device):
    """Evaluate `model` and return sample-weighted loss and accuracy.

    Batches may differ in size (the validation loader keeps its last partial
    batch), so metrics are weighted by sample count rather than averaged per batch.

    Args:
        model: Module to evaluate (switched to eval mode).
        valid_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss callable; assumed to return the batch *mean*, as
            CrossEntropyLoss's default reduction does.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss, accuracy)`` over all samples seen, as tensors.

    Raises:
        ZeroDivisionError: If `valid_loader` yields no batches.
    """
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    model.eval()
    with torch.no_grad():
        for data, target in valid_loader:
            data = data.to(device, dtype=torch.float32)
            target = target.to(device, dtype=torch.int64)

            output = model(data)
            batch_size = target.size(0)
            # Undo the batch-mean so the final division weights every sample equally.
            loss_sum += criterion(output, target) * batch_size
            n_correct += (output.argmax(dim=1) == target).sum()
            n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def fit_tpu(
    model, epochs, device, criterion, optimizer, train_loader, valid_loader=None,
    fold=None,
):
    """Train `model` on an XLA device and checkpoint whenever validation accuracy improves.

    Each epoch wraps the host loaders in ``pl.ParallelLoader`` for `device`,
    then runs ``train_one_epoch`` and, if `valid_loader` is given,
    ``validate_one_epoch``.

    Args:
        model, criterion, optimizer: Passed through to the epoch functions.
        epochs: Number of epochs to run.
        device: XLA device the batches are fed to.
        train_loader: Host DataLoader for training.
        valid_loader: Optional host DataLoader for validation.
        fold: Optional fold id, used in checkpoint file names.

    Returns:
        ``defaultdict(list)`` of per-epoch floats under 'train loss' and
        'train acc', plus 'valid loss' and 'valid acc' when validating.
    """
    best_acc = 0.0  # best validation accuracy seen so far
    history = defaultdict(list)

    def _mean(values):
        return sum(values) / len(values)

    for epoch in range(1, epochs + 1):
        gc.collect()
        # A DistributedSampler only reshuffles when told the epoch; without this
        # every epoch iterates the same order.
        sampler = getattr(train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        para_train_loader = pl.ParallelLoader(train_loader, [device])
        xm.master_print(f"{'='*50}")
        xm.master_print(f"EPOCH {epoch} - TRAINING...")
        train_loss, train_acc = train_one_epoch(
            model, para_train_loader.per_device_loader(device), criterion, optimizer, device
        )
        # Fetch the device scalars once per epoch for logging and history.
        train_loss, train_acc = float(train_loss), float(train_acc)
        xm.master_print(
            f"\n\t[TRAIN] EPOCH {epoch} - LOSS: {train_loss}, ACCURACY: {train_acc}\n"
        )
        history['train loss'].append(train_loss)
        history['train acc'].append(train_acc)

        if valid_loader is None:
            continue

        gc.collect()
        para_valid_loader = pl.ParallelLoader(valid_loader, [device])
        xm.master_print(f"EPOCH {epoch} - VALIDATING...")
        valid_loss, valid_acc = validate_one_epoch(
            model, para_valid_loader.per_device_loader(device), criterion, device
        )
        # Average across replicas so every process sees the same numbers and takes
        # the same checkpoint branch (xm.save synchronises replicas).
        valid_loss = float(xm.mesh_reduce("valid_loss", float(valid_loss), _mean))
        valid_acc = float(xm.mesh_reduce("valid_acc", float(valid_acc), _mean))
        xm.master_print(f"\t[VALID] LOSS: {valid_loss}, ACCURACY: {valid_acc}\n")
        history['valid loss'].append(valid_loss)
        history['valid acc'].append(valid_acc)

        # Checkpoint only on a strict improvement of validation accuracy.
        if valid_acc > best_acc:
            best_acc = valid_acc
            if fold is not None:
                path = f"Fold-{fold}_{best_acc:.4f}_epoch-{epoch}.pth"
            else:
                path = f"best_{best_acc:.4f}_epoch-{epoch}.pth"
            xm.master_print(f"Validation accuracy improved to {best_acc:.4f}; saving {path}")
            xm.save(model.state_dict(), path)

    return history

In [None]:
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is spread uniformly over the other ``classes - 1``
    classes. With ``smoothing=0`` this equals mean-reduced cross-entropy.

    Args:
        classes: Number of classes (size of ``pred`` along ``dim``); must be >= 2.
        smoothing: Probability mass moved off the true class.
        dim: Class dimension of ``pred``.

    Raises:
        ValueError: If ``classes < 2`` (the off-class share would divide by zero).
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        if classes < 2:
            raise ValueError(f"classes must be >= 2, got {classes}")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: Unnormalised logits with class scores along ``self.dim``.
            target: Integer class indices shaped like ``pred`` without ``self.dim``.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (self.cls - 1))
            # Scatter along the configured class dim so any `dim` works, not just 1.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))

In [None]:
# model = EffNet(n_classes=CFG.num_classes)

In [None]:
def _make_xla_loader(dataset, batch_size, shuffle, drop_last):
    """Return a DataLoader whose DistributedSampler shards `dataset` across XLA replicas."""
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=shuffle,
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        num_workers=4,
    )


def _run(fold):
    """Train and validate CV fold `fold` on this process's XLA device.

    Rows with ``df.kfold == fold`` form the validation split; all others train.
    The final weights are written with ``xm.save``.

    Returns:
        The per-epoch history dict produced by ``fit_tpu``.
    """
    valid_df = df[df.kfold == fold]
    train_df = df[df.kfold != fold]

    train_dataset = CassavaLeafDataset(TRAIN_DIR, train_df, transforms=data_transforms["train"])
    valid_dataset = CassavaLeafDataset(TRAIN_DIR, valid_df, transforms=data_transforms["valid"])

    # drop_last=True keeps every training batch the same shape (presumably to
    # avoid an extra XLA compilation for a ragged final batch).
    train_loader = _make_xla_loader(
        train_dataset, CFG.batch_size, shuffle=True, drop_last=True
    )
    valid_loader = _make_xla_loader(
        valid_dataset, CFG.batch_size * 4, shuffle=False, drop_last=False
    )

    # Alternative: LabelSmoothingLoss(smoothing=CFG.smoothing, classes=CFG.num_classes)
    criterion = torch.nn.CrossEntropyLoss()

    device = xm.xla_device()
    model = EffNet(n_classes=CFG.num_classes)
    model.to(device)

    # Scale the base learning rate linearly with the number of replicas.
    lr = CFG.lr * xm.xrt_world_size()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=CFG.weight_decay)

    xm.master_print(f"INITIALIZING TRAINING ON {xm.xrt_world_size()} TPU CORES")
    start_time = datetime.now()
    xm.master_print(f"Start Time: {start_time}")

    logs = fit_tpu(
        model=model,
        epochs=CFG.num_epochs,
        device=device,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        valid_loader=valid_loader,
    )

    xm.master_print(f"Execution time: {datetime.now() - start_time}")

    xm.master_print("Saving Model")
    # File name records the fold and the configured epoch count.
    xm.save(
        model.state_dict(),
        f'model_fold-{fold}_{CFG.num_epochs}e_{datetime.now().strftime("%Y%m%d-%H%M")}.pth',
    )
    return logs

In [None]:
# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``.

    Args:
        rank: Process index supplied by ``xmp.spawn``; unused here (the
            training code queries ``xm.get_ordinal()`` itself).
        flags: Options dict; an optional ``"fold"`` key selects the CV fold
            to train (default 0).
    """
    # Replacement for the deprecated torch.set_default_tensor_type("torch.FloatTensor"):
    # keep float32 as the default floating dtype (default device is already CPU).
    torch.set_default_dtype(torch.float32)
    fold = flags.get("fold", 0) if flags else 0
    _run(fold=fold)


# Launch training via torch_xla's multiprocessing launcher. nprocs=1 runs a single
# process; with start_method="fork" the child inherits this notebook's
# module-level state (df, CFG, data_transforms, ...).
FLAGS = dict()
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=1,
    start_method="fork",
)

In [None]:
!rm *.py
!rm *.whl