In [None]:
# Download the pytorch/xla environment-setup script and run it for the "nightly" build,
# also installing the apt packages libomp5 and libopenblas-dev.
# NOTE(review): "nightly" is unpinned, so a re-run may install a different build — consider pinning a version.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
import os
# XLA runtime settings, set before torch / torch_xla are imported further down.
# NOTE(review): torch_xla documents XLA_USE_BF16=1 as storing float tensors in bfloat16 on TPU — confirm for the installed version.
os.environ.update({
    "XLA_USE_BF16": "1",
    "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",
})

import torch
import pandas as pd
from scipy import stats
import numpy as np

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib

import logging
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule
from transformers import AutoTokenizer, AutoModel, AutoConfig
import sys
from sklearn import metrics, model_selection
from fastai.text import *

In [None]:
import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import warnings
warnings.filterwarnings("ignore")

### Dataset

In [None]:
class JigsawArrayDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized, row-aligned arrays.

    Each item is ``((input_ids, attention_mask), target)``:

    * ``target`` is ``toxic[idx]`` as a tensor, or a scalar ``0.`` tensor when no
      labels were given;
    * when ``text_id`` is given, ``target`` becomes ``(target, text_id[idx])`` so the
      example id travels with the prediction.

    Args:
        input_ids: token ids, one row per example (indexed along the first axis).
        attention_mask: array row-aligned with ``input_ids``.
        toxic: optional per-example labels.
        text_id: optional per-example identifiers.
    """
    def __init__(self, input_ids: np.ndarray, attention_mask: np.ndarray,
                 toxic: "np.ndarray | None" = None, text_id=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.toxic = toxic
        self.text_id = text_id

    @staticmethod
    def _as_tensor(value):
        # Mirrors fastai's `tensor` for array/scalar input (torch.as_tensor, then
        # int32 -> int64) without depending on the fastai star import.
        t = torch.as_tensor(value)
        return t.long() if t.dtype == torch.int32 else t

    def __getitem__(self, idx):
        xb = (self._as_tensor(self.input_ids[idx]),
              self._as_tensor(self.attention_mask[idx]))
        if self.toxic is None:
            yb = self._as_tensor(0.)
        else:
            yb = self._as_tensor(self.toxic[idx])
        if self.text_id is not None:
            yb = (yb, self.text_id[idx])
        return xb, yb

    def __len__(self):
        return len(self.input_ids)

In [None]:
from pathlib import Path  # explicit import instead of relying on the fastai star import

# Root of the pre-tokenized XLM-R inputs (.npy arrays are loaded from sub-folders of it).
XLM_PROCESSED_PATH = Path("/kaggle/input/xlmrobertabase/xlm_roberta_processed/")

In [None]:
# Show what the processed-input dataset contains, in a stable order
# (plain pathlib instead of fastai's monkeypatched Path.ls).
sorted(XLM_PROCESSED_PATH.iterdir())

In [None]:
test_df = pd.read_csv("/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test.csv")
# `id` is carried through prediction as the per-row key, so it must exist and be unique.
assert "id" in test_df.columns and test_df["id"].is_unique, "test.csv needs a unique `id` column"

### Model

In [None]:
def get_xlm_roberta_base(model_name="xlm-roberta-base"):
    """Build an encoder from a Hugging Face config, with hidden states enabled.

    Only the *config* is fetched with ``from_pretrained``. ``AutoModel.from_config``
    builds the architecture without loading pretrained weights, so real weights
    must come from a state dict loaded afterwards.

    Args:
        model_name: Hugging Face model identifier whose config is used.

    Returns:
        The ``AutoModel`` instance with ``output_hidden_states=True``, which
        ``JigsawModel`` needs in order to read the intermediate layers.
    """
    conf = AutoConfig.from_pretrained(model_name)
    conf.output_hidden_states = True
    return AutoModel.from_config(config=conf)

In [None]:
class Head(nn.Module):
    """Concat-pool head: mean- and max-pool over the sequence, then a linear layer.

    Args:
        p: dropout probability applied to the token features before pooling.
        n_in: per-token feature size. The linear layer sees ``2 * n_in`` inputs
            (mean and max concatenated). The default 1536 reproduces the original
            ``nn.Linear(1536*2, 2)``.
        n_out: number of output logits.
    """
    def __init__(self, p=0.5, n_in=1536, n_out=2):
        super().__init__()
        self.d0 = nn.Dropout(p)
        self.l0 = nn.Linear(n_in * 2, n_out)

    def forward(self, x):
        # assumes x is (batch, seq_len, n_in) — pooling runs over dim 1
        # NOTE(review): the attention mask is not used, so padded positions take part in
        # mean/max; kept as-is, presumably matching how the checkpoint was trained.
        x = self.d0(x)
        x = torch.cat([x.mean(1), x.max(1).values], -1)
        return self.l0(x)

class JigsawModel(nn.Module):
    """Encoder plus a head that receives the last two hidden layers concatenated.

    Args:
        model: encoder called as ``model(input_ids=..., attention_mask=...)``. It must
            return all hidden states (``get_xlm_roberta_base`` enables them).
        head: module applied to the concatenated last two hidden states.
    """
    def __init__(self, model, head):
        super().__init__()
        self.sequence_model = model
        self.head = head

    def forward(self, *xargs):
        """Run the encoder on ``(input_ids, attention_mask)`` and apply the head."""
        outputs = self.sequence_model(input_ids=xargs[0], attention_mask=xargs[1])
        hidden_states = self._hidden_states(outputs)
        # feed last 2 hidden states, concatenated along the feature axis
        x = torch.cat(hidden_states[-2:], -1)
        return self.head(x)

    @staticmethod
    def _hidden_states(outputs):
        """Extract the tuple of hidden states from the encoder output.

        Recent transformers versions return a ``ModelOutput`` (a dict subclass), whose
        iteration yields *keys*, so tuple-unpacking it breaks; it exposes
        ``.hidden_states`` instead. Legacy versions return
        ``(last_hidden_state, pooler_output, hidden_states, ...)``.
        """
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None:
            hidden_states = outputs[2]
        return hidden_states

In [None]:
# Assemble encoder + concat-pool head; the fine-tuned weights are loaded into this next.
model, head = get_xlm_roberta_base(), Head()
jigsaw_model = JigsawModel(model, head)

### Load fine-tuned weights

In [None]:
# torch.load unpickles the file: only use it on a trusted checkpoint such as this one.
# map_location="cpu" lets a checkpoint saved from a TPU/GPU device load on the host;
# the model is moved to the XLA device later, inside `run`.
state_dict = torch.load(
    "/kaggle/input/xlmrobertatoxicengmodel/model_finetuned-translated-data.bin",
    map_location="cpu",
)

In [None]:
# strict=True (the default): fail loudly on any missing or unexpected parameter key.
jigsaw_model.load_state_dict(state_dict, strict=True)

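### Prediction

Predict on the original test inputs and on the translated test inputs, then average the two sets of probabilities.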

In [None]:
from sklearn.metrics import roc_auc_score

def predict_fn(data_loader, model, device, num_batches):
    """Run inference and collect the class-1 probability and id of every example.

    Args:
        data_loader: yields ``((input_ids, attention_mask), (target, text_id))``
            batches, i.e. from a ``JigsawArrayDataset`` built with ``text_id``.
        model: module called as ``model(input_ids, attention_mask)`` returning
            2 logits per example (column 1 is used as the positive class).
        device: device the inputs are moved to; the model must already be there.
        num_batches: only used as the progress-bar total.

    Returns:
        ``(preds, text_ids)``: two lists holding one CPU tensor per batch.

    Raises:
        ValueError: if the batches carry no text ids.
    """
    model.eval()
    preds, text_ids = [], []

    with torch.no_grad():
        progress = tqdm(data_loader, total=num_batches, desc="Predicting",
                        disable=not xm.is_master_ordinal())
        for xb, yb in progress:
            # Without text_id the dataset yields a bare target tensor and yb[1] would be
            # silently wrong; fail with a clear message instead.
            if not isinstance(yb, (tuple, list)):
                raise ValueError("predict_fn needs a dataset built with `text_id`")

            input_ids, attention_mask = xb
            input_ids = input_ids.to(device, dtype=torch.long)
            attention_mask = attention_mask.to(device, dtype=torch.long)
            out = model(input_ids, attention_mask)

            # softmax over the logits, keep the class-1 probability, bring it to the host
            preds.append(to_cpu(out.softmax(-1)[:, 1]))
            text_ids.append(to_cpu(yb[1]))

    return preds, text_ids

In [None]:
from torch.utils.data import DataLoader

In [None]:
def run(test_ds, model=None, batch_size=128, num_workers=4):
    """Predict on ``test_ds`` with a model placed on the XLA device.

    Args:
        test_ds: dataset to predict on. Order is preserved (no shuffling), which
            the submission relies on.
        model: module to use. Defaults to the notebook-level ``jigsaw_model``,
            which is the previously hard-wired behaviour.
        batch_size: examples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        ``(preds, text_ids)`` as returned by ``predict_fn``.
    """
    device = xm.xla_device()
    if model is None:
        model = jigsaw_model
    model = model.to(device)

    test_dl = torch.utils.data.DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return predict_fn(test_dl, model, device, len(test_dl))

In [None]:
# Original-language test inputs
test_ds = JigsawArrayDataset(
    input_ids=np.load(XLM_PROCESSED_PATH/'test_inputs/input_ids.npy'),
    attention_mask=np.load(XLM_PROCESSED_PATH/'test_inputs/attention_mask.npy'),
    text_id=test_df['id'].values,
)
# text_id is paired with the arrays purely by row position, so all three must line up.
assert len(test_ds.input_ids) == len(test_ds.attention_mask) == len(test_df), \
    "tokenized test arrays are not row-aligned with test.csv"

In [None]:
# Predict on the original-language test inputs.
preds1, text_ids = run(test_ds=test_ds)

In [None]:
# Translated test inputs
test_ds = JigsawArrayDataset(
    input_ids=np.load(XLM_PROCESSED_PATH/'translated_test_inputs/input_ids.npy'),
    attention_mask=np.load(XLM_PROCESSED_PATH/'translated_test_inputs/attention_mask.npy'),
    text_id=test_df['id'].values,
)
# text_id is paired with the arrays purely by row position, so all three must line up.
assert len(test_ds.input_ids) == len(test_ds.attention_mask) == len(test_df), \
    "tokenized translated-test arrays are not row-aligned with test.csv"

In [None]:
# Predict on the translated test inputs (same row order, so text_ids are unchanged).
preds2, text_ids = run(test_ds=test_ds)

In [None]:
def _flatten_batches(batches):
    """Concatenate a list of per-batch CPU tensors into one flat numpy array.

    Input that is already a numpy array is returned unchanged, so this cell
    can be re-run without crashing on its own output.
    """
    if isinstance(batches, np.ndarray):
        return batches
    return to_np(torch.cat(batches).view(-1))


preds1 = _flatten_batches(preds1)
preds2 = _flatten_batches(preds2)

In [None]:
# 2x2 correlation matrix; the off-diagonal entry is how well the two prediction sets agree.
np.corrcoef(np.stack([preds1, preds2]))

In [None]:
# Average the original-language and translated-input predictions.
preds = (preds1 + preds2) / 2
# Flatten the per-batch id tensors; skipped if already done, so the cell can be re-run.
if not isinstance(text_ids, np.ndarray):
    text_ids = to_np(torch.cat(text_ids).view(-1))
assert len(preds) == len(text_ids), "one id per prediction expected"

In [None]:
# Distribution of the averaged class-1 ("toxic") probabilities.
fig, ax = plt.subplots()
ax.hist(preds, bins=50)
ax.set(title="Averaged test predictions", xlabel="P(toxic)", ylabel="count");

### Submit

In [None]:
subdf = pd.read_csv("/kaggle/input/jigsaw-multilingual-toxic-comment-classification/sample_submission.csv")
# Predictions are in test.csv row order; assigning them positionally is only valid
# if the submission template lists the same ids in the same order.
assert len(subdf) == len(preds), "prediction count does not match the submission template"
assert np.array_equal(subdf["id"].to_numpy(), text_ids), "submission id order differs from prediction order"
subdf["toxic"] = preds
subdf.to_csv("submission.csv", index=False)

### fin