## PyTorch/XLA TPU Profiling Colab tutorial

*Note*: Since we're not using GCS in this tutorial, TPU-side traces won't be collected. To collect full TPU traces, follow [this tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).

### [RUNME] Install Colab compatible PyTorch/XLA wheels and dependencies



In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl

### Define Parameters



In [None]:
# Define Parameters
import os

# Run configuration shared by the cells below; it is also handed to every
# training process when training is spawned.
FLAGS = {
    'data_dir': "/tmp/cifar",
    # Per-process batch size of the generated batches.
    'batch_size': 128,
    'num_workers': 4,
    # SGD hyperparameters.
    'learning_rate': 0.02,
    'momentum': 0.9,
    'num_epochs': 200,
    # Number of training processes: 8 when TPU_NAME is set to a non-empty
    # value in the environment, otherwise 1.
    'num_cores': 8 if os.environ.get('TPU_NAME') else 1,
    # Print a progress line every this many training steps.
    'log_steps': 20,
    # When True, the XLA metrics report is printed after each epoch.
    'metrics_debug': False,
}

In [None]:
# Setup profiler env var. It is set here, before torch_xla is imported in the
# next cell.
# NOTE(review): presumably XLA_HLO_DEBUG=1 makes XLA attach Python-side source
# information to the generated ops so traces are easier to read — confirm
# against the torch_xla documentation.
os.environ['XLA_HLO_DEBUG'] = '1'

In [None]:
import multiprocessing
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms

class BasicBlock(nn.Module):
  """Residual block made of two 3x3 convolutions plus a shortcut branch.

  Each convolution is followed by batch normalization. The shortcut is an
  empty `nn.Sequential` (identity) unless the block changes the spatial
  stride or the channel count, in which case it is a 1x1 convolution followed
  by batch normalization so both branches have matching shapes.
  """

  # Multiplier from `planes` to the block's output channel count.
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    """Create the block's layers.

    Args:
      in_planes: Number of channels of the block input.
      planes: Number of channels produced by both 3x3 convolutions.
      stride: Stride of the first convolution (and of the projection
        shortcut, when one is needed).
    """
    super().__init__()
    out_planes = self.expansion * planes

    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1,
        bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    needs_projection = stride != 1 or in_planes != out_planes
    if needs_projection:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_planes, out_planes, kernel_size=1, stride=stride,
              bias=False),
          nn.BatchNorm2d(out_planes))
    else:
      self.shortcut = nn.Sequential()

  def forward(self, x):
    """Return relu(main_branch(x) + shortcut(x))."""
    main = F.relu(self.bn1(self.conv1(x)))
    main = self.bn2(self.conv2(main))
    return F.relu(main + self.shortcut(x))


class ResNet(nn.Module):
  """ResNet classifier: a 3x3 stem, four residual stages and a linear head.

  The stem takes 3-channel input. The stages produce 64, 128, 256 and 512
  base channels; stages two to four start with a stride-2 block.
  """

  def __init__(self, block, num_blocks, num_classes=10):
    """Create the network.

    Args:
      block: Block class, called as `block(in_planes, planes, stride)` and
        exposing an `expansion` attribute.
      num_blocks: Sequence with the number of blocks for each of the four
        stages.
      num_classes: Size of the output layer.
    """
    super().__init__()
    # Running input-channel count; advanced by _make_layer as blocks are added.
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    """Build one stage; only its first block uses `stride`, the rest use 1.

    Updates `self.in_planes` so the next block (and the next stage) receives
    the channel count produced here.
    """
    per_block_strides = [stride, *([1] * (num_blocks - 1))]
    stage = []
    for block_stride in per_block_strides:
      stage.append(block(self.in_planes, planes, block_stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*stage)

  def forward(self, x):
    """Return per-class log-probabilities (log-softmax over dim 1) for `x`.

    NOTE(review): the fixed 4x4 average pool feeding a linear layer with
    512 * expansion inputs presumably assumes 32x32 images — confirm.
    """
    features = F.relu(self.bn1(self.conv1(x)))
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
      features = stage(features)
    pooled = F.avg_pool2d(features, 4)
    logits = self.linear(torch.flatten(pooled, 1))
    return F.log_softmax(logits, dim=1)


def ResNet18():
  """Return a ResNet built from BasicBlock with two blocks in each stage."""
  blocks_per_stage = [2, 2, 2, 2]
  return ResNet(BasicBlock, blocks_per_stage)

In the following cell we define the training loops and, most importantly, add the tracing annotations `xp.StepTrace` and `xp.Trace`, which we'll be able to inspect in the profiler's trace view on TensorBoard. `xp.StepTrace` should be used only once per step: it denotes a full step, is used to calculate the model's step time, and is displayed on the TensorBoard profile summary page. The `xp.Trace` context manager annotation can be sprinkled around whichever parts you want a more detailed timeline of.

In [None]:
# Module-level helpers created once in the notebook process, before the
# training processes are started with the 'fork' start method.
# NOTE(review): SERIAL_EXEC is not referenced by the cells shown in this
# notebook — confirm whether it is still needed.
SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet18())

