# About this Notebook
* In 2018, Google made the hardware behind many of its state-of-the-art results, the TPU (more on it in a bit), available to anyone through its cloud services at $2/hour. Some time later, TPUs also became available for free on Google Colab.

* TPUs are very fast at neural network workloads. Initially, however, they could only be used with TensorFlow and Keras, which left PyTorch fans frustrated because they didn't want to switch to TF. This set off the development of a way to use TPUs with PyTorch.

* The result was the PyTorch-XLA module, which lets PyTorch run its graphs on XLA devices such as TPUs.

In the **Jigsaw competition**, TPUs have been used in many different ways: on single cores, on multiple cores, with different checkpoints, and so on. However, when I tried to understand the publicly shared kernels to learn how to use TPU cores with PyTorch, I found them really difficult to comprehend. Also, the best-performing model for Jigsaw is XLM-Roberta, which is a fairly large model and throws out-of-memory errors if you are not careful with memory management.

The complexity of TPUs, and of using XLM-Roberta on them, left me frustrated and angry. I decided to take this TPU thing slowly: I started with small models and built up to XLM-Roberta through a lot of experimentation.

**In this notebook I share my experiments with PyTorch-XLA and TPUs. I will start from basic TPU usage and build on that to show how to use TPUs on multiple cores and with multiprocessing. I will also share some tips and tricks that are useful when using TPUs with PyTorch-XLA. If you want to learn to use TPUs the easy way, this might be a good place to start.**

After working through this notebook, you should be able to read and easily understand Alex's and Abhishek's kernels.

# Prerequisites:

* This notebook assumes that you are familiar with PyTorch and have used it before to build models. If you have not, here are some useful links for learning PyTorch from zero:<br>
1) https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html<br>
2) https://pytorch.org/tutorials/beginner/nn_tutorial.html<br>
3) https://pytorch.org/tutorials/beginner/pytorch_with_examples.html <br>

* Using PyTorch is fairly easy; it feels like writing ordinary Python with some extra machinery on top. Here is a repository containing PyTorch implementations of different architectures: https://github.com/tanulsingh/Pytorch-for-Everyone

In [None]:
import numpy as np 
import pandas as pd

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# What Are TPUs? How Do They Work? How Are They Different from a GPU?

You might be wondering why it matters how TPUs work. It's not a must, but to exploit something fully we need to know how it works, right?
TPUs are hardware accelerators specialized for deep learning tasks. For an explanation of what TPUs are and how they work, please go through the following videos:
* [video1](https://www.youtube.com/watch?v=MXxN4fv01c8)
* [video2](https://www.youtube.com/watch?v=kBjYK3K3P6M)<br><br>
It's important to understand the concepts underlying PyTorch-XLA. If you want to dig even deeper, [here](https://codelabs.developers.google.com/codelabs/keras-flowers-data/#2) is an article by Google explaining everything about TPUs.


# Key Takeaways
These are the key takeaways from the videos and article above:

* Each TPU v3 board has 8 TPU cores and 128 GB of memory in total (16 GB per core); the older v2 boards have 64 GB.
* Each TPU core consists of two units: a Matrix Multiply Unit (MXU), which runs matrix multiplications, and a Vector Processing Unit (VPU) for all other tasks such as activations, softmax, etc.
* TPU v2/v3 use a dtype called bfloat16, which combines the range of a 32-bit floating point number with the storage space of a 16-bit one. This lets more matrices fit in memory and therefore allows more matrix multiplications. The extra speed comes at the cost of precision: bfloat16 keeps fewer significant digits than the standard 16-bit floating point format (float16). That is OK, because neural networks can work at reduced precision while maintaining their accuracy.
* The ideal batch size for TPUs is 128 data items per TPU core, but the hardware can already show good utilization from 8 data items per TPU core.
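
To make the bfloat16 trade-off concrete, here is a minimal sketch in plain Python (no TPU needed). It relies on the fact that bfloat16 is the top 16 bits of a float32. The helper names `to_bfloat16` and `fits_in_float16` are made up for this illustration, and the conversion simply truncates, whereas real hardware typically rounds.

```python
import struct

def to_bfloat16(x):
    """Emulate bfloat16 by keeping only the top 16 bits of the float32 encoding.
    Assumption: plain truncation; real hardware usually rounds to nearest."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFF0000  # keep sign, 8 exponent bits and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def fits_in_float16(x):
    """True if x is within the range of IEEE float16 (struct format 'e')."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False

print(to_bfloat16(3.14159))    # 3.140625: only 2-3 significant digits survive
print(to_bfloat16(1e38))       # still finite: same exponent range as float32
print(fits_in_float16(1e38))   # False: float16 tops out at 65504
```

So bfloat16 gives up digits, not range, which is the trade-off described in the takeaways above.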

**Now we move on to the final question: do TPUs directly run the Python code, or is something else working under the hood without getting the credit?**

![](https://3s81si1s5ygj3mzby34dq6qf-wpengine.netdna-ssl.com/wp-content/uploads/2018/12/bfloat.jpg)


# Under the Hood
* A deep learning framework first defines a computation graph, which is then executed by a processing chip to train a neural network. Similarly, the TPU does not directly run Python code; it runs the computation graph defined by your program. The computation graph is first converted into TPU machine code: under the hood, a compiler called XLA (Accelerated Linear Algebra) transforms the graph of computation nodes into TPU machine code. This compiler also performs many advanced optimizations on your code and your memory layout.
* In TensorFlow, the conversion from computation graph to TPU machine code takes place automatically as work is sent to the TPU. PyTorch had no such support, so the PyTorch-XLA module was created to include XLA in our build chain explicitly.

![](https://lh5.googleusercontent.com/NjGqp60oF_3Bu4Q63dprSivZ77BgVnaPEp0Olk1moFm8okcmMfPXs7PIJBgL9LB5QCtqlmM4WTepYxPC5Mq_i_0949sWSpq8pKvfPAkHnFJWuHjrNVLPN2_a0eggOlteV7mZB_Z9)

Let's start by installing PyTorch-XLA and importing the necessary modules.

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

PS: The code above only works when the TPU is on 😜 <br><br>

In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn import metrics
from sklearn.model_selection import train_test_split
import transformers
from transformers import AdamW

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

import warnings
warnings.filterwarnings("ignore")

If you have used PyTorch with GPUs before, you know that running your code and models on a GPU is this simple:
* device = "cuda"
* tensors.to(device)
* model.to(device)<br><br>
The PyTorch-XLA developers had the same thing in mind: they wanted end users to have the same feel and structure when using TPUs as they had when using GPUs.<br><br>
**PyTorch-XLA treats each TPU core as an individual XLA device, so using a TPU core is as easy as**:
* device = xm.xla_device()
* tensors.to(device)
* model.to(device)

In [None]:
t0 = torch.randn(2, 2, device=xm.xla_device()) #creating a tensor and sending it to the TPU
t1 = torch.randn(2, 2, device=xm.xla_device()) #creating a tensor and sending it to the TPU
print(t0 + t1) # As both tensors are now on the same device  i.e same TPU core we can perform any calculation on them like addition
print(t0.mm(t1)) # Matrix Multiplication

Sending a model to the TPU

In [None]:
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)  # The tensor and the model must both be on the same device, just as with GPUs, before any operation can run
print(l_out)

# Training on a Single TPU Core
Now that we know the basics of PyTorch-XLA, let's see how to train a model on a single TPU core and what changes in the code.
I will be using a subset of the Jigsaw competition data with BERT-base for simplicity. At the end of this notebook I describe the experiments I have done with XLM-Roberta on TPUs and give some insights into training XLM-Roberta efficiently on TPUs.<br><br>
**I will start with the functions that stay the same irrespective of where and how we train the model.**

# What Always Stays the Same

In [None]:
class config:
    
    MAX_LEN = 224
    TRAIN_BATCH_SIZE = 32
    VALID_BATCH_SIZE = 8
    EPOCHS = 1
    MODEL_PATH = "model.bin"
    TRAINING_FILE = '/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv'
    TOKENIZER = transformers.BertTokenizer.from_pretrained('bert-base-uncased',do_lower_case =True)

In [None]:
class BERTDataset(torch.utils.data.Dataset):
    def __init__(self,text,target):
        self.text = text
        self.target = target
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN 

    def __len__(self):
        return len(self.text)

    def __getitem__(self,idx):
        text  = str(self.text[idx])
        text = " ".join(text.split())

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens = True,
            max_length = self.max_len,
            pad_to_max_length = True
        )

        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]

        return {
            'ids': torch.tensor(ids,dtype=torch.long),
            'mask': torch.tensor(mask,dtype=torch.long),
            'targets': torch.tensor(self.target[idx],dtype=torch.long)
        }

In [None]:
class BERTBaseUncased(nn.Module):
    def __init__(self):
        super(BERTBaseUncased,self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.bert_drop = nn.Dropout(0.3)
        self.fc1 = nn.Linear(768,1)

    def forward(self,ids,mask):
        _,o2 = self.bert(
            ids,
            mask
        )
        bo = self.bert_drop(o2)
        out = self.fc1(bo)

        return out

In [None]:
def loss_fn(outputs, targets):
    # Binary cross-entropy on raw logits; targets are reshaped to (batch, 1) to match the model output
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))

In [None]:
def eval_fn(data_loader, model, device):
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for bi, d in enumerate(data_loader):
            ids = d["ids"]
            mask = d["mask"]
            targets = d["targets"]

            ids = ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)

            outputs = model(
                ids=ids,
                mask=mask
            )
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets

In [None]:
train = pd.read_csv(config.TRAINING_FILE).fillna("none").sample(n=4000)
train_df, valid_df, train_tar, valid_tar = train_test_split(train.comment_text, train.toxic, 
                                                  stratify=train.toxic.values, 
                                                  random_state=42, 
                                                  test_size=0.2, shuffle=True)

# Functions that will change

In [None]:
def train_fn(data_loader, model, optimizer, device, scheduler,epoch,num_steps):
    model.train()

    for bi, d in enumerate(data_loader):
        ids = d["ids"]
        mask = d["mask"]
        targets = d["targets"]

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)
        
        optimizer.zero_grad()
        outputs = model(
            ids=ids,
            mask=mask
        )

        loss = loss_fn(outputs, targets)
        loss.backward()
        #--------------------------------#------------------------#----------------------------#--------------------------#
        ####################################### CHANGE HAPPENS HERE #######################################################
        xm.optimizer_step(optimizer,barrier=True)
        ###################################################################################################################
        #-------------------------------#------------------------#----------------------------#---------------------------#
        if scheduler is not None:
            scheduler.step()

        # Log the loss every 10 batches
        if (bi+1) % 10 == 0:
            print('Epoch [{}/{}], bi[{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, config.EPOCHS, bi+1,num_steps, loss.item()))

In [None]:
def run():
    train_dataset = BERTDataset(
        text=train_df.values,
        target=train_tar.values
    )
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=4
    )

    valid_dataset = BERTDataset(
        text=valid_df.values,
        target=valid_tar.values
    )
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=64,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=4
    )
    
    #-----------------------------#---------------------#-----------------------------------#-----------------------------------#
    ##################################### Change occurs Here ####################################################################

    device = xm.xla_device()
    model = BERTBaseUncased()
    model.to(device)
    
    #############################################################################################################################
    #----------------------------#----------------------#------------------------------------#-----------------------------------#
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    
    lr = 3e-5 * xm.xrt_world_size()    # Scaling by the number of cores is optional; training also works without multiplying by xm.xrt_world_size()

    num_train_steps = int(len(train_dataset) / config.TRAIN_BATCH_SIZE / xm.xrt_world_size() * config.EPOCHS)
    optimizer = AdamW(optimizer_parameters, lr=lr)

    best_accuracy = 0
    for epoch in range(config.EPOCHS):
        train_fn(train_data_loader, model, optimizer, device, scheduler=None,epoch=epoch,num_steps=num_train_steps)
        
        outputs, targets = eval_fn(valid_data_loader, model, device)
        
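        # AUC needs the raw probabilities; thresholding at 0.5 first would throw away the ranking it measures
        outputs = np.array(outputs).ravel()
        accuracy = metrics.roc_auc_score(targets, outputs)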
        print(f"AUC_SCORE = {accuracy}")
        if accuracy > best_accuracy:
            xm.save(model.state_dict(), config.MODEL_PATH)
            best_accuracy = accuracy

In [None]:
run()

What changes did we make?
* We declared the device as xm.xla_device()
* We used xm.optimizer_step(optimizer, barrier=True)<br><br>
That is all that was needed to run on a single TPU core.

### Curious About What barrier=True Does?

To understand this, we need to know how XLA tensors work internally and how XLA creates graphs and runs operations. CPU and CUDA tensors launch operations immediately, or eagerly. XLA tensors, on the other hand, are lazy: they record operations in a graph until the results are needed. Deferring execution like this lets XLA optimize it. Calling xm.optimizer_step(optimizer, barrier=True) at the end of each training iteration causes XLA to execute its current graph and update the model's parameters.
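
The idea of lazy tensors can be sketched with a toy class in plain Python. This is only an illustration of "record now, execute later"; `LazyValue` and its methods are hypothetical names and have nothing to do with the real PyTorch-XLA internals.

```python
import operator

class LazyValue:
    """Toy stand-in for a lazy tensor: operations are recorded, not executed."""

    def __init__(self, value=None, op=None, parents=()):
        self.value = value
        self.op = op
        self.parents = parents
        self.done = op is None  # leaves already hold concrete data

    def __add__(self, other):
        return LazyValue(op=operator.add, parents=(self, other))

    def __mul__(self, other):
        return LazyValue(op=operator.mul, parents=(self, other))

    def pending_ops(self):
        """Number of recorded operations that have not been executed yet."""
        if self.done:
            return 0
        return 1 + sum(p.pending_ops() for p in self.parents)

    def materialize(self):
        """Plays the role of the barrier: run the recorded graph and cache the result."""
        if not self.done:
            self.value = self.op(*(p.materialize() for p in self.parents))
            self.done = True
        return self.value

a, b = LazyValue(2.0), LazyValue(3.0)
c = (a + b) * b            # nothing is computed here, only recorded
print(c.pending_ops())     # 2
print(c.materialize())     # 15.0
print(c.pending_ops())     # 0
```

In this picture, `barrier=True` corresponds to calling `materialize()` once per training step, so the recorded graph never grows without bound.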

# Training on Multiple TPU cores
Let's now move on to the most interesting part: running PyTorch on MULTIPLE TPU CORES simultaneously.

Working with multiple Cloud TPU cores is different from training on a single Cloud TPU core. With a single Cloud TPU core we simply acquired the device and ran the operations using it directly. To use multiple Cloud TPU cores we must use other processes, one per Cloud TPU core. This indirection and multiplicity makes multicore training a little more complex than training on a single core, but it's necessary to maximize performance.

To follow along, we need to understand four important things offered by the XLA module:
* xla_multiprocessing
* spawn() function
* ParallelLoader
* XLA_USE_BF16

We will first use these in our model and see how they work, and then I will explain what they do.

In [None]:
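# '!export' runs in a separate subshell and would not affect this Python process, so set the variable here instead
os.environ["XLA_USE_BF16"] = "1"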

In [None]:
def train_fn(data_loader, model, optimizer, device, scheduler,epoch,num_steps):
    model.train()

    for bi, d in enumerate(data_loader):
        ids = d["ids"]
        mask = d["mask"]
        targets = d["targets"]

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)
        
        optimizer.zero_grad()
        outputs = model(
            ids=ids,
            mask=mask
        )

        loss = loss_fn(outputs, targets)
        loss.backward()
        #--------------------------------#------------------------#----------------------------#--------------------------#
        ####################################### CHANGE HAPPENS HERE #######################################################
        xm.optimizer_step(optimizer)
        ###################################################################################################################
        #-------------------------------#------------------------#----------------------------#---------------------------#
        if scheduler is not None:
            scheduler.step()

        # Log the loss every 10 batches
        if (bi+1) % 10 == 0:
            print('Epoch [{}/{}], bi[{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, config.EPOCHS, bi+1,num_steps, loss.item()))

In [None]:
model = BERTBaseUncased()

In [None]:
def _run():
    
    train_dataset = BERTDataset(
        text=train_df.values,
        target=train_tar.values
    )
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=4
    )

    valid_dataset = BERTDataset(
        text=valid_df.values,
        target=valid_tar.values
    )
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=64,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=4
    )

    device = xm.xla_device()
    model.to(device)
    
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    
    lr = 0.4 * 1e-5 * xm.xrt_world_size()

    num_train_steps = int(len(train_dataset) / config.TRAIN_BATCH_SIZE / xm.xrt_world_size() * config.EPOCHS)
    optimizer = AdamW(optimizer_parameters, lr=lr)

    best_accuracy = 0
    #---------------------------------------#--------------------------------#----------------------------#-------------------------------#
    ########################################## Change occur In this Loop #################################################################
    for epoch in range(config.EPOCHS):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=None,epoch=epoch,num_steps=num_train_steps)
        
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets = eval_fn(para_loader.per_device_loader(device), model, device)
        
    ########################################################################################################################################
    #---------------------------------------#---------------------------------#------------------------------#------------------------------#
        
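        # AUC needs the raw probabilities; thresholding at 0.5 first would throw away the ranking it measures
        outputs = np.array(outputs).ravel()
        accuracy = metrics.roc_auc_score(targets, outputs)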
        print(f"AUC_SCORE = {accuracy}")
        if accuracy > best_accuracy:
            xm.save(model.state_dict(), config.MODEL_PATH)
            best_accuracy = accuracy

In [None]:
def _mp_fn(rank, flags):
    # Map function: spawn() calls this once in every process, passing that process's index as rank
    torch.set_default_tensor_type('torch.FloatTensor')
    _run()

FLAGS={}
#xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

### NOTE: I didn't know about this and only found out while preparing this notebook. It's always useful to experiment 😉
To run on multiple TPU cores, xm.xla_device() must not be called before spawning, i.e. before the spawn function runs. In simple terms, xm.xla_device() should only be called from inside the function that spawn executes. To see this for yourself, uncomment the last line above and you will get a RuntimeError. If you then restart the kernel and skip the cells where xm.xla_device() was called, the code works fine. Try this exercise for fun. Why does this happen? More on that in a bit.

# Things to Notice:
* xm.optimizer_step() does not take a barrier argument this time
* The model was created outside the run function and sent to the XLA device inside it, whereas on a single TPU core we did both in one place
* Something called ParallelLoader is wrapped around the DataLoader
* The XLA_USE_BF16 environment variable is used
* And of course we now run the spawn function to execute the model training and evaluation <br><br>
Let's now talk about each of these one by one, starting with:

# XLA_USE_BF16 Environment variable

PyTorch/XLA can use the bfloat16 datatype when running on TPUs. In fact, PyTorch/XLA handles float types (torch.float and torch.double) differently on TPUs. This behavior is controlled by the XLA_USE_BF16 environment variable:

* By default both torch.float and torch.double are torch.float on TPUs.
* If XLA_USE_BF16 is set, then torch.float and torch.double are both bfloat16 on TPUs.
* If a PyTorch tensor has torch.bfloat16 data type, this will be directly mapped to the TPU bfloat16 (XLA BF16 primitive type).
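
One practical detail: the variable has to be set inside the Python process that runs PyTorch/XLA. A shell command such as `!export ...` runs in a child process, so its effect is gone as soon as that shell exits. The sketch below demonstrates this with a throwaway variable; `DEMO_XLA_FLAG` and `child_export_reaches_parent` are hypothetical names used only here, and the example assumes a POSIX shell is available.

```python
import os
import subprocess

def child_export_reaches_parent(name="DEMO_XLA_FLAG"):
    """Export a variable in a child shell and report whether this process sees it."""
    os.environ.pop(name, None)
    subprocess.run(f"export {name}=1", shell=True, check=True)
    return name in os.environ

print(child_export_reaches_parent())   # False: the export died with the child shell

# Setting it through os.environ changes the current process,
# and any process started from it afterwards inherits the value.
os.environ["XLA_USE_BF16"] = "1"
print(os.environ["XLA_USE_BF16"])      # 1
```

To be safe, set the variable before the first XLA device is acquired.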

# ParallelLoader
* ParallelLoader loads the training data onto each device, i.e. onto each TPU core
* It wraps an existing PyTorch DataLoader and uploads the data in the background
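
"Background data upload" can be pictured with a small producer/consumer sketch in plain Python: a worker thread prepares upcoming batches while the training loop consumes the current one. This is not how ParallelLoader is implemented; `background_loader`, `to_device` and `prefetch` are hypothetical names for this illustration.

```python
import queue
import threading

def background_loader(batches, to_device, prefetch=2):
    """Yield batches that a background thread has already moved 'to the device'."""
    q = queue.Queue(maxsize=prefetch)   # bounded, so the worker cannot run far ahead
    end = object()                      # sentinel marking the end of the data

    def worker():
        for batch in batches:
            q.put(to_device(batch))     # the upload happens off the main thread
        q.put(end)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is end:
            return
        yield item

# Stand-in for a device transfer: here we just tag each batch.
fake_upload = lambda batch: ("on_device", batch)

for item in background_loader([[1, 2], [3, 4], [5, 6]], fake_upload):
    print(item)
```

The training loop never waits for a transfer that could have been done earlier, which is the benefit the wrapper is after.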

# Barrier No Longer Needed
* xm.optimizer_step(optimizer) no longer needs a barrier. ParallelLoader automatically creates an XLA barrier that evaluates the graph

# Spawn Function
This is the most important one to understand if you want to use multiprocessing and multiple TPU cores effectively.
* The spawn function creates multiple copies of the computation graph, one to be fed to each core (XLA device). It also makes copies of the data on which the model is trained.
* spawn() takes a function (the "map function"), a tuple of arguments (the placeholder flags dict), the number of processes to create, and whether to create these new processes by "forking" or "spawning."
* In the code above, spawn() will create eight processes, one for each Cloud TPU core, and call _mp_fn() (the map function) in each process. The inputs to _mp_fn() are an index (zero through seven) and the placeholder flags. When the processes acquire their device, they actually acquire their corresponding Cloud TPU core automatically.

### Map_function
Let's now talk about the map function. It is the function that is called in each of the n replicated processes. PyTorch-XLA creates nprocs processes as soon as the spawn function is called, one for each device, and the map function is the first thing called in each of them. The map function takes two arguments: the process index (zero to n-1) and the placeholder flags, a dictionary that can hold the configuration of your model such as max_len, epochs, num_workers, etc.

### Now back to why we cannot call xm.xla_device() before spawning
* If we do, we acquire a single TPU core as our XLA device, so PyTorch-XLA thinks there is only one device and cannot use the other TPU cores.
* When spawning occurs, the TPU cores are treated as different devices running the same process, distinguished by the index passed to the map function. In PyTorch-XLA we can grab a specific core using xm.xla_device(n=) and passing the index. That's how spawn is able to grab different TPU cores.

### How did each process in the above cell know to acquire its own Cloud TPU core?
The answer is context. Accelerators, like Cloud TPUs, manage their operations using an implicit stateful context. In the cell above, the `spawn()` function creates a multiprocessing context and gives it to each new, forked process, allowing them to coordinate.

# Experiments with TPUs for Jigsaw
* Since spawn makes copies of the data and the processes, it is the cause of a common out-of-memory error when working with large models such as XLM-Roberta and Roberta itself.
* I have tried and tested the following:
  * BERT-base-uncased: max length = 224, bs = 128, epochs = 3. Works on multiple cores without running out of memory.
  * BERT-base-multilingual: max length = 224, bs = 64, epochs = 3. Works on multiple cores without running out of memory; bs = 128 does not work.
  * XLM-Roberta: runs out of memory even with bs = 16 and max_len = 96.
  * The ideal hyperparameters for XLM-Roberta in this competition would be max length = 224 and bs = 64; this works fine on Google Colab because it provides 25 GB of VM memory.
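
A rough back-of-the-envelope calculation shows why the copies made by spawn hurt large models. The sketch assumes each of the 8 processes holds its own float32 copy of the weights and ignores activations, optimizer state and the data itself. The parameter counts are approximate published figures (about 110M for BERT-base, about 550M for XLM-Roberta large), and `host_memory_gb` is a hypothetical helper.

```python
def host_memory_gb(n_params, n_procs=8, bytes_per_param=4):
    """Host RAM (in GB) needed if every process keeps its own copy of the weights."""
    return n_params * bytes_per_param * n_procs / 1e9

# Approximate parameter counts, used here only as assumptions.
for name, n_params in [("BERT-base", 110e6), ("XLM-Roberta large", 550e6)]:
    print(f"{name}: ~{host_memory_gb(n_params):.1f} GB for the weights alone")
```

Under these assumptions the weights of the large model alone approach the 25 GB mentioned above, which is consistent with the smaller BERT models fitting and XLM-Roberta running out of memory.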

# End Notes

I have tried to explain everything in great detail, and I hope this resolves some of the uncertainties around TPUs. If you find something new or have a doubt, please mention it in the comments and we can discuss it.

<font color='orange'>I hope my efforts helped you and that you now have better clarity about PyTorch-XLA and TPUs in general. An upvote is a gesture of appreciation that tells me my notebook helped you, so please be kind enough to upvote.</font>

# Further Reading
* http://pytorch.org/xla/release/1.5/index.html
* https://github.com/pytorch/xla