- Based on Abhishek's original code
- Data setup from https://www.kaggle.com/xhlulu/jigsaw-tpu-xlm-roberta
- Inspired by @tanlikesmath's notebook: https://www.kaggle.com/tanlikesmath/xlm-roberta-pytorch-xla-tpu
- Special thanks to the PyTorch/XLA developers!

In [None]:
# Install PyTorch/XLA into this environment: download the env-setup script from
# the pytorch/xla repository, then run it for the nightly build, also
# installing the libomp5 and libopenblas-dev apt packages.
# NOTE(review): "nightly" is unpinned, so re-running this notebook may pick up a
# different torch_xla build — consider pinning a release version.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
import os
# XLA runtime settings. They are placed before torch / torch_xla are imported
# further down, presumably so the runtime sees them at import time.
# NOTE(review): per the PyTorch/XLA docs, XLA_USE_BF16=1 stores float tensors as
# bfloat16 on the device; the exact effect of XLA_TENSOR_ALLOCATOR_MAXSIZE is
# not visible here — confirm against the torch_xla version in use.
os.environ.update(
    XLA_USE_BF16="1",
    XLA_TENSOR_ALLOCATOR_MAXSIZE="100000000",
)
import torch
import pandas as pd
from scipy import stats
import numpy as np

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib

import logging
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule, XLMRobertaTokenizer, XLMRobertaModel, XLMRobertaConfig
import sys
from sklearn import metrics, model_selection

In [None]:
import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
# Suppress every warning category for the rest of the session, presumably to
# keep the 8-process training log readable.
# NOTE(review): this also hides deprecation warnings from transformers/torch_xla.
warnings.simplefilter("ignore")

In [None]:
class BERTDatasetTraining:
    """Map-style dataset over pre-tokenized rows.

    Each row of ``X`` is indexable: element 0 is the sequence of token ids and
    element 1 is the target. ``__getitem__`` returns a dict with an int64
    ``'ids'`` tensor and a float32 ``'targets'`` tensor.
    """

    def __init__(self, X=None):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, item):
        row = self.X[item]
        token_ids, label = row[0], row[1]
        as_long = torch.tensor(token_ids, dtype=torch.long)
        as_float = torch.tensor(label, dtype=torch.float)
        return {'ids': as_long, 'targets': as_float}

In [None]:
class CustomRoberta(nn.Module):
    """XLM-RoBERTa-large encoder with a single-logit classification head.

    The encoder's pooled output goes through dropout and a linear layer that
    emits one raw (un-sigmoided) logit per example.
    """

    def __init__(self):
        super(CustomRoberta, self).__init__()
        self.num_labels = 1
        self.roberta = transformers.XLMRobertaModel.from_pretrained(
            "xlm-roberta-large", output_hidden_states=False, num_labels=1
        )
        self.dropout = nn.Dropout(p=0.2)
        # Size the head from the encoder config instead of hard-coding 1024, so
        # the head always matches the loaded checkpoint's hidden size.
        self.classifier = nn.Linear(self.roberta.config.hidden_size, self.num_labels)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None):
        """Return logits of shape ``(batch, num_labels)`` for the given inputs."""
        encoder_out = self.roberta(input_ids,
                                   attention_mask=attention_mask,
                                   position_ids=position_ids,
                                   head_mask=head_mask,
                                   inputs_embeds=inputs_embeds)
        # Index instead of tuple-unpacking. Position 1 is the pooled output for
        # both tuple returns and dict-like ModelOutput returns. Unpacking a
        # ModelOutput iterates its keys, and unpacking a longer tuple (extra
        # outputs) raises.
        pooled = encoder_out[1]

        # Apply the dropout layer defined in __init__; it was previously unused.
        return self.classifier(self.dropout(pooled))

In [None]:
# Build the model once here in the parent process. The training cell launches
# its workers with xmp.spawn(..., start_method='fork'), and each worker's run()
# reuses this global `model`.
model = CustomRoberta()

In [None]:
%%time
# Load the cached, pre-tokenized train/valid splits. allow_pickle=True is
# required because the arrays hold Python objects.
# NOTE(review): only load pickled .npy files from a trusted source.
X_train, X_valid = (
    np.load(path, allow_pickle=True)
    for path in (
        "../input/sample-tpu-xlmr-pytorch-pad-on-fly/x_train_tokenized.npy",
        "../input/sample-tpu-xlmr-pytorch-pad-on-fly/x_valid_tokenized.npy",
    )
)

(X_train.shape, X_valid.shape)

In [None]:
# Training on the full ~240k rows scored poorly (hyper-parameters may need
# tuning), so keep only the first 120k rows.
# Basic slicing returns a *view* whose `.base` keeps the full array (and every
# row object it references) alive. Copy the slice so the dropped rows can
# actually be freed by the gc.collect() below.
X_train = X_train[:120000].copy()
import gc
gc.collect()

