![](https://cloud.google.com/tpu/docs/images/tpu--sys-arch4.png)

# <div align = 'center'><u> Simple PyTorch Pipeline  </u></div>
<div align = 'center'>Using all 8 cores of <b>TPU</b></div>

# Table of contents <a id='0.1'></a>

1. [Introduction](#1)
2. [Import Packages](#2)
3. [Loading Data](#3)
4. [Model](#4)
5. [Training on 8-cores of TPU](#5)
7. [References](#7)

**Version Notes**:-
* Version 12  : Not using bfloat16, since it is giving loss in a whole number format. Also, not using internet in this version.
* Version 16  : Code for training the model on 1-core of the TPU and saving it for the inference.
* Version 19  : Code for training the model on all 8-cores of the TPU and saving it for the inference.
* Version 23  : Fixed a bug (in Version 22, we were normalizing the mask 😛 and that was a mistake) and trained the model for 25 epochs. In this version, the loss drops at a faster rate than in the previous version. 🙂
* Version 24  : Added augs, using resnet34 encoder, increased learning rate. 
* Version 25  : Using seresnext50 encoder and Cosine Annealing LR scheduler. Let's train for 150 epochs.
* Version 31  : Added Group K-Fold validation, dice metric and curated comments/documentation.

**Note:- Inference notebook for this TPU Pytorch pipeline will be out soon, with a decent LB score.**

*Update:
Experimented with the private copy of this NB with 4-fold resnext50 Unet model, 256 img size, Cosine Annealing LR scheduler, 60 epochs per fold and got **0.840** public LB score.* 

# 1. <a id='1'>Introduction</a>

This notebook will show you how to use TPU with PyTorch in this competition.

Training with 1 core of the TPU is demonstrated in the [16th version of this NB](https://www.kaggle.com/joshi98kishan/hubmap-pytorch-tpu-segmentation?scriptVersionId=48001107). In this NB, **we will be training the model on all 8 cores of the TPU.**

The TPU provided with the Kaggle kernel is a **TPU v3-8** (3 for 3rd generation and 8 for 8 cores).

So there are 8 cores in total in this TPU, and each core can run computations independently.

The heavy lifting of multi-core training is done by PyTorch's `DistributedSampler` and PyTorch XLA's `ParallelLoader`.
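
To make the role of `DistributedSampler` concrete, here is a minimal pure-Python sketch of the sharding idea: every core gets a disjoint, equally sized slice of the dataset indices. `shard_indices` is a hypothetical helper written for this illustration, not a PyTorch or PyTorch XLA function. It assumes no shuffling and a dataset at least as large as the number of replicas; the real sampler also reshuffles every epoch.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Return the sample indices that replica `rank` would see (no shuffling)."""
    per_replica = math.ceil(dataset_len / num_replicas)
    total = per_replica * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so that every replica gets the same count.
    indices += indices[: total - dataset_len]
    # Round-robin: replica `rank` takes every `num_replicas`-th index.
    return indices[rank:total:num_replicas]


shards = [shard_indices(20, 8, rank) for rank in range(8)]
print(shards[0])  # [0, 8, 16]
print(shards[4])  # [4, 12, 0]
```

With 20 samples and 8 cores, each core processes 3 samples per epoch, and a few samples are repeated to keep the shards the same size.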

### Installing PyTorch XLA

In order to use the TPU with PyTorch, we have to install the [PyTorch XLA](https://github.com/pytorch/xla/) library. 

We can install this library online by running the below commands:
* `!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py`
* `!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev`

Note:- Turn on the TPU before going ahead, otherwise you will get the error, '*Missing XLA configuration*'.

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

# Removing setup-script and wheel files, as we don't need them now.
!rm pytorch-xla-env-setup.py
!rm torch-nightly-cp37-cp37m-linux_x86_64.whl 
!rm torch_xla-nightly-cp37-cp37m-linux_x86_64.whl 
!rm torchvision-nightly-cp37-cp37m-linux_x86_64.whl

We will be training a model with the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) architecture and a pretrained encoder (ResNet, SE-ResNeXt, etc.). This model is implemented in PyTorch ([here](https://github.com/qubvel/segmentation_models.pytorch)) as a library called **segmentation_models_pytorch**. 

To show how to install it offline, we will install it from source, because we will need this library in the inference kernel, where internet access is not available. The commands below and the associated [dataset](https://www.kaggle.com/vineeth1999/segmentationmodelspytorch) are borrowed from this [notebook](https://www.kaggle.com/vineeth1999/hubmap-pytorch-efficientunet-offline).

In [None]:
!mkdir -p /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/efficientnet_pytorch-0.6.3.xyz /tmp/pip/cache/efficientnet_pytorch-0.6.3.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/segmentation-models-pytorch-0.1.2.xyz /tmp/pip/cache/segmentation_models_pytorch-0.1.2.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.1.20-py3-none-any.whl /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

# 2. <a id='2'>Import Packages</a>
[Table of contents](#0.1)

In [None]:
import os, gc
import numpy as np 
import pandas as pd 
from tqdm.notebook import tqdm
import tifffile as tiff
import matplotlib.pyplot as plt
import cv2
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from albumentations import *

from sklearn.model_selection import GroupKFold

from segmentation_models_pytorch.unet import Unet
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import torch_xla
import torch_xla.core.xla_model as xm
from torch.utils.data.distributed import DistributedSampler
import torch_xla.distributed.parallel_loader as pl                             
import torch_xla.distributed.xla_multiprocessing as xmp

import warnings
warnings.filterwarnings("ignore")

Before going further, let's have a quick look at the **bfloat16 floating-point format**.

> This format is a truncated (16-bit) version of the 32-bit IEEE 754 single-precision floating-point format (binary32) with the intent of accelerating machine learning and near-sensor computing.

> Bfloat16 is used to reduce the storage requirements and increase the calculation speed of machine learning algorithms. 

[wiki](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format)
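
To get a feel for how much precision is lost, the sketch below emulates bfloat16 in NumPy by zeroing the low 16 bits of each float32 value. `truncate_to_bfloat16` is a hypothetical helper for illustration only, and it is an assumption that plain truncation is a fair approximation here: real hardware typically rounds to nearest rather than truncating.

```python
import numpy as np


def truncate_to_bfloat16(values):
    """Emulate bfloat16 storage by zeroing the low 16 bits of each float32."""
    as_f32 = np.asarray(values, dtype=np.float32).reshape(-1)
    bits = as_f32.view(np.uint32)
    # Keep the sign, the 8 exponent bits and the top 7 mantissa bits.
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


losses = np.array([0.8437, 0.1234, 123.456], dtype=np.float32)
print(losses)
print(truncate_to_bfloat16(losses))
```

Only about 2 to 3 significant decimal digits survive, which is why values printed from a bfloat16 run look much coarser than their float32 counterparts.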

In [None]:
'''Note:- Check the version notes above.'''

# *** We set this environment variable, in order to make TPU use bfloat16. ***
# os.environ['XLA_USE_BF16']="1"

In [None]:
def set_all_seeds(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
set_all_seeds(2020)

# 3. <a id='3'>Loading Data</a>
[Table of contents](#0.1)

In [None]:
DATA_PATH = '/kaggle/input/hubmap-256x256'
os.listdir(DATA_PATH)

We are using iafoss's [256x256 data](https://www.kaggle.com/iafoss/hubmap-256x256).

In [None]:
PATH_TRAIN = os.path.join(DATA_PATH, 'train')
PATH_MASKS = os.path.join(DATA_PATH, 'masks')

print(f'No. of training images: {len(os.listdir(PATH_TRAIN))}')
print(f'No. of masks: {len(os.listdir(PATH_MASKS))}')

In [None]:
class HuBMAPDataset(Dataset):
    def __init__(self, 
                 data_path, 
                 fnames, 
                 preprocess_input = None,
                 transforms = None):
        self.data_path = data_path
        self.fnames = fnames
        self.preprocess_input = preprocess_input
        self.transforms = transforms
    
    def __len__(self):
        return len(self.fnames)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.data_path, 'train', self.fnames[idx])
        mask_path = os.path.join(self.data_path, 'masks', self.fnames[idx])
        
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        
        if self.transforms:
            # Applying augmentations if any. 
            sample = self.transforms(image = img, 
                                     mask = mask)
            
            img, mask = sample['image'], sample['mask']
            
        if self.preprocess_input:
            # Normalizing the image with the given mean and
            # std corresponding to each channel.
            img = self.preprocess_input(image = img)['image']
            
        # PyTorch assumes images in channels-first format. 
        # Hence, bringing the channel at the first place.
        img = img.transpose((2, 0, 1))
        
        img = torch.from_numpy(img)
        mask = torch.from_numpy(mask)
            
        return img, mask

In [None]:
# Images and its corresponding masks are saved with the same filename.
fnames = np.array(os.listdir(PATH_TRAIN))

groups = [fname[:9] for fname in fnames]
    
group_kfold = GroupKFold(n_splits = 4)

**Let's also have a look at the preprocessing.**

We are using an encoder pretrained on ImageNet images, which were normalized with the following *mean* and *standard deviation*:

mean : [0.485, 0.456, 0.406]

std : [0.229, 0.224, 0.225]

First the whole image is divided by 255. Then from each channel, the corresponding mean is subtracted.
And lastly, each channel is divided by its corresponding standard deviation (std).


**img = ( (img/255) - mean)/std**
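
The formula above can be sketched in NumPy as follows. `normalize` is a hypothetical helper that simply applies the stated formula to an RGB image in (H, W, 3) layout; it is not the function returned by `get_preprocessing_fn()`.

```python
import numpy as np

# Per-channel (R, G, B) statistics quoted above.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(img_uint8):
    """Scale an (H, W, 3) uint8 image to [0, 1], then standardize per channel."""
    scaled = img_uint8.astype(np.float32) / 255.0
    # Broadcasting applies each channel's mean and std along the last axis.
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD


white = np.full((2, 2, 3), 255, dtype=np.uint8)
print(normalize(white)[0, 0])  # equals (1 - mean) / std for each channel
```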

In [None]:
# You can select and experiment with other encoder from this list:
# https://github.com/qubvel/segmentation_models.pytorch#encoders

ENCODER_NAME = 'se_resnext50_32x4d'



'''
So, we have to do the same kind of preprocessing while doing transfer learning. 

The 'get_preprocessing_fn()' function will return the preprocessing_fn corresponding to the 
mentioned model and then we are using it as Albumentation like transform by wrapping it with 
a Lambda function.
'''
preprocessing_fn = Lambda(image = get_preprocessing_fn(encoder_name = ENCODER_NAME,
                                                       pretrained = 'imagenet'))



# https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter
transforms = Compose([
                HorizontalFlip(),
                VerticalFlip(),
                RandomRotate90(),
                ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9, 
                                 border_mode=cv2.BORDER_REFLECT),
                OneOf([
                    OpticalDistortion(p=0.3),
                    GridDistortion(p=.1),
                    IAAPiecewiseAffine(p=0.3),
                ], p=0.3),
                OneOf([
                    HueSaturationValue(10,15,10),
                    CLAHE(clip_limit=2),
                    RandomBrightnessContrast(),            
                ], p=0.3),
            ], p=1.0)

# 4. <a id='4'>Model</a>
[Table of contents](#0.1)

### U-Net model

U-Net architecture

![](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [None]:
class HuBMAPModel(nn.Module):
    def __init__(self):
        super(HuBMAPModel, self).__init__()
        self.model = Unet(encoder_name = ENCODER_NAME, 
                          encoder_weights = 'imagenet',
                          classes = 1,
                          activation = None)
        
        
    def forward(self, images):
        img_masks = self.model(images)
        return img_masks

In this competition, we have dice coefficient as the evaluation metric.

**The higher the dice coefficient, the better the predictions.**

The dice coefficient is fairly straightforward:
![](https://i.stack.imgur.com/OsH4y.png)

For the training, we have to define the loss function. We can define the dice loss function like this:

**dice loss = (1 - dice_coefficient)**

**The lower the dice loss, the higher the dice coefficient.**
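
As a small worked example, take a prediction that marks two pixels and a true mask that marks one of them. The overlap is 1 pixel and the mask sizes are 2 and 1, so the dice coefficient is 2 * 1 / (2 + 1) = 2/3. The `dice_coefficient` helper below is a hypothetical NumPy sketch of the plain definition; it has no smoothing term, so it assumes at least one mask is non-empty.

```python
import numpy as np


def dice_coefficient(pred, target):
    """Dice = 2 * intersection / (size of pred + size of target) for binary masks."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = (pred * target).sum()
    return 2.0 * intersection / (pred.sum() + target.sum())


pred = [[1, 1], [0, 0]]
target = [[1, 0], [0, 0]]
dice = dice_coefficient(pred, target)
print(f"dice = {dice:.4f}, dice loss = {1 - dice:.4f}")
```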

In [None]:
#https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        # Comment out if your model already ends with a sigmoid or equivalent activation layer.
        # torch.sigmoid replaces the deprecated F.sigmoid.
        inputs = torch.sigmoid(inputs)       
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice
    
    
loss_fn = DiceLoss()

In [None]:
def get_dice_coeff(pred, targs):
    '''
    Calculates the dice coeff of a single or batch of predicted mask and true masks.
    
    Args:
        pred : Batch of Predicted masks (b, w, h) or single predicted mask (w, h)
        targs : Batch of true masks (b, w, h) or single true mask (w, h)
  
    Returns: Dice coeff over a batch or over a single pair.
    '''
    
    
    pred = (pred>0).float()
    return 2.0 * (pred*targs).sum() / ((pred+targs).sum() + 1.0)

    
def reduce(values):
    '''    
    Returns the average of the values.
    Args:
        values : list of values, one calculated on each core 
    '''
    return sum(values) / len(values)

In [None]:
# Downloading the pretrained encoder weights takes time, so we copy them from a Kaggle dataset instead.

!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp ../input/pretrained-model-weights-pytorch/se_resnext50_32x4d-a260b3a4.pth /root/.cache/torch/hub/checkpoints/

# 5. <a id='5'>Training on 8-cores of TPU</a>
[Table of contents](#0.1)

Let's see how the Cosine Annealing LR scheduler changes the learning rate over the epochs.

In [None]:
# Came up with this set of hyperparameters after some experiments.

demo_model = HuBMAPModel()

# LR which we set here is the highest LR value.
demo_optim = optim.SGD(demo_model.parameters(),
                      lr = 0.3,
                      momentum = 0.9)
DEMO_EPOCHS = 60

# Change the T_0 argument to get more number of cycles.
demo_sched = CosineAnnealingWarmRestarts(demo_optim, 
                                         T_0 = DEMO_EPOCHS//3, 
                                         T_mult=1, 
                                         eta_min=0, 
                                         last_epoch=-1, 
                                         verbose=False)
lrs = []

for i in range(DEMO_EPOCHS):
    demo_optim.step()
    lrs.append(demo_optim.param_groups[0]["lr"])
    demo_sched.step()

plt.plot(lrs)
plt.show()
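
The curve plotted above follows a simple closed-form rule. The sketch below reproduces it in plain Python for the case `T_mult = 1`, using the same peak LR (0.3) and cycle length (20 epochs) as the demo. `cosine_warm_restart_lr` is a hypothetical helper written for illustration, not part of PyTorch.

```python
import math


def cosine_warm_restart_lr(epoch, lr_max=0.3, lr_min=0.0, t_0=20):
    """LR at a given epoch for cosine annealing with restarts every `t_0` epochs."""
    t_cur = epoch % t_0  # position inside the current cycle
    return lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / t_0)) / 2


for epoch in (0, 10, 19, 20):
    print(epoch, round(cosine_warm_restart_lr(epoch), 4))
```

The LR starts at its peak, falls to half of it in the middle of a cycle, approaches `lr_min` at the end, and jumps back to the peak when the next cycle starts.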

In [None]:
# Updated
def train_one_epoch(epoch_no, data_loader, model, optimizer, device, scheduler = None):
    '''
    Run one epoch on the 'model'.
    Args:
        epoch_no: Serial number of the given epoch
        data_loader: Data iterator like DataLoader
        model: Model which needs to be trained for one epoch
        optimizer: PyTorch's optimizer
        device: Device (a particular core) on which to run this epoch
        scheduler: PyTorch's LR scheduler (optional)
        
        
    Returns: Nothing
    '''
    model.train()
    losses = []
    dice_coeffs = []
    
    for i, batch in enumerate(data_loader):
        img_btch, mask_btch = batch
        
        # *** Putting the images and masks to the TPU device. ***
        img_btch = img_btch.to(device)
        mask_btch = mask_btch.to(device)
        
        optimizer.zero_grad()
        
        pred_mask_btch = model(img_btch.float())
        
        loss = loss_fn(pred_mask_btch, mask_btch.float())
        
        loss.backward()
        
        
        '''
        xm.optimizer_step():
        Consolidates the gradients between cores and issues the XLA device step computation.
        The `step()` function now not only propagates gradients, but uses the TPU context 
        to synchronize gradient updates across each processes' copy of the network. 
        This ensures that each processes' network copy stays "in sync" (they are all identical).
        This means that each process's network has the same weights after this is called.
        [Source:PyTorch XLA doc]
        '''
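        # Note: barrier=True not needed when using ParallelLoader
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            # Note: this runs once per batch, so T_0 in the scheduler counts
            # batches here, not epochs as in the demo plot above.
            scheduler.step()

        
        # 'mesh_reduce()' gathers the loss calculated on each of the 8 cores.
        # How the values are combined is defined in the 'reduce()' function.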
        loss_reduced = xm.mesh_reduce('train_loss_reduce', 
                                      loss, 
                                      reduce)
        losses.append(loss_reduced.item())
        
        dice_coeff = get_dice_coeff(torch.squeeze(pred_mask_btch), 
                                    mask_btch.float())
        dice_coeffs.append(xm.mesh_reduce('train_dice_reduce', 
                                          dice_coeff, 
                                          reduce).item())  
        
        
        del img_btch, pred_mask_btch, mask_btch
        gc.collect()
    
    xm.master_print(f'{epoch_no+1} - Loss : {reduce(losses): .4f}, Dice Coeff : {reduce(dice_coeffs): .4f}')
    

# New stuff
def eval_fn(data_loader, model, device):
    '''
    Calculates metrics on the validation data.
    
    Returns: returns calculated metrics
    '''
    model.eval()
    
    dice_coeffs = []
    losses = []
    
    # Gradients are not needed for validation, so disable autograd tracking.
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            img_btch, mask_btch = batch
            
            img_btch = img_btch.to(device)
            mask_btch = mask_btch.to(device)
            
            pred_mask_btch = model(img_btch.float())
            
            loss = loss_fn(pred_mask_btch, 
                           mask_btch.float())
            losses.append(xm.mesh_reduce('val_loss_reduce', 
                                         loss, 
                                         reduce).item())
            
            dice_coeff = get_dice_coeff(torch.squeeze(pred_mask_btch), 
                                        mask_btch.float())
            dice_coeffs.append(xm.mesh_reduce('val_dice_reduce', 
                                              dice_coeff, 
                                              reduce).item())
        
    total_dice_coeff = reduce(dice_coeffs)
    total_loss = reduce(losses)
    
    return total_loss, total_dice_coeff

Note: `_mp_fn()` uses `xm.master_print()` instead of `print()`. Because `_mp_fn()` runs on every core, a plain `print()` would emit each message 8 times, once per core, whereas `xm.master_print()` prints it only once, from the master device.

In [None]:
# Updated
def _mp_fn(rank, flags):
    
    ''' 
    This function will be called for each device which takes part in the replication. 
    
    Args:
        rank: Index of the process within the replication.
        flags: Contains custom arguments which you need to pass to each process.
   '''
    
    
    # Acquires the (unique) TPU core corresponding to this process's index
    device = xm.xla_device()
    
    # Creates the (distributed) train sampler, which lets this process access only
    # its own portion of the training dataset.
    train_sampler = DistributedSampler(dataset = flags['TRAIN_DS'],
                                      num_replicas = xm.xrt_world_size(),
                                      rank = xm.get_ordinal(),
                                      shuffle = True)
    train_dl = DataLoader(dataset = flags['TRAIN_DS'],
                          batch_size = flags['BATCH_SIZE'],
                          sampler = train_sampler,
                          num_workers = 0)
    
    
    val_sampler = DistributedSampler(dataset = flags['VAL_DS'],
                                     num_replicas = xm.xrt_world_size(),
                                     rank = xm.get_ordinal(),
                                     shuffle = False)
    val_dl = DataLoader(dataset = flags['VAL_DS'],
                                  batch_size = flags['BATCH_SIZE'],
                                  sampler = val_sampler,
                                  num_workers = 0)
    
    del train_sampler, val_sampler
    gc.collect()
    
#     Uncomment this line to confirm that the training runs on all 8 cores.
#     print("Process is using", xm.xla_real_devices([str(device)])[0])
    
    fold_model = flags['FOLD_MODEL']
    fold_model.to(device)
    
    lr = flags['LR']

    # Use the model passed in through 'flags' rather than relying on the global 'model'.
    optimizer = optim.SGD(fold_model.parameters(), 
                          lr = lr,
                          momentum = 0.9)
    
    scheduler = CosineAnnealingWarmRestarts(optimizer, 
                                            T_0 = flags['EPOCHS']//3, 
                                            T_mult=1, 
                                            eta_min=0, 
                                            last_epoch=-1, 
                                            verbose=False)
    
    xm.master_print('Training now...')
    for e_no, epoch in enumerate(range(flags['EPOCHS'])):
        
        # Here comes our data loader for 8 cores.
        # It takes famous 'DataLoader()' object and list of 
        # devices where data has to be sent.
        # Calling 'per_device_loader()' on it will
        # return the data loader for the particular device.
        train_para_loader = pl.ParallelLoader(train_dl, 
                                              [device]).per_device_loader(device)
        
        train_one_epoch(e_no, 
                        train_para_loader,
                        fold_model, 
                        optimizer, 
                        device,
                        scheduler)
        
        del train_para_loader
        gc.collect()
         
    xm.master_print('\nValidating now...')
    val_para_loader = pl.ParallelLoader(val_dl, 
                                       [device]).per_device_loader(device)
    
    loss, dice_coeff = eval_fn(val_para_loader,
                               fold_model,
                               device)
    del val_para_loader
    gc.collect()    
    
    xm.master_print(f'Val Loss : {loss: .4f}, Val Dice : {dice_coeff: .4f}')
        
    #Saving the model, so that we can import it in the inference kernel.
    xm.save(fold_model.state_dict(), f"8core_fold_model_{flags['FOLD_NO']}.pth")

✨✨✨

**Now we call the most important function of this NB, `xmp.spawn()`, once for each fold.**

This function creates eight processes, one for each TPU core, and calls `_mp_fn()` (the multiprocessing function defined above) in each process. The inputs to `_mp_fn()` are an index (zero through seven) and the `flags` dictionary. When a process acquires its device, it automatically gets its corresponding TPU core. Each core is itself a device.

Arguments of `xmp.spawn()`:
* fn : The function to be called for each device (TPU core).
* args : The arguments which will be passed to the `_mp_fn` function.
* nprocs : The number of processes/devices for the replication.
* start_method : The Python *multiprocessing* process creation method. For more info, visit [here](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
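
The calling convention can be illustrated without a TPU by using only the standard library. The sketch below is an assumption-laden stand-in, not how `xmp.spawn()` is implemented: `spawn_sketch` and `report_rank` are hypothetical names, and the extra `queue` argument exists only so that the example can collect results. It relies on the `fork` start method, which is not available on Windows.

```python
import multiprocessing as mp


def report_rank(rank, flags, queue):
    """Stand-in for `_mp_fn`: receives the process index plus the shared arguments."""
    queue.put((rank, flags["LR"]))


def spawn_sketch(fn, args, nprocs, start_method="fork"):
    """Start `nprocs` processes and call fn(rank, *args, queue) in each one."""
    ctx = mp.get_context(start_method)
    queue = ctx.Queue()
    procs = [ctx.Process(target=fn, args=(rank, *args, queue))
             for rank in range(nprocs)]
    for proc in procs:
        proc.start()
    # Drain the queue before joining so that no child blocks on a full pipe.
    results = [queue.get(timeout=30) for _ in procs]
    for proc in procs:
        proc.join()
    return sorted(results)


if __name__ == "__main__":
    print(spawn_sketch(report_rank, ({"LR": 0.3},), nprocs=4))
```

Each process sees the same `flags` but a different `rank`, which is how `_mp_fn()` knows which core's share of the work it is responsible for.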

In [None]:
for fold, (t_idx, v_idx) in enumerate(group_kfold.split(fnames, 
                                                        groups = groups)):
    
    print(f'Fold: {fold+1}')
    print('-' * 40)
    
    t_fnames = fnames[t_idx]
    v_fnames = fnames[v_idx]
    
    train_ds = HuBMAPDataset(data_path = DATA_PATH,
                             fnames = t_fnames,
                             preprocess_input = preprocessing_fn,
                             transforms = transforms)
    
    val_ds = HuBMAPDataset(data_path = DATA_PATH,
                           fnames = v_fnames,
                           preprocess_input = preprocessing_fn,
                           transforms = None)
    
    model = HuBMAPModel()
    model.float()
    
    lr = 0.3
    
    FLAGS = {'FOLD_NO': fold,
             'TRAIN_DS': train_ds,
             'VAL_DS': val_ds,
             'FOLD_MODEL': model,
             'LR': lr,
             'BATCH_SIZE' : 32,
             'EPOCHS' : 30}

    xmp.spawn(fn = _mp_fn, 
              args = (FLAGS,), 
              nprocs = 8,
              start_method = 'fork')
    
    
    print('\n')
    
    del train_ds, val_ds, model

In [None]:
# We have trained our warriors and they are resting in the current
# working directory.
# Now, we will move them to the battle ground (Inference kernel) for the war.
!ls

**So, this is how we can train our model on all the 8 cores of the TPU with PyTorch.**

Do consider the upvote, if you liked it. :)

# 7. <a id='7'>References</a>
[Table of contents](#0.1)
* https://www.kaggle.com/joshi98kishan/pytorch-xla-setup-script
* https://www.kaggle.com/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r
* https://pytorch.org/xla
* https://www.kaggle.com/abhishek/super-duper-fast-pytorch-tpu-kernel
* https://www.kaggle.com/abhishek/bert-multi-lingual-tpu-training-8-cores