In [1]:
import tensorflow as tf
# Try to resolve a TPU cluster; a ValueError from the resolver is treated as
# "no TPU available" and leaves `tpu` set to None.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    tpu = None

if not tpu:
    # No TPU: fall back to whatever strategy TensorFlow currently reports.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the cluster and initialise the TPU system before creating the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

Running on TPU  ['10.0.0.2:8470']


In [2]:
# Download the PyTorch/XLA environment setup script from the pytorch/xla repository
# and run it for the "nightly" version, also requesting the apt packages
# libomp5 and libopenblas-dev.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly  --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0  18600      0 --:--:-- --:--:-- --:--:-- 18600
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly-cp37-cp37m-linux_x86_64.whl...
\ [1 files][110.1 MiB/110.1 MiB]                                                
Operation completed over 1 objects/110.1 MiB.                                    
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp37-cp37m-linux_x86_64.whl...
\ [1 files][127.2 MiB/127.2 MiB]                                    

In [3]:
%%time
%autosave 60

import os
# XLA runtime settings, placed in the environment ahead of the torch_xla imports below.
# NOTE(review): XLA_USE_BF16 presumably makes torch_xla use bfloat16 and
# XLA_TENSOR_ALLOCATOR_MAXSIZE presumably caps an allocator size -- confirm both
# against the torch_xla documentation.
os.environ['XLA_USE_BF16'] = "1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import gc
gc.enable()
import time
import random 

import numpy as np
import pandas as pd
from tqdm import tqdm 

import transformers
from transformers import (AdamW, 
                          XLMRobertaTokenizer, 
                          XLMRobertaModel, 
                          get_cosine_schedule_with_warmup)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.serialization as xser
import torch_xla.version as xv

import warnings
# Install a blanket "ignore" filter for Python warnings (applies to everything run afterwards).
warnings.filterwarnings("ignore")

# Print the git revisions recorded in torch_xla.version for the torch and XLA builds.
print('PYTORCH:', xv.__torch_gitrev__)
print('XLA:', xv.__xla_gitrev__)

Autosaving every 60 seconds




PYTORCH: b6810c1064eb24cb542a9be56140689dce8ad7a1
XLA: 2c732f20b5343dd1a694cd69cd15dd7f76ee1028
CPU times: user 1.43 s, sys: 215 ms, total: 1.65 s
Wall time: 2.5 s


In [4]:
# Read the three competition CSV files (train, test, sample submission) into DataFrames.
train, test, sample_submission = map(
    pd.read_csv,
    (
        '../input/contradictory-my-dear-watson/train.csv',
        '../input/contradictory-my-dear-watson/test.csv',
        '../input/contradictory-my-dear-watson/sample_submission.csv',
    ),
)

In [5]:
def load_data(data, sort, tokenizer):
    """Turn a premise/hypothesis frame into ``(token_count, text, label)`` tuples.

    Args:
        data: DataFrame providing ``premise``, ``hypothesis`` and ``label`` columns.
        sort: when truthy, the result is ordered by ascending token count.
        tokenizer: object whose ``encode(text, add_special_tokens=True)`` returns
            a sequence of token ids; only its length is used here.

    Returns:
        A list of ``(token_count, text, label)`` tuples, where ``text`` is the
        premise and hypothesis joined by a single space.
    """
    joined_texts = (
        " ".join(pair) for pair in data[['premise', 'hypothesis']].values.tolist()
    )
    records = [
        (len(tokenizer.encode(joined, add_special_tokens=True)), joined, target)
        for joined, target in zip(joined_texts, data.label)
    ]
    # The sort is stable, so rows with equal token counts keep their input order.
    return sorted(records, key=lambda record: record[0]) if sort else records

