# Install torch_xla: enabling PyTorch on Google TPU

In [0]:
import os
# Fail fast with a readable message when the runtime has no TPU attached.
# `.get` is used because indexing os.environ with a missing key raises KeyError,
# which would bypass the assertion message below.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'
 
 
def install_stable():
 """Download the PyTorch/XLA environment setup script and run it for VERSION.

 The body uses IPython `!` shell escapes, so this function only works when
 executed inside a notebook/IPython session.
 """
 # Colab form field; `$VERSION` in the shell line below is expanded from this local.
 VERSION = "20200516"  #@param ["1.5" , "20200516", "nightly"]
 !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
 !python pytorch-xla-env-setup.py --version $VERSION


# Run the installer first; torch is imported only afterwards, so the version
# printed below is the one that is active after the setup script has run.
install_stable()

import torch 
# Show which torch build ended up being imported.
print(torch.__version__)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4264  100  4264    0     0  49581      0 --:--:-- --:--:-- --:--:-- 49581
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200516 ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly+20200516-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 91.0 MiB/ 91.0 MiB]                                                
Operation completed over 1 objects/91.0 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200516-cp36-cp36m-linux_x86_64.whl...
- [1 files][119.8 MiB/119.8 MiB] 

# Set up

Link Google Drive

In [0]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the paths in the config cell all live
# under "/content/gdrive/My Drive/".
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


Install the Transformers library from Hugging Face

In [0]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |████████████████████████████████| 675kB 3.5MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 18.0MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 26.8MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |█████████

# config.py

In [0]:
import os

import matplotlib
import matplotlib.pyplot as plt

import time

import torch
import transformers
import warnings

import pandas as pd
import numpy as np
import torch.nn as nn

from sklearn import metrics
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Suppress every Python warning from this point on.
# NOTE(review): this also hides deprecation notices raised by the libraries
# imported above — confirm that is intended.
warnings.filterwarnings("ignore")

class config:
    """Namespace holding the hyper-parameters and Google Drive paths used below.

    Everything is a class attribute, so values are read as `config.NAME`
    without instantiating the class.
    """

    # Sequence length and per-process batch sizes.
    MAX_LEN = 192
    TRAIN_BATCH_SIZE = 64
    VALID_BATCH_SIZE = 4

    # Optimisation schedule.
    EPOCHS = 2
    LEARNING_RATE = 1e-6

    # Pretrained weights directory and the file the trained weights are saved to.
    BERT_PATH = "/content/gdrive/My Drive/bert-base-multilingual-uncased/"
    MODEL_PATH = "/content/gdrive/My Drive/bert_model.bin"

    # The tokenizer is loaded once, at the moment this class body executes.
    TOKENIZER = transformers.BertTokenizer.from_pretrained(BERT_PATH, do_lower_case=True)

    # CSV inputs, all resolved relative to the same Drive folder.
    JIGSAW_DATA_PATH = "/content/gdrive/My Drive/"
    TRAINING_FILE_1 = os.path.join(JIGSAW_DATA_PATH, "jigsaw-toxic-comment-train.csv")
    TRAINING_FILE_2 = os.path.join(JIGSAW_DATA_PATH, "jigsaw-unintended-bias-train.csv")
    VALIDATION_FILE = os.path.join(JIGSAW_DATA_PATH, "validation.csv")

# dataset.py

In [0]:
class JigsawTraining:
    """Map-style dataset: one tokenized, zero-padded comment plus its target per index.

    Args:
        comment_text: indexable collection of comments (each is passed through str()).
        targets: indexable collection of labels, aligned with `comment_text`.
        config: object exposing `TOKENIZER` (must provide `encode_plus`) and `MAX_LEN`.
    """

    def __init__(self, comment_text, targets, config):
        self.comment_text = comment_text
        self.targets = targets
        self.tokenizer = config.TOKENIZER
        self.max_length = config.MAX_LEN

    def __len__(self):
        """Number of comments in the dataset."""
        return len(self.comment_text)

    def __getitem__(self, item):
        """Return a dict of tensors: 'ids', 'mask', 'token_type_ids' (long) and 'targets' (float)."""
        # Collapse every run of whitespace into a single space before tokenizing.
        text = " ".join(str(self.comment_text[item]).split())

        encoded = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
        )

        # Right-pad with zeros up to max_length; a non-positive count yields an
        # empty filler, so sequences are never shortened here.
        filler = [0] * (self.max_length - len(encoded["input_ids"]))
        field_map = (
            ("ids", "input_ids"),
            ("mask", "attention_mask"),
            ("token_type_ids", "token_type_ids"),
        )

        sample = {
            out_key: torch.tensor(encoded[src_key] + filler, dtype=torch.long)
            for out_key, src_key in field_map
        }
        sample['targets'] = torch.tensor(self.targets[item], dtype=torch.float)
        return sample

