# Setup PyTorch/XLA Environment

In [None]:
import os

# Environment variable for profiling / debug
os.environ['PT_XLA_DEBUG'] = '1'

In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl

In [None]:
import os
# Use .get() so a missing variable triggers the assertion message instead of a KeyError
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

import torch.nn as nn

## Create XLA Tensor 

For illustration, perform a few operations on XLA tensors and view the recorded IR graph:

In [None]:
dev = xm.xla_device()

x1 = torch.rand((3, 3)).to(dev)
x2 = torch.rand((3, 8)).to(dev)

y1 = torch.einsum('bs,st->bt', x1, x2)
y1 = y1 + x2
print(torch_xla._XLAC._get_xla_tensors_text([y1]))

Notice that XLA tensors are "lazy": the operations have been recorded, but no computation is actually performed until the results are required.

Execution happens when a LazyTensor barrier is inserted.
The easiest way to insert a barrier is to call mark_step(), which the first execution scenario below demonstrates.
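
To make the record-then-execute idea concrete, here is a toy pure-Python sketch. It does not use PyTorch/XLA and is not how PyTorch/XLA is implemented; the names `LazyValue`, `const`, and `barrier` are hypothetical and exist only for this illustration.

```python
# Toy model of lazy evaluation; this is NOT how PyTorch/XLA is implemented.
class LazyValue:
    """Records an operation instead of computing it right away."""

    def __init__(self, op, inputs=(), constant=None):
        self.op = op
        self.inputs = inputs
        self.constant = constant
        self.result = None      # filled in only when the graph is executed
        self.executed = False

    def __add__(self, other):
        return LazyValue("add", (self, other))

    def __mul__(self, other):
        return LazyValue("mul", (self, other))

    def graph_text(self):
        """Render the recorded graph without executing it."""
        if self.op == "const":
            return f"const({self.constant})"
        args = ", ".join(i.graph_text() for i in self.inputs)
        return f"{self.op}({args})"


def const(value):
    """Wrap a plain Python number as a leaf node of the graph."""
    return LazyValue("const", constant=value)


def barrier(value):
    """Execute the recorded graph (the toy analogue of a barrier)."""
    if value.executed:
        return value.result
    if value.op == "const":
        value.result = value.constant
    else:
        left, right = (barrier(i) for i in value.inputs)
        value.result = left + right if value.op == "add" else left * right
    value.executed = True
    return value.result


y = (const(2) + const(3)) * const(4)
print(y.graph_text())   # the graph is recorded...
print(y.executed)       # ...but nothing has been computed yet
print(barrier(y))       # execution happens only at the barrier
```

As in the notebook, building `y` only records operations; the value is computed when the barrier is reached.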

## Exploring LazyTensor with Debug Metrics
Report the metrics and counters, and notice that no compilation has been performed yet, nor has the graph been executed.

In [None]:
# Print all available metrics 
print(f"Available metrics:\n {met.metric_names()}")
# Print all available counters
print(f"Available counters:\n {met.counter_names()}")

## Graph Execution Scenario - 1

The simplest way to insert a LazyTensor barrier, which triggers execution of the graph(s) recorded so far, is to call the mark_step API explicitly:


In [None]:
xm.mark_step()

Let's review the available metrics after the mark_step call:

In [None]:
# Print all available metrics 
print(f"Available metrics:\n {met.metric_names()}")

Note that the CompileTime metric is now available. This metric provides the distribution of compilation times across all graph compilations executed so far. Here, however, we are only interested in the number of compilations, which we can report as:

In [None]:
met.metric_data('CompileTime')[:1]
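
The slice above keeps only the first element of the metric data. A small helper can make that intent explicit. This is a minimal sketch that assumes the data is a tuple laid out as `(total_samples, accumulator, samples)` and that a missing metric yields `None`; `compile_count` is a hypothetical name, and the tuple in the usage lines is made-up data, not real output.

```python
def compile_count(metric_data):
    """Return how many samples a metric has recorded, or 0 if it is absent.

    Assumes `metric_data` is either None or a tuple laid out as
    (total_samples, accumulator, samples).
    """
    if metric_data is None:
        return 0
    total_samples = metric_data[0]
    return total_samples


# Made-up data for illustration; in the notebook you would pass
# met.metric_data('CompileTime') instead.
fake_compile_time = (2, 1.5e9, [(1.0, 7.0e8), (2.0, 8.0e8)])
print(compile_count(fake_compile_time))
print(compile_count(None))
```

Comparing this count before and after a cell is a quick way to check whether that cell triggered a new compilation.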

## Execution Scenario - 2
Another scenario in which a LazyTensor barrier is inserted is when PyTorch/XLA encounters an op with no XLA lowering. Let's examine this scenario:

In [None]:
y1 = y1.view(3, 1, 2, 4)
# Example op with no XLA lowering
unfold = nn.Unfold(kernel_size=(2, 3))
y2 = unfold(y1)
y4 = y2 * 2

Notice that an additional compilation is triggered.

In [None]:
met.metric_data('CompileTime')[:1]

Notice also the counters:

In [None]:
print(f"Available counters:\n {met.counter_names()}")
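
The counter list can get long, so it may help to filter it. The sketch below assumes that counters prefixed with `aten::` mark ops that ran through a fallback rather than an XLA lowering; `ops_without_lowering` is a hypothetical helper, and the sample names are placeholders rather than real output.

```python
def ops_without_lowering(counter_names):
    """Return the counter names that start with 'aten::', sorted."""
    return sorted(name for name in counter_names if name.startswith("aten::"))


# Placeholder names for illustration; in the notebook you would pass
# met.counter_names() instead.
sample_counters = ["CreateXlaTensor", "xla::add", "aten::some_op", "DestroyXlaTensor"]
print(ops_without_lowering(sample_counters))
```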

## PyTorch/XLA Profiler
In the remainder of this notebook we will explore how the PyTorch/XLA profiler can surface these metric insights without writing any additional lines of code.

Note: Here we alter lower-level variables to display the debug information, which by default is printed to your terminal (and can be captured in a log file). This is intended for educational purposes and is not the recommended way to use the profiler.

In [None]:
from torch_xla.debug.frame_parser_util import process_frames

Example stack trace:

In [None]:
debug_file = torch_xla._tmp_fname
process_frames(debug_file)

In [None]:
y4 = y4.reshape(-1,1)

## Device to host transfer
Now let's create a device-to-host transfer scenario:

In [None]:
print(y4[0].item())

In [None]:
# Print all available counters
print(f"Available counters:\n {met.counter_names()}")

In [None]:
print(met.counter_value('aten::_local_scalar_dense'))

In [None]:
process_frames(debug_file)

Notice that device-to-host transfers are reported in terms of the _local_scalar_dense op. In the usual setting, the PyTorch/XLA profiler would provide the full stack trace leading to the lines in your code that cause device-to-host transfers.
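
Without the profiler, one rough way to see which counters a piece of code touched is to compare counter snapshots taken before and after it. The sketch below works on plain dictionaries; `counter_delta` is a hypothetical helper, and the snapshots are made-up values rather than real output.

```python
def counter_delta(before, after):
    """Return {name: increase} for counters that grew between two snapshots."""
    return {
        name: value - before.get(name, 0)
        for name, value in after.items()
        if value > before.get(name, 0)
    }


# Made-up snapshots; in the notebook each could be built as
# {name: met.counter_value(name) for name in met.counter_names()}.
before = {"CreateXlaTensor": 10, "xla::mul": 1}
after = {"CreateXlaTensor": 12, "xla::mul": 1, "aten::_local_scalar_dense": 1}
print(counter_delta(before, after))
```

A new or increased `aten::_local_scalar_dense` entry in the delta points to a device-to-host transfer in the code that ran between the two snapshots.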

# Summary
In this notebook we explored LazyTensor behavior with some basic metrics and briefly experimented with some of the functionality of the PyTorch/XLA profiler. To explore other features of the PyTorch/XLA profiler, please review:
- [Blog Posts](https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1)
- [Contrib Notebooks](https://github.com/pytorch/xla.git)