In [None]:
# Selects the Cloud TPU (torch_xla) code path; set to False to train on a CUDA GPU instead.
tpu: bool = True

In [None]:
if tpu:
    # Download and run the PyTorch/XLA environment setup script (presumably installs a
    # matching torch / torch_xla build for the TPU runtime).
    # NOTE(review): "nightly" is unpinned, so re-runs may pick up a different build.
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
    !python pytorch-xla-env-setup.py --version "nightly"

In [None]:
# Presumably provides the `warmup_scheduler` package imported below (GradualWarmupScheduler).
# NOTE(review): installs from the repository HEAD, unpinned.
!pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
# EfficientNet backbone used only by the alternative `Classifier` model. NOTE(review): unpinned.
!pip install efficientnet_pytorch

In [None]:
# timm supplies the ViT backbone built by `cassavamodel`. NOTE(review): unpinned.
!pip install timm

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import transforms
from torch.optim import lr_scheduler
from efficientnet_pytorch import EfficientNet
from torch.nn import functional as F

if tpu:
    import torch_xla
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.data_parallel as dp
    import torch_xla.distributed.parallel_loader as pl

    import torch_xla.utils.utils as xu
    import torch_xla.core.xla_model as xm
    import torch_xla.utils.serialization as xser
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.test.test_utils as test_utils

from warmup_scheduler import GradualWarmupScheduler

import numpy as np
import pandas as pd
import os
import gc
import time
from tqdm import tqdm
import timm

import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
import albumentations
from albumentations import Compose, Normalize, HorizontalFlip, VerticalFlip

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Record which torch build produced the outputs below.
print(f"{torch.__version__}")

In [None]:
# PyTorch/XLA runtime options passed through environment variables.
# NOTE(review): torch_xla is already imported above; confirm these are still read at this
# point (setting them before the import is the safe choice).
os.environ.update({
    'XLA_USE_32BIT_LONG': '1',
    'XLA_TENSOR_ALLOCATOR_MAXSIZE': '100000000',
})

In [None]:
# ---- Configuration ----
import random

# Data locations and the training table (one row per image: image_id, label).
data_dir = '../input/cassava-leaf-disease-classification'
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))
image_folder = os.path.join(data_dir, 'train_images')

# Cross-validation: number of folds, and the single fold trained by train_model().
folds = 6
present_fold = 2

# Input / loader settings.
image_size = 384      # crop size fed to the network (matches the *_384 backbone below)
batch_size = 4        # per process, i.e. per TPU core when tpu=True
val_batch_size = 4
num_workers = 4
out_dim = 5           # number of classes

# Optimisation.
init_lr = 1e-4
warmup_factor = 7     # LR multiplier reached at the end of warm-up
warmup_epo = 1        # number of warm-up epochs

smoothing = 0.05      # label smoothing passed to the loss

epoch_threshold = 5   # NOTE(review): not referenced anywhere in this notebook

debug = False         # True: only 500 rows and 2 epochs

n_epochs = 2 if debug else 8

kernel_type = 'vit_large_patch16_384'  # prefix of the log and checkpoint file names
net_type = 'vit_large_patch16_384'     # timm architecture name

# Bi-tempered logistic loss temperatures (t1 < 1 bounds the loss, t2 > 1 gives heavier tails).
t1 = 0.8
t2 = 1.4

# Seed the Python, NumPy and torch RNGs so head initialisation and sampling are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

print(image_folder)

In [None]:
# Assign every row a stratified CV fold id (0 .. folds-1) in a 'fold' column.
skf = StratifiedKFold(folds, shuffle=True, random_state=42)
if debug:
    # Explicit copy so the column assignment below does not write into a view.
    df_train = df_train.iloc[:500].copy()
# skf.split yields positional indices, so fill a positional array rather than relying
# on .loc labels happening to equal positions.
fold_ids = np.full(len(df_train), -1, dtype=np.int64)
for i, (_, valid_idx) in enumerate(skf.split(df_train, df_train['label'])):
    fold_ids[valid_idx] = i
df_train['fold'] = fold_ids
df_train.head()

