In [1]:
%load_ext tensorboard
# Remember to pick /tmp/profile/

In [2]:
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn


class TrainResNetBase():

  def __init__(self):
    self.img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size //
        xr.world_size())

    self.device = torch_xla.device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()
    self.compiled_step_fn = torch_xla.experimental.compile(self.step_fn)

  def _train_update(self, step, loss, tracker, epoch):
    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')

  def run_optimizer(self):
    self.optimizer.step()

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    output = self.model(data)
    loss = self.loss_fn(output, target)
    loss.backward()
    self.run_optimizer()

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      loss = self.compiled_step_fn(data, target)
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))

  def start_training(self):

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      self.train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


In [3]:
import os
import sys

import torch_xla.debug.profiler as xp

# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"


In [4]:
base = TrainResNetBase()
profile_port = 9012
# you can also set profile_logdir to a gs bucket, for example
# profile_logdir = "gs://your_gs_bucket/profile"
profile_logdir = "/tmp/profile/"
duration_ms = 30000
assert profile_logdir.startswith('gs://') or os.path.exists(profile_logdir)
server = xp.start_server(profile_port)
# Ideally you want to start the profile tracing after the initial compilation, for example
# at step 5.
xp.trace_detached(
    f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
base.start_training()
# You can view the profile at tensorboard by
# 1. pip install tensorflow-cpu tensorboard-plugin-profile
# 2. tensorboard --logdir /tmp/profile/ --port 6006
# For more detail please take a look at https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm


Epoch 1 train begin  8:44PM UTC on Jul 31, 2024
Starting to trace for 30000 ms. Remaining attempt(s): 2


2024-07-31 20:44:38.864042: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 2952170 nanoseconds and will start immediately.


epoch: 1, step: 0, loss: None, rate: 4.546404332686245
epoch: 1, step: 10, loss: None, rate: 234.35643001468216
epoch: 1, step: 20, loss: None, rate: 326.6896138542191
epoch: 1, step: 30, loss: None, rate: 365.50556414045957
epoch: 1, step: 40, loss: None, rate: 383.52400527549355
epoch: 1, step: 50, loss: None, rate: 389.3695036014534
epoch: 1, step: 60, loss: None, rate: 392.5179354161948
epoch: 1, step: 70, loss: None, rate: 394.5766531756856
epoch: 1, step: 80, loss: None, rate: 394.8321200405354
epoch: 1, step: 90, loss: None, rate: 395.1472163241774
epoch: 1, step: 100, loss: None, rate: 394.1538913699761
epoch: 1, step: 110, loss: None, rate: 394.62894760046026
epoch: 1, step: 120, loss: None, rate: 394.675310043137
epoch: 1, step: 130, loss: None, rate: 394.3734698171106
epoch: 1, step: 140, loss: None, rate: 392.9848480556546
epoch: 1, step: 150, loss: None, rate: 394.31165761017155
epoch: 1, step: 160, loss: None, rate: 394.52356824373067
epoch: 1, step: 170, loss: None, rate