In [None]:
gc.collect()

# Column 1 of every pre-tokenized row is the target; column 0 holds the token ids.
valid_targets = X_valid[:, 1]
train_targets = X_train[:, 1]

# Wrap both splits in map-style datasets for the DataLoaders built in run().
train_dataset, valid_dataset = (
    BERTDatasetTraining(X_train),
    BERTDatasetTraining(X_valid),
)
gc.collect()

In [None]:
def run():
    """Train and validate the module-level ``model`` on this process's XLA device.

    Meant to run once in every process started by ``xmp.spawn``. It reads the
    globals ``model``, ``train_dataset`` and ``valid_dataset`` and shards both
    datasets across replicas with ``DistributedSampler``. It trains for
    ``EPOCHS`` epochs and saves a checkpoint after each one. After every epoch
    it prints (master ordinal only) the ROC AUC over the validation predictions
    gathered from *all* replicas.

    Returns:
        None.
    """

    criterion = nn.BCEWithLogitsLoss()

    def loss_fn(outputs, targets):
        # The head emits one logit per example; reshape targets into a matching column.
        return criterion(outputs, targets.view(-1, 1))

    def concat_shards(shards):
        """Flatten a list of per-replica result lists into a single list."""
        return [value for shard in shards for value in shard]

    def train_loop_fn(data_loader, model, optimizer, device, scheduler=None, epoch=None):
        """Run one training epoch over ``data_loader``, then checkpoint the weights."""
        model.train()

        for bi, d in enumerate(data_loader):
            ids = d["ids"].to(device, dtype=torch.long)
            targets = d["targets"].to(device, dtype=torch.float)

            optimizer.zero_grad()
            outputs = model(
                input_ids=ids,
                # Every non-zero id is treated as a real token.
                # NOTE(review): in the stock XLM-R vocabulary <s> is id 0 and
                # <pad> is id 1, so `ids > 0` would mask <s> and attend to
                # padding unless the cached data was padded with 0 — confirm
                # against the preprocessing that produced the .npy files.
                attention_mask=(ids > 0).to(device),
            )

            loss = loss_fn(outputs, targets)

            if bi % 100 == 0:
                xm.master_print(f'bi={bi}, loss={loss}')

            loss.backward()
            xm.optimizer_step(optimizer)

            if scheduler is not None:
                scheduler.step()

        model.eval()
        # One checkpoint per epoch, keyed on the epoch index.
        xm.save(model.state_dict(), f"xlm_roberta_model_{epoch}.bin")

    def eval_loop_fn(data_loader, model, device):
        """Return ``(outputs, targets)`` lists for this replica's validation shard."""
        model.eval()
        fin_targets = []
        fin_outputs = []
        # Inference only: don't build autograd graphs for the eval forward passes.
        with torch.no_grad():
            for bi, d in enumerate(data_loader):
                if bi % 50 == 0:
                    xm.master_print(f'EVAL bi={bi}')

                ids = d["ids"].to(device, dtype=torch.long)
                targets = d["targets"].to(device, dtype=torch.float)

                outputs = model(
                    input_ids=ids,
                    attention_mask=(ids > 0).to(device),
                )

                fin_targets.extend(targets.cpu().detach().numpy().tolist())
                fin_outputs.extend(outputs.cpu().detach().numpy().tolist())

        return fin_outputs, fin_targets

    TRAIN_BATCH_SIZE = 8
    EPOCHS = 2  # change

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2,
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=8,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1,
    )

    device = xm.xla_device()
    model.to(device)

    # No weight decay for biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # The base LR is scaled linearly by the number of replicas.
    lr = 1e-4 * xm.xrt_world_size()
    # Informational only: no LR scheduler is created (scheduler=None below).
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    xm.master_print(f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)

    for epoch in range(EPOCHS):
        # DistributedSampler seeds its shuffle with the epoch number. Without
        # set_epoch() every epoch would replay the same permutation.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=None, epoch=epoch)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)
        # Each replica scored only its own validation shard. Gather all shards
        # so the AUC covers the whole validation set, not just the master's
        # 1/world_size slice. (DistributedSampler may repeat a few samples to
        # even out the shards.)
        o = xm.mesh_reduce('valid_outputs', o, concat_shards)
        t = xm.mesh_reduce('valid_targets', t, concat_shards)
        auc = metrics.roc_auc_score(np.array(t) >= 0.5, o)
        del o, t
        gc.collect()
        xm.master_print(f'AUC = {auc}')

In [None]:
%%time

def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``; ``rank`` and ``flags`` are unused."""
    run()


FLAGS = {}
# Launch 8 worker processes (presumably one per TPU core). start_method='fork'
# lets the workers inherit the model and datasets already built in this process.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')