## Train Your PyTorch Model on Cloud TPU

This notebook will show you how to:

* Install PyTorch/XLA on Colab, which lets you use PyTorch with TPUs.
* Use the PyTorch/XLA syntax elements that are specific to TPUs.



<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a></h3>

* On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
* The cell below makes sure you have access to a TPU on Colab.


In [1]:
import os
# Use .get() so a missing variable triggers the assertion message instead of a KeyError
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Runtime > Change runtime type > Hardware accelerator'

## Installing PyTorch/XLA

Run the following cell (or copy it into your own notebook) to install PyTorch, Torchvision, and PyTorch/XLA. It takes a couple of minutes to run.

The PyTorch/XLA package lets PyTorch connect to Cloud TPUs. (It's named PyTorch/XLA, not PyTorch/TPU, because XLA is the name of the TPU compiler.) In particular, PyTorch/XLA makes TPU cores available as PyTorch devices. This lets PyTorch create and manipulate tensors on TPUs.

In [2]:

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl



In [3]:
# Imports
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

# PyTorch/XLA Library Elements
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met
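
The statement above that TPU cores show up as PyTorch devices can be illustrated with a minimal sketch. It assumes PyTorch/XLA is installed as in the install cell; the CPU fallback is an addition of this sketch so that it also runs without a TPU. It is best run on its own rather than right before the multi-process training further down, since touching the XLA device in the parent process may interfere with `xmp.spawn` when more than one core is used.

```python
import torch

try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()  # a TPU core exposed as a torch.device
except (ImportError, RuntimeError):
    # Assumption of this sketch: fall back to CPU when no TPU is available.
    device = torch.device('cpu')

# Tensors are created on, or moved to, the XLA device like any other device.
a = torch.ones(2, 3, device=device)
b = torch.full((2, 3), 2.0).to(device)

# Arithmetic runs on the device; .item() copies the scalar back to the host.
total = (a + b).sum().item()
print(device, total)
```

Apart from how `device` is obtained, this is ordinary PyTorch code, which is the point of the device abstraction.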



In [8]:
# Model
class ToyModel(nn.Module):
    """ Toy Classifier """
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(1440, 10)

    def forward(self, x):
        x = nn.ReLU()(self.conv1(x))
        x = self.mp1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
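        # Note: nn.CrossEntropyLoss (used in train) applies log-softmax itself
        # and expects raw logits, so this extra Softmax flattens the loss signal.
        x = nn.Softmax(dim=-1)(x)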
        return x

In [9]:
# Config Parameters
FLAGS = {
    'batch_size': 32,
    'world_size': 1,
    'epochs': 1,
    'log_steps': 10,
    'metrics_debug': False,
    'updates_per_epoch' : 400
}
SERIAL_EXEC = xmp.MpSerialExecutor()
WRAPPED_MODEL = xmp.MpModelWrapper(ToyModel())

In [10]:
# Training Loop
def train(rank, FLAGS):
    print("Starting train method on rank: {}".format(rank))
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    def get_dataset():
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]
        )

        return torchvision.datasets.MNIST(
            '/tmp/', train=True, download=True, transform=transform)

    train_dataset = SERIAL_EXEC.run(get_dataset)    

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=FLAGS['world_size'], rank=rank)
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
        num_workers=0, sampler=train_sampler)
    # PyTorch/XLA: Parallel Loader Wrapper
    train_loader = pl.MpDeviceLoader(train_loader, device)

    for epoch in range(FLAGS['epochs']):
        for i, (images, labels) in enumerate(train_loader):
            # Stop after exactly 'updates_per_epoch' optimizer steps
            if i >= FLAGS['updates_per_epoch']:
                break
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            # PyTorch/XLA: All Reduce followed by parameter update 
            xm.optimizer_step(optimizer)

            if not i % FLAGS['log_steps']:
                xm.master_print(
                    'Epoch: {}/{}, Loss:{}'.format(
                        epoch + 1, FLAGS['epochs'], loss.item()
                    )
                )
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report())
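
The `DistributedSampler` in the training loop gives each process its own slice of the dataset, so replicas do not train on the same samples. The pure-Python sketch below illustrates that partitioning under two assumptions: shuffling is disabled (the sampler shuffles by default), and the dataset is padded by repeating its first samples so that every replica gets the same count. `shard_indices` is a hypothetical helper written for illustration, not a PyTorch function.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Hypothetical helper: indices one replica would see, without shuffling."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating samples from the start so the length divides evenly.
    indices += indices[: total_size - dataset_len]
    # Each replica takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]


# Example: 10 samples split across 4 replicas.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

With `'world_size': 1`, as configured in `FLAGS`, the single process simply sees every index.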

In [11]:
# PyTorch/XLA: launch one training process per TPU core.
# With 'world_size': 1 only a single core is used; set it to 8 to train on 4 TPU chips (8 cores).
xmp.spawn(train, nprocs=FLAGS['world_size'], args=(FLAGS,), start_method='fork')

Starting train method on rank: 0
Epoch: 1/1, Loss:2.296616315841675
Epoch: 1/1, Loss:2.299046754837036
Epoch: 1/1, Loss:2.29319429397583
Epoch: 1/1, Loss:2.2989964485168457
Epoch: 1/1, Loss:2.2943575382232666
Epoch: 1/1, Loss:2.2907660007476807
Epoch: 1/1, Loss:2.284073829650879
Epoch: 1/1, Loss:2.284733772277832
Epoch: 1/1, Loss:2.2944512367248535
Epoch: 1/1, Loss:2.2889037132263184
Epoch: 1/1, Loss:2.302015542984009
Epoch: 1/1, Loss:2.2987492084503174
Epoch: 1/1, Loss:2.2814888954162598
Epoch: 1/1, Loss:2.2919561862945557
Epoch: 1/1, Loss:2.300222873687744
Epoch: 1/1, Loss:2.2973620891571045
Epoch: 1/1, Loss:2.287539005279541
Epoch: 1/1, Loss:2.295802593231201
Epoch: 1/1, Loss:2.2876291275024414
Epoch: 1/1, Loss:2.2938716411590576
Epoch: 1/1, Loss:2.3000991344451904
Epoch: 1/1, Loss:2.2910425662994385
Epoch: 1/1, Loss:2.2864394187927246
Epoch: 1/1, Loss:2.284902572631836
Epoch: 1/1, Loss:2.270642042160034
Epoch: 1/1, Loss:2.2965376377105713
Epoch: 1/1, Loss:2.304417371749878
Epoch: 1