In [1]:
# Uncomment and run these if you haven't already installed `torch_xla2`
#!pip uninstall -y tensorflow
#!pip install tpu-info 'torch_xla2[tpu] @ git+https://github.com/pytorch/xla.git#subdirectory=experimental/torch_xla2' -f https://storage.googleapis.com/libtpu-releases/index.html
#!pip install torchvision

# Distributed training with `torch_xla2`

This notebook demonstrates distributed training with `torch_xla2`, which lets you run PyTorch models on a JAX backend.

## Dataset and model setup

Below, we download and preprocess the MNIST dataset and instantiate a simple neural network to use as an example. The details aren't important here; you can follow the same steps for any PyTorch model and dataset.

A couple of important notes about this section:

- The `DataLoader` yields whole batches on the host. Each batch is split across all local devices later, when we shard the inputs.
- `model` remains on the CPU device. We'll move it to the TPU in the next step.

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    drop_last=True,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    drop_last=True,
    shuffle=False)
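
As a rough sanity check on the loader settings, the sketch below works out how many batches each loader yields and how a batch would divide across devices. It assumes the standard MNIST split sizes (60,000 training and 10,000 test images), a hypothetical host with 4 devices, and that a batch has to divide evenly across devices. The helper names are illustrative and not part of `torch_xla2`.

```python
def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a loader yields per epoch."""
    full, remainder = divmod(num_samples, batch_size)
    return full if drop_last or remainder == 0 else full + 1


def per_device_batch(batch_size: int, num_devices: int) -> int:
    """Rows each device receives, assuming an even split is required."""
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch size {batch_size} does not divide across {num_devices} devices")
    return batch_size // num_devices


BATCH_SIZE = 128
NUM_DEVICES = 4  # assumption: matches a v4-8 host; adjust for your environment

train_batches = batches_per_epoch(60_000, BATCH_SIZE, drop_last=True)
test_batches = batches_per_epoch(10_000, BATCH_SIZE, drop_last=True)
# With drop_last=True the trailing partial batch is skipped entirely.
dropped_test_samples = 10_000 - test_batches * BATCH_SIZE

print("train batches per epoch:", train_batches)
print("test batches per epoch:", test_batches)
print("test samples skipped by drop_last:", dropped_test_samples)
print("rows per device:", per_device_batch(BATCH_SIZE, NUM_DEVICES))
```

Note that `drop_last=True` on the test loader means a few test images are never evaluated. That keeps every batch the same shape, at the cost of a slightly incomplete validation pass.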

In [3]:
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)

## Replicating the model across devices

Most TPU configurations include multiple TPU chips per host. For example, a v4-8 TPU has 4 chips in total. We can use `tpu-info` to see how many devices are available on this host.

In [4]:
!tpu-info

TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━┓
┃ Device      ┃ Type        ┃ Cores ┃ PID  ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1     │ None │
│ /dev/accel1 │ TPU v4 chip │ 1     │ None │
│ /dev/accel2 │ TPU v4 chip │ 1     │ None │
│ /dev/accel3 │ TPU v4 chip │ 1     │ None │
└─────────────┴─────────────┴───────┴──────┘
Libtpu metrics unavailable. Did you start a workload with `TPU_RUNTIME_METRICS_PORTS=8431,8432,8433,8434`?


`torch_xla2` uses JAX as a backend, so we can use JAX to double-check the device count. Don't worry -- we won't have to directly use JAX to run the model.

In [5]:
import jax

# The TPU device count will vary depending on your environment.
jax.device_count()

4

The device count above should match the output of `tpu-info` (4 devices in the case of a v4-8).

In this example, we'll use `torch_xla2`'s custom `DistributedDataParallel` implementation to replicate the model parameters across all available TPU devices and split the input data across them.

In [6]:
import torch_xla2

ddp_model = torch_xla2.distributed.DistributedDataParallel(model)



We can dig into the underlying JAX array to see that there's an identical copy of the parameter tensor on each TPU device:

In [7]:
example_param = next(ddp_model.parameters())

In [8]:
import pprint
pprint.pprint(example_param._elem.addressable_shards)

[Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=0, data=[[ 0.03249096  0.01343462 -0.022144   ...  0.00668433  0.00833362
   0.00225713]
 [ 0.02272127  0.02205281  0.00828168 ... -0.02310903  0.02183958
   0.01084254]
 [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385  0.0339912
  -0.02596978]
 ...
 [ 0.0168394   0.0063334  -0.02949585 ... -0.0254653   0.03273752
  -0.02620777]
 [-0.00896274 -0.03342744 -0.0269749  ...  0.01811987  0.03423703
  -0.02689848]
 [ 0.01867637  0.0117135   0.02216029 ...  0.00011777  0.02212651
   0.00852821]]),
 Shard(device=TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=1, data=[[ 0.03249096  0.01343462 -0.022144   ...  0.00668433  0.00833362
   0.00225713]
 [ 0.02272127  0.02205281  0.00828168 ... -0.02310903  0.02183958
   0.01084254]
 [-0.01985117 -0.01139126 -0.00

The replicated tensor still behaves as a plain PyTorch tensor, however:

In [9]:
example_param

XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [[ 0.03249096  0.01343462 -0.022144   ...  0.00668433  0.00833362
   0.00225713]
 [ 0.02272127  0.02205281  0.00828168 ... -0.02310903  0.02183958
   0.01084254]
 [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385  0.0339912
  -0.02596978]
 ...
 [ 0.0168394   0.0063334  -0.02949585 ... -0.0254653   0.03273752
  -0.02620777]
 [-0.00896274 -0.03342744 -0.0269749  ...  0.01811987  0.03423703
  -0.02689848]
 [ 0.01867637  0.0117135   0.02216029 ...  0.00011777  0.02212651
   0.00852821]])
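
Because the parameters are replicated rather than sharded, every device holds a full copy of the model. The sketch below estimates what that costs for the network defined earlier. It assumes 4-byte (`float32`) parameters and 4 devices, and it counts only the weights and biases, not optimizer state or activations.

```python
# Layer widths of the example network: 784 -> 512 -> 512 -> 512 -> 10.
layer_sizes = [784, 512, 512, 512, 10]


def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


total_params = sum(
    linear_param_count(n_in, n_out)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

BYTES_PER_PARAM = 4  # assumption: float32 parameters
NUM_DEVICES = 4      # assumption: matches the v4-8 host in this notebook

per_device_bytes = total_params * BYTES_PER_PARAM
all_devices_bytes = per_device_bytes * NUM_DEVICES

print("parameters:", total_params)
print("bytes per device:", per_device_bytes)
print("bytes across all devices:", all_devices_bytes)
```

For a model this small the duplication is negligible. Replication becomes a real constraint only when a single copy approaches the memory of one device.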

## Sharding inputs

Unlike the model parameters, we want to send a different shard of the input data to each device. We'll take one batch of images as an example:

In [10]:
example_images, _ = next(iter(train_loader))
example_images.shape

torch.Size([128, 1, 28, 28])

Sharding the input batch across devices does not change the overall size of the tensor:

In [11]:
sharded_example_images = ddp_model.shard_input(example_images)
sharded_example_images.shape

(128, 1, 28, 28)

If we dig into the underlying JAX array, we can see that the input has been split (into quarters in this case) across the batch dimension:

In [12]:
[s.data.shape for s in sharded_example_images._elem.addressable_shards]

[(32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28)]
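
For readers curious about what replication and sharding look like at the JAX level, here is a minimal sketch using `jax.sharding` directly. It illustrates the two layouts and is not necessarily how `torch_xla2` implements them internally. The mesh axis name `"batch"` and the zero-filled arrays are placeholders chosen for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every device visible to this process.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# An empty PartitionSpec replicates the array on every device;
# naming the mesh axis splits the leading dimension across devices.
replicated = NamedSharding(mesh, PartitionSpec())
batch_sharded = NamedSharding(mesh, PartitionSpec("batch"))

num_devices = len(devices)
per_device = 32  # rows per device, so the batch always divides evenly

images = jnp.zeros((per_device * num_devices, 1, 28, 28))
weight = jnp.zeros((512, 784))  # same shape as the first Linear layer's weight

images = jax.device_put(images, batch_sharded)
weight = jax.device_put(weight, replicated)

image_shard_shapes = [s.data.shape for s in images.addressable_shards]
weight_shard_shapes = [s.data.shape for s in weight.addressable_shards]

print("global image shape:", images.shape)
print("image shards:", image_shard_shapes)
print("weight shards:", weight_shard_shapes)
```

On a host with a single device this degenerates to one shard of each array, so the sketch also runs on CPU.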

## Putting it all together

`torch_xla2` allows us to shard and replicate tensors across devices while still presenting each one as a single tensor in PyTorch. With some minor changes, we can adapt the conventional PyTorch training loop to use the TPU.

Note that we do not have to spawn any child processes. Although each parameter and input is represented by one tensor, that tensor is already distributed across multiple devices.

The loss function and optimizer stay the same:

In [13]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

JAX gets significantly better performance when compiled, normally through `jax.jit`. `torch_xla2`'s DDP implementation contains a utility, `jit_step`, that can be used to compile a training step. For this to work, the training step must be separated out into its own function. Apart from that, the body of the step is the same as it would be in eager mode on CPU or GPU.

In [14]:
@ddp_model.jit_step
def train_step(sharded_inputs, sharded_labels):
  optimizer.zero_grad()
  outputs = ddp_model(sharded_inputs)
  loss = loss_fn(outputs, sharded_labels)
  loss.backward()
  optimizer.step()

  return loss
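
One practical consequence of compiling the step: `jax.jit` compiles a function for specific input shapes and compiles again when it sees a new shape. Assuming `jit_step` behaves like `jax.jit` in this respect, keeping every batch the same size (as `drop_last=True` does) avoids extra compilations. The toy decorator below mimics only that caching behavior in plain Python. `toy_jit` and `mean_step` are hypothetical names, and nothing here reflects how `jit_step` is actually implemented.

```python
import functools


def toy_jit(fn):
    """Toy stand-in for a JIT decorator: 'compiles' once per input length."""
    compiled = {}  # maps input length -> "compiled" callable

    @functools.wraps(fn)
    def wrapper(batch):
        key = len(batch)
        if key not in compiled:
            # A real compiler would trace and optimize here, which is slow.
            wrapper.compile_count += 1
            compiled[key] = fn
        return compiled[key](batch)

    wrapper.compile_count = 0
    return wrapper


@toy_jit
def mean_step(batch):
    return sum(batch) / len(batch)


# Three full batches of the same size share one compilation.
for _ in range(3):
    mean_step([1.0] * 128)
compiles_after_full_batches = mean_step.compile_count

# A trailing partial batch has a new shape and triggers another one.
mean_step([1.0] * 96)
compiles_after_partial_batch = mean_step.compile_count

print(compiles_after_full_batches, compiles_after_partial_batch)
```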

Finally, let's quickly run training for several epochs and check the validation results:

In [15]:
for epoch in range(10):
  print('Epoch', epoch)
  for i, (inputs, labels) in enumerate(train_loader):
    # Distribute the batch across all TPU devices
    sharded_inputs = ddp_model.shard_input(inputs)
    sharded_labels = ddp_model.shard_input(labels)
    loss = train_step(sharded_inputs, sharded_labels)

    # Report the loss of every 100th batch
    if i % 100 == 0:
      print('  batch {} loss: {}'.format(i, loss.item()))

Epoch 0
  batch 0 loss: 2.3075523376464844
  batch 100 loss: 2.3029651641845703
  batch 200 loss: 2.2921366691589355
  batch 300 loss: 2.2877070903778076
  batch 400 loss: 2.274242401123047
Epoch 1
  batch 0 loss: 2.2708349227905273
  batch 100 loss: 2.269294261932373
  batch 200 loss: 2.2480335235595703
  batch 300 loss: 2.243983268737793
  batch 400 loss: 2.2470455169677734
Epoch 2
  batch 0 loss: 2.234013557434082
  batch 100 loss: 2.2184624671936035
  batch 200 loss: 2.2029666900634766
  batch 300 loss: 2.198725461959839
  batch 400 loss: 2.1829864978790283
Epoch 3
  batch 0 loss: 2.1811957359313965
  batch 100 loss: 2.1297898292541504
  batch 200 loss: 2.1378531455993652
  batch 300 loss: 2.0720174312591553
  batch 400 loss: 2.0413732528686523
Epoch 4
  batch 0 loss: 2.046309471130371
  batch 100 loss: 1.9817270040512085
  batch 200 loss: 1.9381718635559082
  batch 300 loss: 1.847656011581421
  batch 400 loss: 1.808678388595581
Epoch 5
  batch 0 loss: 1.7617125511169434
  batch 100 loss: 
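
The training loop above prints only the loss of every 100th batch, which can be noisy. A common alternative is to report the mean over each reporting interval. The sketch below shows one way to do that with a hypothetical `RunningMean` helper. The loss values are placeholders chosen for illustration, not results from this notebook.

```python
class RunningMean:
    """Accumulates values and returns their mean on demand."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def pop(self) -> float:
        """Return the mean of the values seen so far, then reset."""
        mean = self.total / self.count
        self.total = 0.0
        self.count = 0
        return mean


running_loss = RunningMean()
placeholder_losses = [2.0, 1.0, 3.0, 4.0]  # stand-ins for loss.item()
REPORT_EVERY = 2  # the notebook's loop reports every 100 batches

reported = []
for i, loss_value in enumerate(placeholder_losses):
    running_loss.update(loss_value)
    if (i + 1) % REPORT_EVERY == 0:
        reported.append(running_loss.pop())

print(reported)
```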

In [16]:
@ddp_model.jit_step
def eval_step(sharded_vinputs, sharded_vlabels):
  voutputs = ddp_model(sharded_vinputs)
  vloss = loss_fn(voutputs, sharded_vlabels)
  return vloss

ddp_model.eval()
running_vloss = 0.

# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
  for i, vdata in enumerate(test_loader):
    vinputs, vlabels = vdata
    sharded_vinputs, sharded_vlabels = ddp_model.shard_input(vinputs), ddp_model.shard_input(vlabels)
    vloss = eval_step(sharded_vinputs, sharded_vlabels)
    running_vloss += vloss

avg_vloss = running_vloss / (i + 1)
print('Validation loss', avg_vloss.item())

Validation loss 0.6315549612045288


## Conclusion

With some minor changes to your training loop, `torch_xla2` allows you to distribute a model across multiple devices and run a compiled version with JAX. All of the data you interact with directly is still a `torch` tensor, and JAX handles all of the distributed details in the background.

`torch_xla2` (and especially training) is still under heavy development. To learn more about the project and its current status, see https://github.com/pytorch/xla/tree/master/experimental/torch_xla2