# Configuration

In [None]:
# Fetch the environment setup script from the pytorch/xla repository and save it locally.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the script, requesting the "nightly" version and the libomp5 / libopenblas-dev apt packages.
# NOTE(review): "nightly" is unpinned, so re-running this cell later may install a different build.
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
# Extra Python dependency (unpinned).
!pip install torchsummary


In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.distributed.xla_multiprocessing as xmp
import os,re,gc,pickle,random,sys,collections
import numpy as np 
import pandas as pd 
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from transformers import XLMRobertaModel
import torch.nn.functional as F

In [None]:
# devices = (
#     xm.get_xla_supported_devices(
#         max_devices=8))
# print("Devices: {}".format(devices))


In [None]:
# --- Training hyper-parameters ---
batch_size = 32
learning_rate = 1e-5
num_epochs = 10
MAX_LEN = 192

# --- Data locations ---
input1 = "/kaggle/input/jigsaw-multilingual-toxic-comment-classification/"
inpath = "../input/jwtc-xlmroberta-encoding-192-pickle/datain/"
# Template for the per-split training pickles, filled in with a split key.
form = "training/encode_{}.pkl"

# --- Language selection ---
# All available language keys, and the subset actually loaded for training.
langs = ["en", "en2", "es", "fr", "it", "pt", "ru", "tr"]
used_data = ["en", "en2"]

# os.environ['XLA_USE_BF16']="1"
# os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

# Load Data
Some code is borrowed from https://www.kaggle.com/mint101/basic-xlm-r-lb-9442-intro, a TensorFlow implementation of XLM-R fine-tuning.

