In [None]:
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

Here are all of our imports. Note that one optimization is already applied here so that PyTorch XLA can train.

`XLA_USE_BF16` is an environment variable that tells PyTorch XLA to automatically use [bfloat16](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).

In [None]:
import os
os.environ['XLA_USE_BF16']="1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import numpy as np
import pandas as pd
from scipy import stats


from tqdm import tqdm
from collections import OrderedDict, namedtuple

# torch imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

import joblib
import sys
import logging
import gc
import random
import time

# transformers imports
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import get_cosine_schedule_with_warmup
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from transformers import get_constant_schedule
from transformers import BertConfig
from transformers import BertModel
from transformers import XLMRobertaConfig
from transformers import XLMRobertaTokenizer
from transformers import XLMRobertaModel


# torch_xla imports
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import torch_xla.utils.serialization as xser

import warnings
warnings.filterwarnings("ignore")


from sklearn.metrics import roc_auc_score 
from sklearn.model_selection import StratifiedKFold

In [None]:
# Sanity check: show the value reported by xm.xrt_world_size()
# (presumably the number of XLA devices/processes taking part in training — confirm).
print(f"{xm.xrt_world_size()}")

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch's CPU RNG, and
    exports PYTHONHASHSEED. When CUDA is available the CUDA RNGs are seeded
    as well and cuDNN is switched to deterministic behaviour.

    Args:
        seed_value: integer seed applied to all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode lets cuDNN auto-tune and pick algorithms per run,
        # which can differ between runs; it must be off for reproducibility.
        torch.backends.cudnn.benchmark = False


seed = 999
seed_everything(seed)

