### Check TPU is available

In [1]:
import tensorflow as tf
try:
   tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  
   print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
   tpu = None
if tpu:
   tf.config.experimental_connect_to_cluster(tpu)
   tf.tpu.experimental.initialize_tpu_system(tpu)
   strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
   strategy = tf.distribute.get_strategy()

Running on TPU  ['10.0.0.2:8470']


### Setup Dependencies

In [2]:
!pip install nlp
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly  --apt-packages libomp5 libopenblas-dev

Collecting nlp
  Downloading nlp-0.4.0-py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 1.3 MB/s 
Collecting xxhash
  Downloading xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 4.2 MB/s 
Installing collected packages: xxhash, nlp
Successfully installed nlp-0.4.0 xxhash-2.0.0
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0  34560      0 --:--:-- --:--:-- --:--:-- 34560
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvisio

In [3]:
%%time
%autosave 60

import os
os.environ['XLA_USE_BF16'] = "1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import gc
gc.enable()
import time

import numpy as np
import pandas as pd
from tqdm import tqdm 

import nlp
import transformers
from transformers import (AdamW, 
                          XLMRobertaTokenizer, 
                          XLMRobertaModel, 
                          get_cosine_schedule_with_warmup)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.serialization as xser
import torch_xla.version as xv

import warnings
warnings.filterwarnings("ignore")

print('PYTORCH:', xv.__torch_gitrev__)
print('XLA:', xv.__xla_gitrev__)

Autosaving every 60 seconds




PYTORCH: ecb9e790ed6ceafa738ad52a500b9e50bc0fc241
XLA: 210bfe312e98d58a4f7148e921017775b0623e6e
CPU times: user 1.3 s, sys: 237 ms, total: 1.54 s
Wall time: 2.19 s


### Data Files

In [4]:
train = pd.read_csv('../input/contradictory-my-dear-watson/train.csv')
test = pd.read_csv('../input/contradictory-my-dear-watson/test.csv')
sample_submission = pd.read_csv('../input/contradictory-my-dear-watson/sample_submission.csv')

### CONFIG

In [5]:
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 4
MAX_LEN = 80
# Scale learning rate to 8 TPU's
LR = 2e-5 * xm.xrt_world_size() 
METRICS_DEBUG = True
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…




### External Datasets

