In [None]:
# Installing Pytorch-XLA and other dependencies
# Downloads the PyTorch/XLA env-setup script from the pytorch/xla repository and runs it;
# `--apt-packages` additionally installs libomp5 and libopenblas-dev via apt.
# NOTE(review): the script is fetched from the `master` branch, so the installed
# versions are not pinned — consider pinning a release tag for reproducibility.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
import sys
# Make the offline copy of timm (pytorch-image-models) shipped as a Kaggle dataset
# importable. The path is added at the END of sys.path, so an already-installed
# timm found earlier on the path would still take precedence.
sys.path += ['../input/timm-pytorch-image-models/pytorch-image-models-master']

In [None]:
import os
# TPU runtime options. They are written to os.environ here, before torch_xla is
# imported further down in this cell.
# XLA_USE_BF16=1: per the torch_xla docs, float tensors are held as bfloat16 on the TPU.
# NOTE(review): this flag is deprecated in recent torch_xla releases — confirm against the installed version.
os.environ['XLA_USE_BF16'] = "1"
# NOTE(review): presumably caps the XLA tensor allocator cache size (in bytes) — confirm in the torch_xla docs.
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = "100000000"

import gc
import random
import timeit
import time
import datetime
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List
from tqdm.notebook import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# pytorch imports
import torch
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import make_grid
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

import timm

# for TPU
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


# Matplotlib look-and-feel for every figure in this notebook.
plt.style.use('bmh')
plt.rcParams.update({'figure.figsize': [20, 13]})

# Notebook-wide random seed, consumed by seed_everything().
SEED = 421