def build_batches(sentences, batch_size):
    """Reorder ``sentences`` into randomly chosen runs of neighbouring items.

    Repeatedly picks a random window of up to ``batch_size`` consecutive items
    from the remaining ones and appends it to the output. When the input is
    sorted by length, every window holds items of similar length.

    Args:
        sentences: sequence of items (e.g. the tuples produced by ``load_data``).
            It is NOT modified; a private copy is consumed instead, so the
            function can be called again on the same list.
        batch_size: window size; must be at least 1.

    Returns:
        A new list containing every input item exactly once, window by window.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (a window of size zero
            would never consume anything and the loop would not terminate).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))

    # Work on a copy so the caller's list survives the destructive slicing below.
    remaining = list(sentences)
    batch_ordered_sentences = []
    while remaining:
        to_take = min(batch_size, len(remaining))
        select = random.randint(0, len(remaining) - to_take)
        batch_ordered_sentences.extend(remaining[select:select + to_take])
        del remaining[select:select + to_take]
    return batch_ordered_sentences

class TextDataset(Dataset):
    """Dataset that tokenizes text lazily, one example per ``__getitem__`` call.

    Args:
        data: indexable collection of ``(length, text, label)`` tuples.
        tokenizer: tokenizer exposing ``encode_plus``.
        max_length: maximum number of token ids per example; longer inputs are
            truncated by the tokenizer.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_length

    def encode(self, example):
        """Return the (unpadded, truncated) list of input ids for ``example``.

        Padding is left to the collate function. The deprecated
        ``pad_to_max_length`` flag (superseded by ``padding``) and the
        ``return_position_ids`` flag, which is not a tokenizer argument, are no
        longer passed.
        """
        encoded_dict = self.tokenizer.encode_plus(text=example,
                                                  padding=False,
                                                  truncation=True,
                                                  add_special_tokens=True,
                                                  max_length=self.max_len,
                                                  return_token_type_ids=False,
                                                  return_attention_mask=False,
                                                  return_overflowing_tokens=False,
                                                  return_special_tokens_mask=False,
                                                  )
        return encoded_dict['input_ids']

    def __getitem__(self, idx):
        """Return ``(input_ids, label)`` for the example at ``idx``."""
        # The stored length (first tuple element) is not needed here.
        _, text, label = self.data[idx]
        return self.encode(text), label

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.data)
    
def pad_seq(seq, max_batch_len, pad_value):
    """Return a new list: ``seq`` followed by ``pad_value`` up to ``max_batch_len`` items.

    When ``seq`` is already at least ``max_batch_len`` long, an unpadded copy is returned.
    """
    missing = max_batch_len - len(seq)
    filler = [pad_value] * missing  # a non-positive count yields an empty list
    return seq + filler

def collate_batch(batch, pad_token_id=None):
    """Pad a list of ``(input_ids, label)`` pairs to a common length and tensorize them.

    Args:
        batch: non-empty sequence of ``(input_ids, label)`` pairs, where
            ``input_ids`` is a sequence of token ids.
        pad_token_id: id used to fill the padded positions. When omitted, the
            module-level ``tokenizer.pad_token_id`` is used, so existing
            ``DataLoader(collate_fn=collate_batch)`` call sites keep working.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` tensors of shape
        ``(len(batch), longest_sequence)`` and a ``labels`` tensor of shape
        ``(len(batch),)``. The mask is 1 on real tokens and 0 on padding.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_batch received an empty batch")
    if pad_token_id is None:
        # Fall back to the notebook-level tokenizer only when no id was supplied.
        pad_token_id = tokenizer.pad_token_id

    input_ids, labels = zip(*batch)
    max_size = max(len(ids) for ids in input_ids)

    batch_inputs = []
    batch_attention_masks = []
    for ids in input_ids:
        n_pad = max_size - len(ids)
        batch_inputs.append(list(ids) + [pad_token_id] * n_pad)
        batch_attention_masks.append([1] * len(ids) + [0] * n_pad)

    return {"input_ids": torch.tensor(batch_inputs),
            "attention_mask": torch.tensor(batch_attention_masks),
            "labels": torch.tensor(list(labels))
            }

In [6]:
class XLMRoberta(nn.Module):
    """XLM-RoBERTa encoder with a mean+max pooled classification head.

    The encoder's token representations are pooled over the sequence axis
    (mean and max, concatenated), layer-normalised, passed through dropout and
    projected to ``num_labels`` logits.

    Args:
        num_labels: number of output classes.
        multisample: when truthy, logits are the average of five classifier
            passes, each with an independent high-rate dropout mask
            (multi-sample dropout); otherwise a single pass with the regular
            dropout is used.
    """

    def __init__(self, num_labels, multisample):
        super(XLMRoberta, self).__init__()
        output_hidden_states = False
        self.num_labels = num_labels
        self.multisample = multisample
        self.roberta = XLMRobertaModel.from_pretrained("xlm-roberta-base",
                                                       output_hidden_states=output_hidden_states,
                                                       num_labels=1)
        # Size the head from the encoder config instead of hard-coding 768, so the
        # head still matches if the checkpoint name above is changed.
        pooled_size = 2 * self.roberta.config.hidden_size
        self.layer_norm = nn.LayerNorm(pooled_size)
        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)
        self.classifier = nn.Linear(pooled_size, self.num_labels)

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None):
        """Return classification logits of shape ``(batch, num_labels)``.

        All arguments are forwarded unchanged to the encoder. When
        ``attention_mask`` is given, positions where it is 0 (padding) are
        excluded from both pooling operations, so the logits of an example do
        not depend on how much padding its batch happened to need. Each row is
        expected to have at least one unmasked position.
        """
        outputs = self.roberta(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds)
        hidden_states = outputs[0]

        if attention_mask is None:
            # No mask available: pool over every position.
            average_pool = torch.mean(hidden_states, 1)
            max_pool, _ = torch.max(hidden_states, 1)
        else:
            keep = attention_mask.unsqueeze(-1).to(torch.bool)
            weights = keep.to(hidden_states.dtype)
            # Masked mean: padded positions contribute zero to the sum and the count.
            token_counts = weights.sum(dim=1).clamp(min=1)
            average_pool = (hidden_states * weights).sum(dim=1) / token_counts
            # Masked max: padded positions are pushed to -inf so they can never win.
            max_pool, _ = hidden_states.masked_fill(~keep, float("-inf")).max(dim=1)

        concatenate_layer = torch.cat((average_pool, max_pool), 1)
        normalization = self.layer_norm(concatenate_layer)
        if self.multisample:
            # Multi-sample dropout: average several passes through the high-rate
            # dropout layer, which was defined for this purpose but previously unused.
            logits = torch.mean(
                torch.stack(
                    [self.classifier(self.high_dropout(normalization)) for _ in range(5)],
                    dim=0,
                ),
                dim=0,
            )
        else:
            logits = self.classifier(self.dropout(normalization))
        return logits

In [7]:
class AverageMeter(object):
    """Keeps the most recent value of a metric together with its weighted running mean."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format spec (including the leading colon) applied to val and avg.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the last value, the mean, the weighted sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. "{name} {val:6.3f} ({avg:6.3f})" and fills it from the instance fields.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))

