Commit 31d348e

Add an example that trains the torchtitan version of llama. (#8400)
1 parent d503ca5

File tree: 7 files changed, +292 −33 lines

experimental/torch_xla2/examples/eager_mode.py

Lines changed: 7 additions & 17 deletions
@@ -3,7 +3,7 @@
 from torch.nn import functional as F
 import torch
 
-xla_env = torch_xla2.default_env()
+xla_env = torch_xla2.enable_globally()
 
 
 class MyModel(nn.Module):
@@ -21,28 +21,18 @@ def forward(self, x):
     return x
 
 m = MyModel()
-m = xla_env.to_xla(m)
+m = m.to('jax')
 
 # Execute this model using torch
-inputs = (torch.randn(3, 3, 28, 28), )
-inputs = xla_env.to_xla(inputs)
+inputs = torch.randn(3, 3, 28, 28, device='jax')
 
-print(m(*inputs))
+print(m(inputs))
 print('---=====')
 
-from torch_xla2.interop import jax_jit
+m_compiled = torch_xla2.compile(m)
 
-@jax_jit
-def model_func(param, inputs):
-  return torch.func.functional_call(m, param, inputs)
-
-print(model_func(m.state_dict(), inputs))
-
-print('---=====')
-with xla_env:
-  m2 = MyModel()
-  inputs = (torch.randn(3, 3, 28, 28), )
-  print(m2(*inputs))
+print(m_compiled(inputs))
 
 
+print('---')
 
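For reference, the two hunks above assemble into the following minimal sketch. It relies only on the APIs introduced by this diff (torch_xla2.enable_globally, .to('jax'), device='jax', torch_xla2.compile); TinyModel is a hypothetical stand-in for the example's MyModel, so treat it as an illustration rather than the file's exact contents.

import torch
import torch.nn as nn
import torch_xla2

torch_xla2.enable_globally()          # make the 'jax' device available globally

class TinyModel(nn.Module):           # hypothetical stand-in for MyModel
  def __init__(self):
    super().__init__()
    self.fc = nn.Linear(3 * 28 * 28, 10)

  def forward(self, x):
    return self.fc(x.flatten(1))

m = TinyModel().to('jax')             # move parameters onto the jax device
x = torch.randn(3, 3, 28, 28, device='jax')
print(m(x))                           # eager execution routed through JAX

m_compiled = torch_xla2.compile(m)    # jit-compile the whole module
print(m_compiled(x))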

experimental/torch_xla2/examples/train_llama/utils.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def lightning_mod_loss(
       with xla_env:
         loss = jittable_mod.functional_call(
             'training_step',
-            weights, buffers, (data, batch_id))
+            weights, buffers, data, batch_id)
       return jax_view(loss)
 
     jax_optimizer = self.torch_opt_to_jax_opt(

experimental/torch_xla2/examples/train_llama_torchtitan/__init__.py

Whitespace-only changes.

experimental/torch_xla2/examples/train_llama_torchtitan/ (new file)

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
import os
import time
import logging
from typing import Tuple
from collections import defaultdict
import functools

def _setup_default_env():
  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
  os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
  os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
  # only need for tpu v4
  # os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
  tpu_args = "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
  os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args)

_setup_default_env()

import torch
import torch.nn.functional
from torch.utils import _pytree as pytree

import torch_xla2
import torch_xla2.interop
from torch_xla2.interop import jax_view, torch_view, JittableModule
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.experimental import mesh_utils
import optax

from torchtitan.models.llama import llama3_configs
from torchtitan.models.llama import model as titan

P = jax.sharding.PartitionSpec


SEQLEN = 8192
BATCH = 8
global_axis: Tuple[str, str] = ('fsdp', )
num_global_devices = jax.device_count()
num_local_devices = jax.local_device_count()
num_partitions = (num_global_devices, )


def sharded_device_put(tensor, sharding):
  if isinstance(tensor, tuple):
    return tuple(sharded_device_put(t, sharding) for t in tensor)

  if num_global_devices == num_local_devices:
    return jax.device_put(tensor, sharding)

  shape = tensor.shape
  x_split = [jax.device_put(tensor[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
  return jax.make_array_from_single_device_arrays(shape, sharding, x_split)


class FSDPv2(torch.nn.Module):

  def __init__(self, mod):
    super().__init__()
    self.mod = mod
    self.mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh(num_partitions),
        axis_names=global_axis,
    )
    self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis))

  def forward(self, *args):
    args = list(args)
    args[0] = self.shard(args[0])
    res = self.mod(*args)
    return self.shard(res)

  def shard(self, x):
    return torch_xla2.interop.call_jax(
        jax.lax.with_sharding_constraint,
        x,
        self.sharding,
    )

def print_shapes(pyt):
  for p in pytree.tree_flatten(pyt)[0]:
    if hasattr(p, 'shape'):
      print(p.shape, p.dtype)


class Module(torch.nn.Module):

  def __init__(self, inner):
    super().__init__()
    self.inner = FSDPv2(inner)

  def training_step(self, data, batch_id):
    x, y = data
    logits = self.inner(x)
    num_tokens = logits.shape[-1]
    logits = logits.reshape(-1, num_tokens)
    y = y.reshape(-1)
    return torch.nn.functional.cross_entropy(
        logits, y)


class Trainer:

  def __init__(self):
    self.mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh(num_partitions),
        axis_names=global_axis,
    )
    self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis))
    self.replicated = jax.sharding.NamedSharding(self.mesh, P())

  def _shard_fsdp_style(self, state_dict, sharding=None):
    if sharding is None:
      sharding = self.x_sharding
    def move_one_tensor(x):
      jval = torch_xla2.tensor.t2j(x)
      return sharded_device_put(jval, sharding)

    if isinstance(state_dict, torch.Tensor):
      return move_one_tensor(state_dict)
    res = {}
    for k, v in sorted(state_dict.items()):
      res[k] = move_one_tensor(v)
    return res

  def fit(self, lightning_mod, data_loader):
    xla_env = torch_xla2.default_env()
    jax.config.update('jax_enable_x64', False)
    xla_env._mesh = self.mesh
    xla_env.use_flash_attention = True

    jittable_mod = JittableModule(lightning_mod)
    jax_params = self._shard_fsdp_style(jittable_mod.params)
    jax_buffers = self._shard_fsdp_style(jittable_mod.buffers)

    @jax.checkpoint
    def lightning_mod_loss(
        weights: jax.Array, buffers: jax.Array, data: jax.Array, batch_id):
      """returns loss"""
      with jax.named_scope("Computing_loss"):
        weights, buffers, data = torch_view((weights, buffers, data))
        # NOTE: this is needed because the original model
        # did not register these as persistent buffers
        with xla_env:
          loss = jittable_mod.functional_call(
              'training_step',
              weights, buffers, data, batch_id)
      return jax_view(loss)

    jax_optimizer = optax.adamw(0.001)

    opt_state = jax_optimizer.init(jax_params)
    grad_fn = jax.value_and_grad(lightning_mod_loss)

    opt_state_sharding = jax.tree_util.tree_map(lambda p: p.sharding, opt_state)

    print('Beginning training')

    @functools.partial(
        jax.jit,
        donate_argnums=(0, 2),
    )
    def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid):
      print('Tracing inside of step')
      with jax.named_scope("Computing_loss_and_grad"):
        loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid)
      with jax.named_scope("optimizer_updates"):
        updates, opt_state = jax_optimizer.update(
            grads, optimizer_state, jax_weights)
        jax_weights = optax.apply_updates(jax_weights, updates)
      return loss, jax_weights, opt_state

    total_param_size = 0
    for k, v in jax_params.items():
      total_param_size += v.size

    print('Total number of params: ', total_param_size)

    print('Start compiling')
    start = time.perf_counter()
    lowered = step.lower(
        jax_params, jax_buffers, opt_state,
        (jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding),
         jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)),
        0
    )
    # print(lowered.as_text())
    print('program size:', len(lowered.as_text()) / 1e6, 'm chars')
    step_compiled = lowered.compile()
    end = time.perf_counter()
    compile_time = end - start
    print('End compiling', compile_time)

    for co in step_compiled.cost_analysis():
      print('flops counter:', co['flops'])

    s = time.perf_counter()
    jax.profiler.start_trace('/tmp/tensorboard')
    print('start training')
    min_loop_time = 10000
    for i, item in enumerate(data_loader):
      inputs, labels = sharded_device_put(jax_view(xla_env.to_xla(item)),
                                          self.x_sharding)
      print('INPUT shape', inputs.shape)

      step_start = time.perf_counter()
      loss, jax_params, opt_state = step_compiled(
          jax_params, jax_buffers, opt_state, (inputs, labels), 0)
      jax.block_until_ready((loss, jax_params))
      step_end = time.perf_counter()
      print(i, 'loss', loss, 'step latency: ', step_end - step_start)
      loop_time = step_end - step_start
      min_loop_time = min(min_loop_time, loop_time)
      print('======')
      if i >= 2:
        break
    jax.profiler.stop_trace()
    return min_loop_time, compile_time


def fake_dataloader(size, seqlen, batch_size):
  for _ in range(size):
    x = torch.randint(0, 32000, (batch_size, seqlen), device='cpu')
    yield x, (x + 1) % 32000


def main(
    model_type='8B',
    batch_size=8,
    seqlen=2048,
    mode='regular',
    use_flash_attention=True,  # toggles the flash-attention config flags set below
):
  logging.getLogger("jax").setLevel(logging.DEBUG)
  print(f"Running with parameters {locals()}")
  global SEQLEN
  global BATCH
  SEQLEN = seqlen
  BATCH = batch_size

  mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', ))
  env = torch_xla2.default_env()
  env.config.use_tpu_flash_attention = use_flash_attention
  env.config.shmap_flash_attention = use_flash_attention

  args = llama3_configs[model_type]
  #with torch.device('meta'):
  gpt = titan.Transformer(args)

  light_mod = Module(gpt)
  light_mod.to(torch.bfloat16)

  train_loader = fake_dataloader(10, seqlen, batch_size)

  with mesh:
    trainer = Trainer()
    return trainer.fit(
        light_mod,
        train_loader
    )


if __name__ == '__main__':
  import fire
  fire.Fire(main)
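Trainer.fit above compiles the training step ahead of time instead of paying compilation cost on the first batch. The snippet below is a standalone sketch (not part of this commit) of the same jax.jit lower/compile pattern, using a toy step function and abstract inputs described with jax.ShapeDtypeStruct.

import jax
import jax.numpy as jnp

@jax.jit
def step(w, x):
  # toy stand-in for the real training step
  return w + x.sum()

w = jnp.zeros((4,), jnp.float32)
# describe the batch abstractly so compilation needs no real data
abstract_x = jax.ShapeDtypeStruct((8, 2048), jnp.dtype('int32'))

lowered = step.lower(w, abstract_x)   # trace and lower once, up front
compiled = lowered.compile()          # reusable executable; compile time is measured here

x = jnp.ones((8, 2048), dtype=jnp.int32)
print(compiled(w, x))                 # every later call reuses the compiled program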

experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 1 addition & 0 deletions
@@ -364,6 +364,7 @@ def _aten_mul(x, y):
 
 
 @op(torch.ops.aten.silu)
+@op(torch.ops.aten.silu.default)
 def _aten_silu(x):
   return jax.nn.silu(x)
 
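As a hedged illustration of why both decorators are useful (this is not torch_xla2's actual registry, just a sketch of the idea): torch.ops.aten.silu is an OpOverloadPacket while torch.ops.aten.silu.default is the concrete OpOverload that dispatch frequently hands out, and a lookup table keyed on the op object treats the two as distinct keys.

import torch
import jax

registry = {}

def op(torch_op):
  # illustrative decorator: map a torch op object to its JAX implementation
  def register(fn):
    registry[torch_op] = fn
    return fn
  return register

@op(torch.ops.aten.silu)
@op(torch.ops.aten.silu.default)
def _aten_silu(x):
  return jax.nn.silu(x)

# the packet and its .default overload are different objects, so both keys matter
assert torch.ops.aten.silu is not torch.ops.aten.silu.default
assert registry[torch.ops.aten.silu] is registry[torch.ops.aten.silu.default]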

experimental/torch_xla2/torch_xla2/ops/jtorch.py

Lines changed: 14 additions & 14 deletions
@@ -130,21 +130,21 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
 
 def _tpu_flash_attention(query, key, value, env):
   fsdp_partition = PartitionSpec('fsdp')
-  block_sizes = flash_attention.BlockSizes(
-    block_b=min(2, query.shape[0]),
-    block_q=min(512, query.shape[2]),
-    block_k_major=min(512, key.shape[2]),
-    block_k=min(512, key.shape[2]),
-    block_q_major_dkv=min(512, query.shape[2]),
-    block_k_major_dkv=min(512, key.shape[2]),
-    block_k_dkv=min(512, key.shape[2]),
-    block_q_dkv=min(512, query.shape[2]),
-    block_k_major_dq=min(512, key.shape[2]),
-    block_k_dq=min(256, key.shape[2]),
-    block_q_dq=min(1024, query.shape[2]),
-  )
   def wrap_flash_attention(query, key, value):
-    return flash_attention.flash_attention(
+    block_sizes = flash_attention.BlockSizes(
+      block_b=min(2, query.shape[0]),
+      block_q=min(512, query.shape[2]),
+      block_k_major=min(512, key.shape[2]),
+      block_k=min(512, key.shape[2]),
+      block_q_major_dkv=min(512, query.shape[2]),
+      block_k_major_dkv=min(512, key.shape[2]),
+      block_k_dkv=min(512, key.shape[2]),
+      block_q_dkv=min(512, query.shape[2]),
+      block_k_major_dq=min(512, key.shape[2]),
+      block_k_dq=min(256, key.shape[2]),
+      block_q_dq=min(1024, query.shape[2]),
+    )
+    return flash_attention.flash_attention(
       query, key, value, causal=True, block_sizes=block_sizes)
 
   if env.config.shmap_flash_attention:
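A plausible reading of this move (an assumption on my part, consistent with the shmap_flash_attention branch above and the shard_map import in the new example) is that when the kernel runs under shard_map, query.shape inside the wrapper is the per-shard shape, so the block sizes should be derived there rather than from the global arrays. A small standalone sketch of that shape difference:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('fsdp',))

def show_block(q):
  # inside shard_map, q.shape is the per-device shard's shape,
  # so block sizes computed here adapt to the local sequence length
  print('per-shard shape:', q.shape, '-> block_q =', min(512, q.shape[0]))
  return q

q = jnp.zeros((jax.device_count() * 1024, 8))
wrapped = shard_map(show_block, mesh=mesh,
                    in_specs=P('fsdp', None), out_specs=P('fsdp', None))
out = wrapped(q)
print('global shape:', q.shape)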

experimental/torch_xla2/torch_xla2/tensor.py

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ def __init__(self, configuration=None):
 
   def get_as_jax_device(self, device: Any):
     if device is None:
-      return jax.devices()[0]
+      device = torch.get_default_device()
 
     if isinstance(device, torch.device):
       device = str(device)
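This dovetails with the eager_mode.py change above: when no device is passed, the environment now defers to PyTorch's default-device machinery. The sketch below is plain PyTorch (unrelated to torch_xla2 internals) showing the get/set default-device calls this relies on.

import torch

print(torch.get_default_device())   # cpu unless a default has been set
torch.set_default_device('cpu')     # any available device string works here
x = torch.empty(2, 3)               # allocated on the current default device
print(x.device, torch.get_default_device())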

0 commit comments