In [23]:
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torch.optim as optim
import torch.nn as nn


class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
    self.fc1 = nn.Linear(64 * 54 * 54, 128)
    self.fc2 = nn.Linear(128, 10)
    
  def forward(self, x):
    x = self.conv1(x)
    x = nn.functional.relu(x)
    x = nn.functional.max_pool2d(x, 2)
    x = self.conv2(x)
    x = nn.functional.relu(x)
    x = nn.functional.max_pool2d(x, 2)
    x = x.view(-1, 64 * 54 * 54)
    x = self.fc1(x)
    x = nn.functional.relu(x)
    x = self.fc2(x)
    return x


class TrainResNetBase():

  def __init__(self):
    self.img_dim = 224
    self.batch_size = 256
    self.num_steps = 100
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, self.img_dim, self.img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size //
        xr.world_size())

    self.device = torch_xla.device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = CNN().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()
    self.compiled_step_fn = torch_xla.compile(self.step_fn, name="train_step")

  def _train_update(self, step, loss, tracker, epoch):
    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')

  def run_optimizer(self):
    self.optimizer.step()

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    output = self.model(data)
    loss = self.loss_fn(output, target)
    loss.backward()
    self.run_optimizer()

  def train_loop_fn(self, loader, epoch):
    import torch_xla.debug.profiler as xp
    import os
    profile_port = 9012
    # you can also set profile_logdir to a gs bucket, for example
    # profile_logdir = "gs://your_gs_bucket/profile"
    profile_logdir = "/workspaces/torch/playground/profile/"
    duration_ms = 180
    assert profile_logdir.startswith('gs://') or os.path.exists(profile_logdir)
    server = xp.start_server(profile_port)

    tracker = xm.RateTracker()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      loss = self.compiled_step_fn(data, target)  # type: ignore
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))
      if step == 12:
        xp.trace_detached(
          f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)

  def start_training(self):

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      self.train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


In [24]:
import os


# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
os.environ["XLA_IR_DEBUG"] = "0"
os.environ["XLA_HLO_DEBUG"] = "0"


In [25]:
base = TrainResNetBase()
base.start_training()

# You can view the profile at tensorboard by
# 1. pip install tensorflow-cpu tensorboard-plugin-profile
# 2. tensorboard --logdir /tmp/profile/ --port 6006
# For more detail please take a look at https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm


Epoch 1 train begin  2:05AM UTC on Jan 31, 2025
epoch: 1, step: 0, loss: None, rate: 7836.330372716591
epoch: 1, step: 10, loss: None, rate: 51903.29374240089
Starting to trace for 180 ms. Remaining attempt(s): 2


2025-01-31 02:05:23.069579: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 968060 nanoseconds and will start immediately.


epoch: 1, step: 20, loss: None, rate: 28618.559073762506
epoch: 1, step: 30, loss: None, rate: 18873.392118030828
epoch: 1, step: 40, loss: None, rate: 14842.749830784665
epoch: 1, step: 50, loss: None, rate: 13228.008486592364
epoch: 1, step: 60, loss: None, rate: 13865.532707287919
epoch: 1, step: 70, loss: None, rate: 12957.070395361821
epoch: 1, step: 80, loss: None, rate: 11350.36131181967
epoch: 1, step: 90, loss: None, rate: 12768.881452208418
Epoch 1 train end  2:05AM UTC on Jan 31, 2025
