## Library imports

### install pytorch xla to use TPUs

In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7
!pip install timm

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  21677      0 --:--:-- --:--:-- --:--:-- 21677
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-1.7 ...
Found existing installation: torch 1.7.0
Uninstalling torch-1.7.0:
Done updating TPU runtime
  Successfully uninstalled torch-1.7.0
Found existing installation: torchvision 0.8.1
Uninstalling torchvision-0.8.1:
  Successfully uninstalled torchvision-0.8.1
Copying gs://tpu-pytorch/wheels/torch-1.7-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/114.2 MiB.                                    
Copying gs://tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/127.4 MiB.                                    
Copying gs://tpu-pytorch/wheels/torchvision-1.7-cp37-cp37m-linux_x86_64.whl...

Operati

In [2]:
# append package pathss
import sys
append_paths = ['../input/pytorch-image-models/pytorch-image-models-master', '../input/image-fmix/FMix-master']
for package_path in append_paths:
    sys.path.append(package_path)

# basic imports
import os
import gc
import numpy as np
import pandas as pd
import random
import math
import time
import itertools
from tqdm.notebook import tqdm
from datetime import datetime



# augumentations library
from albumentations.pytorch import ToTensorV2
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightnessContrast,ShiftScaleRotate, Cutout, CoarseDropout, 
    IAAAdditiveGaussianNoise, Transpose, MotionBlur, MedianBlur, GaussianBlur, HueSaturationValue
    )
import albumentations as A
from fmix import sample_mask
import cv2

# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from  torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms

## pytorch-xla imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

# timm import
import timm

# metrics calculation
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold

# basic plotting library
import matplotlib.pyplot as plt
plt.style.use("ggplot")

# interactive plots
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import warnings  
warnings.filterwarnings('ignore')

In [3]:
## For parallelization in TPUs
os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

## Config params

In [4]:
class CFG:
    # pipeline parameters
    SEED        = 42
    NUM_CLASSES = 5
    TGT_LABEL   = 'label'
    TRAIN       = True
    LR_FIND     = False
    RETRAIN     = False
    TEST        = False
    DEBUG       = False
    N_FOLDS     = 5 
    N_EPOCHS    = 5 
    DF_FRAC     = 1  
    TEST_BATCH_SIZE  = 32
    TRAIN_BATCH_SIZE = 16
    SIZE             = [224, 224]
    NUM_WORKERS      = 8
    FOLD_TO_TRAIN    = [0] #, 1, 2, 3, 4

    # model parameters
    MODEL_ARCH  = 'vit_base_patch16_224'
    MODEL_NAME  = 'vit_v1'
    WGT_PATH    = ''
    WGT_MODEL   = ''

    # loss fn parameters
    LOSS_FN     = 'CrossEntropyLoss' # 'LabelSmoothingCrossEntropy'
    SMOOTHING   = 0.3
    MIX_PROB    = 0.25
    
    # scheduler variables
    SCHEDULER = 'OneCycleLR' # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts', 'CosineAnnealingWarmRestarts']
    T_0       = 10 # CosineAnnealingWarmRestarts
    MAX_LR    = 3e-4
    MIN_LR    = 1e-6
    T_MAX     = 5

    # optimizer variables
    OPTIMIZER     = 'Adam'
    WEIGHT_DECAY  = 1e-6
    GRD_ACC_STEPS = 1
    MAX_GRD_NORM  = 1000
    
    # Vit parameter
    GAMMA = 0.7


TRAIN_PATH = '../input/cassava-leaf-disease-classification/train_images'
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'
NPY_FOLDER = '../input/cassava-npy-train-images/train_npy_images'
DIR_INPUT = '../input/cassava-leaf-disease-classification'
MODEL_PATH = '../input/vit-base-models-pretrained-pytorch/jx_vit_base_p16_224-80ecf9dd.pth'

index_label_map = {
                0: "Cassava Bacterial Blight (CBB)", 
                1: "Cassava Brown Streak Disease (CBSD)",
                2: "Cassava Green Mottle (CGM)", 
                3: "Cassava Mosaic Disease (CMD)", 
                4: "Healthy"
                }

class_names = [value for key,value in index_label_map.items()]

## Helper functions