In [None]:
def seed_everything(seed=1234):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (via
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``). It also records the
    seed in ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Note: assigning ``PYTHONHASHSEED`` at runtime only affects interpreters
    started afterwards; the current process's hash seed is fixed at startup.

    Args:
        seed: integer seed applied to every generator.
    """
    seed_text = str(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# Seed all RNGs once, before the data split and model creation further down.
seed_everything(SEED)

In [None]:
# --- Input locations (Kaggle datasets) ---
DATA_DIR = "../input/ranzcr-clip-catheter-line-classification/"
FOLDS_DIR = "../input/ranzcr-folds/"

# --- Backbone ---
# NOTE(review): RanczrModel slices the timm model via `as_sequential()`, which not
# every timm architecture may provide — check before switching to e.g. "resnet200d".
MODEL_NAME = "efficientnet_b7"

# --- Optimisation ---
EPOCHS = 30
INIT_LR = 1e-3
BATCH_SIZE = 16          # batch size of each TPU core's training DataLoader
IMAGE_SIZE = 600         # images are resized to IMAGE_SIZE x IMAGE_SIZE

# --- Logging / early stopping ---
PRINT_EVERY = 50         # log the core-averaged train loss every PRINT_EVERY steps
STOP_TRAINING_AFTER = 5  # stop once val AUC has not improved for MORE than this many consecutive epochs

The cross-validation folds are taken from Abhishek's [kernel](https://www.kaggle.com/abhishek/ranzcr-tez-training-efficientnet-5).

In [None]:
# Sample submission (used for the test-set size below) and the pre-computed CV folds.
sub_df = pd.read_csv(f"{DATA_DIR}sample_submission.csv")
folds_df = pd.read_csv(f"{FOLDS_DIR}train_folds.csv")

In [None]:
# Every column other than the identifiers and the fold index is a target label column.
CLASSES = [column for column in folds_df.columns
           if column not in {'StudyInstanceUID', 'PatientID', 'kfold'}]
CLASSES

In [None]:
# Peek at the first rows of the folds table (identifiers, label columns, kfold).
folds_df.head()

In [None]:
# Hold out fold 0 for validation; every other fold is used for training.
valid_fold = 0
in_valid_fold = folds_df.kfold == valid_fold

train_df = folds_df[~in_valid_fold].reset_index(drop=True)
valid_df = folds_df[in_valid_fold].reset_index(drop=True)

print("NUMBER OF SAMPLES IN:\n")
for split_name, split_df in (("Training", train_df), ("Validation", valid_df), ("Testing", sub_df)):
    print(f"{split_name}: {len(split_df)}")

### Utility function

In [None]:
def plot_input_images(imgs: torch.Tensor, title_string: str, nrow: int = 4) -> None:
    """Draw a batch of images as one grid on the current matplotlib axes.

    Args:
        imgs: image batch in the layout accepted by torchvision's ``make_grid``
            (a 4-D N x C x H x W tensor or a list of C x H x W tensors).
        title_string: title placed above the grid.
        nrow: number of images per grid row.

    NOTE(review): the images arrive already normalised, so values outside [0, 1]
    are clipped by ``imshow``. ``make_grid`` also expands 1-channel input to
    3 channels, in which case ``cmap`` has no effect — confirm this is the
    intended rendering.
    """
    grid = make_grid(imgs, nrow=nrow, padding=10, pad_value=1)
    # make_grid produces CHW; matplotlib wants HWC.
    grid_hwc = grid.permute(1, 2, 0)
    plt.imshow(grid_hwc, cmap=plt.cm.bone)
    plt.title(title_string)

### Dataset class

In [None]:
class RanczrDataset(Dataset):
    """Map-style dataset of RANZCR chest X-ray images and their multi-label targets.

    Each item is ``(image, targets)``:

    - ``targets`` is a uint8 tensor with one entry per column in the module-level
      ``CLASSES`` list.
    - ``image`` is whatever ``transform(image=array)`` returns (in this notebook an
      albumentations dict whose ``'image'`` key holds the tensor), or the raw numpy
      array when no transform is given.

    Args:
        df: frame with a ``StudyInstanceUID`` column plus the ``CLASSES`` columns.
            Rows are looked up by label, so it should carry a default RangeIndex
            (call ``reset_index(drop=True)`` first, as the notebook does).
        data_dir: directory containing ``<StudyInstanceUID>.jpg`` files.
        transform: optional callable invoked as ``transform(image=array)``.
    """

    def __init__(self, df, data_dir, transform=None):
        super().__init__()
        self.df = df
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Single .loc lookups instead of chained df[col][index] indexing.
        img_name = self.df.loc[index, 'StudyInstanceUID']
        targets = torch.from_numpy(self.df.loc[index, CLASSES].values.astype(np.uint8))

        img_path = os.path.join(self.data_dir, img_name + ".jpg")
        # The context manager guarantees the file handle is released, even if decoding fails;
        # np.array() forces the pixel data to load while the file is still open.
        with Image.open(img_path) as img:
            image = np.array(img)

        if self.transform:
            image = self.transform(image=image)

        return image, targets

### Transformations

In [None]:
# `p=1.0` replaces the deprecated `always_apply=True`; both mean "always apply", and
# `p` is accepted by every albumentations version.
# NOTE(review): mean/std equal the ImageNet red-channel statistics, presumably reused
# for these single-channel images — confirm.

# Training pipeline: fixed resize, random CLAHE and flips, then normalisation and
# conversion to a torch tensor.
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE, p=1.0),
    A.CLAHE(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0),
    ToTensorV2(),
])

# Validation pipeline: deterministic resize and the same normalisation, no augmentation.
valid_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE, p=1.0),
    A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0),
    ToTensorV2(),
])

### Create Datasets

In [None]:
# Creating Pytorch Datasets
train_dataset = RanczrDataset(df=train_df, data_dir=DATA_DIR + "train/", transform=train_transform)
validation_dataset = RanczrDataset(df=valid_df, data_dir=DATA_DIR + "train/", transform=valid_transform)

# Throw-away loader used only to preview one augmented training batch; the real
# per-core loaders are built inside _run(). Uses BATCH_SIZE rather than a hard-coded
# 16 so the preview follows the configured batch size.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, drop_last=True, num_workers=0)

inputs, targets = next(iter(train_loader))

# The transform returns a dict, so the collated image batch lives under inputs['image'].
plot_input_images(inputs['image'], title_string='batch images', nrow=8)

# Drop the preview loader before the TPU worker processes are forked below.
del train_loader
gc.collect()

In [None]:
# Throw-away loader used only to preview one validation batch; uses BATCH_SIZE
# rather than a hard-coded 16 so the preview follows the configured batch size.
valid_loader = DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, drop_last=True, num_workers=0)

inputs, targets = next(iter(valid_loader))

# The transform returns a dict, so the collated image batch lives under inputs['image'].
plot_input_images(inputs['image'], title_string='validation batch images', nrow=8)

# Drop the preview loader before the TPU worker processes are forked below.
del valid_loader
gc.collect()

### Model

In [None]:
class RanczrModel(nn.Module):
    """timm backbone (classifier stripped) followed by dropout and a linear multi-label head.

    The backbone is ``timm.create_model(model_name, pretrained=True, in_chans=1)``
    turned into an ``nn.Sequential`` via ``as_sequential()`` with its last two
    modules dropped. Its output is fed to a linear layer sized by
    ``backbone.num_features`` (2560 for efficientnet_b7). The head emits one raw
    logit per class, to be paired with ``BCEWithLogitsLoss``.

    Args:
        model_name: timm architecture name; defaults to the notebook-level MODEL_NAME.
        num_classes: number of output logits; defaults to ``len(CLASSES)``.
        dropout: dropout probability applied to the backbone features.
    """

    def __init__(self, model_name=None, num_classes=None, dropout=0.5):
        super().__init__()
        model_name = MODEL_NAME if model_name is None else model_name
        num_classes = len(CLASSES) if num_classes is None else num_classes

        backbone = timm.create_model(model_name, pretrained=True, in_chans=1)
        # NOTE(review): assumes the last two entries of as_sequential() are the
        # classifier-side dropout and linear layer, so the remaining stack ends in
        # the pooled feature vector — confirm for any backbone other than EfficientNet.
        self.effnet = backbone.as_sequential()[:-2]

        self.dropout = nn.Dropout(p=dropout)
        # Size the head from the backbone instead of hard-coding 2560, so swapping
        # MODEL_NAME for a backbone with a different feature width still works.
        self.dense = nn.Linear(backbone.num_features, num_classes)

    def forward(self, images):
        """Return one logit per class for every image in the batch."""
        pooled_features = self.effnet(images)
        return self.dense(self.dropout(pooled_features))

In [None]:
effnet = RanczrModel()

# The backbone is loaded with pretrained weights (fine-tuning, not training from
# scratch). Parameters are trainable by default; this just makes it explicit that
# nothing is frozen. The trailing ';' suppresses the module repr in Jupyter.
effnet.requires_grad_(True);

In [None]:
# Display the module tree of the backbone + head.
effnet

In [None]:
# objective function
def loss_func(predictions, targets):
    """Binary cross-entropy on raw logits, averaged over every element.

    Args:
        predictions: unnormalised logits.
        targets: float 0/1 labels with the same shape as ``predictions``.

    Returns:
        Scalar tensor holding the mean BCE-with-logits loss.
    """
    return F.binary_cross_entropy_with_logits(predictions, targets, reduction='mean')


# reduction function for xla
def reduce_func(vals):
    """Arithmetic mean of the per-core values handed over by ``xm.mesh_reduce``.

    Args:
        vals: non-empty sequence with one number per participating TPU core.

    Returns:
        The mean of ``vals``.
    """
    num_cores = len(vals)
    total = sum(vals)
    return total / num_cores


# evaluation metric: multilabel AUC
def multilabel_auc(targets: torch.Tensor, predictions: torch.Tensor) -> float:
    """Macro-averaged ROC AUC over the label columns.

    ``roc_auc_score`` is computed per column of ``targets``/``predictions`` and the
    results are averaged. A column whose targets hold a single class has no
    defined AUC. For such a column, one synthetic sample is appended — the missing
    label, scored with the column's mean prediction — so a score can still be
    produced.

    Args:
        targets: (num_samples, num_classes) tensor of 0/1 labels.
        predictions: (num_samples, num_classes) tensor of scores or logits
            (AUC is rank-based, so logits are fine).

    Returns:
        Mean AUC over the ``targets.shape[1]`` columns.

    Raises:
        ValueError: if there are no label columns, or from ``roc_auc_score`` on
            invalid input (e.g. NaN predictions). Such errors are no longer masked.
    """
    num_classes = targets.shape[1]
    if num_classes == 0:
        raise ValueError("multilabel_auc needs at least one label column")

    auc = 0.0
    for j in range(num_classes):
        target_col = targets[:, j]
        pred_col = predictions[:, j]
        present = torch.unique(target_col)
        if present.numel() == 1:
            # Single-class column: add one sample of the absent class, scored with
            # the mean prediction, so roc_auc_score sees both classes.
            target_col = torch.cat([target_col, 1 - present])
            pred_col = torch.cat([pred_col, torch.mean(pred_col, dim=0, keepdim=True)])
        auc += roc_auc_score(target_col, pred_col)

    # Average over the columns actually scored, not the global CLASSES list.
    return auc / num_classes

### Train and validation methods

In [None]:
def train_loop_fn(data_loader, model, optimizer, device):
    """Run one training epoch on the current TPU core.

    Args:
        data_loader: per-device loader yielding ``(inputs, targets)`` where
            ``inputs['image']`` is the image batch.
        model: model already placed on ``device``.
        optimizer: optimizer over ``model``'s parameters.
        device: this process's XLA device.

    Returns:
        List with one entry per step: that step's training loss averaged over all
        TPU cores (via ``xm.mesh_reduce`` with ``reduce_func``).
    """
    model.train()

    # Per-step losses stay on the device as detached tensors. Calling .item() every
    # step would force an XLA graph execution + host transfer (and a cross-core
    # rendezvous) per batch; instead the history is copied to the host once per epoch.
    step_losses = []

    for batch_ix, (inputs, targets) in enumerate(data_loader):
        images, targets = inputs['image'].to(device), targets.float().to(device)

        optimizer.zero_grad()
        predictions = model(images)
        train_loss = loss_func(predictions, targets)
        step_losses.append(train_loss.detach())

        # Materialise the loss only on logging steps. Every core reaches the same
        # batch_ix values, so the mesh_reduce rendezvous stays in lockstep.
        if batch_ix % PRINT_EVERY == 0:
            train_loss_reduce = xm.mesh_reduce('train_loss_reduce', train_loss.item(), reduce_func)
            # master_print will only print once (not from all cores)
            xm.master_print(f"batch_ix={batch_ix}, train_loss={train_loss_reduce:.4f}")
            # Divergence guard on logging steps; replaces a per-step weight scan that
            # forced a device sync on every batch. The reduced value is identical on
            # all cores, so every core exits together.
            if np.isnan(train_loss_reduce):
                sys.exit("NaN encountered!!!")

        train_loss.backward()
        # Use Pytorch XLA optimizer for applying the weight updates
        xm.optimizer_step(optimizer)

    if not step_losses:
        return []

    # One device->host transfer for the whole epoch, then average each step across
    # cores with the same reducer used for the logged values.
    local_history = torch.stack(step_losses).float().cpu().tolist()
    return xm.mesh_reduce(
        'train_loss_history_reduce',
        local_history,
        lambda per_core: [reduce_func(step_vals) for step_vals in zip(*per_core)],
    )

def valid_loop_fn(data_loader, model, device):
    """Evaluate the model on this core's validation shard.

    Args:
        data_loader: per-device loader over this core's validation shard, yielding
            ``(inputs, targets)`` with the images in ``inputs['image']``.
        model: model already placed on ``device``.
        device: this process's XLA device.

    Returns:
        ``(val_loss_history, val_auc)``:

        - ``val_loss_history``: the per-step validation loss averaged over cores.
        - ``val_auc``: the multilabel AUC over the predictions gathered from ALL
          cores, so it is the same value on every core.
    """
    model.eval()

    step_losses = []
    pred_chunks = []
    target_chunks = []
    # Evaluation needs no autograd graph; without no_grad every forward pass would
    # still record activations for a backward pass that never happens.
    with torch.no_grad():
        for inputs, targets in data_loader:
            images, targets = inputs['image'].to(device), targets.float().to(device)
            predictions = model(images)
            step_losses.append(loss_func(predictions, targets))
            # Collect chunks and concatenate once at the end (torch.cat inside the
            # loop is quadratic); keeping them on-device avoids a sync per step.
            pred_chunks.append(predictions)
            target_chunks.append(targets)

    local_losses = torch.stack(step_losses).float().cpu().tolist()
    val_loss_history = xm.mesh_reduce(
        'val_loss_history_reduce',
        local_losses,
        lambda per_core: [reduce_func(step_vals) for step_vals in zip(*per_core)],
    )

    # Gather every core's predictions/targets so the AUC covers the whole validation
    # set. The mean of per-shard AUCs is not the AUC of the full set.
    # (DistributedSampler pads shards by repeating a few samples, so a handful of
    # rows may be counted twice.)
    local_preds = torch.cat(pred_chunks).float().cpu().numpy()
    local_targets = torch.cat(target_chunks).float().cpu().numpy()
    all_predictions = xm.mesh_reduce('val_preds_gather', local_preds, np.concatenate)
    all_targets = xm.mesh_reduce('val_targets_gather', local_targets, np.concatenate)

    val_auc = multilabel_auc(torch.from_numpy(all_targets), torch.from_numpy(all_predictions))
    return val_loss_history, val_auc


In [None]:
def _make_data_loaders():
    """Build this core's (train, validation) DataLoaders, each sharded with a DistributedSampler."""
    world_size = xm.xrt_world_size()  # number of TPU cores taking part in training
    rank = xm.get_ordinal()           # which core this process is driving

    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,  # reshuffled every epoch via set_epoch() in _run
    )
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,  # load data in the main process only, saving a lot of the VM's memory
    )

    valid_sampler = DistributedSampler(
        validation_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
    )
    valid_data_loader = DataLoader(
        validation_dataset,
        batch_size=4,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0,
    )
    return train_data_loader, valid_data_loader