In [None]:
def pick_load_format(path):
    """Unpickle and return the object stored at ``inpath + path``.

    Args:
        path: File path relative to the module-level ``inpath`` prefix. The two
            are joined by plain string concatenation, so ``inpath`` must already
            end with a path separator.

    Returns:
        Whatever object the pickle file contains.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading, so
    only point this at trusted files.
    """
    full_path = inpath + path
    with open(full_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload

def load(path):
    """Load one training pickle identified by its split key.

    Args:
        path: Key substituted into the module-level ``form`` template
            (``"training/encode_{}.pkl"``) to build the relative file name.

    Returns:
        The object returned by ``pick_load_format`` for that file.
    """
    relative_name = form.format(path)
    return pick_load_format(relative_name)

# def load_data():
#     train = []
#     for i in used_data:
#         df1 = load(i+"_l1")
#         df0 = load(i+"_l0")
#         train += [df1,df0]
#     train = pd.concat(train)

#     train = np.stack(train.comment_text.values, axis=0).astype("int32"),train.toxic.values

#     valid = pick_load_format("valid.pkl")
#     x_valid = np.stack(valid.comment_text.values, axis=0).astype("int32")
#     y_valid = valid.toxic.values

#     test = pick_load_format("test.pkl")
#     x_test = np.stack(test.content.values, axis=0).astype("int32")
    
#     return train,(x_valid,y_valid),x_test

def get_cong(n,verb=True):
    """Derive a rounded total sample count and the matching negative count.

    The total is ``round(1 + 2*n/10000) * 10000`` (Python ``round``, so exact
    halves go to the even integer); the negative count is that total minus ``n``.

    Args:
        n: Number of positive examples.
        verb: When true, print the positive, negative and total counts.

    Returns:
        Tuple ``(total, total - n)``.
    """
    bucket = 10_000
    total = round(1 + (n * 2) / bucket) * bucket
    negatives = total - n
    if verb:
        print("Pos: {}, Sample neg: {}, Total: {}".format(n, negatives, total))
    return total, negatives

def load_data(seed=1214):
    """Assemble the training, validation and test arrays from the pickled frames.

    For every key in ``used_data`` the ``<key>_l1`` frame is loaded in full and
    the ``<key>_l0`` frame is down-sampled to the negative count returned by
    ``get_cong`` (presumably ``_l1``/``_l0`` hold toxic / non-toxic rows —
    confirm against the dataset). All frames are concatenated in load order.

    Args:
        seed: ``random_state`` passed to ``DataFrame.sample`` for the ``_l0`` frames.

    Returns:
        ``((x_train, y_train), (x_valid, y_valid), x_test)`` where each ``x_*``
        is the row-wise stack of the encoded text column cast to int32 and each
        ``y_*`` is the ``toxic`` column's values.
    """
    frames = []
    for key in used_data:
        positives = load(key + "_l1")
        _, n_negatives = get_cong(positives.shape[0])
        negatives = load(key + "_l0").sample(n=n_negatives, random_state=seed)
        frames.extend((positives, negatives))
    train_df = pd.concat(frames)

    x_train = np.stack(train_df.comment_text.values, axis=0).astype("int32")
    y_train = train_df.toxic.values

    valid_df = pick_load_format("valid.pkl")
    x_valid = np.stack(valid_df.comment_text.values, axis=0).astype("int32")
    y_valid = valid_df.toxic.values

    # The test frame keeps its encoded text in the ``content`` column.
    test_df = pick_load_format("test.pkl")
    x_test = np.stack(test_df.content.values, axis=0).astype("int32")

    return (x_train, y_train), (x_valid, y_valid), x_test

In [None]:
%%time
# Load the train/valid/test arrays with load_data's default seed; %%time reports the cell's duration.
train,valid,test = load_data()
gc.collect()

# train and valid are (inputs, labels) pairs, so the label array's length is the example count.
valid_size = len(valid[1])
train_size = len(train[1])
print(train_size,valid_size)
# only for debug
# As written these re-pack the same two elements unchanged; presumably a hook for
# slicing a small subset while debugging — TODO confirm.
train=(train[0],train[1])
valid=(valid[0],valid[1])

# PyTorch dataset

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn

# train_dataloader=DataLoader(TensorDataset(torch.Tensor(train[0]).long(),torch.Tensor(train[0]).long()),
#                                                 shuffle=True,batch_size=batch_size,num_workers=1)
# valid_dataloader=DataLoader(TensorDataset(torch.Tensor(valid[0]).long(),torch.Tensor(valid[0]).long()),
#                                                 shuffle=True,batch_size=batch_size,num_workers=1)
# gc.collect()
# Wrap the encoded arrays as int64 tensors.
# torch.tensor(..., dtype=torch.long) converts straight to the target dtype, whereas
# torch.Tensor(x).long() first builds a float32 tensor: an extra full-size copy, and
# integers above 2**24 are not exactly representable in float32.
train_dataset = TensorDataset(
    torch.tensor(train[0], dtype=torch.long),
    torch.tensor(train[1], dtype=torch.long),
)
valid_dataset = TensorDataset(
    torch.tensor(valid[0], dtype=torch.long),
    torch.tensor(valid[1], dtype=torch.long),
)
# The test split has inputs only (no labels).
test_dataset = TensorDataset(torch.tensor(test, dtype=torch.long))

# Pretrained model

In [None]:
from transformers import XLMRobertaModel
# Load the 'xlm-roberta-large' checkpoint by name.
# NOTE(review): presumably fetched from the model hub on first use — needs network access or a local cache.
pretrained_XLM=XLMRobertaModel.from_pretrained('xlm-roberta-large')

class Model(nn.Module):
    """Binary classifier: a pretrained encoder followed by a single linear unit.

    The encoder's first output (per-token hidden states) is average-pooled over
    the token axis, projected to one logit, and squashed with a sigmoid.

    Args:
        pretrained: Encoder module. Calling it with token ids must return a
            sequence whose first element is the hidden-state tensor
            (assumed ``(batch, seq_len, hidden)`` — TODO confirm for non-HF encoders).
        out_dim: Feature width fed to the linear head. When ``None`` (default)
            it is read from ``pretrained.config.hidden_size`` so the head always
            matches the encoder; if the encoder exposes no such attribute the
            historical default of 768 is used.
    """

    def __init__(self,pretrained,out_dim=None):
        super().__init__()
        if out_dim is None:
            # A fixed 768 only matches "base"-sized encoders. With a wider encoder
            # the pooling below would silently resample the hidden axis down to 768.
            encoder_config = getattr(pretrained, "config", None)
            out_dim = getattr(encoder_config, "hidden_size", 768)
        self.pretrained_model = pretrained
        self.fc = nn.Linear(out_dim, 1)
        torch.nn.init.kaiming_uniform_(self.fc.weight)
        self.fc.bias.data.fill_(0.)
        self.out_dim = out_dim

    def forward(self,word_id):
        """Return one probability per input row.

        Args:
            word_id: Integer token-id tensor passed straight to the encoder
                (position ids are optional and not supplied).

        Returns:
            1-D tensor of sigmoid probabilities, one per batch element.
        """
        hidden_states = self.pretrained_model(word_id)[0]
        # Pool the token axis to length 1. When out_dim equals the hidden width
        # this is exactly the mean over tokens (padding positions included).
        avg = F.adaptive_avg_pool2d(hidden_states, (1, self.out_dim)).squeeze(1)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        prob = torch.sigmoid(self.fc(avg).squeeze(1))
        return prob

# Wrap the pretrained encoder in the classification head and print the module tree.
model=Model(pretrained_XLM)
print(model)
# Run a garbage-collection pass before the training cells.
gc.collect()

In [None]:
from sklearn.metrics import roc_auc_score

# context=dict()
# context['device']=xm.xla_device()
# model.to(context['device'])
# train_dataloader_para=pl.ParallelLoader(train_data_loader,context['device'])
# valid_dataloader_para=pl.ParallelLoader(valid_data_loader,context['device'])
# gc.collect()

In [None]:
def reduce_fn(vals):
    """Return the arithmetic mean of ``vals``.

    Args:
        vals: Non-empty sequence of values supporting ``+`` and ``/``.

    Returns:
        ``sum(vals) / len(vals)``.

    Raises:
        ZeroDivisionError: If ``vals`` is empty.
    """
    total = sum(vals)
    count = len(vals)
    return total / count

def _train_fn(index):
    """Train and validate the global ``model`` on this process's XLA device.

    Intended as the per-process entry point for ``xmp.spawn``. Each epoch trains
    on this replica's shard of ``train_dataset``, scores ``valid_dataset`` with
    ROC AUC, saves the best weights to ``'robertaCKPT'`` and steps the
    plateau scheduler. Training stops early once the score has failed to improve
    for more than ``earlyStoppingPatience`` epochs in a row.

    Args:
        index: Process index supplied by ``xmp.spawn``. Unused: the replica rank
            is read from ``xm.get_ordinal()`` instead.

    Returns:
        None.
    """
    # In parallel TPU programs, all device-dependent code must live inside the
    # xmp-spawned function.
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=0
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=0
    )
    device = xm.xla_device()
    print(device)
    model_ = model.to(device)
    nepoch = num_epochs
    # Use the configured learning rate; Adam's default (1e-3) ignores the
    # notebook-level ``learning_rate`` setting.
    optimizer = torch.optim.Adam(model_.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='max',factor=0.1,patience=3,threshold=0.0001,threshold_mode='rel',cooldown=0)
    earlyStoppingPatience = 8
    loss_fn = nn.BCELoss()
    eval_metric = roc_auc_score

    def _flatten(parts):
        # mesh_reduce hands the reducer one list per replica; merge them into one.
        return [value for part in parts for value in part]

    best_score = 0
    no_improve = 0
    for epoch in range(1, nepoch + 1):
        xm.master_print('{}/{} epochs'.format(epoch, nepoch))

        # ---- train ----
        model_.train()
        # Without set_epoch the DistributedSampler reuses the same shuffle every epoch.
        train_sampler.set_epoch(epoch)
        para_train_loader = pl.ParallelLoader(train_dataloader, [device]).per_device_loader(device)
        xm.master_print('parallel loader created... training now')
        pbar = tqdm(para_train_loader, total=len(para_train_loader))
        for data, target in pbar:
            # BCELoss needs the target dtype to match the model's float32 output.
            data, target = data.to(device), target.float().to(device)
            optimizer.zero_grad()
            output = model_(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            # Average the scalar loss across replicas for display.
            loss_reduced = xm.mesh_reduce('loss_reduce', loss.item(), reduce_fn)
            pbar.set_postfix({'loss': loss_reduced})

        # ---- eval ----
        model_.eval()
        para_valid_loader = pl.ParallelLoader(valid_dataloader, [device]).per_device_loader(device)
        xm.master_print('parallel loader created... evaluating now')
        pbar = tqdm(para_valid_loader, total=len(para_valid_loader))
        outputs = []
        targets = []
        with torch.no_grad():  # no autograd bookkeeping needed while scoring
            for data, target in pbar:
                data = data.to(device)
                # Device tensors must be copied to the host before .numpy().
                outputs.extend(model_(data).cpu().numpy().tolist())
                targets.extend(target.cpu().numpy().tolist())
        # Gather every replica's predictions so all replicas compute the same
        # score and therefore agree on checkpointing and early stopping.
        outputs = xm.mesh_reduce('valid_outputs', outputs, _flatten)
        targets = xm.mesh_reduce('valid_targets', targets, _flatten)
        score = eval_metric(targets, outputs)
        xm.master_print('score={}'.format(score))

        if score > best_score:
            xm.save(model_.state_dict(), 'robertaCKPT')
            best_score = score
            no_improve = 0
        else:
            no_improve += 1
            if no_improve > earlyStoppingPatience:
                xm.master_print('EarlyStoppinpg while best score is {}'.format(best_score))
                # Stop only once patience is exhausted, not on the first
                # non-improving epoch.
                break
        scheduler.step(score)

def _mp_fn(index):
    """Multiprocessing entry point: run the training loop for one process.

    Args:
        index: Process index; forwarded to ``_train_fn``, which requires it as a
            positional argument.

    Returns:
        None.
    """
    _train_fn(index)


# Launch training through torch_xla's multiprocessing helper: one process (nprocs=1),
# started with the 'fork' method, running _train_fn with no extra arguments.
xmp.spawn(_train_fn,args=(),start_method='fork',nprocs=1)