In [5]:
def plot_training_results(train_results):
    fig = make_subplots(rows=2, cols=1)

    colors = [
        ('#d32f2f', '#ef5350'),
        ('#303f9f', '#5c6bc0'),
        ('#00796b', '#26a69a'),
        ('#fbc02d', '#ffeb3b'),
        ('#5d4037', '#8d6e63'),
    ]

    for i in range(CFG.N_FOLDS):
        data = train_results[train_results['fold'] == i]

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['train_loss'].values,
                                 mode='lines',
                                 visible='legendonly' if i > 0 else True,
                                 line=dict(color=colors[i][0], width=2),
                                 name='Train loss - Fold #{}'.format(i)),
                     row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'],
                                 y=data['valid_loss'].values,
                                 mode='lines+markers',
                                 visible='legendonly' if i > 0 else True,
                                 line=dict(color=colors[i][1], width=2),
                                 name='Valid loss - Fold #{}'.format(i)),
                     row=1, col=1)

        fig.add_trace(go.Scatter(x=data['epoch'].values,
                                 y=data['valid_score'].values,
                                 mode='lines+markers',
                                 line=dict(color=colors[i][0], width=2),
                                 name='Valid score - Fold #{}'.format(i),
                                 showlegend=False),
                     row=2, col=1)

    fig.update_layout({
      "annotations": [
        {
          "x": 0.225, 
          "y": 1.0, 
          "font": {"size": 16}, 
          "text": "Train / valid losses", 
          "xref": "paper", 
          "yref": "paper", 
          "xanchor": "center", 
          "yanchor": "bottom", 
          "showarrow": False
        }, 
        {
          "x": 0.775, 
          "y": 1.0, 
          "font": {"size": 16}, 
          "text": "Validation scores", 
          "xref": "paper", 
          "yref": "paper", 
          "xanchor": "center", 
          "yanchor": "bottom", 
          "showarrow": False
        }, 
      ]
    })

    fig.show()

In [6]:
def find_no_of_trainable_params(model):
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_trainable_params

In [7]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
        
set_seed(CFG.SEED)

In [8]:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [9]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Dataset 

In [10]:
train_df = pd.read_csv(f'{DIR_INPUT}/train.csv')
#train_df[['cls0', 'cls1', 'cls2', 'cls3', 'cls4']] = train_labels = pd.get_dummies(train_df.iloc[:, 1])
train_df['npy_image_id'] = train_df['image_id'].str.replace('jpg', 'npy')
if CFG.DF_FRAC < 1:
    train_df = train_df.sample(frac=CFG.DF_FRAC).reset_index(drop=True)
train_labels = train_df.iloc[:, 1].values
print(train_df.shape)
train_df.head()
folds = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=CFG.SEED)

if CFG.DEBUG == True:
    pass
    #folds = train_df.copy()
    #for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG.TGT_LABEL])):
    #    folds.loc[val_index, 'fold'] = int(n)
    #folds['fold'] = folds['fold'].astype(int)
    #print(folds.groupby(['fold', CFG.TGT_LABEL]).size())

(21397, 3)


In [11]:
class TrainDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['npy_image_id'].values
        self.labels = df[CFG.TGT_LABEL].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = np.load(f'{NPY_FOLDER}/{self.file_names[idx]}')
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label
    

class TestDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TEST_PATH}/{file_name}'
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [12]:
if CFG.DEBUG == True:
    train_dataset = TrainDataset(train_df, transform=None)
    for i in range(1):
        image, label = train_dataset[i]
        plt.imshow(image)
        plt.title(f'label: {label}')
        plt.show() 

## Transforms for Augumentations

In [13]:
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.clip(np.random.beta(alpha, alpha),0.3,0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    new_data[:, :, bby1:bby2, bbx1:bbx2] = data[indices, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    targets = (target, shuffled_target, lam)
    return new_data, targets

def fmix(device, data, targets, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    #mask =torch.tensor(mask, device=device).float()
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]
    x1 = torch.from_numpy(mask).to(device)*data
    x2 = torch.from_numpy(1-mask).to(device)*shuffled_data
    targets=(targets, shuffled_targets, lam)
    return (x1+x2), targets

In [14]:
def generate_transforms():
    train_transforms = Compose([
            Resize(height=CFG.SIZE[0], width=CFG.SIZE[1]), #RandomResizedCrop(CFG.size, CFG.size),
            Transpose(p=0.3), VerticalFlip(p=0.3), HorizontalFlip(p=0.3), ShiftScaleRotate(p=0.4),
            RandomBrightnessContrast(p=0.4), 
            IAAAdditiveGaussianNoise(p=0.3),  # sharpen, affine transform
            OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3)], p=0.3),
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.3),
            CoarseDropout(p=0.4), Cutout(p=0.4),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0)])
            # RandomCrop, IAAAdditiveGaussianNoise, RandomResizedCrop(sz,sz),   
            # CLAHE, ImageCompression, MaskDropout, elastictransform
            # IAAAffine

    val_transforms = Compose([
            Resize(height=CFG.SIZE[0], width=CFG.SIZE[1]),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0), ToTensorV2(p=1.0)])

    test_transforms = Compose([
            Resize(height=CFG.SIZE[0], width=CFG.SIZE[1]),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0), ToTensorV2(p=1.0)])

    return {'train_transforms':train_transforms, 'val_transforms':val_transforms, 'test_transform':test_transforms}

