In [None]:
# Fetch the PyTorch/XLA env-setup script into pytorch-xla-env-setup.py, then run it
# requesting the nightly version. stdout of both commands is discarded.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py  > /dev/null
!python pytorch-xla-env-setup.py --version nightly  > /dev/null

In [None]:
# timm supplies the model factory used by TimmModels; pip's stdout is discarded.
!pip install timm  > /dev/null

In [None]:
import gc
import os
import time
import torch
import albumentations

import numpy as np
import pandas as pd

import cv2
from PIL import Image

import torch.nn as nn
from sklearn import metrics
from sklearn import model_selection
from torch.nn import functional as F
from torch.optim import Adam

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

import timm

import warnings
# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# NOTE(review): per the torch_xla docs this makes float tensors use bfloat16 on the TPU — confirm for the installed version.
os.environ['XLA_USE_BF16']="1"
# NOTE(review): presumably an upper bound (in bytes) for the XLA tensor allocator — confirm against torch_xla docs.
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

In [None]:
# Run configuration read by the model construction cell and by run().
FLAGS = {
    'fold': 0,                    # fold index held out for validation
    'model': 'resnext50_32x4d',   # architecture name handed to timm
    'pretrained': True,
    'batch_size': 128,            # batch size of each process's DataLoader
    'num_workers': 4,
    'lr': 3e-4,                   # base learning rate; run() scales it by the world size
    'epochs': 10,
}

In [None]:
class TimmModels(nn.Module):
    """timm network whose last child module is replaced by a fresh linear head.

    :param model_name: architecture name passed to ``timm.create_model``
    :param pretrained: forwarded to ``timm.create_model``
    :param num_classes: number of outputs of the replacement linear layer
    """

    def __init__(self, model_name, pretrained=True, num_classes=5):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)

        # The final child is taken to be the classifier; it must expose
        # `in_features` so the new head can match its input width.
        layers = list(backbone.children())
        old_head = layers.pop()
        layers.append(
            nn.Linear(in_features=old_head.in_features,
                      out_features=num_classes,
                      bias=True)
        )

        # Children are chained in definition order inside a Sequential.
        self.m = nn.Sequential(*layers)

    def forward(self, image):
        """Feed ``image`` through the sequential stack and return the result."""
        return self.m(image)

In [None]:
class ImageDataset:
    """Map-style dataset: loads an image from disk and pairs it with its target.

    Each item is a dict with keys ``"image"`` and ``"targets"``, both torch tensors.
    """

    def __init__(self,
                 image_paths,
                 targets,
                 resize,
                 augmentations=None,
                 backend='pil',
                 channel_first=True
                ):
        """
        :param image_paths: list of paths to images
        :param targets: numpy array, indexed in parallel with ``image_paths``
        :param resize: ``(height, width)`` tuple or None to keep the stored size
        :param augmentations: albumentations augmentations (called as
            ``augmentations(image=...)``) or None to skip augmentation
        :param backend: ``'pil'`` or ``'cv2'``, the library used to read images
        :param channel_first: if True the image is transposed from HWC to CHW
            and cast to float32
        """
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations
        self.backend = backend
        self.channel_first = channel_first

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, item):
        """Load, optionally resize and augment, and tensorise sample ``item``.

        :raises Exception: if ``backend`` is neither ``'pil'`` nor ``'cv2'``
        """
        targets = self.targets[item]

        if self.backend == 'pil':
            image = Image.open(self.image_paths[item])

            if self.resize is not None:
                # PIL expects (width, height)
                image = image.resize((self.resize[1], self.resize[0]),
                                     resample=Image.BILINEAR
                                    )

            image = np.array(image)

        elif self.backend == 'cv2':
            image = cv2.imread(self.image_paths[item])
            # OpenCV decodes to BGR; convert so both backends yield RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if self.resize is not None:
                # cv2.resize also expects (width, height)
                image = cv2.resize(image,
                                   (self.resize[1], self.resize[0]),
                                   interpolation=cv2.INTER_CUBIC
                                  )

        else:
            raise Exception("Backend not implemented")

        # Shared by both backends. Previously the cv2 branch read
        # `augmented['image']` even when no augmentations were configured,
        # which failed because `augmented` was never assigned.
        if self.augmentations is not None:
            image = self.augmentations(image=image)['image']

        if self.channel_first:
            image = np.transpose(image, (2,0,1)).astype(np.float32)

        return {"image":torch.tensor(image),
                "targets":torch.tensor(targets)
               }

