In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk /kaggle/input and print the full path of every file found there.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # Combine the directory being visited with the file name to get the full path.
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/imdb-csv/train.csv
/kaggle/input/imdb-csv/test.csv


In [2]:
# Download the PyTorch/XLA environment setup script from the master branch of pytorch/xla.
# NOTE(review): both the script (master) and the requested build (nightly) are unpinned,
# so re-running this cell at a later date may install different versions.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the script: request the nightly build and the libomp5 / libopenblas-dev apt packages.
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0  19452      0 --:--:-- --:--:-- --:--:-- 19452
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/94.3 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp37-cp37m-linux_x86_64.whl...

Operation completed over 1 objects/131.0 MiB.                                    
Copying gs://tpu-pytorch/wheels/torchvision-nig

In [3]:
import argparse
import os
import torch
import torch.nn as nn
import random
import numpy as np
from tqdm import tqdm
import time
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp



In [4]:
# Hyper-parameters are declared through argparse; parse_args([]) below is given an
# empty argument list, so every option takes its default value.
parser = argparse.ArgumentParser()
parser.add_argument('-seed', default=0, type=int)
# Total sequence length fed to the model, including the [CLS] and [SEP] tokens.
parser.add_argument('-max_seq_length', default=512, type=int)
parser.add_argument('-batch_size', default=24, type=int)
parser.add_argument('-num_epochs', default=4, type=int)
parser.add_argument('-learning_rate', default=2e-5, type=float)
parser.add_argument('-max_grad_norm', default=1.0, type=float)
# Fraction of all optimizer steps used for linear learning-rate warm-up.
parser.add_argument('-warm_up_proportion', default=0.1, type=float)
parser.add_argument('-bert_path', default='bert-base-uncased')
# Truncation strategy for over-long texts: "head", "tail", or a number N meaning
# "keep the first N tokens and fill the rest of the budget from the end".
# The default is written as a string so it agrees with type=str (argparse does not
# run `type` on non-string defaults, so the previous int default bypassed it).
parser.add_argument('-trunc_mode', default='128', type=str)
args = parser.parse_args([])
# NOTE(review): this line runs in the notebook's main process, before xmp.spawn starts
# the workers — confirm xm.xrt_world_size() already reports the intended replica count
# here; if it reports 1 the learning rate is effectively left unscaled.
args.learning_rate = args.learning_rate * xm.xrt_world_size()

In [5]:
# Tokenizer matching the checkpoint named by args.bert_path.
tokenizer = BertTokenizer.from_pretrained(args.bert_path)

