Install the PyTorch/XLA packages (the 1.6 release wheel, plus the Cloud TPU client) and set up the backend version.

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

Uncomment and run the cell below only if you want a nightly release.

In [None]:
# VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION
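To confirm which build was installed (the 1.6 wheel above or a nightly), you can use a small check like the one below. It is a sketch and is not part of the original notebook. It imports each module in a fresh interpreter, so the notebook kernel itself never imports *torch_xla*, and it assumes the modules expose *\_\_version\_\_*. The helper name *module_version* is hypothetical.

```python
import subprocess
import sys


def module_version(module_name):
  """Return module_name.__version__ as seen by a fresh interpreter.

  Returns None if the module cannot be imported, and 'unknown' if it has
  no __version__ attribute. (Hypothetical helper.)
  """
  code = ("import importlib, sys; "
          "m = importlib.import_module(sys.argv[1]); "
          "print(getattr(m, '__version__', 'unknown'))")
  result = subprocess.run(
      [sys.executable, '-c', code, module_name],
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
      universal_newlines=True)
  if result.returncode != 0:
    return None
  # The version is the last line printed, even if the import logged something.
  lines = result.stdout.strip().splitlines()
  return lines[-1] if lines else 'unknown'


for name in ('torch', 'torch_xla'):
  print(name, module_version(name))
```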

Install any other publicly available dependencies (pip, APT, ...).

In [None]:
# !pip install transformers

Clone the repo containing the model to be tested.

If all the code fits in a single snippet (see the *%%writefile* cell below), you can leave the next cell empty or remove it.

In [None]:
!rm -rf pytorch-xla-transformer-language-model/
!git clone https://github.com/dlibenzi/pytorch-xla-transformer-language-model.git

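Set up the environment. Variables set through *os.environ* here are inherited by the training script that is launched later with *!python*.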

In [None]:
import os
# os.environ['XLA_IR_DEBUG'] = '1'
# os.environ['XLA_HLO_DEBUG'] = '1'
# os.environ['TF_CPP_VMODULE'] = 'tensor=5'
# os.environ['XLA_USE_32BIT_LONG'] = '1'
# os.environ['XLA_SAVE_TENSORS_FILE'] = 'tensors.log'
# os.environ['XLA_SAVE_TENSORS_FMT'] = 'text'
# os.environ['XLA_TRIM_GRAPH_SIZE'] = '1000000'
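You can also toggle the commented variables above from one place. The helper below is a hypothetical convenience and is not part of PyTorch/XLA. It only writes the same variable names into *os.environ*. According to the PyTorch/XLA troubleshooting notes, *XLA_IR_DEBUG* and *XLA_HLO_DEBUG* attach Python stack information to the generated IR and HLO, and *XLA_SAVE_TENSORS_FILE* dumps the executed graphs in the *XLA_SAVE_TENSORS_FMT* format. The training script runs in a child process, so these variables must be set before the *!python* cell runs.

```python
import os

# Debug switches copied from the commented-out cell above.
DEFAULT_DEBUG_ENV = {
    'XLA_IR_DEBUG': '1',
    'XLA_HLO_DEBUG': '1',
    'XLA_SAVE_TENSORS_FILE': 'tensors.log',
    'XLA_SAVE_TENSORS_FMT': 'text',
}


def enable_xla_debug_env(overrides=None):
  """Hypothetical helper: export debug variables into os.environ.

  Processes started afterwards (for example via `!python train.py`)
  inherit these values. Returns the settings that were applied.
  """
  settings = dict(DEFAULT_DEBUG_ENV)
  settings.update(overrides or {})
  for key, value in settings.items():
    os.environ[key] = str(value)
  return settings


# Usage: also map PyTorch Long tensors to 32-bit on the device.
# enable_xla_debug_env({'XLA_USE_32BIT_LONG': '1'})
```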

Override the files that need editing or tweaking during the debug session.

This means copying and pasting the content of one or more of the original GitHub repo files, so that you can iterate easily while debugging. If the test/debug code does not need to pull any GitHub repo, the cell(s) below simply contain the test code itself.

We strongly suggest running on a single core when debugging. If you are using multiprocessing, just pass *nprocs=1* to *xmp.spawn()*.

If accuracy debugging is not needed, you can avoid fetching large datasets by using the PyTorch/XLA [data generators](https://github.com/pytorch/xla/blob/dfab0b544c02b5319c3d52bef12cf4487829c182/test/test_train_mp_mnist.py#L61).
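This particular script allows another shortcut. *LazyDataset* only seeks into a raw byte file, so you can point it at a small synthetic file instead of enwik8. The sketch below assumes that random text is acceptable when accuracy does not matter. The helper names and the file path are hypothetical, and *read_window* simply mirrors *LazyDataset.\_\_getitem\_\_* without torch.

```python
import os
import random


def write_fake_corpus(path, num_bytes, seed=0):
  """Hypothetical helper: write num_bytes of random lowercase text to path."""
  rng = random.Random(seed)
  alphabet = b'abcdefghijklmnopqrstuvwxyz .'
  # Indexing/choosing from a bytes object yields ints, which bytes() accepts.
  data = bytes(rng.choice(alphabet) for _ in range(num_bytes))
  directory = os.path.dirname(path)
  if directory:
    os.makedirs(directory, exist_ok=True)
  with open(path, 'wb') as f:
    f.write(data)


def read_window(path, index, length):
  """Read `length` bytes starting at `index`, as LazyDataset.__getitem__ does."""
  with open(path, 'rb') as f:
    f.seek(index)
    return f.read(length)


# Usage: 1 MiB of fake text; pass this (hypothetical) path to get_dataloader().
# write_fake_corpus('/tmp/fake_enwik8/train.txt.raw', 1 << 20)
```

Each window is *SEQUENCE_LENGTH + 1* bytes (257 here), matching how *get_dataloader* builds the dataset.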


In [None]:
%%writefile pytorch-xla-transformer-language-model/train.py
# Copyright (c) 2019, Bryan McCann
# All rights reserved.

import os
import time
import math

import numpy
import torch
import torch.utils.data

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

from transformer import Transformer


class LazyDataset:

  def __init__(self, path, sequence_length):
    self.path = path
    self.size = os.stat(path).st_size - sequence_length
    self.sequence_length = sequence_length

  def __getitem__(self, index):
    with open(self.path, 'rb') as f:
      f.seek(index)
      chunk = f.read(self.sequence_length)
    return torch.ByteTensor(numpy.frombuffer(chunk, dtype=numpy.uint8))

  def __len__(self):
    return self.size


def get_dataloader(path, batch_size, sequence_length, num_workers):
  dataset = LazyDataset(path, sequence_length + 1)
  if xm.xrt_world_size() > 1:
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
  else:
    sampler = torch.utils.data.RandomSampler(dataset)
  return torch.utils.data.DataLoader(
      dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)


def main(index):
  BATCH_SIZE = 128
  LOG_STEPS = 10
  METRICS_STEP = 50
  NUM_EPOCHS = 8
  SEQUENCE_LENGTH = 256

  device = xm.xla_device()
  model = Transformer(256, 12, 512, 2048, 8, 0.2).to(device)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

  def train_loop_fn(loader):
    tracker = xm.RateTracker()

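    positions = torch.arange(SEQUENCE_LENGTH).long().view(
        1, SEQUENCE_LENGTH).to(device)
    # 1s above the diagonal mark future positions, presumably the ones the
    # model must not attend to (causal language modeling).
    causal_mask = torch.triu(
        torch.ones(
            SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.uint8, device=device),
        diagonal=1).unsqueeze(0)

    model.train()
    for iteration, batch in enumerate(loader):
      # Each sample holds SEQUENCE_LENGTH + 1 bytes: the target is the input
      # shifted left by one byte (next-character prediction).
      input = batch[:, :-1].long()
      target = batch[:, 1:].long()

      # Clear the gradients of the previous step; without this they accumulate.
      optimizer.zero_grad()
      loss = model(input, positions, target, batch_mask=causal_mask)
      loss.backward()
      xm.optimizer_step(optimizer)

      tracker.add(BATCH_SIZE)
      if iteration % LOG_STEPS == 0:
        # loss.item() syncs with the device; dividing by ln(2) reports bits.
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
            device, iteration,
            loss.item() / math.log(2), tracker.rate()))
      if iteration % METRICS_STEP == 0:
        xm.master_print(met.metrics_report())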

  train_loader = get_dataloader('pytorch-xla-transformer-language-model/datasets/enwik8/train/train.txt.raw',
                                BATCH_SIZE, SEQUENCE_LENGTH, 0)

  for epoch in range(0, NUM_EPOCHS):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))


if __name__ == '__main__':
  # Set nprocs=1 for debugging (using one core).
  xmp.spawn(main, args=(), nprocs=1)


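The script logs *loss.item() / math.log(2)*. Assuming the model returns a mean cross-entropy in nats (natural log, as PyTorch's cross-entropy losses do), dividing by ln 2 converts it to bits per character, the usual enwik8 metric. A quick check of the arithmetic:

```python
import math


def nats_to_bits(loss_nats):
  """Convert a cross-entropy value from nats to bits."""
  return loss_nats / math.log(2)


# A model that predicts all 256 byte values uniformly has a loss of
# ln(256) nats, which is 8 bits per character.
print(nats_to_bits(math.log(256)))
```

A logged value well below 8 therefore means the model is doing better than a uniform guess over bytes.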
Optionally, clean up the products of previous runs, since some operations append to existing content (tensor logging, for example).

In [None]:
!rm -f tensors.log
!rm -rf /tmp/debug_run*
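If you would rather not shell out (for example, when you run the same steps from a plain Python script), you can do the cleanup above with the standard library. The function name and its defaults are hypothetical. The defaults mirror the two *rm* commands.

```python
import glob
import os
import shutil


def clean_previous_runs(tensors_log='tensors.log', debug_glob='/tmp/debug_run*'):
  """Remove stale outputs, like `rm -f tensors.log` and `rm -rf /tmp/debug_run*`.

  Returns the list of paths that were removed. (Hypothetical helper.)
  """
  removed = []
  if os.path.isfile(tensors_log) or os.path.islink(tensors_log):
    os.remove(tensors_log)
    removed.append(tensors_log)
  for path in glob.glob(debug_glob):
    if os.path.isdir(path) and not os.path.islink(path):
      shutil.rmtree(path)  # Directories are removed recursively, like rm -rf.
    else:
      os.remove(path)
    removed.append(path)
  return removed
```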

Run the model's script with the proper command line.

In [None]:
!python pytorch-xla-transformer-language-model/train.py

For debugging, it is also useful to run the *debug_run.py* script, which collects a set of debug information and packages it in a TAR file.

Run the *debug_run.py* command below for only a few steps (around 10 should be enough), or stop it after some time if the run hangs.
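You can automate stopping a hung run. The sketch below uses a hypothetical helper that is not part of *debug_run.py*. It starts a command and, if the command is still running after a deadline, sends SIGINT (the same signal as Ctrl-C). If the process does not exit within a grace period, the helper kills it. The signal reaches only the direct child process. We have not verified whether *debug_run.py* still writes a complete TAR file when interrupted, so check the output archive.

```python
import signal
import subprocess


def run_with_timeout(cmd, timeout_s, grace_s=30):
  """Run cmd; interrupt it after timeout_s seconds, kill it after grace_s more.

  Returns the process exit code (negative on POSIX if ended by a signal).
  """
  proc = subprocess.Popen(cmd)
  try:
    return proc.wait(timeout=timeout_s)
  except subprocess.TimeoutExpired:
    proc.send_signal(signal.SIGINT)  # Same signal as Ctrl-C.
    try:
      return proc.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
      proc.kill()
      return proc.wait()


# Usage with the debug_run.py command line used in this notebook; the
# 600-second limit is an arbitrary example value.
# run_with_timeout(
#     ['./xla/scripts/debug_run.py', '--outfile', 'debug_run.tar.gz', '--hlo',
#      '--', 'python', '-u', 'pytorch-xla-transformer-language-model/train.py'],
#     timeout_s=600)
```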

In [None]:
!git clone https://github.com/pytorch/xla.git

In [None]:
!./xla/scripts/debug_run.py --outfile debug_run.tar.gz --hlo -- python -u pytorch-xla-transformer-language-model/train.py

Download the generated debug files or logs.

In [None]:
from google.colab import files
# files.download('tensors.log')
# files.download('debug_run.tar.gz')
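Which of these artifacts exist depends on the options used in the run, so it can help to download only the files that were actually produced. The helper below is hypothetical. The artifact names are the two used in the cell above.

```python
import os

DEBUG_ARTIFACTS = ('tensors.log', 'debug_run.tar.gz')


def existing_artifacts(names=DEBUG_ARTIFACTS):
  """Return the names (in their given order) that exist as regular files."""
  return [name for name in names if os.path.isfile(name)]


# In Colab:
# from google.colab import files
# for name in existing_artifacts():
#   files.download(name)
```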