### Credits and citations

## Full credit to the original author, [@shoheiazuma](https://kaggle.com/shoheiazuma), for the original notebook. Big thanks also to [@abhishek](https://www.kaggle.com/abhishek) for his example notebooks and videos on how to build and run models on one or more TPUs.

I tidied and reorganised the code to learn more about how to switch code between CPU/GPU/TPU. There is room for more improvement as we go along, so please join me in simplifying the process of writing and running code on CPUs, GPUs and TPUs.

Please feel free to answer there as well as comment below.

#### Forked from https://www.kaggle.com/shoheiazuma/tweet-sentiment-roberta-pytorch

### This is the training version of the notebook; the [inference version can be found here](https://www.kaggle.com/neomatrix369/tse2020-roberta-pytorch-multi-tpu-10-skfd-2-2).

### TPU installation

Thanks to the tips on https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/159221, I updated my TPU/PyTorch installation process.

In [None]:
%%bash

echo "TPU_DEPS_INSTALLED=${TPU_DEPS_INSTALLED:-}"
if [[ -z "${TPU_DEPS_INSTALLED:-}" ]]; then
    echo "Installing TPU dependencies."
    pip install --upgrade pip
    # Nightly builds can be unstable. After discussing this with others, I am following
    # @Kirderf's advice and no longer install the nightly version:
    # python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
    curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py \
         -o pytorch-xla-env-setup.py
    python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev # --version nightly
    export TPU_DEPS_INSTALLED=true
    echo "TPU dependencies installed."
else
   echo "TPU dependencies already exist. Skipping step."
fi

In [None]:
%%bash
export XLA_USE_BF16=1
export XRT_TPU_CONFIG="tpu_worker;0;10.240.1.2:8470"
ls -lash *.whl 
rm -fr *.whl || true
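
One caveat with the cell above: a `%%bash` cell runs in its own subprocess, so variables exported there disappear when the cell finishes and are not visible to the Python kernel. A minimal sketch, assuming the same two values are wanted inside the notebook process, sets them through `os.environ` instead. It probably needs to run before `torch_xla` is imported, and the worker address is simply the one copied from the cell above.

```python
import os

# Values copied from the %%bash cell; the address is specific to that TPU worker.
tpu_env = {
    "XLA_USE_BF16": "1",
    "XRT_TPU_CONFIG": "tpu_worker;0;10.240.1.2:8470",
}

for name, value in tpu_env.items():
    # setdefault keeps any value the environment has already provided.
    os.environ.setdefault(name, value)

print({name: os.environ[name] for name in tpu_env})
```

Using `setdefault` rather than plain assignment is a design choice here: it avoids overwriting a configuration that the hosting environment may already have set.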

In [None]:
import numpy as np
import pandas as pd
import os
import warnings
import random
import torch 
from torch import nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold
import tokenizers
import transformers
from transformers import RobertaModel, RobertaConfig

warnings.filterwarnings('ignore')

In [None]:
from transformers import get_linear_schedule_with_warmup, AdamW

import torch_xla.core.xla_model as xm # using TPUs
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from tqdm.autonotebook import tqdm

from joblib import Parallel, delayed

In [None]:
def print_to_console(string_to_print, end='\n', flush=False):
    if accelerator_device == "tpu":
        xm.master_print(string_to_print) 
    else:
        print(string_to_print, end=end, flush=flush)

In [None]:
accelerator_device = "tpu"
# import tensorflow as tf
# print(f'Tensorflow version {tf.__version__}')

import os

TPU_WORKER = os.environ["TPU_NAME"]
# tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
# print(f"Running on TPU: {tpu.cluster_spec().as_dict()['worker']}")
print_to_console(f"TPU_WORKER: {TPU_WORKER}")
    
# tf.config.experimental_connect_to_cluster(tpu)
# strategy = tf.distribute.experimental.TPUStrategy(tpu)

# REPLICAS_OR_WORKERS = strategy.num_replicas_in_sync
REPLICAS_OR_WORKERS = xm.xrt_world_size()
print_to_console(f'REPLICAS: {REPLICAS_OR_WORKERS}')
TPU_CORES=8
print_to_console(f'TPU_CORES: {TPU_CORES}')

In [None]:
cpu_count = os.cpu_count()
MAX_LEN = 128
multiple_workers = REPLICAS_OR_WORKERS > 1
TRAIN_BATCH_SIZE = 32 # originally 16 * REPLICAS_OR_WORKERS
VALID_BATCH_SIZE = 16
ROBERTA_PATH = "../input/roberta-base"

# Seed

In [None]:
def seed_everything(seed_value):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

seed = 42
seed_everything(seed)

# Data Loader

In [None]:
class TweetDataset(torch.utils.data.Dataset):
    def __init__(self, df, max_len=MAX_LEN): # original max_len=96
        self.df = df
        self.max_len = max_len
        self.labeled = 'selected_text' in df
        self.tokenizer = tokenizers.ByteLevelBPETokenizer(
            vocab_file='../input/roberta-base/vocab.json', 
            merges_file='../input/roberta-base/merges.txt', 
            lowercase=True,
            add_prefix_space=True)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        
        ids, masks, tweet, offsets = self.get_input_data(row)
                
        data = {
            'ids': torch.tensor(ids, dtype=torch.long),
            'masks': torch.tensor(masks, dtype=torch.long),
            'tweet': tweet,
            'offsets': torch.tensor(offsets, dtype=torch.long),
        }        
    
        if self.labeled:
            start_idx, end_idx = self.get_target_idx(row, tweet, offsets)
            data['start_idx'] = torch.tensor(start_idx, dtype=torch.long) 
            data['end_idx'] = torch.tensor(end_idx, dtype=torch.long)

        return data

    def __len__(self):
        return len(self.df)
    
    def get_input_data(self, row):
        tweet = " " + " ".join(row.text.lower().split())
        encoding = self.tokenizer.encode(tweet)
        sentiment_id = self.tokenizer.encode(row.sentiment).ids
        ids = [0] + sentiment_id + [2, 2] + encoding.ids + [2]
        offsets = [(0, 0)] * 4 + encoding.offsets + [(0, 0)]
                
        pad_len = self.max_len - len(ids)
        if pad_len > 0:
            ids += [1] * pad_len
            offsets += [(0, 0)] * pad_len
        
        ids = torch.tensor(ids)
        masks = torch.where(ids != 1, torch.tensor(1), torch.tensor(0))
        offsets = torch.tensor(offsets)
        
        return ids, masks, tweet, offsets
        
    def get_target_idx(self, row, tweet, offsets):
        selected_text = " " +  " ".join(row.selected_text.lower().split())

        len_st = len(selected_text) - 1
        idx0 = None
        idx1 = None

        for ind in (i for i, e in enumerate(tweet) if e == selected_text[1]):
            if " " + tweet[ind: ind+len_st] == selected_text:
                idx0 = ind
                idx1 = ind + len_st - 1
                break

        char_targets = [0] * len(tweet)
        if idx0 is not None and idx1 is not None:
            for ct in range(idx0, idx1 + 1):
                char_targets[ct] = 1

        target_idx = []
        for j, (offset1, offset2) in enumerate(offsets):
            if sum(char_targets[offset1: offset2]) > 0:
                target_idx.append(j)

        start_idx = target_idx[0]
        end_idx = target_idx[-1]
        
        return start_idx, end_idx
        
def get_train_val_loaders(df, train_idx, val_idx, batch_size=TRAIN_BATCH_SIZE):
    train_df = df.iloc[train_idx]
    val_df = df.iloc[val_idx]

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        TweetDataset(train_df),
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    train_loader = torch.utils.data.DataLoader(
        TweetDataset(train_df), 
        batch_size=batch_size,  # honour the function argument (defaults to TRAIN_BATCH_SIZE)
        num_workers=cpu_count, # num_workers=2
        sampler=train_sampler,
        drop_last=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        TweetDataset(val_df),
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        TweetDataset(val_df), 
        batch_size=VALID_BATCH_SIZE,
        sampler=val_sampler,
        num_workers=2) # num_workers=2

    dataloaders_dict = {"Training": train_loader, "Validation": val_loader}

    return dataloaders_dict

def get_test_loader(df, batch_size=VALID_BATCH_SIZE):
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        TweetDataset(df),
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False  # keep test rows in their original order
    )
    loader = torch.utils.data.DataLoader(
        TweetDataset(df), 
        batch_size=batch_size,  # honour the function argument (defaults to VALID_BATCH_SIZE)
        sampler=test_sampler,
        num_workers=2) # num_workers=2
    return loader
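
To make the idea behind `get_target_idx` easier to follow, here is a minimal sketch of mapping a character span (the selected text) onto token indices. It is a simplification, not the notebook's code: `toy_offsets` is a hypothetical whitespace "tokenizer" standing in for the byte-level BPE offsets, and the resulting indices are not shifted by the four special/sentiment tokens that the real dataset prepends.

```python
def toy_offsets(text):
    """Character (start, end) offsets of whitespace-separated words.

    Hypothetical stand-in for the offsets a real tokenizer returns.
    """
    offsets, pos = [], 0
    for word in text.split():
        start = text.index(word, pos)
        end = start + len(word)
        offsets.append((start, end))
        pos = end
    return offsets


def char_span_to_token_span(text, selected, offsets):
    """Return (first, last) indices of tokens overlapping `selected`, or None."""
    start_char = text.find(selected)
    if start_char < 0:
        return None
    end_char = start_char + len(selected)
    # A token counts as a hit when its character range overlaps the selection.
    hits = [i for i, (s, e) in enumerate(offsets) if s < end_char and e > start_char]
    return (hits[0], hits[-1]) if hits else None


text = "so sad i will miss you here"
offsets = toy_offsets(text)
print(char_span_to_token_span(text, "miss you", offsets))  # (4, 5)
```

The pair returned here plays the role of `start_idx` and `end_idx`, the two classification targets the model is trained on.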

# Model

In [None]:
class TweetModel(nn.Module):
    def __init__(self):
        super(TweetModel, self).__init__()
        
        config = RobertaConfig.from_pretrained(
            f'{ROBERTA_PATH}/config.json', output_hidden_states=True)    
        config.output_hidden_states = True
        self.roberta = RobertaModel.from_pretrained(
            f'{ROBERTA_PATH}/pytorch_model.bin', config=config)

        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(config.hidden_size, 2)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, attention_mask):
        _, _, hs = self.roberta(input_ids, attention_mask)
         
        x = torch.stack([hs[-1], hs[-2], hs[-3]])
        x = torch.mean(x, 0)
        x = self.dropout(x)
        x = self.fc(x)
        start_logits, end_logits = x.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
                
        return start_logits, end_logits

# Loss Function

In [None]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    ce_loss = nn.CrossEntropyLoss()
    start_loss = ce_loss(start_logits, start_positions)
    end_loss = ce_loss(end_logits, end_positions)    
    total_loss = start_loss + end_loss
    return total_loss
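
As a sanity check on what `loss_fn` computes, the sketch below reimplements the same sum of two cross-entropy terms for a single example in plain Python. It is an illustration under simplifying assumptions: no batching (`nn.CrossEntropyLoss` averages over the batch by default), and `cross_entropy` / `span_loss` are names made up for this example.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target position, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def span_loss(start_logits, end_logits, start_pos, end_pos):
    """Start loss plus end loss, mirroring loss_fn for one example."""
    return cross_entropy(start_logits, start_pos) + cross_entropy(end_logits, end_pos)


# With uniform logits over 4 tokens each term equals log(4).
uniform = [0.0, 0.0, 0.0, 0.0]
print(span_loss(uniform, uniform, 1, 2), 2 * math.log(4))

# A confident, correct prediction drives the loss towards zero.
print(span_loss([0.0, 10.0, 0.0, 0.0], [0.0, 0.0, 10.0, 0.0], 1, 2))
```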

# Evaluation Function

In [None]:
def get_selected_text(text, start_idx, end_idx, offsets):
    selected_text = ""
    for ix in range(start_idx, end_idx + 1):
        selected_text += text[offsets[ix][0]: offsets[ix][1]]
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            selected_text += " "
    return selected_text

def jaccard(str1, str2): 
    a = set(str1.lower().split()) 
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

def compute_jaccard_score(text, start_idx, end_idx, start_logits, end_logits, offsets):
    start_pred = np.argmax(start_logits)
    end_pred = np.argmax(end_logits)
    if start_pred > end_pred:
        pred = text
    else:
        pred = get_selected_text(text, start_pred, end_pred, offsets)
        
    true = get_selected_text(text, start_idx, end_idx, offsets)
    
    return jaccard(true, pred)
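
A small worked example of the word-level Jaccard score used above may help. The function is copied from the cell (with `a & b` in place of `a.intersection(b)`); the input strings are made up for illustration.

```python
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a & b
    return float(len(c)) / (len(a) + len(b) - len(c))


# 2 shared words out of 4 distinct words overall.
print(jaccard("i love this movie", "love this"))  # 0.5
# Identical selections score 1, disjoint selections score 0.
print(jaccard("so sad", "So Sad"))  # 1.0
print(jaccard("good", "bad"))  # 0.0
```

Note that the score is undefined (division by zero) when both strings are empty or whitespace-only, which the function above does not guard against.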

# Training Function

In [None]:
def set_to_device(data_, field_name, device, data_type=torch.long):
    field = data_[field_name]
    return field.to(device, dtype=data_type)

In [None]:
def train_model(device, model, dataloaders_dict, criterion, optimizer, num_epochs, filename, scheduler):
    for epoch in range(num_epochs):
        for phase in ['Training', 'Validation']:
            if phase == 'Training':
                print_to_console(f'Started training Epoch: {epoch + 1}/{num_epochs}')
                model.train()
            else:
                print_to_console(f'Started validation Epoch: {epoch + 1}/{num_epochs}')       
                model.eval()

            epoch_loss = 0.0
            epoch_jaccard = 0.0
            
            data_para_loader = pl.ParallelLoader(dataloaders_dict[phase], [device])
            data_para_loader = data_para_loader.per_device_loader(device)
            tk0 = tqdm(data_para_loader, total=len(dataloaders_dict[phase]), desc=f'{phase}: {epoch + 1}/{num_epochs}, {filename}')
            
            for index, data in enumerate(tk0):
                if index % 500 == 0:
                    print_to_console(f'Started {phase}: index={index}/{len(tk0)}')
                ids = set_to_device(data, 'ids', device)
                masks = set_to_device(data, 'masks', device)
                tweet = data['tweet']
                offsets = data['offsets'].cpu().detach().numpy()
                start_idx = set_to_device(data, 'start_idx', device)
                end_idx = set_to_device(data, 'end_idx', device)

                model.zero_grad()

                with torch.set_grad_enabled(phase == 'Training'):
                    
                    optimizer.zero_grad()
                    start_logits, end_logits = model(ids, masks)
                    
                    loss = criterion(start_logits, end_logits, start_idx, end_idx)
                    
                    if phase == 'Training':
                        loss.backward()
                        xm.optimizer_step(optimizer)
                        scheduler.step()

                    epoch_loss += loss.item() * len(ids)
                    
                    start_logits = torch.softmax(start_logits, dim=1).cpu().detach().numpy()
                    end_logits = torch.softmax(end_logits, dim=1).cpu().detach().numpy()
                    start_idx = start_idx.cpu().detach().numpy()
                    end_idx = end_idx.cpu().detach().numpy()

                    for i in range(len(ids)):                        
                        jaccard_score = compute_jaccard_score(
                            tweet[i],
                            start_idx[i],
                            end_idx[i],
                            start_logits[i], 
                            end_logits[i], 
                            offsets[i])
                        epoch_jaccard += jaccard_score
                        if index % 500 == 0:
                            print(f'{i}/{len(ids)-1}: epoch_loss: {epoch_loss}, jaccard_score: {jaccard_score}', end="\r", flush=True)

            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_jaccard = epoch_jaccard / len(dataloaders_dict[phase].dataset)
            
            print_to_console('')
            print_to_console('Epoch {}/{} | {:^5} | Loss: {:.4f} | Jaccard: {:.4f}'.format(
                epoch + 1, num_epochs, phase, epoch_loss, epoch_jaccard))

    print_to_console(f'Saving model to {filename}')
    xm.save(model.state_dict(), filename)

# Training

In [None]:
num_epochs = 3
folds=10
LEARNING_RATE = 4e-5
skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)

In [None]:
def get_optimizer_params(model):    
    param_optimizer = list(model.named_parameters())
    no_decay = [
        "bias",
        "LayerNorm.bias",
        "LayerNorm.weight"
    ]

    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ], 
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ], 
            'weight_decay': 0.0
        },
    ]
    
    return optimizer_parameters
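
To see how the substring match in `get_optimizer_params` splits parameters into the two groups, here is a minimal sketch that applies the same rule to a handful of names. The names are illustrative only (the real ones come from `model.named_parameters()`), and `split_by_weight_decay` is a hypothetical helper.

```python
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

# Illustrative parameter names, not read from an actual model.
example_names = [
    "roberta.encoder.layer.0.attention.self.query.weight",
    "roberta.encoder.layer.0.attention.self.query.bias",
    "roberta.encoder.layer.0.output.LayerNorm.weight",
    "fc.weight",
    "fc.bias",
]


def split_by_weight_decay(names, no_decay_markers):
    """Return (decayed, not_decayed) using the same substring test as above."""
    decayed, not_decayed = [], []
    for name in names:
        if any(marker in name for marker in no_decay_markers):
            not_decayed.append(name)
        else:
            decayed.append(name)
    return decayed, not_decayed


decayed, not_decayed = split_by_weight_decay(example_names, no_decay)
print("weight_decay=0.001:", decayed)
print("weight_decay=0.0:  ", not_decayed)
```

Because `"bias"` already matches any name containing `LayerNorm.bias`, the second marker is redundant, although it does no harm.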

In [None]:
def run(fold, train_idx, val_idx):
    saved_model_filename = f'roberta_fold{fold}.pth'
    print_to_console(f'Looking for saved model file {saved_model_filename}...')
    if os.path.exists(saved_model_filename):
        print_to_console(f'Model file {saved_model_filename} found, skipping this fold (delete the model file to re-run it).')
        return
    else:
        print_to_console(f'Model file {saved_model_filename} not found, proceeding with the model building process...')

    print_to_console("")
    print_to_console(f'Fold: {fold}')
    MX = TweetModel()
    
    if fold >= TPU_CORES:
        fold = fold % TPU_CORES

    device = xm.xla_device(fold + 1)
    model = MX.to(device)
    
    # Total number of scheduler steps: batches per epoch times the number of epochs.
    num_train_steps = int(
        len(train_df) / TRAIN_BATCH_SIZE * num_epochs
    )
    optimizer = AdamW(
        get_optimizer_params(model), 
        lr=LEARNING_RATE * REPLICAS_OR_WORKERS, # xm.xrt_world_size()
        betas=(0.9, 0.999)
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )
        
    num_batches = int(len(train_df) / TRAIN_BATCH_SIZE)

    criterion = loss_fn    
    dataloaders_dict = get_train_val_loaders(train_df, train_idx, val_idx, TRAIN_BATCH_SIZE)

    train_model(
        device,
        model, 
        dataloaders_dict,
        criterion, 
        optimizer, 
        num_epochs,
        saved_model_filename,
        scheduler
    )
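
For intuition about the scheduler configured in `run`, the sketch below reproduces the learning-rate multiplier of a linear schedule with warmup in plain Python. It is an illustration, not the library code: `linear_schedule_multiplier` is a made-up name and the step count is a toy value rather than the real `num_train_steps`.

```python
def linear_schedule_multiplier(step, num_training_steps, num_warmup_steps=0):
    """Learning-rate multiplier at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup.
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 4e-5    # LEARNING_RATE in this notebook
total_steps = 10  # toy value for illustration

for step in (0, 5, 10):
    lr = base_lr * linear_schedule_multiplier(step, total_steps)
    print(f"step {step:2d}: lr = {lr:.1e}")
```

With `num_warmup_steps=0`, as in this notebook, training starts at the full learning rate and decays linearly to zero at the last step. If `num_training_steps` overestimates the real number of optimizer steps, the rate never quite reaches zero.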

In [None]:
%env JOBLIB_TEMP_FOLDER=/tmp
# JOBLIB_START_METHOD is usually set to stop Parallel from hanging or going idle; commenting out the next line has also helped.
%env JOBLIB_START_METHOD=forkserver
%env TMPDIR=/tmp

In [None]:
%%time
train_df = pd.read_csv('../input/tweet-sentiment-extraction/train.csv')
train_df['text'] = train_df['text'].astype(str)
train_df['selected_text'] = train_df['selected_text'].astype(str)

is_incorrect_fn = lambda row: (" " + row.text + " ").find(" " + row.selected_text + " ") < 0
filtered_bad_data = train_df.apply(is_incorrect_fn, axis=1)
print(f"train before correcting {train_df.shape}")
train_df = train_df[~filtered_bad_data].copy().reset_index(drop=True)
print(f"train after correcting {train_df.shape}")

USABLE_TPU_CORES=int(TPU_CORES*(3/4))
print_to_console(f'USABLE_TPU_CORES: {USABLE_TPU_CORES}')

### Note: there are issues using joblib's Parallel/delayed with multiple TPU cores.
### The process can hang during the training or validation session of a fold.
### Workaround: wait until all other folds finish, click on Cancel Run and then re-run just this cell.
### As the other models have already been saved, it will only start from the folds that have not been created yet (a sort of pipeline architecture).
### Please also ensure there is enough disk space: the 10 models take up 470MB and extra headroom is needed to build and compress them.
Parallel(n_jobs=USABLE_TPU_CORES, backend="threading")(
    delayed(run)(fold, train_idx, val_idx) for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df.sentiment))
)
### Alternatively you can use xmp.spawn() see solution in the next cell, but it's still being worked on for the moment.
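
The data-cleaning step in the cell above drops rows whose `selected_text` does not sit on whole-word boundaries inside `text`. A minimal sketch of the same check on plain strings, with made-up example rows:

```python
def is_incorrect(text, selected_text):
    """True when selected_text does not appear in text on word boundaries."""
    return (" " + text + " ").find(" " + selected_text + " ") < 0


rows = [
    ("i am so happy today", "so happy"),    # starts and ends on word boundaries
    ("i am so happy today", "so happ"),     # cuts "happy" in the middle
    ("i am so happy today", "m so happy"),  # starts in the middle of "am"
]

for text, selected in rows:
    status = "dropped" if is_incorrect(text, selected) else "kept"
    print(f"{selected!r:14} -> {status}")
```

Padding both strings with spaces is what turns a plain substring search into a whole-word match.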


In [None]:
!ls -lash *.pth

### Alternative way to spawn / run tasks in parallel using PyTorch/XLA's multiprocessing (`xmp.spawn`) instead of joblib's Parallel
(Please convert the following cells into code cells to execute them. They may still need some work before they can replace joblib's Parallel.)

train_idx_list = {}
val_idx_list = {}
for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df.sentiment)):
    train_idx_list[fold] = train_idx
    val_idx_list[fold] = val_idx

def _mp_fn(rank, flags):
    xm.master_print(f'TPU worker {rank}: triggering task for this TPU worker')
    torch.set_default_tensor_type('torch.FloatTensor')
    a = run(rank, train_idx_list[rank], val_idx_list[rank])

%%time
FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=TPU_CORES, start_method='fork')
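
One thing to watch in the draft above: `xmp.spawn` starts `TPU_CORES` (8) processes while there are 10 folds, and `_mp_fn` runs only the fold equal to its rank, so folds 8 and 9 would never be trained. A possible way around this, sketched here as an assumption rather than a tested fix, is to give each rank every fold whose index matches it modulo the number of processes. `folds_for_rank` is a hypothetical helper.

```python
def folds_for_rank(rank, n_folds, n_procs):
    """Folds a given process should train: rank, rank + n_procs, ..."""
    return list(range(rank, n_folds, n_procs))


n_folds, n_procs = 10, 8  # folds and TPU_CORES from this notebook

for rank in range(n_procs):
    print(f"rank {rank}: folds {folds_for_rank(rank, n_folds, n_procs)}")
```

Inside `_mp_fn`, the single `run(rank, ...)` call could then become a loop over `folds_for_rank(rank, folds, TPU_CORES)`. Ranks 0 and 1 would train two folds each, so they would finish later than the others.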