In [15]:
if CFG.DEBUG == True:
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    for i in range(1):
        image, label = train_dataset[i]
        plt.imshow(image[0])
        plt.title(f'label: {label}')
        plt.show() 

## Model class

In [16]:
print("Available Vision Transformer Models: ")
timm.list_models("vit*")

Available Vision Transformer Models: 


['vit_base_patch16_224',
 'vit_base_patch16_384',
 'vit_base_patch32_384',
 'vit_base_resnet26d_224',
 'vit_base_resnet50d_224',
 'vit_huge_patch16_224',
 'vit_huge_patch32_384',
 'vit_large_patch16_224',
 'vit_large_patch16_384',
 'vit_large_patch32_384',
 'vit_small_patch16_224',
 'vit_small_resnet26d_224',
 'vit_small_resnet50d_s3_224']

In [17]:
class ViTBase16(nn.Module):
    def __init__(self, model_name=CFG.MODEL_ARCH, pretrained=False):
        super(ViTBase16, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        #if pretrained == False:
        #    model.load_state_dict(torch.load(MODEL_PATH))
        self.model.head = nn.Linear(self.model.head.in_features, CFG.NUM_CLASSES)
        
    def forward(self, x):
        x = self.model(x)
        return x

In [18]:
if CFG.DEBUG == True:
    model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=False)
    train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    train_loader = DataLoader(train_dataset, batch_size= 4, shuffle=True,
                              num_workers=CFG.NUM_WORKERS, pin_memory=True, drop_last=True)
    for image, label in train_loader:
        output = model(image)
        print(output)
        break

## Loss function

In [19]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
    

In [20]:
def get_loss_fn():
    if CFG.LOSS_FN == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = LabelSmoothingCrossEntropy(smoothing=CFG.SMOOTHING)
    return criterion

criterion = get_loss_fn()

## Device as cpu or tpu or gpu

In [21]:
#device = xm.xla_device()
#print(f"Device found is {device}")

## Lr_find

In [22]:
def plot_lr_finder_results(lr_finder): 
    # Create subplot grid
    fig = make_subplots(rows=1, cols=2)
    # layout ={'title': 'Lr_finder_result'}
    
    # Create a line (trace) for the lr vs loss, gradient of loss
    trace0 = go.Scatter(x=lr_finder['log_lr'], y=lr_finder['smooth_loss'],name='log_lr vs smooth_loss')
    trace1 = go.Scatter(x=lr_finder['log_lr'], y=lr_finder['grad_loss'],name='log_lr vs loss gradient')

    # Add subplot trace & assign to each grid
    fig.add_trace(trace0, row=1, col=1);
    fig.add_trace(trace1, row=1, col=2);
    #iplot(fig, show_link=False)
    fig.write_html(CFG.MODEL_NAME + '_lr_find.html');