In [None]:
class AverageMeter:
    """Running tracker for a scalar quantity.

    Keeps the most recent value (`val`), the weighted total (`sum`), the
    accumulated weight (`count`) and their ratio (`avg`).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest observation, weighted as `n` samples."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        # mean over everything seen since the last reset
        self.avg = self.sum / self.count
        

class ArrayDataset(Dataset):
    """Dataset over several parallel arrays that share their first dimension.

    Item `i` is a tuple holding row `i` of every array, each converted to a
    torch tensor.
    """

    def __init__(self, *arrays):
        # every array must have as many rows as the first one
        for array in arrays:
            assert arrays[0].shape[0] == array.shape[0]
        self.arrays = arrays

    def __len__(self):
        reference = self.arrays[0]
        return reference.shape[0]

    def __getitem__(self, index):
        row_tensors = []
        for array in self.arrays:
            # np.array(...) copies the selected row before handing it to torch
            row_tensors.append(torch.from_numpy(np.array(array[index])))
        return tuple(row_tensors)
    
    
class JigsawDataset(object):
    """Indexable dataset serving pre-tokenised examples as dicts of tensors.

    Each item is a dict with the keys 'ids', 'mask' and 'target', taken from
    the same row of `input_ids`, `attention_mask` and `target`.
    `token_type_ids` is stored on the instance but not returned by items.
    """

    def __init__(self, input_ids=None, token_type_ids=None, attention_mask=None, target=None):
        self.ids, self.mask = input_ids, attention_mask
        self.token_ids, self.target = token_type_ids, target

    def __len__(self):
        n_examples = self.ids.shape[0]
        return n_examples

    def __getitem__(self, item):
        fields = (
            ('ids', self.ids),
            ('mask', self.mask),
            ('target', self.target),
        )
        # np.array(...) copies the selected row before handing it to torch
        return {key: torch.from_numpy(np.array(source[item])) for key, source in fields}
    
 


# MODELS
    
class XLMRobertaLargeTC(nn.Module):
    """'xlm-roberta-large' encoder with a single-logit classification head.

    The head applies dropout (p=0.2) and a linear layer to the encoder output
    at index 1 (presumably the pooled representation — confirm against the
    installed transformers version) and returns one raw logit per example.
    """

    def __init__(self):
        super(XLMRobertaLargeTC, self).__init__()
        config = XLMRobertaConfig.from_pretrained('xlm-roberta-large', output_hidden_states=True)
        self.xlm_roberta = XLMRobertaModel.from_pretrained('xlm-roberta-large', config=config)

        self.fc = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(p=0.2)

        # Initialize the head: small-variance weights, zero bias.
        # nn.init.normal_(bias, 0) would sample the bias from N(0, 1), since the
        # second positional argument is the mean and std defaults to 1.
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """Return logits of shape (batch, 1).

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: attention mask passed to the encoder.
            token_type_ids: accepted for interface compatibility; not used.
        """
        encoder_outputs = self.xlm_roberta(
            input_ids=input_ids,
            attention_mask=attention_mask)

        # Index rather than tuple-unpack: this works for plain tuples and for
        # dict-like outputs (where unpacking would yield keys), and does not
        # depend on how many extra outputs the encoder returns.
        pooled = encoder_outputs[1]

        pooled = self.dropout(pooled)
        logits = self.fc(pooled)

        return logits
    
    
class XLMRobertaBaseTC(nn.Module):
    """'xlm-roberta-base' encoder with a single-logit classification head.

    The head applies dropout (p=0.2) and a linear layer to the encoder output
    at index 1 (presumably the pooled representation — confirm against the
    installed transformers version) and returns one raw logit per example.
    """

    def __init__(self):
        super(XLMRobertaBaseTC, self).__init__()
        config = XLMRobertaConfig.from_pretrained('xlm-roberta-base', output_hidden_states=True)
        self.xlm_roberta = XLMRobertaModel.from_pretrained('xlm-roberta-base', config=config)

        self.fc = nn.Linear(config.hidden_size, 1)
        self.dropout = nn.Dropout(p=0.2)

        # Initialize the head: small-variance weights, zero bias.
        # nn.init.normal_(bias, 0) would sample the bias from N(0, 1), since the
        # second positional argument is the mean and std defaults to 1.
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        """Return logits of shape (batch, 1).

        Args:
            input_ids: token ids passed to the encoder.
            token_type_ids: accepted for interface compatibility; not used.
            attention_mask: attention mask passed to the encoder.
        """
        encoder_outputs = self.xlm_roberta(
            input_ids=input_ids,
            attention_mask=attention_mask)

        # Index rather than tuple-unpack: this works for plain tuples and for
        # dict-like outputs (where unpacking would yield keys), and does not
        # depend on how many extra outputs the encoder returns.
        pooled = encoder_outputs[1]

        pooled = self.dropout(pooled)
        logits = self.fc(pooled)

        return logits
    
    
class BertMultilingualCaseTC(nn.Module):
    """'bert-base-multilingual-cased' encoder with a single-logit head.

    The head applies dropout (p=0.2) and a linear layer to the encoder output
    at index 1 (presumably the pooled representation — confirm against the
    installed transformers version) and returns one raw logit per example.
    """

    def __init__(self):
        super(BertMultilingualCaseTC, self).__init__()
        # The configuration must come from BertConfig; loading it through
        # BertModel would yield a model object with no `hidden_size` attribute
        # and would not be a valid `config=` argument below.
        config = BertConfig.from_pretrained('bert-base-multilingual-cased', output_hidden_states=True)
        self.bert_model = BertModel.from_pretrained('bert-base-multilingual-cased', config=config)
        self.fc = nn.Linear(config.hidden_size, 1)

        # Initialize the head: small-variance weights, zero bias.
        # nn.init.normal_(bias, 0) would sample the bias from N(0, 1), since the
        # second positional argument is the mean and std defaults to 1.
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

        self.dropout = nn.Dropout(p=0.2)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        """Return logits of shape (batch, 1).

        Args:
            input_ids: token ids passed to the encoder.
            token_type_ids: accepted for interface compatibility; not used.
            attention_mask: attention mask passed to the encoder.
        """
        encoder_outputs = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask)

        # Index rather than tuple-unpack: this works for plain tuples and for
        # dict-like outputs (where unpacking would yield keys), and does not
        # depend on how many extra outputs the encoder returns.
        pooled = encoder_outputs[1]

        pooled = self.dropout(pooled)
        logits = self.fc(pooled)

        return logits
    
    
class BertMultilingualUncasedTC(nn.Module):
    """'bert-base-multilingual-uncased' encoder with a single-logit head.

    The head applies dropout (p=0.2) and a linear layer to the encoder output
    at index 1 (presumably the pooled representation — confirm against the
    installed transformers version) and returns one raw logit per example.
    """

    def __init__(self):
        super(BertMultilingualUncasedTC, self).__init__()
        # The configuration must come from BertConfig; loading it through
        # BertModel would yield a model object with no `hidden_size` attribute
        # and would not be a valid `config=` argument below.
        config = BertConfig.from_pretrained('bert-base-multilingual-uncased', output_hidden_states=True)
        self.bert_model = BertModel.from_pretrained('bert-base-multilingual-uncased', config=config)
        self.fc = nn.Linear(config.hidden_size, 1)

        # Initialize the head: small-variance weights, zero bias.
        # nn.init.normal_(bias, 0) would sample the bias from N(0, 1), since the
        # second positional argument is the mean and std defaults to 1.
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

        self.dropout = nn.Dropout(p=0.2)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        """Return logits of shape (batch, 1).

        Args:
            input_ids: token ids passed to the encoder.
            token_type_ids: accepted for interface compatibility; not used.
            attention_mask: attention mask passed to the encoder.
        """
        encoder_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)

        # Index rather than tuple-unpack: this works for plain tuples and for
        # dict-like outputs (where unpacking would yield keys), and does not
        # depend on how many extra outputs the encoder returns.
        pooled = encoder_outputs[1]

        pooled = self.dropout(pooled)
        logits = self.fc(pooled)

        return logits

# Training

In [None]:
# Print the revision identifiers exposed by torch_xla.version for the PyTorch
# and XLA builds in use (attribute names suggest git revisions — confirm).
import torch_xla.version as xv
print('PYTORCH:', xv.__torch_gitrev__)
print('XLA:', xv.__xla_gitrev__)

In [None]:
!free -h

In [None]:
def loss_fn(outputs, targets):
    """Binary cross-entropy on raw logits.

    `targets` is viewed as a column (-1, 1) before being compared with
    `outputs`; the loss uses BCEWithLogitsLoss with its default settings.
    """
    criterion = nn.BCEWithLogitsLoss()
    column_targets = targets.view(-1, 1)
    return criterion(outputs, column_targets)


def reduce_fn(vals):
    """Return the arithmetic mean of `vals` (a non-empty sized collection)."""
    total = 0
    for value in vals:
        total = total + value
    return total / len(vals)


def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one training pass over `data_loader`.

    Each batch is a dict with the keys 'ids', 'mask' and 'target'. For every
    batch the model is updated through `xm.optimizer_step`, and the scheduler
    (if given) is stepped once.

    Args:
        data_loader: iterable of batch dicts.
        model: module called as model(input_ids=..., attention_mask=...).
        optimizer: optimizer handed to `xm.optimizer_step`.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler stepped after every batch.

    Returns:
        List with the loss value (Python number from `.item()`) of every batch.
    """
    model.train()

    batch_losses = []

    for step, batch in enumerate(data_loader):
        ids = batch['ids'].to(device, dtype=torch.long)
        mask = batch['mask'].to(device, dtype=torch.long)
        targets = batch['target'].to(device, dtype=torch.float)

        optimizer.zero_grad()

        logits = model(input_ids=ids, attention_mask=mask)
        loss = loss_fn(logits, targets)

        batch_losses.append(loss.item())

        # Every 500th batch: combine the loss through xm.mesh_reduce with
        # reduce_fn and log it via xm.master_print.
        if step % 500 == 0:
            mean_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
            xm.master_print(f'bi={step}, loss={mean_loss:.5f}')

        loss.backward()

        xm.optimizer_step(optimizer)

        if scheduler is None:
            continue
        scheduler.step()

    return batch_losses
    
    
