# Install torch_xla: enabling PyTorch on Google TPU

In [0]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'
 
 
def install_stable():
 VERSION = "nightly"  #@param ["1.5" , "20200516", "nightly"]
 !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
 !python pytorch-xla-env-setup.py --version $VERSION


install_stable()

import torch 
print(torch.__version__)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4264  100  4264    0     0  18620      0 --:--:-- --:--:-- --:--:-- 18620
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 89.6 MiB/ 89.6 MiB]                                                
Operation completed over 1 objects/89.6 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][117.1 MiB/117.1 MiB]                                                
Operation completed over 1 objects/117.1 MiB.         

# Set up

In [0]:
# Mount Google Drive under /content/gdrive; later cells read the tokenized
# arrays from there and write the trained weights back to it.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |████████████████████████████████| 675kB 2.1MB/s 
[?25hCollecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 10.2MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 26.5MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |███

# config.py

In [0]:
import os
os.environ['XLA_USE_BF16']="1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

import torch
import pandas as pd
from scipy import stats
import numpy as np

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib

import logging
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule, XLMRobertaTokenizer, XLMRobertaModel, XLMRobertaConfig
import sys
from sklearn import metrics, model_selection


import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import warnings
warnings.filterwarnings("ignore")


class config:
    """Hyper-parameters read by the training/evaluation cells of this notebook."""
    # Sequence length; matches the second dimension of the pre-tokenized arrays (…, 192).
    MAX_LEN = 192
    # Batch size handed to the training DataLoader in `_run`.
    TRAIN_BATCH_SIZE = 16
    # Batch size handed to the validation DataLoader in `_run`.
    VALID_BATCH_SIZE = 4
    # Number of passes over the training data in `_run`.
    EPOCHS = 1
    # NOTE(review): not read by any visible cell; `_run` computes its own `lr`.
    LEARNING_RATE = 1e-6

In [0]:
class AverageMeter:
    """
    Tracks the most recent value and the running weighted mean of a metric.

    Attributes:
        val: last value passed to ``update``.
        sum: weighted sum of all values since the last ``reset``.
        count: total weight (sample count) since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean.

        Args:
            val: the new measurement (typically a per-batch average).
            n: weight of the measurement, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

# model.py

In [0]:
class CustomRoberta(nn.Module):
    """XLM-RoBERTa-large encoder followed by dropout and a single-logit linear head."""

    def __init__(self):
        super(CustomRoberta, self).__init__()
        self.num_labels = 1
        self.roberta = transformers.XLMRobertaModel.from_pretrained("xlm-roberta-large", output_hidden_states=False, num_labels=1)
        self.dropout = nn.Dropout(p=0.2)
        # 1024 is the hidden width of the loaded encoder (see the printed module tree).
        self.classifier = nn.Linear(1024, self.num_labels)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None):
        """Return raw (pre-sigmoid) logits, one row per input sequence.

        All arguments are forwarded unchanged to the underlying encoder.
        """
        encoder_outputs = self.roberta(input_ids,
                                       attention_mask=attention_mask,
                                       position_ids=position_ids,
                                       head_mask=head_mask,
                                       inputs_embeds=inputs_embeds)

        # Take the second encoder output (the pooled representation) by index
        # rather than by tuple-unpacking, so dict-like outputs also work.
        pooled = encoder_outputs[1]

        # The dropout layer was constructed in __init__ but never used; apply it
        # before the classifier. It is the identity in eval mode.
        pooled = self.dropout(pooled)

        logits = self.classifier(pooled)
        return logits

In [0]:
# Build the model once at notebook level. `_run` later moves this same `mx`
# object to the XLA device inside each process started by `xmp.spawn`.
mx = CustomRoberta();
mx

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=513.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2244861551.0, style=ProgressStyle(descr…




CustomRoberta(
  (roberta): XLMRobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1

# dataset.py

In [0]:
class ArrayDataset(torch.utils.data.Dataset):
    """Dataset over several parallel arrays that share the same first dimension.

    Item ``i`` is a tuple holding row ``i`` of every array, each converted to
    a ``torch.Tensor``.
    """

    def __init__(self, *arrays):
        # Every array must have as many rows as the first one.
        for array in arrays:
            assert arrays[0].shape[0] == array.shape[0]
        self.arrays = arrays

    def __getitem__(self, index):
        row_tensors = []
        for source in self.arrays:
            # np.array(...) makes a fresh copy of the selected row before it is
            # handed to torch.from_numpy.
            row_copy = np.array(source[index])
            row_tensors.append(torch.from_numpy(row_copy))
        return tuple(row_tensors)

    def __len__(self):
        first = self.arrays[0]
        return first.shape[0]

In [0]:
tokenized_path = '/content/gdrive/My Drive/xlm-r-large-tokenize-dataset/'

In [0]:
# mmap_mode='r' opens each .npy file as a read-only memory map instead of
# reading the whole array into RAM up front.
x_train = np.load(tokenized_path+'x_train.npy',mmap_mode='r')
train_toxic = np.load(tokenized_path+'df_train_toxic.npy',mmap_mode='r')

x_valid = np.load(tokenized_path+'x_valid.npy',mmap_mode='r')
valid_toxic = np.load(tokenized_path+'df_valid_toxic.npy',mmap_mode='r')

In [0]:
x_train.shape, x_valid.shape

((208000, 192), (8000, 192))

In [0]:
# Pair each input array with its target array; the training/eval loops read
# element 0 of every batch as input ids and element 1 as targets.
train_dataset = ArrayDataset(x_train, train_toxic)
valid_dataset = ArrayDataset(x_valid, valid_toxic)

In [0]:
# Drop the notebook-level names only; the ArrayDataset objects still hold
# references to the underlying arrays. `gc` is reused by later cells.
del x_train, x_valid
import gc;gc.collect()

771

In [0]:
gc.collect()

0

# train.py

In [0]:
import torch_xla.version as xv
print('PYTORCH:', xv.__torch_gitrev__)
print('XLA:', xv.__xla_gitrev__)

PYTORCH: af05158c56af29e062580f458a86a32b8f4c2b85
XLA: 54fe46e0397bfa5f5589e17b4696923a15a959f0


In [0]:
!free -h

              total        used        free      shared  buff/cache   available
Mem:            12G        2.8G        9.5G        976K        387M        9.6G
Swap:            0B          0B          0B


In [0]:
def loss_fn(outputs, targets):
    """Mean binary cross-entropy between raw logits and targets.

    ``targets`` is reshaped to a single column before the comparison.
    """
    column_targets = targets.view(-1, 1)
    return nn.functional.binary_cross_entropy_with_logits(outputs, column_targets)

def reduce_fn(vals):
    """Arithmetic mean of ``vals``; passed to ``xm.mesh_reduce`` as the reducer."""
    total = 0
    for value in vals:
        total = total + value
    return total / len(vals)

def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one training pass over ``data_loader``.

    Each batch is indexed as ``(ids, targets)``. Every 50th batch the loss is
    combined across processes with ``xm.mesh_reduce`` and printed by the master.
    The model is switched to train mode on entry and to eval mode on exit.
    """
    model.train()
    for bi, batch in enumerate(data_loader):
        # Move the batch to the target device with the dtypes the model/loss expect.
        ids = batch[0].to(device, dtype=torch.long)
        targets = batch[1].to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(input_ids=ids)
        loss = loss_fn(outputs, targets)

        # Periodic progress report, issued before the backward pass.
        is_report_step = (bi % 50 == 0)
        if is_report_step:
            loss_reduced = xm.mesh_reduce('loss_reduce',loss,reduce_fn)
            xm.master_print(f'bi={bi}, loss={loss_reduced}')

        loss.backward()
        xm.optimizer_step(optimizer)

        if scheduler is None:
            continue
        scheduler.step()

    # Leave the model in eval mode for whatever the caller runs next.
    model.eval()
    
def eval_loop_fn(data_loader, model, device):
    """Run inference over ``data_loader`` and collect outputs and targets.

    Args:
        data_loader: iterable of batches indexed as ``(ids, targets)``.
        model: module called as ``model(input_ids=ids)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(fin_outputs, fin_targets)``: Python lists built from the model
        outputs and the targets of every batch, in iteration order.
    """
    fin_targets = []
    fin_outputs = []

    # Do not rely on the caller having left the model in eval mode.
    model.eval()

    # Inference needs no gradients; disabling autograd avoids keeping the
    # graph and activations alive for every batch.
    with torch.no_grad():
        for d in data_loader:
            ids = d[0]
            targets = d[1]

            ids = ids.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)

            outputs = model(
                input_ids=ids,
            )

            targets_np = targets.cpu().detach().numpy().tolist()
            outputs_np = outputs.cpu().detach().numpy().tolist()
            fin_targets.extend(targets_np)
            fin_outputs.extend(outputs_np)
            # Release per-batch temporaries eagerly, as the original loop did.
            del targets_np, outputs_np
            gc.collect()
    return fin_outputs, fin_targets