In [23]:
def find_lr(model, optimizer, data_loader, init_value = 1e-8, final_value=100.0, beta = 0.98, num_batches = 200):
    assert(num_batches > 0)
    mult = (final_value / init_value) ** (1/num_batches)
    lr = init_value
    optimizer.param_groups[0]['lr'] = lr
    batch_num = 0
    avg_loss = 0.0
    best_loss = 0.0
    smooth_losses = []
    raw_losses = []
    log_lrs = []
    dataloader_it = iter(data_loader)
    progress_bar = tqdm(range(num_batches))                
        
    for idx in progress_bar:
        batch_num += 1
        try:
            images, labels = next(dataloader_it)
            #print(images.shape)
        except:
            dataloader_it = iter(data_loader)
            images, labels = next(dataloader_it)

        # Move input and label tensors to the default device
        images = images.to(device)
        labels = labels.to(device)

        # handle exception in criterion
        try:
            # Forward pass
            y_preds = model(images.float())
            loss = criterion(y_preds, labels)
        except:
            if len(smooth_losses) > 1:
                grad_loss = np.gradient(smooth_losses)
            else:
                grad_loss = 0.0
            lr_finder_results = {'log_lr':log_lrs, 'raw_loss':raw_losses, 
                                 'smooth_loss':smooth_losses, 'grad_loss': grad_loss}
            return lr_finder_results 
                    
        #Compute the smoothed loss
        avg_loss = beta * avg_loss + (1-beta) *loss.item()
        smoothed_loss = avg_loss / (1 - beta**batch_num)
        
        #Stop if the loss is exploding
        if batch_num > 1 and smoothed_loss > 50 * best_loss:
            if len(smooth_losses) > 1:
                grad_loss = np.gradient(smooth_losses)
            else:
                grad_loss = 0.0
            lr_finder_results = {'log_lr':log_lrs, 'raw_loss':raw_losses, 
                                 'smooth_loss':smooth_losses, 'grad_loss': grad_loss}
            return lr_finder_results
        
        #Record the best loss
        if smoothed_loss < best_loss or batch_num==1:
            best_loss = smoothed_loss
        
        #Store the values
        raw_losses.append(loss.item())
        smooth_losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # print info
        progress_bar.set_description(f"loss: {loss.item()},smoothed_loss: {smoothed_loss},lr : {lr}")

        #Update the lr for the next step
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr
    
    grad_loss = np.gradient(smooth_losses)
    lr_finder_results = {'log_lr':log_lrs, 'raw_loss':raw_losses, 
                         'smooth_loss':smooth_losses, 'grad_loss': grad_loss}
    return lr_finder_results

In [24]:
if CFG.LR_FIND == True:
    # create Dataset
    temp_train_dataset = TrainDataset(train_df, transform=generate_transforms()['train_transforms'])
    temp_train_dataloader = DataLoader(temp_train_dataset, batch_size= CFG.TRAIN_BATCH_SIZE, shuffle=True,
                          num_workers=CFG.NUM_WORKERS, pin_memory=False, drop_last=False)

    # create model instance
    # load pretrained weight file, if present
    if CFG.RETRAIN == True:
        i_fold = 0
        checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth')
        model = CustomResNext(model_name=CFG.MODEL_ARCH, pretrained=False)
        model.to(device)
        model.load_state_dict(checkpoint['model'])
        print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')
            
    else:
        model = CustomResNext(model_name=CFG.MODEL_ARCH, pretrained=True)
        model.to(device)

    optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
    lr_finder_results = find_lr(model, optimizer, temp_train_dataloader)
    plot_lr_finder_results(lr_finder_results)

## One fold train and validation function

