Commit ae63cd1

Add simple example for how to use torch_xla (#7048)
1 parent b64d8a2 commit ae63cd1

File tree

examples/README.md
examples/train_resnet_base.py
examples/train_resnet_ddp.py
examples/train_resnet_profile.py
examples/train_resnet_xla_ddp.py

5 files changed: +139 −0


examples/README.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
## Overview

This repo aims to provide some basic examples of how to run an existing PyTorch model with PyTorch/XLA. `train_resnet_base.py` is a minimal trainer that runs ResNet50 with fake data on a single device. Other examples will import `train_resnet_base` and demonstrate how to enable different features (distributed training, profiling, dynamo, etc.) on PyTorch/XLA.
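Each feature example follows the same pattern: subclass the base trainer and override a hook. A minimal sketch of that pattern (`MyTrainer` is a placeholder name; see `train_resnet_xla_ddp.py` below for a real instance):

```python
from train_resnet_base import TrainResNetBase


class MyTrainer(TrainResNetBase):

  def run_optimizer(self):
    # Override this hook to change how the optimizer step runs, e.g.
    # train_resnet_xla_ddp.py calls xm.optimizer_step() here instead.
    super().run_optimizer()


if __name__ == '__main__':
  MyTrainer().start_training()
```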

examples/train_resnet_base.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torchvision
import torch.optim as optim
import torch.nn as nn


def _train_update(step, loss, tracker, epoch):
  print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')


class TrainResNetBase():

  def __init__(self):
    img_dim = 224
    self.batch_size = 128
    self.num_steps = 300
    self.num_epochs = 1
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, 3, img_dim, img_dim),
              torch.zeros(self.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // self.batch_size // xr.world_size())

    self.device = torch_xla.device()
    # MpDeviceLoader preloads each batch onto the XLA device in the background.
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = torchvision.models.resnet50().to(self.device)
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
    self.loss_fn = nn.CrossEntropyLoss()

  def run_optimizer(self):
    self.optimizer.step()

  def start_training(self):

    def train_loop_fn(loader, epoch):
      tracker = xm.RateTracker()
      self.model.train()
      loader = itertools.islice(loader, self.num_steps)
      for step, (data, target) in enumerate(loader):
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_fn(output, target)
        loss.backward()
        self.run_optimizer()
        tracker.add(self.batch_size)
        if step % 10 == 0:
          # add_step_closure defers the host-side print until the step's
          # tensors are materialized, instead of forcing an early execution.
          xm.add_step_closure(_train_update, args=(step, loss, tracker, epoch))

    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    # Block until all pending XLA device operations have finished.
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainResNetBase()
  base.start_training()
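Since the hyperparameters above are plain attributes, you can shrink the run for a quick smoke test before launching the full 300 steps. A short, illustrative sketch:

    trainer = TrainResNetBase()
    trainer.num_steps = 20  # run 20 steps instead of 300 for a quick check
    trainer.start_training()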

examples/train_resnet_ddp.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
from train_resnet_base import TrainResNetBase
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp


class TrainResNetDDP(TrainResNetBase):

  def __init__(self):
    super().__init__()
    # Use the 'xla' backend so DDP's collectives run through the XLA runtime.
    dist.init_process_group('xla', init_method='xla://')
    self.model = DDP(
        self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
    # Recreate the optimizer so it holds the DDP-wrapped model's parameters.
    self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)


def _mp_fn(index):
  ddp = TrainResNetDDP()
  ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
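In multiprocess runs it can be handy to confirm each replica's ordinal at startup. A small, hypothetical variant of _mp_fn using the runtime module the base trainer already imports:

    from torch_xla import runtime as xr

    def _mp_fn(index):
      # index is the ordinal xmp.spawn assigns to this process.
      print(f'replica {index} of {xr.world_size()} starting')
      TrainResNetDDP().start_training()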

examples/train_resnet_profile.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
import os

from train_resnet_base import TrainResNetBase
import torch_xla.debug.profiler as xp

# check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#environment-variables
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

if __name__ == '__main__':
  base = TrainResNetBase()
  profile_port = 9012
  profile_logdir = "/tmp/profile/"
  duration_ms = 30000
  assert os.path.exists(profile_logdir)
  server = xp.start_server(profile_port)
  # Ideally you want to start the profile tracing after the initial
  # compilation, for example at step 5.
  xp.trace_detached(
      f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)
  base.start_training()
  # You can view the profile in TensorBoard by
  # 1. pip install tensorflow tensorboard-plugin-profile
  # 2. tensorboard --logdir /tmp/profile/ --port 6006
  # For more detail please take a look at
  # https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm
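As the comment in the script notes, tracing ideally starts after the initial compilation. One way to approximate that without touching the training loop is to fire trace_detached from a background timer while training runs on the main thread (a sketch; the 60-second warm-up delay is an assumption, not a measured value):

    import threading

    import torch_xla.debug.profiler as xp
    from train_resnet_base import TrainResNetBase

    if __name__ == '__main__':
      base = TrainResNetBase()
      server = xp.start_server(9012)
      # Assumed warm-up: give the first steps (and compilation) ~60s, then trace.
      threading.Timer(
          60.0,
          xp.trace_detached,
          args=('localhost:9012', '/tmp/profile/'),
          kwargs={'duration_ms': 30000}).start()
      base.start_training()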

examples/train_resnet_xla_ddp.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
1+
from train_resnet_base import TrainResNetBase
2+
import torch_xla.distributed.xla_multiprocessing as xmp
3+
import torch_xla.core.xla_model as xm
4+
5+
6+
class TrainResNetXLADDP(TrainResNetBase):
7+
8+
def run_optimizer(self):
9+
xm.optimizer_step(self.optimizer)
10+
11+
12+
def _mp_fn(index):
13+
xla_ddp = TrainResNetXLADDP()
14+
xla_ddp.start_training()
15+
16+
17+
if __name__ == '__main__':
18+
xmp.spawn(_mp_fn, args=())
