## PyTorch/TPU MNIST Demo

This Colab example corresponds to the implementation in [test_train_mnist.py](https://github.com/pytorch/xla/blob/master/test/test_train_mnist.py) and is compatible with TF/XRT 1.15.

<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a></h3>

* On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
* The cell below makes sure you have access to a TPU on Colab.


In [0]:
import os

# COLAB_TPU_ADDR is expected to be present only when a TPU runtime is attached.
# Look it up with .get() so that a missing variable fails with the helpful
# assertion message below instead of a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

### [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies
This may take about 2 minutes.

In [0]:
DIST_BUCKET="gs://tpu-pytorch/wheels"
TORCH_WHEEL="torch-1.15-cp36-cp36m-linux_x86_64.whl"
TORCH_XLA_WHEEL="torch_xla-1.15-cp36-cp36m-linux_x86_64.whl"
TORCHVISION_WHEEL="torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl"

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5

### Define Parameters


In [0]:
# Tunable settings for the whole notebook. The trailing `#@param` annotations
# are part of each line and must stay on the same line as the assignment.

# Directory the MNIST datasets are downloaded to and read from.
data_dir = "/tmp/mnist" #@param {type:"string"}
# Batch size passed to both the train and test DataLoaders.
batch_size = 128 #@param {type:"integer"}
# Passed as num_workers to both DataLoaders.
num_workers = 4 #@param {type:"integer"}
# Base SGD learning rate; it is multiplied by the number of devices before use.
learning_rate = 0.01 #@param {type:"number"}
# Momentum passed to optim.SGD.
momentum = 0.5 #@param {type:"number"}
# Number of train-then-evaluate passes run by the training cell.
num_epochs = 10 #@param {type:"integer"}
# Upper bound on the XLA devices requested; 0 results in an empty device list.
num_cores = 8 #@param [8, 1] {type:"raw"}
# The training loop prints progress whenever the step index is a multiple of this.
log_steps = 20 #@param {type:"integer"}
# When True, the torch_xla metrics report is printed after every epoch.
metrics_debug = False #@param {type:"boolean"}

In [0]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.distributed.data_parallel as dp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
from torchvision import datasets, transforms

class MNIST(nn.Module):
  """Small convolutional classifier: 1-channel image in, 10 log-probabilities out.

  Two `conv -> max-pool -> ReLU -> batch-norm` stages are followed by a
  two-layer fully connected head. `fc1` takes 320 flattened features, which
  is 20 channels x 4 x 4 when the input images are 28x28.
  """

  def __init__(self):
    super().__init__()
    # Feature extractor.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(num_features=10)
    self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(num_features=20)
    # Classifier head.
    self.fc1 = nn.Linear(in_features=320, out_features=50)
    self.fc2 = nn.Linear(in_features=50, out_features=10)

  def forward(self, x):
    """Return per-class log-probabilities (log-softmax over dim 1) for batch `x`."""
    stage1 = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))
    stage2 = self.bn2(F.relu(F.max_pool2d(self.conv2(stage1), 2)))
    # Keep the batch dimension and collapse everything else for the linear layers.
    features = torch.flatten(stage2, start_dim=1)
    hidden = F.relu(self.fc1(features))
    logits = self.fc2(hidden)
    return F.log_softmax(logits, dim=1)
  
# Normalization applied to every image; the two tuples are the mean and std
# arguments of transforms.Normalize.
norm = transforms.Normalize((0.1307,), (0.3081,))
# Built from the same two constants as mean=-m/s and std=1/s; the
# visualization cell applies it to normalized images before plotting them.
inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))

# Train split first (download enabled), then the test split from the same
# directory (download left off). Each dataset gets its own Compose pipeline.
train_dataset, test_dataset = (
    datasets.MNIST(
        data_dir,
        train=is_train,
        download=is_train,
        transform=transforms.Compose([transforms.ToTensor(), norm]))
    for is_train in (True, False))