def eval_loop_fn(data_loader, model, device):
    """Run inference over `data_loader` and collect outputs and targets.

    Each batch is a dict with the keys 'ids', 'mask' and 'target'.

    Args:
        data_loader: iterable of batch dicts.
        model: module called as model(input_ids=..., attention_mask=...).
        device: device the batch tensors are moved to.

    Returns:
        (fin_outputs, fin_targets): Python lists built with `.tolist()` from
        the model outputs and the targets of every batch, in iteration order.
    """
    model.eval()

    fin_targets = []
    fin_outputs = []

    # Inference only: disabling autograd avoids building a graph for every
    # batch, which would only cost memory and time here.
    with torch.no_grad():
        for data in data_loader:
            ids = data['ids'].to(device, dtype=torch.long)
            mask = data['mask'].to(device, dtype=torch.long)
            targets = data['target'].to(device, dtype=torch.float)

            outputs = model(input_ids=ids, attention_mask=mask)

            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(outputs.cpu().detach().numpy().tolist())

    # A single collection after the loop replaces the per-batch gc.collect(),
    # which added a full garbage-collection pass to every step.
    gc.collect()

    return fin_outputs, fin_targets

In [None]:
def _run():
    """Train and evaluate the wrapped model for one fold in the current process.

    Relies on module-level names set by the driver cell: `train_dataset`,
    `valid_dataset`, `mx`, `fold`, `TRAIN_BATCH_SIZE` and `EPOCHS`.

    Steps: build distributed samplers and loaders, move the model to this
    process's XLA device, create AdamW (weight decay 0.001, none for bias and
    LayerNorm parameters) with a cosine schedule without warmup, then for each
    epoch train, evaluate, and print the training loss and the validation AUC
    combined across processes with `reduce_fn`. Finally the state dict is
    saved with `xser.save(..., master_only=True)`.
    """
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0)
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    # NOTE(review): the batch size is hard-coded to 4 here although the driver
    # cell defines VALID_BATCH_SIZE — confirm which one is intended.
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=4,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0)

    device = xm.xla_device()
    
    model = mx.to(device)
    
    # print only once
    if fold == 0:
        xm.master_print('done loading model')

    param_optimizer = list(model.named_parameters())
    
    # parameters whose name contains one of these get no weight decay
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # learning rate is scaled with the number of participating processes
    lr = 0.5e-5 * xm.xrt_world_size()
    
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps)
    
    xm.master_print(f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    for epoch in range(EPOCHS):
        gc.collect()
        
        # Without set_epoch the DistributedSampler reuses the same shuffled
        # order in every epoch.
        train_sampler.set_epoch(epoch)
        
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        
        # print only once
        if epoch == 0: 
            xm.master_print('parallel loader created... training now')
            
        
        # train mode/function
        train_loss = train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
        
        del para_loader
        
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        
        # eval mode/function
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)
        
        del para_loader
        
        # AUC is computed on this process's shard, then combined via reduce_fn
        auc = roc_auc_score(np.array(t) >= 0.5, o)
        auc_reduced = xm.mesh_reduce('auc_reduce',auc,reduce_fn)
        xm.master_print(f'Epoch: {epoch+1}/{EPOCHS} | train loss: {np.mean(train_loss):.4f} | val auc: {auc_reduced:.4f}')
        
        gc.collect()
        
    xser.save(model.state_dict(), f"f{fold+1}_xlm_roberta_large.pth", master_only=True)
    