# run.py

In [0]:
def _run():
    """Train and validate the notebook-level ``mx`` model in the current process.

    Builds distributed samplers and loaders over ``train_dataset`` and
    ``valid_dataset``, moves ``mx`` to this process's XLA device, trains for
    ``config.EPOCHS`` epochs with AdamW and a linear schedule, prints the
    validation AUC after each epoch, and finally saves the model state dict
    to Google Drive.
    """
    TRAIN_BATCH_SIZE = config.TRAIN_BATCH_SIZE
    EPOCHS = config.EPOCHS

    # Each process sees its own shard of the data, selected by its ordinal.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0
    )

    device = xm.xla_device()
    model = mx.to(device)
    xm.master_print('done loading model')

    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # Base learning rate scaled by the number of participating processes.
    lr = 0.5e-5 * xm.xrt_world_size()
    # Optimizer steps taken by one process over the whole run.
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )
    xm.master_print(f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')


    for epoch in range(EPOCHS):
        # Without set_epoch the sampler reuses the same shuffle every epoch;
        # for epoch 0 this call is a no-op.
        train_sampler.set_epoch(epoch)
        gc.collect()
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        xm.master_print('parallel loader created... training now')
        gc.collect()
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
        del para_loader
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        gc.collect()
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)
        del para_loader
        gc.collect()
        # AUC is computed on this process's validation shard and then averaged
        # across processes, so it is a mean of per-shard AUCs.
        auc = metrics.roc_auc_score(np.array(t) >= 0.5, o)
        auc_reduced = xm.mesh_reduce('auc_reduce',auc,reduce_fn)
        xm.master_print(f'AUC = {auc_reduced}')
        gc.collect()
    xm.save(model.state_dict(), "/content/gdrive/My Drive/xlm_roberta_model.bin")

In [0]:
import time

def train_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as integers, each truncated toward zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    ``rank`` and ``flags`` are accepted to match the spawn call but are not
    used; the work is done by ``_run``.
    """
    _run()

# Empty flags dict, passed through to `_mp_fn` as its second argument.
FLAGS={}
start_time = time.time()
# Launch `_mp_fn` in 8 processes using the 'fork' start method; the timing
# below is taken after this call returns.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
end_time = time.time()
mins, secs = train_time(start_time, end_time)
print(f'Train Time: {mins}m {secs}s')

Exception: ignored