# model.py

In [0]:
import torch.nn as nn

class JigsawModel(nn.Module):
    """BERT encoder followed by dropout and a single-logit linear head.

    The head consumes the concatenation of the mean and the max of the
    encoder's first output taken along dim 1, so its input width is twice the
    encoder hidden size.

    Args:
        bert_path: path or identifier passed to `BertModel.from_pretrained`.
    """

    def __init__(self, bert_path):
        super().__init__()
        self.bert_path = bert_path
        self.bert = transformers.BertModel.from_pretrained(self.bert_path)
        self.bert_drop = nn.Dropout(0.5)
        # Read the width from the loaded encoder instead of hard-coding 768 so
        # the head always matches the checkpoint that was loaded.
        hidden_size = self.bert.config.hidden_size
        self.out = nn.Linear(hidden_size * 2, 1)

    def forward(
            self,
            ids,
            mask,
            token_type_ids
    ):
        """Return one raw logit per input row (no sigmoid is applied here).

        Args:
            ids: token id tensor fed to the encoder.
            mask: attention mask forwarded as `attention_mask`.
            token_type_ids: segment ids forwarded as `token_type_ids`.
        """
        encoder_outputs = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)

        # Index rather than tuple-unpack: only the first output is needed, and
        # indexing works whether the encoder returns a plain tuple or an
        # output object with more (or fewer) than two entries.
        sequence_output = encoder_outputs[0]  # assumes (batch, seq_len, hidden) — TODO confirm

        # Pool over dim 1; note the attention mask is not applied to the pooling.
        apool = torch.mean(sequence_output, 1)
        mpool, _ = torch.max(sequence_output, 1)
        cat = torch.cat((apool, mpool), 1)

        bo = self.bert_drop(cat)
        return self.out(bo)

In [0]:
# Instantiate the model once at notebook scope; run() later moves this same
# object onto the XLA device of whichever process calls it.
MX = JigsawModel(config.BERT_PATH)

# engine.py

In [0]:
def loss_fn(outputs, targets):
    """Binary cross-entropy on raw logits.

    `targets` is reshaped to a single column with `view(-1, 1)` before being
    compared with `outputs`.
    """
    criterion = nn.BCEWithLogitsLoss()
    column_targets = targets.view(-1, 1)
    return criterion(outputs, column_targets)


def train_fn(data_loader, model, optimizer, device, scheduler):
    """Run one training pass over `data_loader` and return the per-batch losses.

    Args:
        data_loader: iterable of dicts with keys "ids", "token_type_ids",
            "mask" and "targets".
        model: module called as model(ids=..., mask=..., token_type_ids=...).
        optimizer: optimizer stepped through `xm.optimizer_step`.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch, or None.

    Returns:
        List with one detached loss tensor per batch.
    """
    import torch_xla.core.xla_model as xm
    model.train()

    train_losses = []

    for bi, d in enumerate(data_loader):
        ids = d["ids"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        targets = d["targets"].to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids
        )

        loss = loss_fn(outputs, targets)
        loss.backward()

        xm.optimizer_step(optimizer)

        if scheduler is not None:
            scheduler.step()

        # Keep only the value: storing the attached tensor would hold a
        # reference to every batch's autograd history for the whole epoch.
        train_losses.append(loss.detach())

        if bi % 100 == 0:
            print(f'[xla:{xm.get_ordinal()}]: bi={bi}, train loss={loss}')

    return train_losses


def eval_fn(data_loader, model, device):
    """Evaluate `model` on `data_loader` with gradients disabled.

    Args:
        data_loader: iterable of dicts with keys "ids", "token_type_ids",
            "mask" and "targets".
        model: module called as model(ids=..., mask=..., token_type_ids=...).
        device: device the batch tensors are moved to.

    Returns:
        (fin_outputs, fin_targets, val_losses): model outputs and targets as
        nested Python lists accumulated over all batches, and the list of
        per-batch loss tensors.
    """
    import torch_xla.core.xla_model as xm

    model.eval()

    val_losses, fin_outputs, fin_targets = [], [], []

    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            ids = batch["ids"].to(device, dtype=torch.long)
            token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
            mask = batch["mask"].to(device, dtype=torch.long)
            targets = batch["targets"].to(device, dtype=torch.float)

            logits = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

            batch_loss = loss_fn(logits, targets)
            val_losses.append(batch_loss)

            # Progress line every 100 batches, tagged with this process's ordinal.
            if step % 100 == 0:
                print(f'[xla:{xm.get_ordinal()}]: bi={step}, valid loss={batch_loss}')

            # Copy to host and flatten into plain Python lists for the metric code.
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(logits.cpu().detach().numpy().tolist())

    return fin_outputs, fin_targets, val_losses