# Build the two-label sequence classifier once and hand it to MpModelWrapper;
# _mp_fn later calls .to(device) on this wrapper to obtain its model.
wrapped_model = xmp.MpModelWrapper(
    BertForSequenceClassification.from_pretrained(
        args.bert_path,
        num_labels=2,
    )
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




In [6]:
def _truncate_tokens(tokens, max_tokens, trunc_mode):
    """Shorten ``tokens`` to at most ``max_tokens`` entries.

    Args:
        tokens: List of word-piece tokens (without [CLS]/[SEP]).
        max_tokens: Token budget, already validated to be >= 0.
        trunc_mode: ``"head"`` (keep the start), ``"tail"`` (keep the end) or an
            int ``h`` with ``0 <= h <= max_tokens`` (keep the first ``h`` tokens
            and fill the remaining budget from the end).

    Returns:
        The original list when it already fits, otherwise a new, shorter list.
    """
    overflow = len(tokens) - max_tokens
    if overflow <= 0:
        return tokens
    if trunc_mode == "head":
        return tokens[:max_tokens]
    if trunc_mode == "tail":
        return tokens[overflow:]
    # Head + tail. The tail start is computed as an absolute index: a negative
    # slice such as tokens[-0:] would return the whole list when the tail is empty.
    tail_start = len(tokens) - (max_tokens - trunc_mode)
    return tokens[:trunc_mode] + tokens[tail_start:]


def load_data(path):
    """Read a tab-separated ``label<TAB>text`` file and encode it for BERT.

    Each text is tokenized with the module-level ``tokenizer``, truncated according
    to ``args.trunc_mode`` so that it fits ``args.max_seq_length`` including the
    surrounding [CLS]/[SEP] tokens, and padded with id 0 up to that length.

    Args:
        path: Path of a UTF-8 text file with one example per line.

    Returns:
        Four numpy arrays ``(input_ids, attention_mask, token_type_ids, sentiments)``.
        The first three hold one row of ``args.max_seq_length`` entries per example;
        ``sentiments`` holds the integer labels.

    Raises:
        ValueError: If ``args.max_seq_length`` leaves no room for [CLS]/[SEP], or if
            ``args.trunc_mode`` is neither "head"/"tail" nor an integer in
            ``[0, args.max_seq_length - 2]``.
    """
    max_tokens = args.max_seq_length - 2  # two slots are reserved for [CLS] and [SEP]
    if max_tokens < 0:
        raise ValueError("max_seq_length must be at least 2 to hold [CLS] and [SEP]")

    # Resolve the truncation mode once, into a local, instead of rewriting the shared
    # args object while iterating over the data.
    trunc_mode = args.trunc_mode
    if trunc_mode not in ("head", "tail"):
        trunc_mode = int(trunc_mode)
        if not 0 <= trunc_mode <= max_tokens:
            raise ValueError(
                "numeric trunc_mode must lie in [0, max_seq_length - 2], got {}".format(trunc_mode))

    with open(path, encoding="utf8") as input_file:
        lines = input_file.readlines()

    input_ids, attention_mask, token_type_ids = [], [], []
    sentiments = []
    for line in tqdm(lines):
        if not line.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        # Split on the first tab only, so tabs inside the text do not break unpacking.
        label, text = line.split("\t", 1)
        tokens = _truncate_tokens(tokenizer.tokenize(text), max_tokens, trunc_mode)
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        padding = [0] * (args.max_seq_length - len(tokens))
        attention_mask.append([1] * len(tokens) + padding)
        token_type_ids.append([0] * args.max_seq_length)  # single-segment input
        input_ids.append(tokenizer.convert_tokens_to_ids(tokens) + padding)
        sentiments.append(int(label))
    return np.array(input_ids), np.array(attention_mask), np.array(token_type_ids), np.array(sentiments)

# Encode both splits up front; every call yields
# (input ids, attention mask, token type ids, labels) as numpy arrays.
(
    train_input_ids,
    train_attention_mask,
    train_token_type_ids,
    y_train,
) = load_data('../input/imdb-csv/train.csv')
(
    test_input_ids,
    test_attention_mask,
    test_token_type_ids,
    y_test,
) = load_data('../input/imdb-csv/test.csv')

100%|██████████| 25000/25000 [03:36<00:00, 115.67it/s]
100%|██████████| 25000/25000 [03:31<00:00, 118.39it/s]


In [7]:
# Token ids, token type ids and labels are converted to int64 tensors ...
train_input_ids, train_token_type_ids, y_train = (
    torch.tensor(array, dtype=torch.long)
    for array in (train_input_ids, train_token_type_ids, y_train)
)
test_input_ids, test_token_type_ids, y_test = (
    torch.tensor(array, dtype=torch.long)
    for array in (test_input_ids, test_token_type_ids, y_test)
)
# ... while the attention masks are converted to float tensors.
train_attention_mask = torch.tensor(train_attention_mask, dtype=torch.float)
test_attention_mask = torch.tensor(test_attention_mask, dtype=torch.float)

# One dataset per split, fields ordered as (ids, mask, token types, label).
train_data = TensorDataset(
    train_input_ids, train_attention_mask, train_token_type_ids, y_train
)
test_data = TensorDataset(
    test_input_ids, test_attention_mask, test_token_type_ids, y_test
)

# Single-process loaders; _mp_fn builds its own sampler-backed loaders from
# train_data / test_data and binds them to local names of the same spelling.
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

In [8]:
def _mp_fn(index, flags):
    """Fine-tune and evaluate the wrapped BERT classifier on one XLA device.

    Used as the per-process entry point handed to ``xmp.spawn``. Each process
    builds its own shard of the train/test data with ``DistributedSampler``,
    trains for ``args.num_epochs`` epochs and, after every epoch, evaluates on
    its test shard and averages the per-process accuracies with ``xm.mesh_reduce``.

    Args:
        index: Process index passed in by the spawner. Not used in the body; the
            shard rank is taken from ``xm.get_ordinal()`` instead.
        flags: Extra argument forwarded from the spawn call. Not used in the body.
    """
    torch.manual_seed(args.seed)
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    # --- data: one shard per replica ---------------------------------------
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=world_size,
        rank=rank,
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_data,
        num_replicas=world_size,
        rank=rank,
        shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=4,
        drop_last=False)

    # --- model, optimizer, schedule ----------------------------------------
    device = xm.xla_device()
    model = wrapped_model.to(device)
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    total_step = len(train_loader)
    num_training_steps = total_step * args.num_epochs
    scheduler = get_linear_schedule_with_warmup(
            optimizer,
            # A step count is an integer; truncate the proportional value.
            num_warmup_steps=int(num_training_steps * args.warm_up_proportion),
            num_training_steps=num_training_steps)
    # The loss module is stateless, so build it once instead of on every step.
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(args.num_epochs):
        # DistributedSampler derives its shuffle order from the epoch number; without
        # this call every epoch would iterate the training shard in the same order.
        train_sampler.set_epoch(epoch)

        # --- train ---------------------------------------------------------
        model.train()
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        for i, batch in enumerate(para_train_loader):
            cur_input_ids, cur_attention_mask, cur_token_type_ids, cur_y = (
                tensor.to(device) for tensor in batch)
            outputs = model(cur_input_ids, cur_attention_mask, cur_token_type_ids)
            loss = loss_fn(outputs[0], cur_y)  # outputs[0] holds the class logits
            optimizer.zero_grad()
            loss.backward()
            # NOTE(review): clipping is applied to this replica's gradients before
            # xm.optimizer_step — confirm that is the intended order relative to the
            # cross-replica gradient reduction.
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            xm.optimizer_step(optimizer)
            scheduler.step()
            if (i + 1) % 10 == 0:
                print('[{}] [xla:{}] Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                        time.strftime("%Y-%m-%d %H:%M:%S"), xm.get_ordinal(), epoch + 1,
                        args.num_epochs, i + 1, total_step, loss.item()))
        xm.master_print("Finished training epoch {}".format(epoch))

        # --- evaluate ------------------------------------------------------
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
            for batch in para_test_loader:
                cur_input_ids, cur_attention_mask, cur_token_type_ids, cur_y = (
                    tensor.to(device) for tensor in batch)
                outputs = model(cur_input_ids, cur_attention_mask, cur_token_type_ids)
                _, predicted = torch.max(outputs[0], 1)  # index of the largest logit
                total += cur_y.size(0)
                correct += (predicted == cur_y).sum().item()
            accuracy = correct / total
            print('[{}] [xla:{}] samples: {} accuracy: {}'.format(
                    time.strftime("%Y-%m-%d %H:%M:%S"), xm.get_ordinal(), total, accuracy))
            # Unweighted mean of the per-replica accuracies.
            acc_reduced = xm.mesh_reduce('acc_reduce', accuracy, lambda x: sum(x) / len(x))
            xm.master_print('reduced accuracy: {}'.format(acc_reduced))
        xm.master_print("Finished evaluating epoch {}".format(epoch))