In [25]:
def train_one_fold(i_fold, model, optimizer, scheduler, device, dataloader_train, dataloader_valid):
    train_fold_results = []
    lr_list = []
    best_val_acc = 0.0
    best_epoch = 0

    para_train_loader = pl.ParallelLoader(dataloader_train, [device]).per_device_loader(device)
    para_valid_loader = pl.ParallelLoader(dataloader_valid, [device]).per_device_loader(device)

    for epoch in range(CFG.N_EPOCHS):
        xm.master_print('  Epoch {}/{}'.format(epoch + 1, CFG.N_EPOCHS))
        #gc.collect()
        
        model.train()
        tr_loss = 0.0    
        # training iterator
        tr_iterator = iter(para_train_loader)
        #train_progress_bar = tqdm(range(len(para_train_loader)))
    
        #for idx in train_progress_bar:
        for idx in range(len(para_train_loader)):
            try:
                images, labels = next(tr_iterator)
                #print(images.shape)
            except StopIteration:
                tr_iterator = iter(dataloader_train)
                images, labels = next(tr_iterator)

            images = images.to(device,dtype=torch.float32)
            labels = labels.to(device, dtype=torch.int64)  
            
            # adding fmix
            mix_decision = np.random.rand()
            if mix_decision < CFG.MIX_PROB:
                images, labels = fmix(device, images, labels, alpha=1., decay_power=5., shape=(CFG.SIZE[0],CFG.SIZE[1]))
            
            # clear the gradients of all optimized variables
            optimizer.zero_grad() 

            # Forward pass
            y_preds = model(images)            
            if mix_decision < CFG.MIX_PROB:
                loss = criterion(y_preds, labels[0]) * labels[2] + criterion(y_preds, labels[1]) * (1.0 - labels[2])
            else:
                loss = criterion(y_preds, labels)
                    
            # Backward pass
            loss.backward()
            tr_loss += loss.item()
            xm.optimizer_step(optimizer)
                    
            # lr scheduler
            scheduler.step()
            lr_list.append(optimizer.state_dict()["param_groups"][0]['lr'])
            
            #if idx % 20 == 0:
            #    xm.master_print(f"Batch {idx} : Train_loss: {tr_loss} loss(avg): {tr_loss/(idx+1)}")
            #train_progress_bar.set_description(f"Train_loss: {tr_loss} loss(avg): {tr_loss/(idx+1)}")
        #gc.collect()
        xm.master_print(f"Epoch {epoch} : Train_loss: {tr_loss} loss(avg): {tr_loss/(idx+1)}")
        
        # Validate
        model.eval()
        val_loss = 0.0
        val_preds = None
        val_labels = None
        valid_iterator = iter(para_valid_loader)
        #valid_progress_bar = tqdm(range(len(para_valid_loader)))
        #pbar = tqdm(enumerate(train_loader), total=len(train_loader))


        for idx in range(len(para_valid_loader)):
        #for idx in valid_progress_bar:
            try:
                images, labels = next(valid_iterator)
            except StopIteration:
                tr_iterator = iter(dataloader_valid)
                images, labels = next(valid_iterator)
            
            images = images.to(device,dtype=torch.float32)
            labels = labels.to(device, dtype=torch.int64)  

            if val_labels is None:
                val_labels = labels.clone()
            else:
                val_labels = torch.cat((val_labels, labels), dim=0)
            
            with torch.no_grad():
                y_preds = model(images)
            
            # computing validation loss and metric
            loss = criterion(y_preds, labels)
            val_loss += loss.item()
            preds = torch.softmax(y_preds, dim=1)
            
            # store predictions            
            if val_preds is None:
                val_preds = preds
            else:
                val_preds = torch.cat((val_preds, preds), dim=0)
                
            # print to console
            #valid_progress_bar.set_description(f"val_loss: {val_loss} loss(avg): {val_loss/(idx+1)}")
        
        # save predictions
        val_preds  = np.argmax(val_preds.cpu().data.numpy(), axis=1)
        val_labels = val_labels.cpu().data.numpy()

        # compute accuracy
        val_score = accuracy_score(val_labels, val_preds)
        cm = confusion_matrix(val_labels, val_preds)
        class_wise_acc = []
        for i, val in enumerate(cm):
            class_wise_acc.append(val[i]/sum(val)*100)
        xm.master_print(f"Fold:{i_fold}, Epoch:{epoch}, val acc:{val_score * 100.0}, Classwise_acc:{class_wise_acc}")
        #gc.collect()
        
        # store results
        train_fold_results.append({ 'fold': i_fold, 'epoch': epoch, 'train_loss': tr_loss / len(dataloader_train), 
                                    'valid_loss': val_loss / len(dataloader_valid), 'valid_score': val_score,
                                    'class_wise_acc': class_wise_acc})
            
        # save best models        
        if val_score > best_val_acc:
            # reset variables
            best_val_acc = val_score
            best_epoch = epoch
                        
            # save model weights
            xm.save({'model': model.state_dict(), 'val_preds':val_preds, 'val_labels':val_labels}, 
                        f"{CFG.MODEL_NAME}_fold_{i_fold}_epoch{epoch}_{val_score}.pth")
    
    xm.master_print(f"For Fold {i_fold}, Best validation accuracy of {best_val_acc} was got at epoch {best_epoch}")                
    lr_list = np.array(lr_list)
    np.save(f"{CFG.MODEL_NAME}_fold{i_fold}_LRlist.npy", lr_list)
    return train_fold_results

## Training and validation function calls