def _run():
    """Per-core training entry point; spawned once per TPU core by ``xmp.spawn``.

    Steps:

    1. Build sharded loaders and move the shared ``effnet`` model to this core.
    2. Train for up to EPOCHS epochs with Adam, lowering the LR on
       validation-AUC plateaus.
    3. Save the best checkpoint.
    4. Stop early once the AUC has not improved for more than
       STOP_TRAINING_AFTER consecutive epochs.
    """
    train_data_loader, valid_data_loader = _make_data_loaders()

    device = xm.xla_device()  # our device (single TPU core)
    model = effnet.to(device)  # put the pre-built, shared model on this core

    # Whatever needs to be printed only once for all TPU cores goes through `master_print`.
    xm.master_print("Model loading on TPU completed")
    xm.master_print("Training initiated")

    world_size = xm.xrt_world_size()
    # Total optimisation steps over the run; only reported, not used by the scheduler.
    num_train_steps = int(len(train_dataset) / BATCH_SIZE / world_size * EPOCHS)

    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    # Divide the LR by 10 when validation AUC (mode='max') stalls for 2 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2, mode='max', verbose=True)

    xm.master_print(f"num_training_steps = {num_train_steps}, world_size={world_size}")

    # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
    max_mean_auc_score = -np.inf
    epoch_since_last_improvement = 0

    for epoch in range(1, EPOCHS + 1):
        xm.master_print(f"EPOCH: {epoch}/{EPOCHS}")
        epoch_start_tick = timeit.default_timer()  # record time at the start of the epoch
        gc.collect()

        # DistributedSampler derives its shuffle order from (seed, epoch). Without
        # set_epoch(), every epoch would visit the training samples in the same order.
        train_data_loader.sampler.set_epoch(epoch)

        # Wrap the loader with Pytorch XLA's ParallelLoader for TPU-core-specific data loading.
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        xm.master_print("Training Parallel loader created...\nTraining now")
        gc.collect()

        train_loss_hist = train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device)
        del para_loader
        gc.collect()

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        xm.master_print("Validation Parallel loader created...\nValidating now")
        gc.collect()

        val_loss_hist, epoch_mean_auc_score = valid_loop_fn(para_loader.per_device_loader(device), model, device)
        del para_loader
        gc.collect()

        # train loss, validation loss and validation AUC at the end of each epoch
        xm.master_print(f"EPOCH: {epoch}==> mean_train_loss: {np.mean(train_loss_hist):.4f}, mean_valid_loss: {np.mean(val_loss_hist):.4f}, mean_valid_auc: {epoch_mean_auc_score:.4f}")
        gc.collect()

        elapsed = timeit.default_timer() - epoch_start_tick
        elapsed = str(datetime.timedelta(seconds=elapsed)).split('.')[0]
        xm.master_print(f"Time taken for epoch: {elapsed}\n")

        # Lower the LR if the validation AUC has plateaued.
        scheduler.step(epoch_mean_auc_score)

        if epoch_mean_auc_score >= max_mean_auc_score:
            xm.master_print(f"Validation AUC score increased ({max_mean_auc_score:.6f} -> {epoch_mean_auc_score:.6f}) at EPOCH {epoch}\n  Saving model ...\n")
            xm.save(model.state_dict(), 'timm-effnet-b7-res600-final.pt')
            max_mean_auc_score = epoch_mean_auc_score
            epoch_since_last_improvement = 0
        else:
            # No improvement. A NaN score also lands here and counts as no improvement.
            epoch_since_last_improvement += 1
            xm.master_print(f'{epoch_since_last_improvement} epochs have finished since the last improvement in val AUC')
            # The AUC is identical on every core, so all cores stop on the same epoch.
            if epoch_since_last_improvement > STOP_TRAINING_AFTER:
                xm.master_print('Stopping training prematurely due to no improvement in val AUC')
                break

    xm.master_print(f'Best Validation AUC is {max_mean_auc_score:.6f}')
    gc.collect()


In [None]:
# start training process; we need to spawn the training processes on each of the TPU cores.
def _map_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``; ``rank`` and ``flags`` are unused."""
    _run()

FLAGS = {}
start_time = timeit.default_timer()
# start_method='fork' makes every child inherit the datasets and the `effnet` model built above.
xmp.spawn(_map_fn, args=(FLAGS,), nprocs=8, start_method='fork')
# Wall-clock time for the whole run (assumes xmp.spawn's default join=True, i.e. it
# returns only after all training processes have finished).
total_elapsed = str(datetime.timedelta(seconds=timeit.default_timer() - start_time)).split('.')[0]
print(f"Total training time: {total_elapsed}")