# Install PyTorch XLA

In [None]:
# Fetch the PyTorch/XLA environment setup script, then run it to install the
# `nightly` build together with the apt packages libomp5 and libopenblas-dev.
# NOTE(review): "nightly" is not reproducible across runs; consider pinning a release.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

# Import PyTorch XLA

In [None]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

In [None]:
import pandas as pd
from PIL import Image
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import os
import torch
import torchvision
import cv2
import numpy as np
import time
import warnings
import datetime

In [None]:
# Folds CSV of the merged external dataset; the image_id, target and patient_id
# columns are consumed further down (dataset construction and GroupKFold split).
df = pd.read_csv(filepath_or_buffer='/kaggle/input/melanoma-merged-external-data-512x512-jpeg/folds_13062020.csv')

In [None]:
class MelanomaDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame of ``image_id`` / ``target`` rows.

    Each item is loaded from ``<path_files>/<image_id><extension>``. A ``target``
    of 0 is encoded as the one-hot label ``[1, 0]``; any other value becomes
    ``[0, 1]``.

    Args:
        path_files: directory containing the image files.
        path_csv: CSV to load when ``pd_loaded`` is not given.
        pd_loaded: an already-loaded DataFrame; takes precedence over ``path_csv``.
        extension: file extension appended to ``image_id``.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=..., label=...)`` and expected to return a dict
            with ``'image'`` and ``'label'`` entries.

    Raises:
        ValueError: if neither ``path_csv`` nor ``pd_loaded`` is provided.
    """

    def __init__(self, path_files, path_csv=None, pd_loaded=None, extension=".jpg", transforms=None):
        self.path_files = path_files
        self.ext = extension
        self.transforms = transforms

        if pd_loaded is not None:
            self.df = pd_loaded
        elif path_csv is not None:
            self.df = pd.read_csv(path_csv)
        else:
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError("Both parameters path_csv and pd_loaded can't be none")

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label_tensor)`` for row ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[index]
        img_path = os.path.join(self.path_files, row['image_id'] + self.ext)
        img_arr = cv2.imread(img_path)
        if img_arr is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising; fail here with a message naming the file.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        img_arr_rgb = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)

        img_label = [1, 0] if row['target'] == 0 else [0, 1]

        if self.transforms:
            sample = self.transforms(image=img_arr_rgb, label=img_label)
            img_tens = sample['image']
            img_label = sample['label']
        else:
            # Previously referenced an undefined name (img_arr_rgb_no_hair),
            # so this branch always raised NameError.
            img_tens = torchvision.transforms.ToTensor()(img_arr_rgb)

        return img_tens, torch.tensor(img_label)

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

In [None]:
# Albumentations pipelines. Training: random flips, 90-degree rotations,
# brightness/contrast jitter and coarse dropout, then normalisation and
# conversion to a PyTorch tensor. Validation: normalisation + tensor only.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=32. / 255, contrast_limit=0.5),
    A.CoarseDropout(max_holes=15, max_height=15, max_width=15),
    A.Normalize(),
    ToTensorV2(),
])

valid_transform = A.Compose([
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
# Thank you @rwightman (Ross Wightman) for that awesome libary
pip install timm

In [None]:
import timm

In [None]:
def create_model(model_name="tf_efficientnet_b4", num_classes=2, pretrained=True):
    """Build a timm model whose classifier head is replaced by a fresh linear layer.

    Calling it with no arguments reproduces the original behaviour: a pretrained
    ``tf_efficientnet_b4`` with a 2-logit head, matching the two-element one-hot
    labels produced by ``MelanomaDataset``.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        num_classes: number of output logits of the new head.
        pretrained: whether timm should load pretrained weights.

    Returns:
        The model with ``model.classifier`` set to an untrained
        ``torch.nn.Linear(old_in_features, num_classes)``.

    Note:
        Assumes the architecture exposes its head as ``.classifier`` (as the
        EfficientNet default does); other timm families may name it differently.
    """
    model = timm.create_model(model_name, pretrained=pretrained)
    in_features = model.classifier.in_features
    model.classifier = torch.nn.Linear(in_features, num_classes)
    return model

# Training function

The function executed by each TPU core is `_mp_fn`:
* its arguments are passed in through a single dict (`flags` here)
* do not call `xm.xla_device()` anywhere other than inside `_mp_fn`, otherwise it raises an error
* likewise, do not call `xm.xrt_world_size()` anywhere other than inside `_mp_fn`, otherwise it raises an error

In [None]:
def _elapsed(start_time):
    """Wall time since ``start_time``, as the first 7 characters of ``str(timedelta)``."""
    return str(datetime.timedelta(seconds=time.time() - start_time))[:7]


def _train_one_epoch(model, loader, optimizer, loss_criterion):
    """Run one training pass over ``loader`` on this core.

    Loss and correct-prediction counts are accumulated as device tensors and
    read back only once at the end, instead of forcing a device->host transfer
    on every step.

    Returns:
        ``(avg_loss, avg_acc)`` over the samples this core saw (0.0 each if the
        loader yielded nothing).
    """
    model.train()
    loss_sum = 0.0
    correct_sum = 0
    num_samples = 0
    for images, labels in loader:
        # BCEWithLogitsLoss needs float targets.
        labels = labels.float()
        outputs = model(images)
        loss = loss_criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        # Reduces gradients across cores, then steps the optimizer.
        xm.optimizer_step(optimizer)

        batch_size = images.size(0)
        num_samples += batch_size
        loss_sum = loss_sum + loss.detach() * batch_size
        # Labels are one-hot, so argmax gives the class index on both sides.
        correct_sum = correct_sum + (outputs.detach().argmax(dim=1) == labels.argmax(dim=1)).sum()

    denom = max(num_samples, 1)
    return float(loss_sum) / denom, float(correct_sum) / denom


def _validate_one_epoch(model, loader, loss_criterion):
    """Evaluate ``model`` on this core's validation shard.

    Returns:
        ``(avg_loss, avg_acc, preds, targets)``. ``preds`` holds the sigmoid
        probabilities and ``targets`` the float one-hot labels, both as numpy
        arrays with one row per sample.

    Raises:
        ValueError: (from ``np.concatenate``) if the loader yields no batches.
    """
    model.eval()
    sigmoid = torch.nn.Sigmoid()
    loss_sum = 0.0
    correct_sum = 0
    num_samples = 0
    preds, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            labels = labels.float()
            outputs = model(images)
            # BCEWithLogitsLoss applies the sigmoid itself, so it gets raw logits.
            loss = loss_criterion(outputs, labels)

            batch_size = images.size(0)
            num_samples += batch_size
            loss_sum = loss_sum + loss * batch_size
            correct_sum = correct_sum + (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum()

            preds.append(sigmoid(outputs).cpu().numpy())
            targets.append(labels.cpu().numpy())

    denom = max(num_samples, 1)
    return (float(loss_sum) / denom, float(correct_sum) / denom,
            np.concatenate(preds), np.concatenate(targets))


def _global_roc_auc(val_preds, val_targets):
    """ROC-AUC over the validation predictions of *all* cores.

    The mean of per-core AUCs is not the AUC of the full fold. It also fails
    when a shard happens to contain a single class. So the rows are gathered
    first. Predictions and targets are packed side by side before the exchange
    so they stay row-aligned, and they are sent as plain lists so the payload
    pickles cleanly. This is a collective call: every core must reach it.

    Note: DistributedSampler may pad shards with a few repeated samples, which
    are included here as they were in the per-core metric.
    """
    num_cols = val_preds.shape[1]
    packed = np.concatenate([val_preds, val_targets], axis=1).tolist()
    gathered = xm.mesh_reduce(
        'val-preds-targets', packed,
        lambda parts: np.asarray([row for part in parts for row in part], dtype=np.float64))
    return roc_auc_score(gathered[:, num_cols:], gathered[:, :num_cols])


def _mp_fn(index, flags):
    """Per-core training entry point launched by ``xmp.spawn``.

    Args:
        index: process index supplied by ``xmp.spawn`` (unused; the ordinal is
            read from ``xm.get_ordinal()``).
        flags: dict with keys 'seed', 'es_criterion' ('roc_score' or
            'val_loss'), 'validation_on', 'use_lr_scheduler', 'patience'
            (None disables early stopping), 'batch_size', 'num_workers', 'lr',
            'num_epochs', 'id', 'train_ds' and 'valid_ds'.

    Raises:
        ValueError: if validation is enabled and 'es_criterion' is unknown.
    """
    # Same seed on every core so the freshly initialised head (and any other
    # RNG use) is identical across replicas.
    torch.manual_seed(flags['seed'])
    # Must only be called inside _mp_fn.
    device = xm.xla_device()

    es_criterion = flags['es_criterion']
    if flags['validation_on'] and es_criterion not in ('roc_score', 'val_loss'):
        raise ValueError(f"Unknown es_criterion {es_criterion!r}; expected 'roc_score' or 'val_loss'")
    mode = 'max' if es_criterion == 'roc_score' else 'min'

    # Samplers give each core its own slice of the data.
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        flags['train_ds'], num_replicas=world_size, rank=rank, shuffle=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        flags['valid_ds'], num_replicas=world_size, rank=rank, shuffle=False)

    train_dl = torch.utils.data.DataLoader(dataset=flags['train_ds'], batch_size=flags['batch_size'],
                                           sampler=train_sampler, num_workers=flags['num_workers'])
    valid_dl = torch.utils.data.DataLoader(dataset=flags['valid_ds'], batch_size=flags['batch_size'],
                                           sampler=valid_sampler, num_workers=flags['num_workers'])

    model = create_model().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=flags['lr'])
    loss_criterion = torch.nn.BCEWithLogitsLoss()

    lr_scheduler = None
    if flags['use_lr_scheduler']:
        # Only the master core reports LR changes.
        lr_scheduler = ReduceLROnPlateau(optimizer=optimizer, mode=mode, patience=1,
                                         verbose=xm.is_master_ordinal(), factor=0.2)

    weights_path = f"trained_melanoma_weights_{flags['id']+1}"
    best_score = None
    patience_counter = 0

    for epoch in range(flags['num_epochs']):
        # Without set_epoch, DistributedSampler replays the same shuffle order every epoch.
        train_sampler.set_epoch(epoch)

        # Training
        start_time = time.time()
        xm.master_print('=' * 20, 'Training - begin for epoch', epoch+1, '=' * 20)
        para_loader_train = pl.ParallelLoader(train_dl, [device]).per_device_loader(device)
        avg_train_loss, avg_train_acc = _train_one_epoch(model, para_loader_train, optimizer, loss_criterion)
        print(f'Core: {xm.get_ordinal()} Train: Epoch {epoch+1} - loss: {avg_train_loss:.4f} - acc: {avg_train_acc:.4f} - Time: {_elapsed(start_time)}')

        if not flags['validation_on']:
            continue

        # Validation
        start_time = time.time()
        para_loader_valid = pl.ParallelLoader(valid_dl, [device]).per_device_loader(device)
        avg_valid_loss, avg_valid_acc, val_preds, val_targets = _validate_one_epoch(
            model, para_loader_valid, loss_criterion)
        # Collective calls (rendezvous): every core reaches them in the same order.
        roc_score = _global_roc_auc(val_preds, val_targets)
        # Shards have equal size (DistributedSampler pads them), so the mean of
        # the per-core means equals the mean over all validation rows.
        global_valid_loss = float(xm.mesh_reduce('mean-valid-loss', avg_valid_loss, np.mean))
        xm.master_print(f'Valid: Epoch {epoch+1} - master-loss: {avg_valid_loss:.4f} - master-acc: {avg_valid_acc:.4f} - roc: {roc_score:.4f} - Time: {_elapsed(start_time)}')

        # Both candidates are identical on every core, so all replicas take the
        # same LR / early-stopping decisions. The per-core loss could previously
        # make cores diverge or break at different epochs and deadlock.
        crit_var = global_valid_loss if es_criterion == 'val_loss' else roc_score

        if lr_scheduler is not None:
            lr_scheduler.step(crit_var)

        # Early stopping / checkpointing. xm.save writes from the master only,
        # and all cores wait for each other at that point.
        if best_score is None:
            best_score = crit_var
            xm.save(model.state_dict(), weights_path)
        elif es_criterion == 'val_loss' and crit_var < best_score:
            patience_counter = 0
            xm.master_print(f'Loss reduced from {best_score} ----> {crit_var}')
            best_score = crit_var
            xm.save(model.state_dict(), weights_path)
            xm.master_print("Saving model...")
        elif es_criterion == 'roc_score' and crit_var > best_score:
            patience_counter = 0
            xm.master_print(f'ROC Score increased from {best_score} ----> {crit_var}')
            best_score = crit_var
            xm.save(model.state_dict(), weights_path)
            xm.master_print("Saving model...")
        else:
            if flags['patience'] is not None and patience_counter == flags['patience']:
                xm.master_print(f'Early Stopping at Epoch {epoch+1}')
                break
            patience_counter += 1

# Call training with all 8 cores

Training is done fold by fold: for each model I rerun the whole notebook, so every model gets its own 3-hour session instead of all models having to share a single 3-hour budget. `id_` must satisfy `0 <= id_ < num_folds`.

In [None]:
# Hyper-parameters and run options shared by every TPU core (handed to _mp_fn by xmp.spawn)
flags = {}
flags['num_epochs'] = 8
flags['es_criterion'] = 'roc_score'
flags['validation_on'] = True
flags['use_lr_scheduler'] = True
flags['patience'] = None  # None disables early stopping
flags['seed'] = 1234
flags['batch_size'] = 5
flags['num_workers'] = 8
flags['num_folds'] = 5
flags['lr'] = 0.0001

# Directory holding the image files named by df['image_id'] (shared by train and valid)
IMAGE_DIR = "/kaggle/input/melanoma-external-data-jpeg-384x384/melanoma/"

# GroupKFold on patient_id: all images of one patient end up in the same fold
gkf = GroupKFold(n_splits=flags['num_folds'])

# Foldwise model training: increment id_ and rerun to train the next model
# Model 1
id_ = 0
# A negative id_ would silently select a fold from the end of the list while the
# weights file is still named after id_+1, so reject out-of-range values early.
if not 0 <= id_ < flags['num_folds']:
    raise ValueError(f"id_ must satisfy 0 <= id_ < {flags['num_folds']}, got {id_}")
train_index, test_index = list(gkf.split(np.zeros(len(df)), df['target'], list(df['patient_id'])))[id_]
print('-' * 20, 'Training Model', id_+1, '-' * 20)
train_ds = MelanomaDataset(IMAGE_DIR, pd_loaded=df.iloc[train_index], transforms=train_transform)
valid_ds = MelanomaDataset(IMAGE_DIR, pd_loaded=df.iloc[test_index], transforms=valid_transform)
flags['id'] = id_
flags['train_ds'] = train_ds
flags['valid_ds'] = valid_ds
# Run _mp_fn on all 8 TPU cores
if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(flags,), nprocs=8, start_method='fork')