In [None]:
# Build a stratified 5-fold split of the training table and persist it for run().
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
df['kfold'] = -1  # placeholder, overwritten for every row below

# Shuffle rows with a fixed seed so fold membership is identical on every
# Restart & Run All; otherwise models trained for different folds in separate
# sessions would be validated on inconsistent splits.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

y = df.label.values
skf = model_selection.StratifiedKFold(n_splits=5)

# v_ holds the positional indices of the validation rows of fold f; the index
# was reset above, so positions and labels coincide for .loc.
for f, (t_, v_) in enumerate(skf.split(X=df, y=y)):
    df.loc[v_, 'kfold'] = f

df.to_csv('train_folds.csv', index=False)

In [None]:
# Build the network once at module level and wrap it for multiprocessing;
# each process later obtains its copy via MX.to(device) inside run().
MX = xmp.MpModelWrapper(
    TimmModels(
        model_name=FLAGS['model'],
        pretrained=FLAGS['pretrained'],
        num_classes=5,
    )
)

In [None]:
def train_loop_fn(data_loader, loss_fn, model, optimizer, device, scheduler=None):
    """Train ``model`` for one pass over ``data_loader``.

    :param data_loader: iterable of dicts with ``'image'`` and ``'targets'`` tensors
    :param loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor
    :param model: network to optimise
    :param optimizer: optimizer stepped through ``xm.optimizer_step``
    :param device: device the batch tensors are moved to
    :param scheduler: optional LR scheduler, stepped once per batch

    After the pass, the loss of the *last* batch is averaged across processes
    and printed by the master process. The model is left in eval mode.
    """
    def _mean(values):
        return sum(values) / len(values)

    model.train()

    for bi, batch in enumerate(data_loader):
        images = batch['image'].to(device, dtype=torch.float32)
        # int64 is the dtype torch classification losses such as
        # nn.CrossEntropyLoss expect for class-index targets.
        targets = batch['targets'].to(device, dtype=torch.int64)

        optimizer.zero_grad()
        loss = loss_fn(model(images), targets)
        loss.backward()

        # XLA-aware optimizer step instead of a plain optimizer.step().
        xm.optimizer_step(optimizer, barrier=True)

        if scheduler is not None:
            scheduler.step()

    loss_reduced = xm.mesh_reduce('loss_reduce', loss, _mean)
    xm.master_print(f'bi={bi}, train loss={loss_reduced}')

    model.eval()
    
def eval_loop_fn(data_loader, 
                 loss_fn, 
                 model, 
                 device):
    """Evaluate ``model`` on ``data_loader`` and print loss and accuracy.

    :param data_loader: iterable of dicts with ``'image'`` and ``'targets'`` tensors
    :param loss_fn: callable ``loss_fn(outputs, targets)``; it is applied on CPU
        to the concatenated outputs/targets of this process
    :param model: network to evaluate; it is switched to eval mode here
    :param device: device the batch tensors are moved to

    The per-process loss and accuracy are averaged across processes with
    ``xm.mesh_reduce`` and printed by the master process. Returns None.
    """
    # Put the model in inference mode here instead of relying on the caller
    # (or a preceding training loop) having done so.
    model.eval()

    fin_targets = []
    fin_outputs = []

    with torch.no_grad():
        for d in data_loader:
            images = d['image'].to(device)
            targets = d['targets'].to(device)

            outputs = model(images)

            # Accumulate on the host as plain Python lists.
            fin_targets.extend(targets.cpu().numpy().tolist())
            fin_outputs.extend(outputs.cpu().numpy().tolist())

            gc.collect()

    o, t = np.array(fin_outputs), np.array(fin_targets)

    loss = loss_fn(torch.tensor(o), torch.tensor(t))

    loss_reduced = xm.mesh_reduce('loss_reduce',
                                  loss,
                                  lambda x: sum(x) / len(x)
                                 )

    xm.master_print(f'val. loss = {loss_reduced}')

    # Predicted class = index of the largest output per sample.
    acc = metrics.accuracy_score(t,o.argmax(axis=1))
    acc_reduced = xm.mesh_reduce('acc_reduce',
                                  acc,
                                  lambda x: sum(x) / len(x)
                                 )

    xm.master_print(f'val. accuracy = {acc_reduced}')