class ProgressMeter(object):
    """Prints a tab-separated progress line: prefix, "[batch/total]" counter, then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a "[{:Nd}/total]" template whose counter is as wide as the total."""
        width = len(str(num_batches // 1))
        counter = '{:' + str(width) + 'd}'
        return '[{}/{}]'.format(counter, counter.format(num_batches))

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: score tensor of shape ``(batch, num_classes)``.
        target: tensor of ``batch`` class indices.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per requested ``k``, holding the
        percentage of samples whose target is among the ``k`` highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose; row r holds the (r+1)-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape instead of view: `correct` derives from a transposed tensor,
            # so `correct[:k]` may be non-contiguous and `.view(-1)` can fail for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [8]:
def get_model_optimizer(model):
    """Create an AdamW optimizer with separate learning rates for backbone and head.

    Parameters whose name contains "roberta" form the first group and use the
    module-level ``LR``; all remaining parameters form the second group and use
    a fixed 1e-3. Weight decay is 1e-2 for the optimizer as a whole.
    """
    backbone_params, head_params = [], []
    # Single pass over the parameters, routing each one by its name.
    for param_name, param in model.named_parameters():
        bucket = backbone_params if "roberta" in param_name else head_params
        bucket.append(param)

    param_groups = [
        {'params': backbone_params, 'lr': LR},
        {'params': head_params, 'lr': 1e-3},
    ]
    return AdamW(param_groups, lr=LR, weight_decay=1e-2)

In [9]:
def loss_fn(outputs, targets):
    """Cross-entropy loss between logits and targets.

    Args:
        outputs: logits tensor of shape ``(batch, num_classes)``.
        targets: either class indices of shape ``(batch,)`` or class
            probabilities with the same shape as ``outputs``.

    Returns:
        Scalar loss tensor (mean over the batch).

    The training and evaluation loops in this notebook hand over class indices
    cast to float. Index-shaped floating-point targets are therefore converted
    back to integers here; probability targets (same rank as ``outputs``) are
    passed through untouched.
    """
    if targets.is_floating_point() and targets.dim() == outputs.dim() - 1:
        targets = targets.long()
    return nn.CrossEntropyLoss()(outputs, targets)

In [10]:
def train_loop_fn(train_loader, model, optimizer, device, scheduler, epoch=None):
    """Run one training epoch over ``train_loader``.

    For every batch: move the tensors to ``device``, run the forward and
    backward pass, step the optimizer through ``xm.optimizer_step`` and then
    the scheduler. Time, data-loading time, loss and top-1 accuracy are tracked
    and a progress line is printed every 10 batches.

    Args:
        train_loader: iterable (with ``len``) yielding dicts with the keys
            ``input_ids``, ``attention_mask`` and ``labels``.
        model: module called as ``model(input_ids=..., attention_mask=...)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
        epoch: epoch number, used only in the progress prefix.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="[xla:{}]Train:  Epoch: [{}]".format(xm.get_ordinal(), epoch)
    )
    model.train()
    end = time.time()
    for i, data in enumerate(train_loader):
        data_time.update(time.time() - end)
        ids = data["input_ids"].to(device, dtype=torch.long)
        mask = data["attention_mask"].to(device, dtype=torch.long)
        # Labels are class indices, so keep them as integers for the loss.
        targets = data["labels"].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(
            input_ids=ids,
            attention_mask=mask
        )
        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)

        # The loss from the forward pass above is reused for logging; computing it
        # a second time on the same outputs would yield the same value.
        acc1 = accuracy(outputs, targets, topk=(1,))
        losses.update(loss.item(), ids.size(0))
        top1.update(acc1[0].item(), ids.size(0))
        scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if i % 10 == 0:
            progress.display(i)

    # The per-batch locals go out of scope on return, so no explicit `del` is
    # needed (deleting them raised UnboundLocalError when the loader was empty).
    gc.collect()

In [11]:
def eval_loop_fn(validation_loader, model, device):
    """Evaluate ``model`` on ``validation_loader`` without tracking gradients.

    Tracks batch time, loss and top-1 accuracy and prints a progress line every
    10 batches. Nothing is returned.

    Args:
        validation_loader: iterable (with ``len``) yielding dicts with the keys
            ``input_ids``, ``attention_mask`` and ``labels``.
        model: module called as ``model(input_ids=..., attention_mask=...)``;
            it is switched to eval mode here.
        device: device the batch tensors are moved to.
    """
    model.eval()
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(validation_loader),
        [batch_time, losses, top1],
        prefix='[xla:{}]Validation: '.format(xm.get_ordinal()))
    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(validation_loader):
            ids = data["input_ids"].to(device, dtype=torch.long)
            mask = data["attention_mask"].to(device, dtype=torch.long)
            # Labels are class indices, so keep them as integers for the loss.
            targets = data["labels"].to(device, dtype=torch.long)

            outputs = model(
                input_ids=ids,
                attention_mask=mask
            )
            loss = loss_fn(outputs, targets)
            acc1 = accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), ids.size(0))
            top1.update(acc1[0].item(), ids.size(0))

            batch_time.update(time.time() - end)
            end = time.time()
            if i % 10 == 0:
                progress.display(i)

    # The per-batch locals go out of scope on return, so no explicit `del` is
    # needed (deleting them raised UnboundLocalError when the loader was empty).
    gc.collect()

In [12]:
# Batch sizes used by each process's DataLoader, number of epochs, and the
# maximum number of token ids per example (longer inputs are truncated).
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 15
MAX_LEN = 96
# Backbone learning rate: a base of 2e-5 multiplied by the XLA world size
# (intended: the 8 TPU cores).
# NOTE(review): this line runs in the notebook process, before xmp.spawn starts
# the workers -- confirm that xm.xrt_world_size() already reports 8 here and not 1.
LR = 2e-5 * xm.xrt_world_size() 
# When True, the XLA metrics report is printed at the end of the run.
METRICS_DEBUG = True

# The model is built once here and wrapped in MpModelWrapper; each spawned
# process later moves it to its own device via WRAPPED_MODEL.to(device).
WRAPPED_MODEL = xmp.MpModelWrapper(XLMRoberta(num_labels=3, multisample=False))
# Module-level tokenizer: used to build the datasets and read as a global by collate_batch.
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=512.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1115590446.0, style=ProgressStyle(descr…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…




In [13]:
# Fix the seeds so the train/validation split and the random batch order are
# identical on every Restart & Run All.
SEED = 42
random.seed(SEED)                       # drives build_batches
split_rng = np.random.RandomState(SEED)  # drives the split below

# Random ~95% / ~5% row-level split of the training frame.
mask = split_rng.rand(len(train)) < 0.95
train_df = train[mask]
valid_df = train[~mask]
assert len(train_df) + len(valid_df) == len(train)

# Sort by token count, then regroup into randomly ordered runs of similar-length
# examples, and wrap the result in a lazily tokenizing dataset.
train_sentences = load_data(train_df, sort=True, tokenizer=tokenizer)
train_batches = build_batches(train_sentences, batch_size=TRAIN_BATCH_SIZE)
train_dataset = TextDataset(train_batches, tokenizer, max_length=MAX_LEN)

valid_sentences = load_data(valid_df, sort=True, tokenizer=tokenizer)
valid_batches = build_batches(valid_sentences, batch_size=VALID_BATCH_SIZE)
valid_dataset = TextDataset(valid_batches, tokenizer, max_length=MAX_LEN)

In [15]:
def _make_loader(dataset, batch_size):
    """Build this process's DataLoader over its DistributedSampler shard of ``dataset``."""
    sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    return DataLoader(
        dataset,
        collate_fn=collate_batch,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,
        num_workers=0
    )


def _run():
    """Train and validate the wrapped model inside one spawned XLA process.

    Builds the sharded train/validation loaders, moves the shared model to this
    process's XLA device, creates the optimizer and a cosine schedule without
    warmup, then alternates one training and one validation pass per epoch.
    After the last epoch the state dict is written to "model.bin"; if
    ``METRICS_DEBUG`` is set, the XLA metrics report is printed at the end.

    NOTE(review): the datasets were pre-ordered into runs of similar-length
    examples, but the sampler decides which indices each process sees --
    confirm that DistributedSampler's index assignment keeps those runs
    together in a batch.
    """
    xm.master_print('Starting Run ...')

    train_data_loader = _make_loader(train_dataset, TRAIN_BATCH_SIZE)
    xm.master_print('Train Loader Created.')

    valid_data_loader = _make_loader(valid_dataset, VALID_BATCH_SIZE)
    xm.master_print('Valid Loader Created.')

    # Use the real number of batches this process iterates per epoch, so the
    # schedule length matches the number of scheduler.step() calls exactly
    # (an estimate from len(train_df) can differ because of sharding and drop_last).
    num_train_steps = len(train_data_loader)
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    xm.master_print('Done Model Loading.')
    optimizer = get_model_optimizer(model)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps * EPOCHS
    )
    xm.master_print(f'Num Train Steps= {num_train_steps}, XRT World Size= {xm.xrt_world_size()}.')

    for epoch in range(EPOCHS):
        # A fresh ParallelLoader is created for every pass over the data.
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        xm.master_print('Parallel Loader Created. Training ...')
        train_loop_fn(para_loader.per_device_loader(device),
                      model,
                      optimizer,
                      device,
                      scheduler,
                      epoch
                     )

        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        xm.master_print('Parallel Loader Created. Validating ...')
        eval_loop_fn(para_loader.per_device_loader(device),
                     model,
                     device
                    )

        # Serialized and Memory Reduced Model Saving (last epoch only)
        if epoch == EPOCHS - 1:
            xm.master_print('Saving Model ..')
            xser.save(model.state_dict(), "model.bin", master_only=True)
            xm.master_print('Model Saved.')

    if METRICS_DEBUG:
        xm.master_print(met.metrics_report(), flush=True)

In [16]:
def _mp_fn(rank, flags):
    """Entry point handed to ``xmp.spawn``; ``rank`` and ``flags`` are accepted but unused."""
    _run()


# Start 8 worker processes with the 'fork' start method, each running _mp_fn.
FLAGS = dict()
xmp.spawn(_mp_fn, nprocs=8, args=(FLAGS,), start_method='fork')

Starting Run ...
Train Loader Created.
Valid Loader Created.
Done Model Loading.
Num Train Steps= 90, XRT World Size= 8.
Parallel Loader Created. Training ...
[xla:7]Train:  Epoch: [0][ 0/90]	Time 19.825 (19.825)	Data  0.070 ( 0.070)	Loss 1.0703e+00 (1.0703e+00)	Acc@1  43.75 ( 43.75)
[xla:3]Train:  Epoch: [0][ 0/90]	Time 16.296 (16.296)	Data  0.085 ( 0.085)	Loss 1.0391e+00 (1.0391e+00)	Acc@1  62.50 ( 62.50)
[xla:4]Train:  Epoch: [0][ 0/90]	Time 12.814 (12.814)	Data  0.075 ( 0.075)	Loss 1.2031e+00 (1.2031e+00)	Acc@1  37.50 ( 37.50)
[xla:6]Train:  Epoch: [0][ 0/90]	Time  2.306 ( 2.306)	Data  0.080 ( 0.080)	Loss 1.1562e+00 (1.1562e+00)	Acc@1  31.25 ( 31.25)
[xla:0]Train:  Epoch: [0][ 0/90]	Time 27.359 (27.359)	Data  0.307 ( 0.307)	Loss 1.4219e+00 (1.4219e+00)	Acc@1  12.50 ( 12.50)
[xla:5]Train:  Epoch: [0][ 0/90]	Time 23.773 (23.773)	Data  0.069 ( 0.069)	Loss 1.2734e+00 (1.2734e+00)	Acc@1  25.00 ( 25.00)
[xla:2]Train:  Epoch: [0][ 0/90]	Time  9.538 ( 9.538)	Data  0.068 ( 0.068)	Loss 1.218

Exception: process 2 terminated with signal SIGKILL