def train_resnet18(training_started):
  """Train and evaluate ResNet18 on fake CIFAR-shaped data under the profiler.

  Starts a profiler server on port 9012, then for each of FLAGS['num_epochs']
  epochs runs one training pass (annotated with `xp.StepTrace` / `xp.Trace`)
  followed by one evaluation pass.

  Args:
    training_started: Event-like object; its `set()` method is called when a
      training pass reaches step index 5, signalling that training is under
      way and tracing can begin.

  Returns:
    Tuple `(accuracy, data, pred, target)` from the last evaluation pass:
    the accuracy in percent and the last evaluated batch's inputs, predicted
    class indices and labels. If no evaluation batch was processed the
    accuracy is 0.0 and the three tensors are `None`.
  """
  torch.manual_seed(1)

  def fake_loader(dataset_len):
    # One all-zero (images, labels) batch handed to xu.SampleGenerator; the
    # sample count is this process's share of `dataset_len` examples.
    return xu.SampleGenerator(
        data=(torch.zeros(FLAGS['batch_size'], 3, 32, 32),
              torch.zeros(FLAGS['batch_size'], dtype=torch.int64)),
        sample_count=dataset_len // FLAGS['batch_size'] //
        xm.xrt_world_size())

  # We are using fake data here (not real CIFAR dataset).
  train_loader = fake_loader(50000)  # Number of examples in CIFAR train set.
  test_loader = fake_loader(10000)

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=FLAGS['learning_rate'],
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  # Keep a reference for the whole function: per the torch_xla profiler docs
  # the server is shut down when the returned object is garbage collected.
  server = xp.start_server(9012)

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      if x == 5:
        # A few steps have completed; let the waiting notebook cell start
        # capturing traces.
        training_started.set()
      # Let's now profile the training step.
      with xp.StepTrace('train_loop', step_num=x):
        # This profiles the construction of the graph.
        with xp.Trace('build_graph'):
          optimizer.zero_grad()
          output = model(data)
          loss = loss_fn(output, target)
          loss.backward()

        xm.optimizer_step(optimizer)
        tracker.add(FLAGS['batch_size'])
        if x % FLAGS['log_steps'] == 0:
          print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
              xm.get_ordinal(), x, loss.item(), tracker.rate(),
              tracker.global_rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    # Evaluation needs no gradients; skip autograd bookkeeping.
    with torch.no_grad():
      for data, target in loader:
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += data.size()[0]

    # The loader is empty when batch_size * world size exceeds the eval set
    # size; report 0% instead of dividing by zero.
    accuracy = 100.0 * correct / total_samples if total_samples else 0.0
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target = test_loop_fn(
        para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target
  

In [None]:
# Start training processes
def _mp_fn(rank, flags, training_started):
  """Per-process entry point handed to `xmp.spawn`.

  Args:
    rank: Index of this process; only rank 0 plots results.
    flags: Parameter dict that replaces the module-level FLAGS in this
      process.
    training_started: Event forwarded to `train_resnet18`.
  """
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_resnet18(training_started)
  # train_resnet18 returns None for the tensors when no evaluation batch was
  # processed (e.g. num_epochs == 0); skip plotting instead of failing on
  # None.cpu().
  if rank == 0 and data is not None:
    # Retrieve tensors that are on TPU core 0 and plot.
    # NOTE(review): plot_results is not defined in the visible part of this
    # notebook — confirm it exists before relying on this call.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

def target_fn(training_started):
  """Redirect this process's output to log files, then spawn training.

  stdout and stderr are pointed at 'training_logs.stdout' and
  'training_logs.stderr' in the current directory before FLAGS['num_cores']
  training processes are spawned with the 'fork' start method.

  Args:
    training_started: Event passed through to each training process.
  """
  log_out = open('training_logs.stdout', 'w')
  sys.stdout = log_out
  log_err = open('training_logs.stderr', 'w')
  sys.stderr = log_err
  xmp.spawn(
      _mp_fn,
      args=(FLAGS, training_started),
      nprocs=FLAGS['num_cores'],
      start_method='fork')
  
# Event the training loop sets once it has run a few steps; the tracing cell
# below waits on it.
training_started = multiprocessing.Event()
# Run training in a separate process so that this cell returns right away and
# later cells can capture traces while training is still in progress.
p = multiprocessing.Process(target=target_fn, args=(training_started,))
p.start()

The following cell first waits for training to start and then traces both the client VM side (i.e., where the XLA graph is built and the input pipeline is run) and the TPU device side (where the actual compilation and execution happen). However, note that since we're running on Colab and not using GCS in this tutorial, TPU-side traces won't be collected. To collect full TPU traces, follow this [tutorial](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).

In [None]:
import re

# Wait (at most 120 seconds) for the training process to signal that it has
# run a few steps. wait() returns False on timeout; tracing is still attempted
# in that case, but the reader is told why the traces may be empty.
if not training_started.wait(120):
  print('Warning: training did not signal start within 120 seconds; '
        'captured traces may be empty.')

# TPU_NAME is expected to look like grpc://<ipv4>:<4-digit port>. It may be
# unset (e.g. no TPU runtime attached), so fall back to an empty string
# rather than passing None to re.match.
tpu_match = re.match(r'grpc://((\d{1,3}\.){3}\d{1,3}):\d{4}',
                     os.environ.get('TPU_NAME') or '')

xp.trace('localhost:9012', '/tmp/tensorboard')  # client side profiling
if tpu_match:
  tpu_ip = tpu_match.group(1)
  xp.trace(f'{tpu_ip}:8466', '/tmp/tensorboard')  # need GCS bucket for all traces to be written
else:
  print('TPU_NAME is not set to a grpc://<ip>:<port> address; '
        'skipping the TPU-side trace.')

In [None]:
%load_ext tensorboard
%tensorboard --logdir /tmp/tensorboard --load_fast=false
# Click on "INACTIVE" dropdown and select "PROFILE"