HuggingFace [`nlp`](https://huggingface.co/nlp/index.html) library contains datasets and evaluation metrics for natural language processing.
It is compatible with NumPy, Pandas, PyTorch and TensorFlow. 

We will be using three external datasets:

1. [XNLI](https://huggingface.co/nlp/viewer/?dataset=xnli) is a subset of a few thousand examples from MNLI which has been translated into a 14 different languages (some low-ish resource). 
2. [Glue](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli) the General Language Understanding Evaluation benchmark (https://gluebenchmark.com/) is a collection of resources for training, evaluating, and analyzing natural language understanding systems.
3. [SNLI](https://huggingface.co/nlp/viewer/?dataset=snli) is a collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels entailment, contradiction, and neutral, supporting the task of natural language inference (NLI), also known as recognizing textual entailment (RTE).

*Note: Three datasets combined reach more than 100k examples and hence will be difficult to train within the assigned resources. Since this is a multi-lingual based competition we will fetch **100% test, validation data** from XNLI, **25% train data** from Glue/MNLI and **25% train data** from SNLI.*

In [6]:
# mnli data
mnli = nlp.load_dataset(path='glue', name='mnli', split='train[:5%]')

# xnli data
xnli = nlp.load_dataset(path='xnli')
xnli = nlp.concatenate_datasets([xnli['test'], xnli['validation']])

# snli data
snli = nlp.load_dataset(path='snli', split='train[:5%]')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29068.0, style=ProgressStyle(descriptio…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=30329.0, style=ProgressStyle(descriptio…


Downloading and preparing dataset glue/mnli (download: 298.29 MiB, generated: 78.65 MiB, post-processed: Unknown sizetotal: 376.95 MiB) to /root/.cache/huggingface/datasets/glue/mnli/1.0.0/005857b1e5a6280d8f1a9b9537d44a08ba30cb6be958e81fac98e625a0d487a7...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=312783507.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/mnli/1.0.0/005857b1e5a6280d8f1a9b9537d44a08ba30cb6be958e81fac98e625a0d487a7. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=4263.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2301.0, style=ProgressStyle(description…


Downloading and preparing dataset xnli/plain_text (download: 17.04 MiB, generated: 27.66 MiB, post-processed: Unknown sizetotal: 44.70 MiB) to /root/.cache/huggingface/datasets/xnli/plain_text/1.0.0/9bed2c9a464959786460b62992dcfde22ca64526ba9d9e151cd0754249266614...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=17865352.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset xnli downloaded and prepared to /root/.cache/huggingface/datasets/xnli/plain_text/1.0.0/9bed2c9a464959786460b62992dcfde22ca64526ba9d9e151cd0754249266614. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3827.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1945.0, style=ProgressStyle(description…


Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown sizetotal: 155.68 MiB) to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/e417f6f2e16254938d977a17ed32f3998f5b23e4fcab0f6eb1d28784f23ea60d...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1929.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1259440.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=65886400.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1263568.0, style=ProgressStyle(descript…


Dataset snli downloaded and prepared to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/e417f6f2e16254938d977a17ed32f3998f5b23e4fcab0f6eb1d28784f23ea60d. Subsequent calls will reuse this data.


#### External Data Statistics

In [7]:
print("#"*25)
print("  MNLI"); print("#"*25)
print("Shape: ", mnli.shape)
print("Num of Samples: ", mnli.num_rows)
print("Num of Columns: ", mnli.num_columns)
print("Column Names: ", mnli.column_names)
print("Features: ", mnli.features)
print("Num of Classes: ", mnli.features['label'].num_classes)
print("Split: ", mnli.split)
print("Description: ", mnli.description)
print(f"Labels Count - 0's:{len(mnli.filter(lambda x: x['label']==0))}, 1's:{len(mnli.filter(lambda x: x['label']==1))}, 2's: 0's:{len(mnli.filter(lambda x: x['label']==2))}")
print()
print("#"*25)
print("  XNLI"); print("#"*25)
print("Shape: ", xnli.shape)
print("Num of Samples: ", xnli.num_rows)
print("Num of Columns: ", xnli.num_columns)
print("Column Names: ", xnli.column_names)
print("Features: ", xnli.features)
print("Split: ", xnli.split)
print("Description: ", xnli.description)
print(f"Labels Count - 0's:{len(xnli.filter(lambda x: x['label']==0))}, 1's:{len(xnli.filter(lambda x: x['label']==1))}, 2's: 0's:{len(xnli.filter(lambda x: x['label']==2))}")
print()
print("#"*25)
print("  SNLI"); print("#"*25)
print("Shape: ", snli.shape)
print("Num of Samples: ", snli.num_rows)
print("Num of Columns: ", snli.num_columns)
print("Column Names: ", snli.column_names)
print("Features: ", snli.features)
print("Num of Classes: ", snli.features['label'].num_classes)
print("Split: ", snli.split)
print("Description: ", snli.description)
print(f"Labels Count - 0's:{len(snli.filter(lambda x: x['label']==0))}, 1's:{len(snli.filter(lambda x: x['label']==1))}, 2's: 0's:{len(snli.filter(lambda x: x['label']==2))}")

#########################
  MNLI
#########################
Shape:  (19635, 4)
Num of Samples:  19635
Num of Columns:  4
Column Names:  ['hypothesis', 'idx', 'label', 'premise']
Features:  {'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}
Num of Classes:  3
Split:  train[:5%]
Description:  GLUE, the General Language Understanding Evaluation benchmark
(https://gluebenchmark.com/) is a collection of resources for training,
evaluating, and analyzing natural language understanding systems.




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


Labels Count - 0's:6648, 1's:5836, 2's: 0's:7151

#########################
  XNLI
#########################
Shape:  (7500, 3)
Num of Samples:  7500
Num of Columns:  3
Column Names:  ['hypothesis', 'label', 'premise']
Features:  {'hypothesis': {'language': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'translation': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}, 'label': Value(dtype='int64', id=None), 'premise': {'ar': Value(dtype='string', id=None), 'bg': Value(dtype='string', id=None), 'de': Value(dtype='string', id=None), 'el': Value(dtype='string', id=None), 'en': Value(dtype='string', id=None), 'es': Value(dtype='string', id=None), 'fr': Value(dtype='string', id=None), 'hi': Value(dtype='string', id=None), 'ru': Value(dtype='string', id=None), 'sw': Value(dtype='string', id=None), 'th': Value(dtype='string', id=None), 'tr': Value(dtype='string', id=None), 'ur': Value(dtype='string', id=None), 'vi': Value(dtype='string', id=None), 'z

HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Labels Count - 0's:2500, 1's:2500, 2's: 0's:2500

#########################
  SNLI
#########################
Shape:  (27508, 3)
Num of Samples:  27508
Num of Columns:  3
Column Names:  ['premise', 'hypothesis', 'label']
Features:  {'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], names_file=None, id=None)}
Num of Classes:  3
Split:  train[:5%]
Description:  The SNLI corpus (version 1.0) is a collection of 570k human-written English
sentence pairs manually labeled for balanced classification with the labels
entailment, contradiction, and neutral, supporting the task of natural language
inference (NLI), also known as recognizing textual entailment (RTE).



HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))


Labels Count - 0's:9176, 1's:9145, 2's: 0's:9151


#### Helper Functions

In [8]:
# encoding
def convert_to_features(batch):
    input_pairs = list(zip(batch['premise'], batch['hypothesis']))
    encodings = tokenizer.batch_encode_plus(input_pairs, 
                                            add_special_tokens=True, 
                                            pad_to_max_length=True, 
                                            max_length=MAX_LEN, 
                                            truncation=True, 
                                            return_attention_mask=True, 
                                            return_token_type_ids=True)
    return encodings

In [9]:
# function to preprocess special structure of xnli
def preprocess_xnli(example):
    premise_output = []
    hypothesis_output = []
    label_output = []
    for prem, hyp, lab in zip(example['premise'],  example['hypothesis'], example["label"]):
        label = lab
        langs = hyp['language']
        translations = hyp['translation']
        hypothesis = {k: v for k, v in zip(langs, translations)}
        for lang in prem:
            if lang in hypothesis:
                premise_output += [prem[lang]]
                hypothesis_output += [hypothesis[lang]]
                label_output += [label]
    return {'premise':premise_output, 'hypothesis':hypothesis_output, 'label':label_output}

#### Encode Datasets

In [10]:
# encode mnli and convert to torch tensor
mnli_encoded = mnli.map(convert_to_features, batched=True, remove_columns=['idx', 'premise', 'hypothesis'])
mnli_encoded.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label'])

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




In [11]:
# preprocess xnli, encode and convert to torch tensor
xnli_processed = xnli.map(preprocess_xnli, batched=True)
xnli_encoded = xnli_processed.map(convert_to_features, batched=True, remove_columns=['premise', 'hypothesis'])
xnli_encoded.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label']) 

HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=113.0), HTML(value='')))




In [12]:
# encode snli and convert to torch tensor
snli_encoded = snli.map(convert_to_features, batched=True, remove_columns=['premise', 'hypothesis'])
snli_encoded.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label']) 

HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))




#### Encoded Data Statistics

In [13]:
print(mnli_encoded.column_names)
print(snli_encoded.column_names)
print(xnli_encoded.column_names)

print(mnli_encoded.num_rows)
print(snli_encoded.num_rows)
print(xnli_encoded.num_rows)

['label', 'input_ids', 'token_type_ids', 'attention_mask']
['label', 'input_ids', 'token_type_ids', 'attention_mask']
['label', 'input_ids', 'token_type_ids', 'attention_mask']
19635
27508
112500


#### Competitions Data - Convert & Encode 

In [14]:
train_dataset = nlp.load_dataset('csv', data_files=['../input/contradictory-my-dear-watson/train.csv'])['train']

print(train_dataset.num_rows)
print(train_dataset.column_names)
drop_columns = train_dataset.column_names[:-1]

encoded_train_dataset = train_dataset.map(convert_to_features, batched=True, remove_columns=drop_columns)
encoded_train_dataset.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label']) 
print(encoded_train_dataset.num_rows)
print(encoded_train_dataset.column_names)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2720.0, style=ProgressStyle(description…


Downloading and preparing dataset csv/default-2a59171e5677bc0a (download: Unknown size, generated: Unknown size, post-processed: Unknown sizetotal: Unknown size) to /root/.cache/huggingface/datasets/csv/default-2a59171e5677bc0a/0.0.0/d27f9d4163bc98ad11a8c6b35120f1486e488f0fba0736cae84fcc51c291c35e...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-2a59171e5677bc0a/0.0.0/d27f9d4163bc98ad11a8c6b35120f1486e488f0fba0736cae84fcc51c291c35e. Subsequent calls will reuse this data.
12120
['id', 'premise', 'hypothesis', 'lang_abv', 'language', 'label']


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


12120
['label', 'input_ids', 'token_type_ids', 'attention_mask']


#### Cocatenate and shuffle all datasets

In [15]:
train_dataset = nlp.concatenate_datasets([mnli_encoded, 
                                          xnli_encoded, 
                                          snli_encoded,
                                          encoded_train_dataset
                                         ])

print(train_dataset.num_rows)
print(train_dataset.column_names)

171763
['label', 'input_ids', 'token_type_ids', 'attention_mask']


In [16]:
train_dataset.cleanup_cache_files()
del mnli, mnli_encoded
del xnli, xnli_encoded, xnli_processed
del snli, snli_encoded
gc.collect()

54

### Dataset Factory

In [17]:
class DatasetRetriever(Dataset):
    def __init__(self, dataset:nlp.arrow_dataset.Dataset):
        self.dataset = dataset
        self.ids = self.dataset['input_ids']
        self.mask = self.dataset['attention_mask']
        self.type_ids = self.dataset['token_type_ids']
        self.targets = self.dataset["label"]
        
    def __len__(self):
        return self.dataset.num_rows
    
    def __getitem__(self, index):   
        ids = self.ids[index]
        mask = self.mask[index]
        type_ids = self.type_ids[index]
        targets = self.targets[index]
        return {
            'ids':torch.tensor(ids),
            'mask':torch.tensor(mask),
            'type_ids':torch.tensor(type_ids),
            'targets':targets
        }

### Model Factory

In [18]:
class XLMRoberta(nn.Module):
    def __init__(self, num_labels, multisample):
        super(XLMRoberta, self).__init__()
        output_hidden_states = False
        self.num_labels = num_labels
        self.multisample= multisample
        self.roberta = XLMRobertaModel.from_pretrained("xlm-roberta-large", 
                                                       output_hidden_states=output_hidden_states, 
                                                       num_labels=1)
        self.layer_norm = nn.LayerNorm(1024*2)
        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)        
        self.classifier = nn.Linear(1024*2, self.num_labels)
    
    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None):
        outputs = self.roberta(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds)
        average_pool = torch.mean(outputs[0], 1)
        max_pool, _ = torch.max(outputs[0], 1)
        concatenate_layer = torch.cat((average_pool, max_pool), 1)
        normalization = self.layer_norm(concatenate_layer)
        if self.multisample:
            # Multisample Dropout
            logits = torch.mean(
                torch.stack(
                    [self.classifier(self.dropout(normalization)) for _ in range(5)],
                    dim=0,
                ),
                dim=0,
            )
        else:
            logits = self.dropout(normalization)
            logits = self.classifier(logits)       
        outputs = logits
        return outputs  

### Metrics Factory

In [19]:
class AverageMeter(object):
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

### Optimizer Factory

In [20]:
def get_model_optimizer(model):
    # Differential Learning Rate
    def is_backbone(name):
        return "roberta" in name
    
    optimizer_grouped_parameters = [
       {'params': [param for name, param in model.named_parameters() if is_backbone(name)], 'lr': LR},
       {'params': [param for name, param in model.named_parameters() if not is_backbone(name)], 'lr': 1e-3} 
    ]
    
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=LR, weight_decay=1e-2
    )
    
    return optimizer

### Loss Factory

In [21]:
def loss_fn(outputs, targets):
    return nn.CrossEntropyLoss()(outputs, targets)

### Training

In [22]:
def train_loop_fn(train_loader, model, optimizer, device, scheduler, epoch=None):
    # Train
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="[xla:{}]Train:  Epoch: [{}]".format(xm.get_ordinal(), epoch)
    )
    model.train()
    end = time.time()
    for i, data in enumerate(train_loader):
        data_time.update(time.time()-end)
        ids, mask, type_ids, targets = data["input_ids"], data["attention_mask"], data['token_type_ids'], data["label"]
        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        type_ids = type_ids.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)
        optimizer.zero_grad()
        outputs = model(
            input_ids = ids,
            attention_mask = mask,
            token_type_ids = type_ids
        )
        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)
        loss = loss_fn(outputs, targets)
        acc1= accuracy(outputs, targets, topk=(1,))
        losses.update(loss.item(), ids.size(0))
        top1.update(acc1[0].item(), ids.size(0))
        scheduler.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % 50 == 0:
            progress.display(i)
    del loss
    del outputs
    del ids
    del mask
    del targets
    gc.collect()

### Evaluation

In [23]:
def eval_loop_fn(validation_loader, model, device):
    #Validation
    model.eval()
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    learning_rate = AverageMeter('LR',':2.8f')
    progress = ProgressMeter(
        len(validation_loader),
        [batch_time, losses, top1],
        prefix='[xla:{}]Validation: '.format(xm.get_ordinal()))
    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(validation_loader):
            ids, mask, type_ids, targets = data["input_ids"], data["attention_mask"], data['token_type_ids'], data["label"]
            ids = ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            type_ids = type_ids.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)
            outputs = model(
                input_ids = ids,
                attention_mask = mask,
                token_type_ids = type_ids
            )
            loss = loss_fn(outputs, targets)
            acc1= accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), ids.size(0))
            top1.update(acc1[0].item(), ids.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            if i % 50 == 0:
                progress.display(i)
    del loss
    del outputs
    del ids
    del mask
    del targets
    gc.collect()

### Model and Dataset Config

In [24]:
WRAPPED_MODEL = xmp.MpModelWrapper(XLMRoberta(num_labels=3, multisample=False))

dataset = train_dataset.train_test_split(test_size=0.1)
train_dataset = dataset['train']
valid_dataset = dataset['test']
train_dataset.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label']) 
valid_dataset.set_format("torch", columns=['attention_mask', 'input_ids', 'token_type_ids', 'label']) 

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=513.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2244861551.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, max=155.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))




