# Distributing Large Embedding Tables over TPU Cores

Use Colab Cloud TPU

<a href="https://cloud.google.com/tpu/"><img valign=middle src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>

* On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
* The cell below makes sure you have access to a TPU on Colab.

In [0]:
import os
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

## [RUNME] Install Colab-TPU-compatible PyTorch/XLA wheels and dependencies
This may take about 2 minutes.


In [0]:
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

## Description and Objective

The goal of this notebook is to illustrate a technique for distributing embedding tables over many TPU cores. This technique is useful when an embedding table is too large to fit in the memory (HBM) of a single TPU core.

We will use the popular [`fairseq`](https://github.com/pytorch-tpu/fairseq) repository to demonstrate the training, with parameters that would make a regular (non-distributed) run fail with an HBM out-of-memory (OOM) error.
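
To get a feel for why sharding helps, here is a back-of-the-envelope estimate of the per-core memory taken by an embedding table and its training state. The assumptions are ours, not the notebook's: fp32 weights, each parameter carrying a gradient plus Adam's two moment buffers, and a hypothetical 32,000-token vocabulary. It ignores activations and the temporary all-gather buffers.

```python
# Rough per-core memory for an embedding table plus its training state.
# Assumptions (hypothetical, for illustration only): fp32 values, and Adam,
# so each parameter has 4 copies: weight, gradient, first and second moments.
BYTES_PER_FP32 = 4
COPIES_WITH_ADAM = 4


def embedding_mib_per_core(vocab_size, embedding_size, world_size=1):
  """MiB one core needs for its slice of the table plus training state."""
  assert embedding_size % world_size == 0, "embedding size must divide evenly"
  sliced_emb_size = embedding_size // world_size
  num_bytes = vocab_size * sliced_emb_size * BYTES_PER_FP32 * COPIES_WITH_ADAM
  return num_bytes / 2**20


# Hypothetical 32,000-token vocabulary with a 4096-wide table.
print(embedding_mib_per_core(32000, 4096))                # 2000.0 (one core)
print(embedding_mib_per_core(32000, 4096, world_size=8))  # 250.0 (8 shards)
```

Under these assumptions, slicing the embedding dimension over 8 cores divides the table's share of each core's HBM by 8.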

### Explanation of the technique

The trick can be roughly summarized as follows:
- Each core owns a slice of the embedding table, sliced along the embedding dimension.
  - e.g. Core 1 owns dimensions 1-10, Core 2 owns 11-20, and so on.
  - Every core holds all the entities (rows) being embedded, but only its own dimensions for them.
- During the forward pass:
  - Every core shares its input with the other cores (an all-gather) and ends up with the full batch input.
  - Each core then looks up its own embedding dimensions for the full batch.
  - A second all-gather, along the embedding dimension, collects the other dimensions from the other cores.
  - At this point, every core has the full embeddings for the full batch.
  - Each core then slices the full batch and keeps only the samples that belong to its own sub-batch.
  - The rest of the forward pass proceeds normally.
- During the backward pass, the operations are reversed: the gradients of the embeddings for each core's sub-batch are all-gathered along the batch dimension, each core keeps only the gradient slice for the dimensions it owns, and it updates its own slice of the embedding table.
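
The forward steps above can be simulated in a single process to check that they reproduce a regular embedding lookup. In this sketch (ours, for illustration only), NumPy arrays stand in for per-core tensors, `np.concatenate` stands in for `xm.all_gather`, and the sizes are arbitrary toy values. Because every core computes the same all-gathered tensors, the simulation computes them once.

```python
import numpy as np

# Hypothetical toy sizes: 4 cores, 10-token vocabulary, 8 embedding dims,
# 3 samples per core.
WORLD_SIZE, VOCAB, EMB, BSZ = 4, 10, 8, 3
SLICE = EMB // WORLD_SIZE
rng = np.random.default_rng(0)

full_table = rng.normal(size=(VOCAB, EMB))
# Core r owns embedding dimensions [r * SLICE, (r + 1) * SLICE) for every token.
shards = [full_table[:, r * SLICE:(r + 1) * SLICE] for r in range(WORLD_SIZE)]
# Each core starts with its own sub-batch of token ids.
local_ids = [rng.integers(0, VOCAB, size=BSZ) for _ in range(WORLD_SIZE)]


def all_gather(per_core_tensors, axis):
  """Simulated all-gather: every core receives the concatenation."""
  return np.concatenate(per_core_tensors, axis=axis)


# 1. All-gather the token ids so every core sees the full batch.
full_ids = all_gather(local_ids, axis=0)               # (WORLD_SIZE * BSZ,)
# 2. Each core looks up only its own dimensions, for the full batch.
sliced_embeds = [shard[full_ids] for shard in shards]  # each (WORLD_SIZE * BSZ, SLICE)
# 3. All-gather along the embedding dimension to get the full embeddings.
full_embeds = all_gather(sliced_embeds, axis=-1)       # (WORLD_SIZE * BSZ, EMB)
# 4. Each core keeps only the rows of its own sub-batch.
outputs = [full_embeds[r * BSZ:(r + 1) * BSZ] for r in range(WORLD_SIZE)]

for r in range(WORLD_SIZE):
  # Same result as a lookup into the unsharded table.
  assert np.array_equal(outputs[r], full_table[local_ids[r]])
print("sharded forward matches the full-table lookup")
```

Each core ends up with exactly the embeddings an unsharded table would give for its sub-batch, while storing only `EMB // WORLD_SIZE` columns of the table.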

## Setting up the task

We will adapt the translation [tutorial](https://cloud.google.com/tpu/docs/tutorials/transformer-pytorch), which trains `fairseq`'s Transformer model. Let's begin by installing fairseq and downloading the data.

In [0]:
fairseq_path = '/tmp/fairseq'
!git clone https://github.com/pytorch-tpu/fairseq.git -b tpu {fairseq_path}
!pip install --editable {fairseq_path}
!wget https://dl.fbaipublicfiles.com/fairseq/data/wmt18_en_de_bpej32k.zip
!unzip wmt18_en_de_bpej32k.zip -d /tmp

import sys
sys.path.append(fairseq_path)

Now let's define `DistributedEmbedding` and a wrapper around `fairseq_model` that uses it. The wrapper replaces the original model's encoder embedding table, adds the forward and backward methods described above, and adds a couple of helper methods used later.

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm


class DistributedEmbedding(nn.Module):

  def __init__(self, vocab_size, embedding_size, world_size=None,
               batch_dim=0):
    super(DistributedEmbedding, self).__init__()
    self._embedding_size = embedding_size
    self._world_size = world_size
    self._batch_dim = batch_dim
    assert embedding_size % self._world_size == 0, \
        ("For this example to work, please provide embedding size "
         "a multiple of {}".format(self._world_size))
    self._sliced_emb_size = self._embedding_size // self._world_size
    self.embeddings = nn.Embedding(vocab_size, self._sliced_emb_size)

  @property
  def _rank(self):
    # The rank is looked up at call time rather than stored, so the module can
    #   be created at global scope before the per-core processes are spawned.
    return xm.get_ordinal()

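  def _get_embedding_pad(self):
    # (left, right) zero-padding that would place this core's slice at its
    #   position within the full embedding dimension. Not used in this notebook.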
    size = self._sliced_emb_size
    return self._rank * size, (self._world_size - 1 - self._rank) * size

  def forward(self, batch):
    bsz = batch.size(self._batch_dim)
    fullbatch = xm.all_gather(
        batch.type(torch.float), dim=self._batch_dim).type(batch.dtype)
    embeds = self.embeddings(fullbatch)
    pembeds = xm.all_gather(embeds, dim=-1)
    sliced_embeds = torch.narrow(pembeds, self._batch_dim, self._rank*bsz, bsz)
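    # We return both the sub-batch's full embeddings and the full batch's
    #   sliced embeddings. The former is detached into a new leaf tensor that
    #   feeds the remainder of the model; loss.backward() stops there and
    #   leaves its gradient in `.grad`. The latter is needed for the custom
    #   backward pass that updates this core's slice of the embedding table.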
    return sliced_embeds.clone().detach().requires_grad_(True), embeds

  def backward(self, fullbatch_slicedemb, grad):
    # Gradient at this point has the full embedding dimensions
    # and only contains info on the samples this core processed.
    grad = xm.all_gather(grad, dim=self._batch_dim)
    size = self._sliced_emb_size
    sliced_grad = torch.narrow(grad, grad.ndim-1, self._rank * size, size)
    fullbatch_slicedemb.backward(sliced_grad)


class TransformerWithDistributedEmbeddings(nn.Module):

  def __init__(self, model, emb_size, world_size):
    super(TransformerWithDistributedEmbeddings, self).__init__()
    self.model = model
    self.dropout = self.model.encoder.dropout
    self.embedding_size = emb_size
    self._world_size = world_size
    self._distribute_embeddings()

  def _distribute_embeddings(self):
    vocab_size = self.model.encoder.embed_tokens.weight.size(0)
    self.padding_idx = self.model.encoder.embed_tokens.padding_idx
    self.embedding = DistributedEmbedding(
        vocab_size, self.embedding_size, world_size=self._world_size)
    # We remove the original embedding layer.
    self.model.encoder.embed_tokens = None

  def init_emb_weights(self):
    # Use the full embedding size, not the per-core slice size, for the std.
    std = self.embedding_size ** -0.5
    nn.init.normal_(self.embedding.embeddings.weight, mean=0, std=std)
    nn.init.constant_(self.embedding.embeddings.weight[self.padding_idx], 0)

  def forward(self, **kwargs):
    inputs = kwargs['src_tokens']
    embedded_batch, emb_globalbatch_dimsliced = self.embedding(inputs)
    x = F.dropout(
        embedded_batch, p=self.dropout, training=self.model.training)
    # Hack the encoder's `forward_embedding` method so that it returns what
    #   was just computed in distributed fashion. Anything else the original
    #   method does is bypassed.
    # This needs to return two values.
    self.model.encoder.forward_embedding = lambda _: (x, None)
    return self.model(**kwargs), embedded_batch, emb_globalbatch_dimsliced

  def emb_backward(self, *args, **kwargs):
    self.embedding.backward(*args, **kwargs)

  def non_distr_params(self):
    # Yield every parameter except the distributed embedding table, which is
    #   the last registered parameter because it is added after `self.model`.
    last_index = len(list(self.parameters())) - 1
    for i, _ in enumerate(self.parameters()):
      if i != last_index:
        yield _
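
The backward pass of `DistributedEmbedding` can be checked the same way. In this single-process NumPy sketch (an illustration, not a notebook cell), `np.concatenate` stands in for `xm.all_gather` and `np.add.at` for the scatter-add that autograd performs for an embedding lookup; sizes are arbitrary toy values. Because the all-gather gives every core the gradients for the whole global batch, each core's shard gradient already equals the matching columns of the gradient an unsharded table would receive for that batch. This is why the shard is kept out of `xm.reduce_gradients` and gets its own optimizer in the training loop.

```python
import numpy as np

# Hypothetical toy sizes for a single-process simulation of 4 cores.
WORLD_SIZE, VOCAB, EMB, BSZ = 4, 10, 8, 3
SLICE = EMB // WORLD_SIZE
rng = np.random.default_rng(1)

# Each core's sub-batch of token ids, and the gradient w.r.t. its
# (BSZ, EMB) full embeddings, i.e. what loss.backward() leaves in `.grad`.
local_ids = [rng.integers(0, VOCAB, size=BSZ) for _ in range(WORLD_SIZE)]
local_grads = [rng.normal(size=(BSZ, EMB)) for _ in range(WORLD_SIZE)]

# Simulated all-gathers along the batch dimension (same result on every core).
full_ids = np.concatenate(local_ids)
full_grad = np.concatenate(local_grads, axis=0)  # (WORLD_SIZE * BSZ, EMB)


def shard_weight_grad(rank):
  """Gradient of core `rank`'s (VOCAB, SLICE) table shard."""
  # Keep only the embedding dimensions this core owns...
  sliced_grad = full_grad[:, rank * SLICE:(rank + 1) * SLICE]
  # ...and scatter-add them into the rows of the tokens that were looked up.
  grad = np.zeros((VOCAB, SLICE))
  np.add.at(grad, full_ids, sliced_grad)
  return grad


# Reference: gradient of an unsharded table for the same global batch.
reference = np.zeros((VOCAB, EMB))
np.add.at(reference, full_ids, full_grad)

for r in range(WORLD_SIZE):
  assert np.allclose(shard_weight_grad(r), reference[:, r * SLICE:(r + 1) * SLICE])
print("each shard's gradient is already complete; no all-reduce needed")
```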

Let's now create the `Namespace`, which `fairseq` uses to define the task, dataset, model and more.

In [0]:
from fairseq import options

# With the regular (non-distributed) embedding, the following settings lead
#   to an HBM OOM. Use the pair that matches your TPU.
#   On v3-8:
# EMBEDDING_SIZE = 4096
# INPUT_SHAPES = [[64, 64],]
#   On v2-8:
EMBEDDING_SIZE = 2048
INPUT_SHAPES = [[64, 64],]

args = [
  '/tmp/wmt18_en_de_bpej32k',
  '--arch=transformer_wmt_en_de',
  '--max-target-positions=64',
  '--max-source-positions=64',
  '--attention-dropout=0.0',
  '--dropout=0.0',
  '--no-progress-bar',
  '--criterion=label_smoothed_cross_entropy',
  '--source-lang=en',
  '--target-lang=de',
  '--lr-scheduler=inverse_sqrt',
  '--min-lr=1e-09',
  '--label-smoothing=0.1',
  '--optimizer=adam',
  '--adam-betas',
  '(0.9, 0.98)',
  '--warmup-init-lr=1e-07',
  '--lr=0.0005',
  '--warmup-updates=4000',
  '--weight-decay=0.0',
  '--no-save',
  '--log-interval=20',
  '--num-workers=1',
  '--disable-validation',
  '--max-epoch=1',
  '--encoder-embed-dim={}'.format(EMBEDDING_SIZE),
  '--decoder-embed-dim=512',
]

parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, input_args=args)
args.input_shapes = INPUT_SHAPES
args.use_gpu = False

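Now let's create the models. We build them at global scope, before the per-core processes are forked, and wrap them with `xmp.MpModelWrapper`; this saves host memory because the processes share one copy of the weights instead of each creating its own. Let's also define the training function: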

In [0]:
import math
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
from fairseq import tasks, optim

# GLOBAL SCOPE
NUM_DEVICES = 8
# Set up the fairseq task, dataset and criterion.
task = tasks.setup_task(args)
task.load_dataset(args.train_subset, epoch=0)
criterion = task.build_criterion(args)
# This is our initial model.
fairseq_model = task.build_model(args)
# Wrap the model so that its encoder embedding table, EMBEDDING_SIZE wide,
#   is sharded along the embedding dimension across NUM_DEVICES cores.
distr_model = TransformerWithDistributedEmbeddings(
    fairseq_model, EMBEDDING_SIZE, world_size=NUM_DEVICES)
distr_model.train(), fairseq_model.train()
wrapped_model = xmp.MpModelWrapper(distr_model)


def train(index):
  device = xm.xla_device()
  m = wrapped_model.to(device)
  # Initialize the embedding table weights.
  #   Seed by ordinal so that every core initializes its slice differently.
  torch.manual_seed(xm.get_ordinal())
  m.init_emb_weights()
  torch.manual_seed(args.seed)
  epoch_itr = task.get_batch_iterator(
      dataset=task.dataset(args.train_subset),
      max_tokens=args.max_tokens,
      max_sentences=args.max_sentences,
      max_positions=(args.max_source_positions, args.max_target_positions),
      ignore_invalid_inputs=True,
      required_batch_size_multiple=args.required_batch_size_multiple,
      seed=args.seed,
      num_shards=NUM_DEVICES,
      shard_id=xm.get_ordinal(),
      num_workers=args.num_workers,
      epoch=0,
  )
  itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=False, shuffle=False)
  para_loader = pl.MpDeviceLoader(itr, device)
  # The distributed embedding needs its own optimizer: each core holds a
  #   different slice of the table, so its gradients must not be all-reduced
  #   across cores like the gradients of the replicated parameters.
  # Thus, we create two optimizers, one for the distributed embedding and
  #   another for the rest of the model. The latter's gradients are reduced
  #   as usual; the former's come from the custom backward below.

  model_optimizer = optim.build_optimizer(args, m.non_distr_params())
  model_lr_scheduler = optim.lr_scheduler.build_lr_scheduler(
      args, model_optimizer)  # learning rate warmup
  demb_optimizer = optim.build_optimizer(args, m.embedding.parameters())

  running_loss = 0
  for step, batch in enumerate(para_loader, 1):
    # Run 100 steps to show that training proceeds without OOMs.
    if step > 100 or step == len(itr):
      break  # drop the last batch
    model_optimizer.zero_grad(), demb_optimizer.zero_grad()
    demb_optimizer.set_lr(model_optimizer.get_lr())
    net_output, fewsamples_fullemb, fullsamples_slicedemb = \
        m(**batch['net_input'])
    loss, _ = criterion.compute_loss(m.model, net_output, batch)
    loss.backward()  # this only back-propagates up to the embeddings
    xm.reduce_gradients(model_optimizer)
    model_optimizer.clip_grad_norm(args.clip_norm)
    model_optimizer.step()  # update model weights up to the embeddings
    # Custom backward to handle distributed embeddings
    m.emb_backward(fullsamples_slicedemb, fewsamples_fullemb.grad)
    demb_optimizer.clip_grad_norm(args.clip_norm)
    demb_optimizer.step()  # update embeddings
    # Learning rate warmup
    model_lr_scheduler.step_update(step)
    # Record loss for reporting later.
    running_loss += loss / math.log(2) / batch['ntokens']
    if step % args.log_interval:
      continue
    running_loss = running_loss.item()
    update = 'Step {}, loss {:.4f}'.format(step, running_loss / step)
    xm.add_step_closure(lambda s: xm.master_print(s, flush=True), args=(update,))
  xm.master_print(met.metrics_report())

Now let's start the training and observe that it does not crash with an HBM OOM. Note that the first few steps are slow because of the initial XLA graph compilations.

In [0]:
xmp.spawn(train, args=(), nprocs=NUM_DEVICES, start_method='fork')