In [9]:
# _mp_fn takes a second positional argument; an empty dict is passed for it.
FLAGS = {}

# Launch _mp_fn in 8 forked processes.
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)

[2021-01-15 11:31:03] [xla:3] Epoch [1/4], Step [10/130], Loss: 0.7032
[2021-01-15 11:31:03] [xla:7] Epoch [1/4], Step [10/130], Loss: 0.7055
[2021-01-15 11:31:03] [xla:5] Epoch [1/4], Step [10/130], Loss: 0.6613
[2021-01-15 11:31:03] [xla:2] Epoch [1/4], Step [10/130], Loss: 0.7137
[2021-01-15 11:31:02] [xla:4] Epoch [1/4], Step [10/130], Loss: 0.6851
[2021-01-15 11:31:02] [xla:6] Epoch [1/4], Step [10/130], Loss: 0.6550
[2021-01-15 11:31:03] [xla:0] Epoch [1/4], Step [10/130], Loss: 0.6531
[2021-01-15 11:31:03] [xla:1] Epoch [1/4], Step [10/130], Loss: 0.6489
[2021-01-15 11:31:21] [xla:2] Epoch [1/4], Step [20/130], Loss: 0.4512
[2021-01-15 11:31:21] [xla:3] Epoch [1/4], Step [20/130], Loss: 0.5078
[2021-01-15 11:31:21] [xla:6] Epoch [1/4], Step [20/130], Loss: 0.4510
[2021-01-15 11:31:21] [xla:5] Epoch [1/4], Step [20/130], Loss: 0.5114
[2021-01-15 11:31:21] [xla:4] Epoch [1/4], Step [20/130], Loss: 0.4897
[2021-01-15 11:31:21] [xla:0] Epoch [1/4], Step [20/130], Loss: 0.4664
[2021-