### Run

In [25]:
def _run():
    xm.master_print('Starting Run ...')
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0
    )
    xm.master_print('Train Loader Created.')
    
    valid_sampler = DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=True,
        num_workers=0
    )
    xm.master_print('Valid Loader Created.')
    
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size())
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    xm.master_print('Done Model Loading.')
    optimizer = get_model_optimizer(model)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps = 0,
        num_training_steps = num_train_steps * EPOCHS
    )
    xm.master_print(f'Num Train Steps= {num_train_steps}, XRT World Size= {xm.xrt_world_size()}.')
    
    for epoch in range(EPOCHS):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        xm.master_print('Parallel Loader Created. Training ...')
        train_loop_fn(para_loader.per_device_loader(device),
                      model,  
                      optimizer, 
                      device, 
                      scheduler, 
                      epoch
                     )
        
        xm.master_print("Finished training epoch {}".format(epoch))
            
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        xm.master_print('Parallel Loader Created. Validating ...')
        eval_loop_fn(para_loader.per_device_loader(device), 
                     model,  
                     device
                    )
        
        # Serialized and Memory Reduced Model Saving
        if epoch == EPOCHS-1:
            xm.master_print('Saving Model ..')
            xm.save(model.state_dict(), "model.bin")
            xm.master_print('Model Saved.')
            
    if METRICS_DEBUG:
      xm.master_print(met.metrics_report(), flush=True)

