# Setup PyTorch/XLA Environment

In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl

In [None]:
import os
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

import torch.nn as nn

## LazyTensor Basics 

This colab is a companion to the blog post titled "Understanding Lazy Tensor System Performance".

To illustrate lazy tensor behavior, let's perform a few operations on XLA tensors and examine the resulting HLO graph:

In [None]:
dev = xm.xla_device()

x1 = torch.rand((3, 3)).to(dev)
x2 = torch.rand((3, 8)).to(dev)

y1 = torch.einsum('bs,st->bt', x1, x2)
y1 = y1 + x2
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

Notice that XLA tensors are "lazy": the operations have been recorded, but no computation is executed until the result is required.

Execution is triggered when a LazyTensor barrier is inserted.
The easiest way to insert a barrier is to call mark_step():

In [None]:
xm.mark_step()
print(torch_xla._XLAC._get_xla_tensors_text([x1]))
print(y1.device)
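The record-then-execute behavior described above can be mimicked in a few lines of plain Python. The sketch below is a toy model only, not how PyTorch/XLA is implemented: `LazyValue`, `graph_text` and `toy_mark_step` are hypothetical names invented for this illustration. Operators build a small graph, and nothing is computed until the toy barrier runs.

```python
class LazyValue:
    """Toy stand-in for a lazy tensor: records an op instead of running it."""
    pending = []  # values recorded since the last barrier

    def __init__(self, op, inputs=(), data=None):
        self.op = op
        self.inputs = inputs
        self.data = data  # stays None until the value is materialized
        if data is None:
            LazyValue.pending.append(self)

    def __add__(self, other):
        return LazyValue('add', (self, other))

    def __mul__(self, other):
        return LazyValue('mul', (self, other))

    def graph_text(self):
        # Materialized values print as data; pending ones print their graph.
        if self.data is not None:
            return f'data({self.data})'
        args = ', '.join(i.graph_text() for i in self.inputs)
        return f'{self.op}({args})'

    def _materialize(self):
        if self.data is None:
            a, b = (i._materialize() for i in self.inputs)
            self.data = a + b if self.op == 'add' else a * b
        return self.data


def toy_mark_step():
    """Toy barrier: execute everything recorded so far."""
    for value in LazyValue.pending:
        value._materialize()
    LazyValue.pending.clear()


a = LazyValue('const', data=2.0)
b = LazyValue('const', data=3.0)
c = a * b + b
before = c.graph_text()  # only the recorded graph exists at this point
toy_mark_step()
after = c.graph_text()   # the value is now backed by concrete data
print(before)
print(after)
```

As with the XLA tensors above, the text of `c` before the barrier is a graph of pending operations, and after the barrier it refers to concrete data.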

## Dynamic Graph
Now let's define a function that executes a few operations on XLA tensors, followed by a mark_step() call. Optionally, the function introduces a dynamic structure: when acc is True, the loss tensor grows by one element on every call, so its shape changes from step to step. We then run this function with and without the dynamic structure and measure the run time.

In [None]:
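def dummy_step(x, y, loss, acc=False):
  # (512, 512) x (512, 8) -> (512, 8) for the inputs used below
  z = torch.einsum('bs,st->bt', y, x)
  step_loss = z.sum().view(1,)
  if acc:
    # Dynamic: loss grows by one element per call, so its shape changes every step
    loss = torch.cat((loss, step_loss))
  else:
    # Static: loss always has shape (1,)
    loss = step_loss
  xm.mark_step()  # barrier: ends the recorded graph for this step
  return loss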

In [None]:
import time
def measure_time(acc=False):
  exec_times = []
  iter_count = 100
  x = torch.rand((512, 8)).to(dev)
  y = torch.rand((512, 512)).to(dev)
  loss = torch.zeros(1).to(dev)
  for i in range(iter_count):
    tic = time.time()
    loss = dummy_step(x, y, loss, acc=acc)
    toc = time.time()
    exec_times.append(toc - tic)
  return exec_times

In [None]:
dyn = measure_time(acc=True)  # acc=True results in a dynamic graph
st = measure_time(acc=False)  # Static graph: the computation and the input and output shapes don't change

In [None]:
import matplotlib.pyplot as plt
plt.plot(st, label = 'static graph')
plt.plot(dyn, label = 'dynamic graph')
plt.legend()
plt.title('Execution time in seconds')

Notice that the dynamic graph's execution times are consistently higher for the same computation because of the compilation cost incurred in every iteration. The static graph curve benefits from the compilation cache and quickly stabilizes at a faster execution time.
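To see why the two curves differ, it can help to count cache misses in a simplified model. The sketch below is plain Python under stated assumptions: it pretends the compilation cache is keyed only by the operation sequence and the shape of `loss`. The real cache key is derived from the whole graph, and `simulate` is a hypothetical helper, not part of PyTorch/XLA.

```python
def simulate(iter_count, acc):
    """Count compilations when each step's graph is keyed by the loss shape."""
    cache = set()   # assumed stand-in for the compilation cache
    compiles = 0
    loss_len = 1    # loss starts with shape (1,), as in measure_time
    for _ in range(iter_count):
        # Simplified graph signature: op sequence plus the input shape of loss.
        key = ('einsum-sum-cat', loss_len) if acc else ('einsum-sum',)
        if key not in cache:
            cache.add(key)  # cache miss: this graph must be compiled
            compiles += 1
        if acc:
            loss_len += 1   # torch.cat grows loss by one element per step
    return compiles


dynamic_compiles = simulate(100, acc=True)
static_compiles = simulate(100, acc=False)
print(dynamic_compiles, static_compiles)
```

In this model the static case compiles once and reuses the result, while the dynamic case sees a new signature on every iteration, which matches the shape of the two curves in the plot.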