In [None]:
def run(rank, flags):
    """Per-process entry point for ``xmp.spawn``: train one fold and save the weights.

    :param rank: process index supplied by ``xmp.spawn``; not used directly, the
        replica ordinal is obtained from ``xm.get_ordinal()``
    :param flags: configuration dict with keys ``'fold'``, ``'batch_size'``,
        ``'num_workers'``, ``'lr'`` and ``'epochs'``. Falls back to the
        module-level ``FLAGS`` when None.
    """
    # Honour the configuration that was actually passed in.
    cfg = FLAGS if flags is None else flags

    torch.set_default_tensor_type('torch.FloatTensor')

    xm.master_print("let's start!")
    training_data_path = '../input/cassava-jpeg-256x256/kaggle/train_images_jpeg'
    df = pd.read_csv("/kaggle/working/train_folds.csv")

    device = xm.xla_device()

    epochs = cfg['epochs']
    fold = cfg['fold']

    # Rows of the chosen fold form the validation set, all others train.
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # NOTE(review): Normalize runs first, so the colour transforms below see
    # normalised float images rather than raw pixel values — confirm intended.
    train_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean,
                std,
                max_pixel_value=255.0,
                always_apply=True
            ),
            albumentations.Transpose(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2,
                sat_shift_limit=0.2,
                val_shift_limit=0.2,
                p=0.5
            ),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1,0.1),
                contrast_limit=(-0.1,0.1),
                p=0.5
            ),
            albumentations.CoarseDropout(p=0.5),
            albumentations.Cutout(p=0.5)
        ]
    )

    # Validation only normalises; no random transforms.
    valid_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean,
                std,
                max_pixel_value=255.0,
                always_apply=True
            )
        ]
    )

    train_images = [
        os.path.join(training_data_path, i)
        for i in df_train.image_id.values.tolist()
    ]
    train_targets = df_train.label.values

    valid_images = [
        os.path.join(training_data_path, i)
        for i in df_valid.image_id.values.tolist()
    ]
    valid_targets = df_valid.label.values

    train_dataset = ImageDataset(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug
    )

    valid_dataset = ImageDataset(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug
    )

    # Each process reads its own shard of the data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg['batch_size'],
        sampler=train_sampler,
        num_workers=cfg['num_workers'],
        drop_last=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg['batch_size'],
        sampler=valid_sampler,
        num_workers=cfg['num_workers'],
        drop_last=False
    )

    # NOTE(review): the loop functions iterate the plain DataLoaders and move
    # each batch to the device themselves; no pl.MpDeviceLoader wrapper is used.

    model = MX.to(device)
    loss_fn = nn.CrossEntropyLoss()
    # Learning rate is scaled linearly with the number of processes.
    optimizer = Adam(model.parameters(),
                     lr=cfg['lr']*xm.xrt_world_size()
                    )

    # The scheduler is stepped once per batch, so its horizon is the total
    # number of training batches over all epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           len(train_loader)*epochs
                                                          )

    xm.master_print(f'========== training fold {fold} for {epochs} epochs ===========')

    for i in range(epochs):
        xm.master_print(f'EPOCH {i}:')

        # Without this the DistributedSampler reuses the same shuffle order
        # for every epoch.
        train_sampler.set_epoch(i)

        train_loop_fn(train_loader, loss_fn, model, optimizer, device, scheduler)

        eval_loop_fn(valid_loader, loss_fn, model, device)

    xm.master_print('save model')

    xm.save(model.state_dict(), f'xla_trained_model_{epochs}_epochs_fold_{fold}.pth')

In [None]:
# Launch run() in 8 forked processes and report the wall-clock duration.
start_time = time.time()
xmp.spawn(
    run,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)
print('time taken: ', time.time() - start_time)