In [26]:
def _mp_fn(rank, flags):
    # torch.set_default_tensor_type('torch.FloatTensor')
    _run()

FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

Starting Run ...
Train Loader Created.
Valid Loader Created.
Done Model Loading.
Num Train Steps= 1207, XRT World Size= 8.
Parallel Loader Created. Training ...
[xla:2]Train:  Epoch: [0][   0/1207]	Time 94.077 (94.077)	Data  0.200 ( 0.200)	Loss 1.2578e+00 (1.2578e+00)	Acc@1  31.25 ( 31.25)
[xla:6]Train:  Epoch: [0][   0/1207]	Time 82.875 (82.875)	Data  0.203 ( 0.203)	Loss 1.3516e+00 (1.3516e+00)	Acc@1  25.00 ( 25.00)
[xla:3]Train:  Epoch: [0][   0/1207]	Time 99.932 (99.932)	Data  0.198 ( 0.198)	Loss 1.3984e+00 (1.3984e+00)	Acc@1  31.25 ( 31.25)
[xla:7]Train:  Epoch: [0][   0/1207]	Time 65.987 (65.987)	Data  0.191 ( 0.191)	Loss 1.4844e+00 (1.4844e+00)	Acc@1  31.25 ( 31.25)
[xla:4]Train:  Epoch: [0][   0/1207]	Time 71.837 (71.837)	Data  0.286 ( 0.286)	Loss 1.2344e+00 (1.2344e+00)	Acc@1  18.75 ( 18.75)
[xla:5]Train:  Epoch: [0][   0/1207]	Time 77.344 (77.344)	Data  0.189 ( 0.189)	Loss 1.3047e+00 (1.3047e+00)	Acc@1  31.25 ( 31.25)
[xla:1]Train:  Epoch: [0][   0/1207]	Time 88.447 (88.447)	D