In [None]:
class CASSAVA(Dataset):
    """Cassava leaf image dataset backed by a DataFrame.

    Each row must provide ``image_id`` (a file name inside the notebook-level
    ``image_folder``) and an integer ``label``. ``__getitem__`` returns
    ``(image, torch.tensor(label))``, where ``image`` is the RGB array decoded by OpenCV,
    passed through ``transforms`` (albumentations-style, called as
    ``transforms(image=img)['image']``) when one is given.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(image_folder, row['image_id'])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None instead of raising for missing/unreadable files.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        # OpenCV decodes to BGR; the ImageNet mean/std in the transforms and the
        # pretrained backbones expect RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transforms is not None:
            img = self.transforms(image=img)['image']
        return img, torch.tensor(row.label)
# Smoke test: build an untransformed dataset and fetch one sample.
dataset = CASSAVA(df_train)
X, Y = dataset[3]

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline.

    Random resized crop to ``image_size``, flips/transpose, shift-scale-rotate, colour
    jitter, ImageNet normalisation, two occlusion augmentations, then ToTensorV2.
    """
    geometric = [
        RandomResizedCrop(image_size, image_size),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
    ]
    photometric = [
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    finishing = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        # The occlusions run after Normalize, so their fill values are in normalised units.
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(geometric + photometric + finishing, p=1.)
  
        
def get_valid_transforms():
    """Build the deterministic validation pipeline: centre crop, resize, normalise, to tensor."""
    steps = [
        CenterCrop(image_size, image_size, p=1.),
        # NOTE(review): after a CenterCrop to image_size this Resize does not change the size.
        Resize(image_size, image_size),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
# Build both pipelines once; the datasets created later reuse these objects.
transforms_train, transforms_valid = get_train_transforms(), get_valid_transforms()

In [None]:
# Visual sanity check of the training augmentations: two rows of five random samples.
dataset_show = CASSAVA(df_train, transforms=transforms_train)

# Undo the ImageNet Normalize from get_train_transforms so imshow receives values in [0, 1].
_show_mean = np.array([0.485, 0.456, 0.406])
_show_std = np.array([0.229, 0.224, 0.225])

for _ in range(2):
    fig, axarr = plt.subplots(1, 5, figsize=(20, 10))
    for p in range(5):
        idx = np.random.randint(0, len(dataset_show))
        img, label = dataset_show[idx]
        print(type(img), img.size(), img.max())
        # CHW tensor -> HWC array for matplotlib.
        img_hwc = img.permute(1, 2, 0).numpy()
        axarr[p].imshow(np.clip(img_hwc * _show_std + _show_mean, 0.0, 1.0))
        # `label` is a 0-d class-index tensor; argmax over it would always report 0.
        axarr[p].set_title(str(label.item()))
    plt.show()

In [None]:
class Classifier(nn.Module):
    """EfficientNet-B4 backbone + global max pooling + linear head (alternative to cassavamodel).

    ``forward`` returns raw logits of shape ``(batch, out_dim)``.
    """

    def __init__(self,):
        super().__init__()
        self.effn = EfficientNet.from_pretrained('efficientnet-b4')
        # 1792 must equal the channel count of the feature map from extract_features.
        self.dense = nn.Linear(1792, out_dim)
        # Global max pool to 1x1 for any input size. The previous MaxPool2d((12, 12)) only
        # reduced to 1x1 for a 12x12 feature map (384px input at stride 32), where the two
        # are identical.
        self.pooling = nn.AdaptiveMaxPool2d(1)
        self.flatten = nn.Flatten()
        # Fine-tune the whole backbone.
        for param in self.effn.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.effn.extract_features(x)
        x = self.pooling(x)
        x = self.flatten(x)
        return self.dense(x)
    
class cassavamodel(nn.Module):
    """timm backbone whose classifier head is replaced by a fresh ``n_class``-way Linear.

    Args:
        model_arch: timm architecture name passed to ``timm.create_model``.
        n_class: number of output logits.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_arch, n_class, pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Assumes the architecture exposes its classifier as `.head` with `in_features`
        # (the case for the ViT used here) — TODO confirm for other archs.
        backbone.head = nn.Linear(backbone.head.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
    
# Dead code: a Classifier smoke test kept inside a string literal, so it never executes.
# NOTE(review): delete it or turn it back into a real cell if it is still needed.
'''dummy_img = torch.rand((1,3,image_size,image_size))
dummy = Classifier()

#print(dummy)
o_p = dummy(dummy_img)
print(o_p)
model = Classifier()
dummy_img = torch.rand((1,3,image_size,image_size))
o_p = model(dummy_img)
print(o_p)'''

In [None]:
class TaylorSoftmax(nn.Module):
    """Softmax variant that replaces exp with its order-``n`` Taylor polynomial.

    out = f(x) / sum(f(x)) along ``dim`` with f(x) = 1 + x + x^2/2! + ... + x^n/n!
    (each factorial offset by 1e-6). ``n`` must be even, presumably so f stays positive.
    """

    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        series = torch.ones_like(x)
        factorial = 1.
        for order in range(1, self.n + 1):
            factorial *= order
            series = series + x.pow(order) / (factorial + 1e-6)
        return series / series.sum(dim=self.dim, keepdims=True)
    
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed targets; expects log-probabilities as input.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` uniformly over the other ``classes - 1`` classes. The per-sample loss
    ``-sum(target_dist * pred)`` along ``dim`` is averaged over all remaining positions.

    Args:
        classes: number of classes (size of ``pred`` along ``dim``); must be >= 2.
        smoothing: probability mass moved off the true class.
        dim: class dimension of ``pred``.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Taylor Softmax and log are already applied on the logits.

        Args:
            pred: log-probabilities with ``pred.shape[dim] == classes``.
            target: integer (int64) class indices shaped like ``pred`` without ``dim``.

        Raises:
            ValueError: if fewer than two classes are configured, since the off-class
                mass ``smoothing / (classes - 1)`` is then undefined.
        """
        if self.cls < 2:
            raise ValueError(f'self.cls = {self.cls}')
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.cls - 1))
            # Scatter along the configured class dim. This was previously hard-coded to
            # dim 1, which only agreed with the default dim=-1 for 2-D inputs.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
    

class TaylorCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy computed on Taylor-softmax probabilities.

    NOTE(review): not used by the training/validation loops in this notebook, which call
    bi_tempered_logistic_loss.

    Args:
        n: even Taylor order passed to TaylorSoftmax.
        ignore_index: stored only; NOTE(review): not applied in forward.
        reduction: stored only; NOTE(review): not applied — forward returns the mean
            computed by LabelSmoothingLoss.
        smoothing: label-smoothing mass.
        classes: number of classes. Defaults to the notebook-level ``out_dim`` so
            existing callers are unchanged; pass it explicitly to avoid the global.
    """

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.2, classes=None):
        super().__init__()
        assert n % 2 == 0
        if classes is None:
            classes = out_dim
        self.taylor_softmax = TaylorSoftmax(dim=-1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.lab_smooth = LabelSmoothingLoss(classes, smoothing=smoothing)

    def forward(self, logits, labels):
        # LabelSmoothingLoss consumes log-probabilities, so log the Taylor-softmax output.
        log_probs = self.taylor_softmax(logits).log()
        return self.lab_smooth(log_probs, labels)

In [None]:
def log_t(u, t):
    """Tempered logarithm: (u**(1 - t) - 1) / (1 - t), reducing to log(u) when t == 1."""
    if t != 1.0:
        one_minus_t = 1.0 - t
        return (u.pow(one_minus_t) - 1.0) / one_minus_t
    return u.log()

def exp_t(u, t):
    """Tempered exponential: relu(1 + (1 - t) * u) ** (1 / (1 - t)), reducing to exp(u) when t == 1."""
    if t != 1:
        one_minus_t = 1.0 - t
        return (1.0 + one_minus_t * u).relu().pow(1.0 / one_minus_t)
    return u.exp()

def compute_normalization_fixed_point(activations, t, num_iters):

    """Returns the normalization value for each example (t > 1.0).

    Runs a fixed-point iteration on the max-shifted activations and returns the additive
    normalization constant that tempered_softmax subtracts before applying exp_t.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.
    """
    # Shift by the per-example max for numerical stability; mu is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    for _ in range(num_iters):
        # Current partition estimate under exp_t, then rescale the shifted activations
        # by partition**(1 - t) to form the next iterate.
        logt_partition = torch.sum(
                exp_t(normalized_activations, t), -1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * \
                logt_partition.pow(1.0-t)

    logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
    # Convert the partition sum into an additive constant on the original (un-shifted) scale.
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants

def compute_normalization_binary_search(activations, t, num_iters):

    """Returns the normalization value for each example (t < 1.0).

    Bisects, per example, for the constant c at which
    sum(exp_t(activations - mu - c, t)) over the last dim equals 1.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    # Shift by the per-example max for numerical stability; mu is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    # Count of classes whose shifted activation exceeds -1/(1-t), i.e. that have a
    # non-zero exp_t value at zero offset; used for the initial upper bound.
    effective_dim = \
        torch.sum(
                (normalized_activations > -1.0 / (1.0-t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    # Per-example search interval [lower, upper].
    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower)/2.0
        sum_probs = torch.sum(
                exp_t(normalized_activations - logt_partition, t),
                dim=-1, keepdim=True)
        # update == 1 where the mass is below 1 (midpoint too large): move `upper` down to
        # the midpoint; otherwise move `lower` up. Done elementwise, without Python branching.
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
                lower * update + (1.0-update) * logt_partition,
                shape_partition)
        upper = torch.reshape(
                upper * (1.0 - update) + update * logt_partition,
                shape_partition)

    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.

    forward dispatches to the bisection solver for t < 1 and to the fixed-point solver
    otherwise. backward uses a closed-form gradient (the escort distribution) instead of
    differentiating through the solver iterations.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        # Saved for the closed-form gradient in backward.
        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        normalized_activations = activations - normalization_constants 
        probabilities = exp_t(normalized_activations, t)
        # Escort distribution: probabilities**t renormalised over the last dim.
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        # grad_output matches the keepdim'ed constants (trailing dim 1) and broadcasts over classes.
        grad_input = escorts * grad_output
        
        # No gradients for the non-tensor inputs t and num_iters.
        return grad_input, None, None

def compute_normalization(activations, t, num_iters=5):
    """Per-example normalization constant for the tempered softmax, with a custom backward.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of solver iterations.
    Returns:
      A tensor of the same rank as `activations` whose last dimension is 1.
    """
    normalizer = ComputeNormalization.apply(activations, t, num_iters)
    return normalizer

def tempered_sigmoid(activations, t, num_iters=5):
    """Tempered sigmoid: a two-class tempered softmax over [activation, 0].

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      The positive-class probabilities, shaped like `activations`.
    """
    zeros = torch.zeros_like(activations)
    pair = torch.stack((activations, zeros), dim=-1)
    pair_probs = tempered_softmax(pair, t, num_iters)
    return pair_probs[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature. t == 1.0 is the ordinary softmax; any other value goes through
        compute_normalization (> 1.0 heavier tails, < 1.0 finite support).
      num_iters: Number of iterations for the normalization solver.
    Returns:
      A probabilities tensor shaped like `activations`.
    """
    if t != 1.0:
        normalization_constants = compute_normalization(activations, t, num_iters)
        return exp_t(activations - normalization_constants, t)
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):
    """Bi-Tempered binary logistic loss, expressed as the two-class bi-tempered loss.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor shaped like activations, containing probabilities for class 1.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing.
      num_iters: Number of iterations to run the method.
      reduction: Forwarded to bi_tempered_logistic_loss.
    Returns:
      A loss tensor.
    """
    # Class 1 gets the activation; class 0 is pinned to a zero logit.
    pair_activations = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    positive_prob = labels.to(activations.dtype)
    pair_labels = torch.stack([positive_prob, 1.0 - positive_prob], dim=-1)
    return bi_tempered_logistic_loss(
        pair_activations,
        pair_labels,
        t1,
        t2,
        label_smoothing=label_smoothing,
        num_iters=num_iters,
        reduction=reduction)

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):
    """Bi-Tempered Logistic Loss.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot),
        or a long tensor of one dimension less than activations (pytorch standard).
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over all examples (0-d tensor).
        ``'sum'``: Loss is summed over all examples (0-d tensor).
    Returns:
      A loss tensor.
    Raises:
      ValueError: if `reduction` is not one of 'none', 'mean', 'sum'.
    """
    if reduction not in ('none', 'mean', 'sum'):
        # Previously an unknown value silently returned None, failing only later at backward().
        raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

    if len(labels.shape) < len(activations.shape):  # class indices, not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Classes live on the last dim, so scatter there. The hard-coded dim 1 used
        # before was only correct for 2-D activations.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        # Leaves 1 - label_smoothing on the true class and label_smoothing / (C - 1) elsewhere.
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 keeps the target term finite when t1 == 1 (log of zero targets).
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1)  # sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()

In [None]:
def reduce_fn(vals):
    """Arithmetic mean of a non-empty sequence (presumably meant for xm.mesh_reduce; currently unused)."""
    total = sum(vals)
    count = len(vals)
    return total / count
def train_one_epoch(model,device,loader,optimizer):
    """Run one training epoch and return the per-batch losses.

    Args:
        model: network to train (switched to train mode here).
        device: target device (the XLA device when ``tpu`` is True, otherwise e.g. 'cuda').
        loader: iterable of ``(image, label)`` batches.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        list with one 0-d numpy array (the detached loss) per batch.
    """
    model.train()
    train_loss = []
    # On TPU only the master ordinal draws a progress bar; otherwise every core prints one.
    show_bar = (not tpu) or xm.is_master_ordinal()
    for image, label in tqdm(loader, disable=not show_bar):
        # Cast and move in one step. `image.type(torch.FloatTensor)` targets a CPU type, so
        # it copied device-resident batches back to the host every iteration.
        image = image.to(device, dtype=torch.float32)
        label = label.to(device)

        optimizer.zero_grad()
        y_pred = model(image)
        loss = bi_tempered_logistic_loss(y_pred, label, t1=t1, t2=t2, label_smoothing=smoothing)

        loss.backward()
        if tpu:
            # xm.optimizer_step all-reduces gradients across replicas before stepping.
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        # NOTE(review): fetching the loss every step forces a device sync on XLA.
        train_loss.append(loss.detach().cpu().numpy())
    return train_loss

def validate_one_epoch(model,device,loader):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, accuracy)``.

    The loss is the mean over this process's batches. On TPU the accuracy is combined
    across replicas with ``xm.mesh_reduce(..., np.mean)`` over the gathered per-sample
    hit arrays (assumes every replica contributes the same number of samples).
    """
    model.eval()
    val_loss = []
    y_preds = []
    y_trues = []

    show_bar = (not tpu) or xm.is_master_ordinal()
    # Evaluation needs no autograd graph; without no_grad every batch kept activations alive.
    with torch.no_grad():
        for image, label in tqdm(loader, disable=not show_bar):
            # Cast and move in one step instead of a round-trip through a CPU FloatTensor.
            image = image.to(device, dtype=torch.float32)
            label = label.to(device)
            pred = model(image)
            loss = bi_tempered_logistic_loss(pred, label, t1=t1, t2=t2, label_smoothing=smoothing)

            y_preds.extend(torch.argmax(pred, dim=-1).cpu().numpy())
            y_trues.extend(label.cpu().numpy())
            val_loss.append(loss.cpu().numpy())

    val_loss = np.mean(val_loss)
    y_preds = np.asarray(y_preds)
    y_trues = np.asarray(y_trues)
    acc = np.asarray(y_trues == y_preds, dtype=np.float32)
    if tpu:
        accuracy = xm.mesh_reduce('test_accuracy', acc, np.mean)
        xm.master_print("Validation Accuracy = ",accuracy)
    else:
        accuracy = np.mean(acc)
        print('Validation Accuracy = ',accuracy)
    return val_loss,accuracy

In [None]:
# timm ViT backbone with an out_dim-way head. Use `Classifier()` instead for the EfficientNet-B4 variant.
model = cassavamodel(model_arch=net_type, n_class=out_dim)

In [None]:
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """GradualWarmupScheduler whose ``get_lr`` hands off to ``after_scheduler.get_lr()``.

    While ``last_epoch <= total_epoch`` the LR ramps linearly from ``base_lr`` to
    ``base_lr * multiplier`` (from 0 to ``base_lr`` when ``multiplier == 1``). After that,
    the after-scheduler (if any) takes over, re-based on ``base_lr * multiplier``.
    Otherwise the LR stays at ``base_lr * multiplier``.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        if self.last_epoch <= self.total_epoch:
            # Linear warm-up phase.
            if self.multiplier == 1.0:
                ratio = float(self.last_epoch) / self.total_epoch
                return [base_lr * ratio for base_lr in self.base_lrs]
            scale = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * scale for base_lr in self.base_lrs]

        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if not self.finished:
            # First call past warm-up: start the after-scheduler from the peak LRs.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

# Preview the per-epoch LR: linear warm-up to init_lr * warmup_factor over warmup_epo
# epochs, then cosine annealing for the remaining epochs.
if tpu:
    # Scale the base LR with the number of replicas.
    # NOTE(review): this runs in the parent process before xmp.spawn, where
    # xrt_world_size() likely returns 1. Re-running this cell also compounds the scaling.
    init_lr = init_lr * xm.xrt_world_size()

optimizer = optim.Adam(model.parameters(), lr=init_lr)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs-warmup_epo)
# Use warmup_factor (the multiplier train_model configures). It was hard-coded to 10 here,
# so the preview did not match the configured schedule.
scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=warmup_factor, total_epoch=warmup_epo, after_scheduler=scheduler_cosine)

lrs = []
for epoch in range(1, n_epochs+1):
    scheduler_warmup.step(epoch-1)
    lrs.append(optimizer.param_groups[0]["lr"])

fig, ax = plt.subplots(figsize=(20, 3))
ax.plot(range(1, n_epochs + 1), lrs, marker='o')
ax.set(title='Learning-rate schedule', xlabel='epoch', ylabel='lr');

In [None]:
def train_model(model=model):
    """Train `model` on fold `present_fold`, checkpointing the best-accuracy weights.

    Called once per process: under xmp.spawn every TPU core runs it on its own
    DistributedSampler shard; on GPU it runs once on 'cuda'. Each epoch appends a
    summary line to ``log_{kernel_type}.txt`` and saves
    ``{kernel_type}_best_fold{fold}.pth`` whenever validation accuracy improves.
    """
    for fold in range(folds):
        if fold != present_fold:
            continue
        device = xm.xla_device() if tpu else 'cuda'
        model = model.to(device)
        #model.load_state_dict(torch.load('../input/effnetb4/vit_large_patch16_384_best_fold2.pth'))
        optimizer = optim.Adam(model.parameters(), lr=init_lr)

        # Boolean masks select rows independently of the frame's index labels.
        df_curr = df_train[df_train['fold'] != fold]
        df_val = df_train[df_train['fold'] == fold]

        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs - warmup_epo)
        # Same scheduler class and per-epoch step(epoch - 1) pattern as the LR preview cell.
        # Previously the scheduler was built but never stepped, so the LR stayed at init_lr.
        scheduler = GradualWarmupSchedulerV2(optimizer, multiplier=warmup_factor, total_epoch=warmup_epo, after_scheduler=scheduler_cosine)

        dataset_train = CASSAVA(df_curr, transforms=transforms_train)
        dataset_valid = CASSAVA(df_val, transforms=transforms_valid)

        train_sampler = valid_sampler = None
        if tpu:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                dataset_train,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            valid_sampler = torch.utils.data.distributed.DistributedSampler(
                dataset_valid,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=dataset_train,
            batch_size=batch_size,
            sampler=train_sampler,
            # Without a sampler (GPU path) the loader has to shuffle on its own.
            shuffle=train_sampler is None,
            drop_last=True,
            num_workers=num_workers,
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset=dataset_valid,
            batch_size=val_batch_size,
            sampler=valid_sampler,
            drop_last=True,
            num_workers=num_workers,
        )
        if tpu:
            # Wrap once. Wrapping inside the epoch loop nested a new MpDeviceLoader around
            # the previous one on every epoch.
            train_loader = pl.MpDeviceLoader(train_loader, device)
            valid_loader = pl.MpDeviceLoader(valid_loader, device)

        is_master = (not tpu) or xm.is_master_ordinal()
        best_file = f'{kernel_type}_best_fold{fold}.pth'
        acc_max = 0.

        for epoch in range(1, n_epochs + 1):
            scheduler.step(epoch - 1)
            if train_sampler is not None:
                # Give each epoch a different shuffle of this replica's shard.
                train_sampler.set_epoch(epoch)

            train_loss = train_one_epoch(model, device, train_loader, optimizer)
            gc.collect()

            val_loss, acc = validate_one_epoch(model, device, valid_loader)
            gc.collect()

            content = time.ctime() + ' ' + f'FOLD -> {fold} --> Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {np.mean(train_loss):.5f}, val loss: {np.mean(val_loss):.5f}, acc: {(acc):.5f}'
            # A single writer; otherwise every TPU core appends a duplicate line.
            if is_master:
                with open(f'log_{kernel_type}.txt', 'a') as appender:
                    appender.write(content + '\n')

            if acc > acc_max:
                if tpu:
                    # acc comes from mesh_reduce, so all replicas reach this rendezvous together.
                    xm.rendezvous('save_model')
                    xm.master_print('save model')
                    xm.save(model.state_dict(), best_file)
                else:
                    torch.save(model.state_dict(), best_file)
                acc_max = acc
            

In [None]:
%%time

def _mp_fn(rank, flags):
    """Per-process entry point; `rank` and `flags` are supplied by xmp.spawn and unused here."""
    # Make float32 CPU tensors the default type in every spawned process.
    torch.set_default_tensor_type('torch.FloatTensor')
    train_model()


if tpu:
    FLAGS = {}
    # nprocs=8: one process per core of an 8-core TPU. 'fork' lets the children inherit
    # the already-built global `model`.
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
else:
    _mp_fn(None, None)

In [None]:
# Print the per-epoch log appended by train_model(); `with` closes the file handle.
with open(f'./log_{kernel_type}.txt', "r") as log_file:
    print(log_file.read())