Commit 781ee93

Add examples for training (#6929)
1 parent aee1d37 commit 781ee93

10 files changed

+819
-3
lines changed

experimental/torch_xla2/README.md

Lines changed: 95 additions & 1 deletion
@@ -71,4 +71,98 @@ pip install -e .
```bash
pip install -r test_requirements.txt
pytest test
```


## Run a model

Now let's execute a model under torch_xla2. We'll start with a simple model made of
three linear layers; in theory it can be any instance of `torch.nn.Module`.

```python

import torch
import torch.nn.functional as F
import torch_xla2
from torch import nn

class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(28 * 28, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # Flatten each 28x28 image into a 784-element vector.
    x = x.view(-1, 28 * 28)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

m = MyModel()

# Execute this model using torch
inputs = (torch.randn(3, 3, 28, 28), )
print(m(*inputs))
```

The model `m` carries state: the weights stored inside the model
and its submodules (`nn.Linear`).

To execute this model with `torch_xla2`, we need to move the tensors involved in the
computation to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`.

We need to move both the weights and the input to XLA devices:

```python
from torch.utils import _pytree as pytree
from torch_xla2.tensor import move_to_device

inputs = move_to_device(inputs)
new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict())
m.load_state_dict(new_state_dict, assign=True)

res = m(*inputs)

print(type(res))  # outputs XLATensor2
```
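
To make the `tree_map_only` call above less opaque, here is a minimal pure-Python sketch of the idea: walk a nested container and apply a function only to leaves of a given type. `map_only` and the sample `state` dict are hypothetical names used for illustration; this is not PyTorch's implementation, which handles more container types.

```python
def map_only(leaf_type, fn, tree):
  """Apply fn to every leaf of type leaf_type; leave other leaves untouched."""
  if isinstance(tree, dict):
    return {k: map_only(leaf_type, fn, v) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(map_only(leaf_type, fn, v) for v in tree)
  if isinstance(tree, leaf_type):
    return fn(tree)
  return tree


# Floats stand in for tensors; the string stands in for non-tensor metadata.
state = {"fc1.weight": 1.5, "fc1.bias": [0.5, 2.0], "note": "unchanged"}
moved = map_only(float, lambda x: x * 2, state)
print(moved)
```

In the README's snippet, `move_to_device` plays the role of `fn` and `torch.Tensor` plays the role of `leaf_type`, so every tensor in the state dict is moved while the dict structure stays the same.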

### Executing with jax.jit

The above script executes the model using eager-mode Jax as the backend. This
does allow executing torch models on TPU, but is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that takes a Jax function (i.e. a function that takes jax arrays
and returns jax arrays) and returns the same function, but faster.

We have made the `jax_jit` decorator, which accomplishes the same for functions
that take and return `torch.Tensor`. To use it, the first step is to create
a functional version of the model: the parameters are passed in
as input instead of being attributes on the class:


```python

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

```
Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model
weights with `param`, then call the model. Apart from `functional_call` leaving the
weights stored on `m` unchanged, this is roughly equivalent to:

```python
def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
```
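
The difference between the stateful and the functional style can be sketched without PyTorch. `TinyModel`, `functional_call_sketch`, and the parameter names below are hypothetical; the sketch only mirrors the idea that parameters are passed in as arguments and the model's own state is the same after the call as before it.

```python
class TinyModel:
  """A stand-in for nn.Module: y = w * x + b, with parameters as attributes."""

  def __init__(self):
    self.params = {"w": 2.0, "b": 1.0}

  def __call__(self, x):
    return self.params["w"] * x + self.params["b"]


def functional_call_sketch(model, params, inputs):
  """Call model with params swapped in, then restore the original params."""
  saved = model.params
  model.params = params
  try:
    return model(*inputs)
  finally:
    model.params = saved


model = TinyModel()
out = functional_call_sketch(model, {"w": 10.0, "b": 0.0}, (3.0,))
print(out)         # computed with the passed-in parameters
print(model(3.0))  # computed with the model's own, unchanged parameters
```

Because the result depends only on the arguments, a function of this shape is the kind that a tracing compiler such as `jax.jit` can work with.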

Now, we can apply `jax_jit`:

```python
from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

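
As a rough mental model for a decorator that lets a function written for one tensor type be called with another, the following pure-Python sketch adapts a function on raw values into one that takes and returns wrapped values. `Box`, `wrap`, `unwrap`, and `boxed_view` are hypothetical names; this is an assumption about the general shape of such an adapter, not torch_xla2's actual code.

```python
class Box:
  """Hypothetical stand-in for a tensor wrapper holding a raw backend value."""

  def __init__(self, value):
    self.value = value


def unwrap(x):
  """Return the raw value inside a Box; pass anything else through."""
  return x.value if isinstance(x, Box) else x


def wrap(x):
  """Put a raw value back into a Box."""
  return Box(x)


def boxed_view(raw_fn):
  """Adapt a function on raw values into one that takes and returns Box."""
  def wrapper(*args):
    return wrap(raw_fn(*(unwrap(a) for a in args)))
  return wrapper


def raw_add(a, b):
  return a + b


boxed_add = boxed_view(raw_add)
result = boxed_add(Box(2), Box(3))
print(type(result).__name__, result.value)
```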
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
## Intro

This readme has a subsection for every example `*.py` file.

Please follow the instructions in [README.md](../README.md) to install torch_xla2,
then install the requirements for all of the examples with

```bash
pip install -r requirements.txt
```



## basic_training.py

This file was constructed by first copying and pasting code fragments from this PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then a few lines of code were added to move `torch.Tensor`s onto
`XLA devices`.

Example:

```python
state_dict = pytree.tree_map_only(torch.Tensor,
    torch_xla2.tensor.move_to_device, state_dict)
```

This fragment moves the state_dict to XLA devices; the state_dict is then passed
back to the model via `load_state_dict`.

Then, you can train the model. This shows the minimum needed to train a model on XLA
devices. The performance is not as good as it could be because we didn't use `jax.jit`;
this is intentional, as the example is meant to showcase the minimum code change.

Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training.py
Training set has 60000 instances
Validation set has 10000 instances
Bag Dress Sneaker T-shirt/top
tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035,
         0.2279],
        [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561,
         0.5985],
        [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504,
         0.8738],
        [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086,
         0.5709]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.325265645980835
EPOCH 1:
batch 1000 loss: 1.041275198560208
batch 2000 loss: 0.6450189483696595
batch 3000 loss: 0.5793989677671343
batch 4000 loss: 0.5170258888280951
batch 5000 loss: 0.4920090722264722
batch 6000 loss: 0.48910293977567926
batch 7000 loss: 0.48058812761632724
batch 8000 loss: 0.47159107415075413
batch 9000 loss: 0.4712311488997657
batch 10000 loss: 0.4675815168160479
batch 11000 loss: 0.43210567891132085
batch 12000 loss: 0.445208148030797
batch 13000 loss: 0.4119230824254337
batch 14000 loss: 0.4190662656680215
batch 15000 loss: 0.4094535468676477
LOSS train 0.4094535468676477 valid XLA
```

## basic_training_jax.py

This file was constructed by first copying and pasting code fragments from this PyTorch training tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

Then the torch optimizer was replaced with an `optax` optimizer, and `jax.grad` is used
to compute gradients instead of `torch.Tensor.backward()`.

Then, you can train the model using the jax ecosystem's training loop. This is meant to
showcase how easy it is to integrate with Jax.

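The functional training style this example relies on can be sketched with the standard library alone. In the sketch below, `grad_loss` is a hand-derived gradient standing in for what `jax.grad` would compute automatically, and `sgd_update` stands in for an optimizer update; both names, the toy data, and the learning rate are assumptions made for illustration and do not come from `basic_training_jax.py`.

```python
def loss(params, x, y):
  """Squared error of a one-dimensional linear model."""
  pred = params["w"] * x + params["b"]
  return (pred - y) ** 2


def grad_loss(params, x, y):
  """Hand-derived gradient of loss with respect to params."""
  err = params["w"] * x + params["b"] - y
  return {"w": 2 * err * x, "b": 2 * err}


def sgd_update(params, grads, lr):
  """Return new params; nothing is mutated in place."""
  return {k: params[k] - lr * grads[k] for k in params}


params = {"w": 0.0, "b": 0.0}
data = [(1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]  # samples of y = 2x + 1
for _ in range(2000):
  for x, y in data:
    params = sgd_update(params, grad_loss(params, x, y), lr=0.01)
print(params)  # w approaches 2 and b approaches 1
```

The point of the design is that gradients and updated parameters are returned as values rather than stored on tensors by a `backward()` call, which is the pattern a Jax training loop follows.
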
Example run:
```bash
(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py
Training set has 60000 instances
Validation set has 10000 instances
Pullover Ankle Boot Pullover Ankle Boot
tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872,
         0.1854],
        [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010,
         0.2766],
        [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607,
         0.6450],
        [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377,
         0.7453]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.4054245948791504
EPOCH 1:
batch 1000 loss: 1.0705260595591972
batch 2000 loss: 1.0997755021179327
batch 3000 loss: 1.0186579653513108
batch 4000 loss: 0.9090727646966116
batch 5000 loss: 0.8309370622411024
batch 6000 loss: 0.8702225417760783
batch 7000 loss: 0.8750176187023462
batch 8000 loss: 0.9652624803795453
batch 9000 loss: 0.8688667197711766
batch 10000 loss: 0.8021814124770199
batch 11000 loss: 0.8000540231048071
batch 12000 loss: 0.9150884484921057
batch 13000 loss: 0.819690621060171
batch 14000 loss: 0.8569030471532278
batch 15000 loss: 0.8740896808278603
LOSS train 0.8740896808278603 valid 2.3132264614105225
```
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
import functools

import torch
from time import time
from diffusers import DiffusionPipeline
from torch.utils import _pytree as pytree


import torch_xla2
import torch_xla2.functions
from torch_xla2.extra import torch_view, jax_view

import jax
import torch.func


class CompiledModule:
  """Wraps a torch module so that its forward pass runs through jax.jit."""

  def __init__(self, model):
    # Collect buffers and parameters, then move every tensor to the XLA device.
    weights = model.state_dict()
    weights.update(model.named_parameters())
    self._weights = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.move_to_device, weights)
    self._model = model

    # The jitted function is built lazily on the first call to forward().
    self._func_jitted_torch = None #torch_view(func_mod_jitted)


  def _maybe_move_tensor(self, tensor):
    # Move plain torch tensors to the XLA device; leave everything else as is.
    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torch_xla2.tensor.XLATensor2):
      return torch_xla2.tensor.move_to_device(tensor)
    return tensor

  def _make_jitted(self, args, kwargs):
    # NOTE: `static` is collected but not passed to jax.jit below;
    # only `static_argnames` is used.
    static = []
    for i, a in enumerate(args):
      if not isinstance(a, torch.Tensor):
        static.append(i + 1)  # weight is 0
    static_argnames = []
    for k, v in kwargs.items():
      if not isinstance(v, torch.Tensor):
        static_argnames.append(k)

    def f(weights, *args, **kwargs):
      # Functional form of the model: weights are passed in as an argument.
      weights, args, kwargs = torch_xla2.tensor.wrap((weights, args, kwargs))
      with torch_xla2.functions.XLAFunctionMode(), torch_xla2.tensor.XLADispatchMode():
        res = torch.func.functional_call(self._model, weights, args, kwargs)
        if isinstance(res, tuple) and len(res) == 1:
          res = res[0]
      return torch_xla2.tensor.unwrap(res)

    fjit = jax.jit(f, static_argnames=tuple(static_argnames))
    return torch_view(fjit)


  def forward(self, *args, **kwargs):
    (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs))
    if self._func_jitted_torch is None:
      self._func_jitted_torch = self._make_jitted(args, kwargs)
    return self._func_jitted_torch(
      self._weights,
      *args,
      **kwargs
    )

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)

  def __getattr__(self, key):
    # Fall back to the wrapped model for attributes not defined on the wrapper.
    return getattr(self._model, key)


def compile_pipe(pipe):
  pipe.text_encoder = CompiledModule(pipe.text_encoder)
  pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2)
  pipe.unet = CompiledModule(pipe.unet)
  pipe.vae = CompiledModule(pipe.vae)


def main():
  pipe = DiffusionPipeline.from_pretrained(
    # "stabilityai/stable-diffusion-xl-base-0.9",
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_safetensors=True,

  )
  compile_pipe(pipe)

  global_bs = 10
  inference_steps = 20
  resol = 1024
  prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs
  print(f'global batch size {global_bs}',
        f'inference steps {inference_steps}',
        f'Image resolution {resol}',
        flush=True
        )

  iters = 5
  for i in range(iters):
    prompt = prompts
    # print('per device prompts len',len(prompt))
    # prompt = prompts[rank]
    start = time()
    image = pipe(prompt,
                 num_inference_steps=inference_steps,
                 height=resol,
                 width=resol).images[0]
    print(f'Step {i} inference time {time()-start} sec', flush=True)


if __name__ == '__main__':
  main()