In [None]:
# load pre-encoded binary files

root_dir = '../input/jtc2020-xlm-roberta-large-tokenizer/'

INPUT_IDS = np.load(root_dir + 'xlm_roberta_large_train_input_ids.npy', mmap_mode='r')
ATTENTION_MASK = np.load(root_dir + 'xlm_roberta_large_train_attention_mask.npy', mmap_mode='r')
TARGETS = np.load(root_dir + 'train_toxic.npy', mmap_mode='r')


# didn't shuffle it earlier
# (the same permutation is applied to all three arrays so rows stay aligned)
shuffler = np.random.permutation(len(INPUT_IDS))
INPUT_IDS = INPUT_IDS[shuffler]
ATTENTION_MASK = ATTENTION_MASK[shuffler]
TARGETS = TARGETS[shuffler]


# we're going to use the validation set provided in the competition as testset
test_input_ids = np.load(root_dir + 'xlm_roberta_large_valid_input_ids.npy', mmap_mode='r')
test_attention_mask = np.load(root_dir + 'xlm_roberta_large_valid_attention_mask.npy', mmap_mode='r')
test_targets = np.load(root_dir + 'valid_toxic.npy', mmap_mode='r')

# for testing code over smaller sample
SAMPLE_SIZE = 307200
input_ids = INPUT_IDS[:SAMPLE_SIZE]
attention_mask = ATTENTION_MASK[:SAMPLE_SIZE]
targets = TARGETS[:SAMPLE_SIZE]

del INPUT_IDS, ATTENTION_MASK, TARGETS
gc.collect()


EPOCHS = 1
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32


skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)

FLAGS = {}


def _mp_fn(rank, flags):
    """Entry point for each spawned process; the arguments are not used."""
    _run()


# Started once, before the loop, so the final print reports the time for all
# folds rather than only the last one.
start_time = time.time()

for fold, (train_idx, valid_idx) in enumerate(skf.split(input_ids, targets)):
    train_ids = input_ids[train_idx]
    train_mask = attention_mask[train_idx]
    train_trg = targets[train_idx]

    # _run() reads train_dataset, valid_dataset, mx and fold as globals.
    train_dataset = JigsawDataset(input_ids=train_ids, attention_mask=train_mask, target=train_trg)
    valid_dataset = JigsawDataset(input_ids=test_input_ids, attention_mask=test_attention_mask, target=test_targets)

    mx = xmp.MpModelWrapper(XLMRobertaLargeTC())
    
    del train_ids, train_mask, train_trg
    gc.collect()
    
    xm.master_print(f"Fold: {fold+1}")
    xm.master_print("")
    # Start training processes
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
    
print('Time taken: ',time.time()-start_time)