In [26]:
def get_TPU_Dataloaders(train_data, valid_data):
    dataset_train = TrainDataset(train_data, transform=generate_transforms()['train_transforms'])
    dataset_valid = TrainDataset(valid_data, transform=generate_transforms()['val_transforms'])
            
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train,
                                    num_replicas=xm.xrt_world_size(),rank=xm.get_ordinal(),shuffle=True)

    valid_sampler = torch.utils.data.distributed.DistributedSampler(dataset_valid,
                                    num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),shuffle=False)

    dataloader_train = DataLoader(dataset=dataset_train, batch_size=CFG.TRAIN_BATCH_SIZE, sampler=train_sampler,
                              drop_last=False, num_workers=CFG.NUM_WORKERS) #pin_memory=False

    dataloader_valid = DataLoader(dataset=dataset_valid, batch_size=CFG.TRAIN_BATCH_SIZE, sampler=valid_sampler,
                              drop_last=False, num_workers=CFG.NUM_WORKERS) #pin_memory=False
    return dataloader_train, dataloader_valid

In [27]:
def fit_TPU():
    train_results = []
    device = xm.xla_device()
    xm.master_print(f"INITIALIZING TRAINING ON {xm.xrt_world_size()} TPU CORES")
    
    for i_fold, (train_idx, valid_idx) in enumerate(folds.split(train_df, train_labels)):
        if i_fold in CFG.FOLD_TO_TRAIN:
            xm.master_print("Fold {}/{}".format(i_fold + 1, CFG.N_FOLDS))

            # create fold data
            train_data = train_df.iloc[train_idx].reset_index()    
            valid_data = train_df.iloc[valid_idx].reset_index()
            xm.master_print(train_data.shape, valid_data.shape)
            dataloader_train, dataloader_valid = get_TPU_Dataloaders(train_data, valid_data)

            # load pretrained weight file
            if CFG.RETRAIN == True:
                checkpoint = torch.load(f'{CFG.WGT_PATH}/{CFG.WGT_MODEL}_fold{i_fold}.pth')
                model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=False)
                model.to(device)
                model.load_state_dict(checkpoint['model'])
                xm.master_print(f'Model loaded for {CFG.WGT_MODEL}_fold{i_fold}')

            else:
                model = ViTBase16(model_name=CFG.MODEL_ARCH, pretrained=True)
                model.to(device)

            ## optimizer function
            if CFG.OPTIMIZER == 'Adam':
                optimizer = optim.Adam(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR)
            else:
                optimizer = optim.SGD(model.parameters(), weight_decay=CFG.WEIGHT_DECAY, lr=CFG.MAX_LR, momentum=0.9)


            # lr scheduler
            if CFG.SCHEDULER == 'OneCycleLR':
                scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr= CFG.MAX_LR, epochs = CFG.N_EPOCHS, 
                                                          steps_per_epoch = len(dataloader_train), pct_start=0.4, 
                                                          div_factor=10, anneal_strategy='cos')
            elif CFG.SCHEDULER == 'CosineAnnealingWarmRestarts':
                scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, 
                                                        eta_min=CFG.MIN_LR, last_epoch=-1)
            else:
                scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_MAX,eta_min=CFG.MIN_LR, last_epoch=-1)

            #xm.master_print(f"scheduler:{scheduler},optimizer:{optimizer},loss_fn:{criterion}")
            train_fold_results = train_one_fold(i_fold, model, optimizer, scheduler, device,
                                                dataloader_train, dataloader_valid)
            train_results = train_results + train_fold_results
    return train_results

In [28]:
# Start training processes
def _mp_fn(rank, flags):
    torch.set_default_tensor_type("torch.FloatTensor")
    train_results = fit_TPU()    
    train_results = pd.DataFrame(train_results)
    print(train_results)
    train_results.to_csv('train_results.csv', index=False)
    #best_folds = np.array([train_results[train_results['fold']==x]['valid_score'].max() for x in CFG.FOLD_TO_TRAIN])
    #print(f'Overall CV accuracy : {best_folds.mean()}, std: {best_folds.std()}')
    #plot_training_results(train_results)

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [29]:
%%time
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method="fork")

INITIALIZING TRAINING ON 8 TPU CORES
Fold 1/5
(17117, 4) (4280, 4)


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth
Exception in device=TPU:3: [Errno 104] Connection reset by peer
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-28-2eef5c61b590>", line 4, in _mp_fn
    train_results = fit_TPU()
  File "<ipython-input-27-0311e0fff9dd>", line 25, in fit_TPU
    model = ViTBase16(model_name=CFG.MODEL_ARCH, pretra

Exception: process 3 terminated with exit code 17