In [2]:
import torch
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.data_parallel as dp
import torch_xla.utils.utils as xu
import torch_xla.debug.metrics as met
from torchvision import transforms
from torch.utils.data import DataLoader

def get_data_loader(batch_size, root='./data', train=False, num_workers=4,
                    drop_last=False, image_size=224):
    """Build a DataLoader over CIFAR-10, preprocessed for an ImageNet-style model.

    Images are resized with ``transforms.Resize(image_size)``, converted to
    tensors, and normalized with the ImageNet channel mean/std used by
    torchvision's pretrained classifiers. The dataset is downloaded to
    ``root`` if it is not already present. Batches come out in a fixed
    order (no shuffling).

    Args:
        batch_size: Number of images per batch.
        root: Directory the CIFAR-10 files are stored in / downloaded to.
        train: Load the training split instead of the test split.
        num_workers: Number of DataLoader worker processes.
        drop_last: Drop the final incomplete batch. On XLA / with
            ``torch.compile`` a differently-shaped last batch can trigger a
            recompilation; setting this avoids it at the cost of skipping
            the leftover samples.
        image_size: Size passed to ``transforms.Resize``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` pairs.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    dataset = torchvision.datasets.CIFAR10(
        root=root, train=train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, drop_last=drop_last)

def eval_model(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and print each output's sum.

    Labels yielded by the loader are ignored, so this is a smoke test of the
    forward pass (e.g. of a ``torch.compile``-d model), not an accuracy
    evaluation.

    Args:
        model: Callable taking a batch of inputs, e.g. an ``nn.Module`` or an
            ``OptimizedModule`` from ``torch.compile``. It must already live on
            ``device`` and is normally put in ``eval()`` mode by the caller.
        loader: Iterable of ``(inputs, labels)`` pairs.
        device: Device the inputs are moved to. Defaults to
            ``xm.xla_device()``.
    """
    if device is None:
        device = xm.xla_device()
    # The no-grad context is loop-invariant, so enter it once for the whole
    # pass instead of once per batch; inference needs no autograd graph.
    with torch.no_grad():
        for data, _ in loader:
            output = model(data.to(device))
            # Printing needs the concrete value; on XLA this forces the pending
            # computation to execute and copies the scalar back to the host.
            print(output.sum())


In [3]:
# CIFAR-10's test split (10,000 images) is not a multiple of 256, so the
# final batch is smaller than the rest.
batch_size = 256
loader = get_data_loader(batch_size)

device = xm.xla_device()
# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# exact checkpoint it used to load, so the model weights are unchanged.
xla_resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).to(device)
xla_resnet18.eval()
# Compile with TorchDynamo using PyTorch/XLA's 'openxla' backend.
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
eval_model(dynamo_resnet18, loader)

2024-07-31 05:01:32.194688: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 2710650 nanoseconds and will start immediately.


Files already downloaded and verified




tensor(20.2726, device='xla:0')
tensor(20.5006, device='xla:0')
tensor(19.9703, device='xla:0')
tensor(19.6732, device='xla:0')
tensor(20.0878, device='xla:0')
tensor(19.7444, device='xla:0')
tensor(19.4213, device='xla:0')
tensor(19.5926, device='xla:0')
tensor(19.1535, device='xla:0')


2024-07-31 05:01:54.091469: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 3279160 nanoseconds and will start immediately.


tensor(19.9890, device='xla:0')
tensor(20.1976, device='xla:0')
tensor(20.4524, device='xla:0')
tensor(18.8251, device='xla:0')
tensor(19.2354, device='xla:0')
tensor(20.5606, device='xla:0')
tensor(19.6442, device='xla:0')
tensor(20.3219, device='xla:0')
tensor(18.3967, device='xla:0')
tensor(18.0711, device='xla:0')
tensor(19.5000, device='xla:0')
tensor(18.7404, device='xla:0')
tensor(19.9876, device='xla:0')
tensor(24.5668, device='xla:0')
tensor(21.3810, device='xla:0')
tensor(19.9725, device='xla:0')
tensor(19.6045, device='xla:0')
tensor(20.1262, device='xla:0')
tensor(18.6039, device='xla:0')
tensor(21.2876, device='xla:0')
tensor(20.9116, device='xla:0')
tensor(19.7567, device='xla:0')
tensor(18.0652, device='xla:0')
tensor(20.0142, device='xla:0')
tensor(17.7503, device='xla:0')
tensor(20.2894, device='xla:0')
tensor(20.7117, device='xla:0')
tensor(19.4540, device='xla:0')
tensor(21.3903, device='xla:0')
tensor(20.0119, device='xla:0')
tensor(1.3527, device='xla:0')


In [4]:
# Display the compiled model: torch.compile returns an OptimizedModule wrapper
# whose `_orig_mod` is the original ResNet-18.
dynamo_resnet18

OptimizedModule(
  (_orig_mod): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True,