# train.py

In [0]:
# Seed for the shuffle and cap on the number of training rows kept after it.
# Fixing the seed makes the selected subset identical across notebook runs.
SHUFFLE_SEED = 42
MAX_TRAIN_ROWS = 1000000

# Both training files are reduced to the two columns used downstream;
# missing values are replaced with the string "none".
df_train1 = pd.read_csv(
    config.TRAINING_FILE_1,
    usecols=["comment_text", "toxic"]
).fillna("none")

df_train2 = pd.read_csv(
    config.TRAINING_FILE_2,
    usecols=["comment_text", "toxic"]
).fillna("none")

df_valid = pd.read_csv(config.VALIDATION_FILE).reset_index(drop=True)

# Stack the two sources, shuffle reproducibly, keep the first MAX_TRAIN_ROWS
# rows and renumber the index from 0 so positional lookups line up.
df_train = (
    pd.concat([df_train1, df_train2], axis=0)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .head(MAX_TRAIN_ROWS)
    .reset_index(drop=True)
)

train_targets = df_train.toxic.values
valid_targets = df_valid.toxic.values

In [0]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole minutes and seconds.

    Both parts are truncated toward zero with int().

    Returns:
        (minutes, seconds) as a tuple of ints.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

def run():
    """Train and validate the model on the XLA device owned by the calling process.

    Reads the notebook-level `df_train`, `df_valid`, `train_targets`,
    `valid_targets`, `config` and `MX`. Each process works on its own shard of
    the training and validation data (via DistributedSampler). After every
    epoch the per-shard ROC AUC is averaged across processes, and the model
    weights are saved to `config.MODEL_PATH` whenever that average improves.
    """
    world_size = xm.xrt_world_size()
    ordinal = xm.get_ordinal()

    train_dataset = JigsawTraining(
        comment_text=df_train.comment_text.values,
        targets=train_targets,
        config=config
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=world_size,
          rank=ordinal,
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )

    valid_dataset = JigsawTraining(
        comment_text=df_valid.comment_text.values,
        targets=valid_targets,
        config=config
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=world_size,
          rank=ordinal,
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )

    device = xm.xla_device()
    model = MX.to(device)

    # Two parameter groups: weight decay is applied to everything except
    # parameters whose name contains one of the `no_decay` fragments.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]

    # Scheduler horizon: batches per process per epoch, times the epoch count.
    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE / world_size * config.EPOCHS
    )
    # The base learning rate is scaled by the number of participating processes.
    optimizer = AdamW(
        optimizer_parameters,
        lr=config.LEARNING_RATE * world_size
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    best_auc = 0
    for epoch in range(config.EPOCHS):
        start_time = time.time()

        # DistributedSampler derives its shuffle from the epoch number; without
        # this call every epoch would iterate the data in the same order.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        # The returned loss lists are not consumed here; they are kept
        # available for optional plotting.
        train_losses = train_fn(
            para_loader.per_device_loader(device),
            model,
            optimizer,
            device,
            scheduler
        )

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets, val_losses = eval_fn(
            para_loader.per_device_loader(device),
            model,
            device
        )

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Binarise the targets at 0.5 and score this process's validation shard.
        targets = np.array(targets) >= 0.5
        auc = metrics.roc_auc_score(targets, outputs)

        # Average the per-shard AUCs so every process holds the same value.
        # xm.save is expected to be reached by all processes together; gating
        # it on a process-local metric could let some processes skip the call
        # while others wait in it.
        auc = xm.mesh_reduce('auc_reduce', auc, np.mean)

        print(f'[xla:{xm.get_ordinal()}]: AUC={auc}')
        if auc > best_auc:
            xm.save(model.state_dict(), config.MODEL_PATH)
            best_auc = auc

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')

Multiprocessing entry point: the function each spawned process runs

In [0]:
def _mp_fn(rank, flags):
    """Entry point executed in each spawned process.

    `rank` and `flags` are accepted to match the spawn call signature but are
    not used; the function sets the default tensor type and delegates to run().
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    run()

Process spawner for training on TPUs

In [0]:
# _mp_fn does not read its flags argument; an empty dict is passed to fill the slot.
FLAGS={}
# Launch 8 processes using the 'fork' start method; each one calls _mp_fn with
# its process index followed by FLAGS.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')