# Settings shared by both loaders; only shuffling differs between them.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(
    test_dataset, shuffle=False, **_loader_kwargs)

# Devices to train on; num_cores == 0 yields an empty list.
devices = (
    xm.get_xla_supported_devices(
        max_devices=num_cores) if num_cores != 0 else [])
print("Devices: {}".format(devices))
# Scale learning rate to num cores. The result is stored under a separate name
# so that re-running this cell does not multiply `learning_rate` again and
# again; `learning_rate` keeps the value chosen in the parameters cell.
scaled_learning_rate = learning_rate * max(len(devices), 1)
# Pass [] as device_ids to run using the PyTorch/CPU engine.
model_parallel = dp.DataParallel(MNIST, device_ids=devices)

def train_loop_fn(model, loader, device, context):
  """Train `model` for one pass over `loader`.

  Args:
    model: module to train; called as `model(data)`.
    loader: iterable yielding `(step, (data, target))` items.
    device: only used to label the progress output.
    context: object exposing `getattr_or(name, factory)`; used to obtain the
      optimizer (presumably created once and reused across calls - confirm
      against the torch_xla DataParallel context implementation).
  """
  loss_fn = nn.NLLLoss()
  optimizer = context.getattr_or(
      'optimizer',
      lambda: optim.SGD(model.parameters(), lr=scaled_learning_rate,
                        momentum=momentum))
  tracker = xm.RateTracker()

  model.train()
  for step, (data, target) in loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # The optimizer step goes through xm.optimizer_step rather than optimizer.step().
    xm.optimizer_step(optimizer)
    tracker.add(batch_size)
    # loss.item() is only evaluated on logging steps.
    if step % log_steps == 0:
      print('[{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}\n'.format(
          device, step, loss.item(), tracker.rate(),
          tracker.global_rate(), time.asctime()), flush=True)

def test_loop_fn(model, loader, device, context):
  total_samples = 0
  correct = 0
  model.eval()
  data, pred, target = None, None, None
  for x, (data, target) in loader:
    output = model(data)
    pred = output.max(1, keepdim=True)[1]
    correct += pred.eq(target.view_as(pred)).sum().item()
    total_samples += data.size()[0]

  accuracy = 100.0 * correct / total_samples
  print('[{}] Accuracy={:.2f}%\n'.format(device, accuracy))
  return accuracy, data, pred, target

In [0]:
# Start training
# After the loop, data/pred/target hold the values from the final epoch's
# evaluation; the visualization cell reads them.
data, target, pred = None, None, None
for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)
  # A sequence of (accuracy, data, pred, target) tuples as returned by
  # test_loop_fn (presumably one per device).
  results = model_parallel(test_loop_fn, test_loader)
  # Transpose into one tuple per field. zip keeps the tensors untouched,
  # whereas np.array() would have to infer an object array from a mix of
  # floats and tensors.
  accuracies, data, pred, target = zip(*results)
  accuracy = sum(accuracies) / len(accuracies)
  print('[Epoch {}] Global accuracy: {:.2f}%'.format(epoch, accuracy))
  if metrics_debug:
    print(torch_xla._XLAC._xla_metrics_report())

## Visualize Predictions

In [0]:
%matplotlib inline
from matplotlib import pyplot as plt

# Retrieve tensors that is on TPU cores
M = 4
N = 6
images, labels, preds = data[0][:N*M].cpu(), \
  target[0][:N*M].cpu(), pred[0][:N*M].cpu()

fig, axes = plt.subplots(M, N, figsize=(11, 9))
fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')

for i, ax in enumerate(fig.axes):
  img, label, prediction = images[i], labels[i], preds[i]
  img = inv_norm(img)
  img = img.squeeze() # [1,Y,X] -> [Y,X]
  ax.axis('off')
  label, prediction = label.item(), prediction.item()
  if label == prediction:
    ax.set_title(u'\u2713', color='blue', fontsize=22)
  else:
    ax.set_title(
        'X {}/{}'.format(label, prediction), color='red')
  ax.imshow(img)