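# Trains a torchtitan Llama 3 model on TPU via torch_xla2 and JAX:
# parameters are sharded FSDP-style across a 1-D 'fsdp' device mesh, the
# train step is jit-compiled once, and a few steps over synthetic data are
# run to report compile time and per-step latency.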
import os
import time
import logging
from typing import Tuple
from collections import defaultdict
import functools

def _setup_default_env():
  os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
  os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
  os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
  # Only needed for TPU v4.
  # os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
  tpu_args = ' '.join([
      '--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true',
      '--xla_tpu_megacore_fusion_allow_ags=false',
      '--xla_enable_async_collective_permute=true',
      '--xla_tpu_enable_ag_backward_pipelining=true',
      '--xla_tpu_enable_data_parallel_all_reduce_opt=true',
      '--xla_tpu_data_parallel_opt_different_sized_ops=true',
      '--xla_tpu_enable_async_collective_fusion=true',
      '--xla_tpu_enable_async_collective_fusion_multiple_steps=true',
      '--xla_tpu_overlap_compute_collective_tc=true',
      '--xla_enable_async_all_gather=true',
  ])
  os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args)

_setup_default_env()

import torch
import torch.nn.functional
from torch.utils import _pytree as pytree

import torch_xla2
import torch_xla2.interop
from torch_xla2.interop import jax_view, torch_view, JittableModule
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.experimental import mesh_utils
import optax

from torchtitan.models.llama import llama3_configs
from torchtitan.models.llama import model as titan

P = jax.sharding.PartitionSpec


SEQLEN = 8192
BATCH = 8
global_axis: Tuple[str, ...] = ('fsdp', )
num_global_devices = jax.device_count()
num_local_devices = jax.local_device_count()
num_partitions = (num_global_devices, )


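# Places a host tensor (or a nested tuple of tensors) onto devices with the
# given sharding. On a single host this is just jax.device_put; on multi-host
# setups (num_global_devices != num_local_devices) each process puts only its
# addressable shards and the global array is assembled from them.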
def sharded_device_put(tensor, sharding):
  if isinstance(tensor, tuple):
    return tuple(sharded_device_put(t, sharding) for t in tensor)

  if num_global_devices == num_local_devices:
    return jax.device_put(tensor, sharding)

  shape = tensor.shape
  x_split = [
      jax.device_put(tensor[i], device)
      for device, i in sharding.addressable_devices_indices_map(shape).items()
  ]
  return jax.make_array_from_single_device_arrays(shape, sharding, x_split)


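# Wrapper that applies a sharding constraint to the first input and to the
# output of the wrapped module, so XLA keeps activations sharded along the
# 'fsdp' axis; the parameters themselves are sharded in
# Trainer._shard_fsdp_style.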
class FSDPv2(torch.nn.Module):

  def __init__(self, mod):
    super().__init__()
    self.mod = mod
    self.mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh(num_partitions),
        axis_names=global_axis,
    )
    self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis))

  def forward(self, *args):
    args = list(args)
    args[0] = self.shard(args[0])
    res = self.mod(*args)
    return self.shard(res)

  def shard(self, x):
    return torch_xla2.interop.call_jax(
        jax.lax.with_sharding_constraint,
        x,
        self.sharding,
    )


def print_shapes(pyt):
  for p in pytree.tree_flatten(pyt)[0]:
    if hasattr(p, 'shape'):
      print(p.shape, p.dtype)


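# Minimal Lightning-style wrapper: training_step computes the token-level
# cross-entropy loss of the FSDP-wrapped transformer.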
class Module(torch.nn.Module):

  def __init__(self, inner):
    super().__init__()
    self.inner = FSDPv2(inner)

  def training_step(self, data, batch_id):
    x, y = data
    logits = self.inner(x)
    # Flatten (batch, seq, vocab) logits and (batch, seq) labels for the loss.
    vocab_size = logits.shape[-1]
    logits = logits.reshape(-1, vocab_size)
    y = y.reshape(-1)
    return torch.nn.functional.cross_entropy(logits, y)


class Trainer:

  def __init__(self):
    self.mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh(num_partitions),
        axis_names=global_axis,
    )
    self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis))
    self.replicated = jax.sharding.NamedSharding(self.mesh, P())

  def _shard_fsdp_style(self, state_dict, sharding=None):
    if sharding is None:
      sharding = self.x_sharding

    def move_one_tensor(x):
      jval = torch_xla2.tensor.t2j(x)
      return sharded_device_put(jval, sharding)

    if isinstance(state_dict, torch.Tensor):
      return move_one_tensor(state_dict)
    res = {}
    for k, v in sorted(state_dict.items()):
      res[k] = move_one_tensor(v)
    return res

  def fit(self, lightning_mod, data_loader):
    xla_env = torch_xla2.default_env()
    jax.config.update('jax_enable_x64', False)
    xla_env._mesh = self.mesh
    xla_env.use_flash_attention = True

    jittable_mod = JittableModule(lightning_mod)
    jax_params = self._shard_fsdp_style(jittable_mod.params)
    jax_buffers = self._shard_fsdp_style(jittable_mod.buffers)

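    # jax.checkpoint (rematerialization) drops intermediate activations during
    # the forward pass and recomputes them in the backward pass, trading extra
    # FLOPs for a lower peak memory footprint at this sequence length.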
    @jax.checkpoint
    def lightning_mod_loss(
        weights: jax.Array, buffers: jax.Array, data: jax.Array, batch_id):
      """Returns the loss."""
      with jax.named_scope("Computing_loss"):
        weights, buffers, data = torch_view((weights, buffers, data))
        # NOTE: this is needed because the original model did not register
        # these tensors as persistent buffers.
        with xla_env:
          loss = jittable_mod.functional_call(
              'training_step',
              weights, buffers, data, batch_id)
      return jax_view(loss)

    jax_optimizer = optax.adamw(0.001)

    opt_state = jax_optimizer.init(jax_params)
    grad_fn = jax.value_and_grad(lightning_mod_loss)

    opt_state_sharding = jax.tree_util.tree_map(lambda p: p.sharding, opt_state)

    print('Beginning training')

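    # donate_argnums=(0, 2) lets XLA reuse the buffers of the old weights and
    # optimizer state for the updated ones, so two copies of the model state
    # never live in device memory at the same time.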
    @functools.partial(
        jax.jit,
        donate_argnums=(0, 2),
    )
    def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid):
      print('Tracing inside of step')
      with jax.named_scope("Computing_loss_and_grad"):
        loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid)
      with jax.named_scope("optimizer_updates"):
        updates, opt_state = jax_optimizer.update(
            grads, optimizer_state, jax_weights)
        jax_weights = optax.apply_updates(jax_weights, updates)
      return loss, jax_weights, opt_state

    total_param_size = 0
    for k, v in jax_params.items():
      total_param_size += v.size

    print('Total number of params: ', total_param_size)

    print('Start compiling')
    start = time.perf_counter()
    lowered = step.lower(
        jax_params, jax_buffers, opt_state,
        (jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding),
         jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)),
        0
    )
    # print(lowered.as_text())
    print('program size:', len(lowered.as_text()) / 1e6, 'M chars')
    step_compiled = lowered.compile()
    end = time.perf_counter()
    compile_time = end - start
    print('End compiling', compile_time)

    for co in step_compiled.cost_analysis():
      print('flops counter:', co['flops'])

    s = time.perf_counter()
    jax.profiler.start_trace('/tmp/tensorboard')
    print('start training')
    min_loop_time = 10000
    for i, item in enumerate(data_loader):
      inputs, labels = sharded_device_put(jax_view(xla_env.to_xla(item)),
                                          self.x_sharding)
      print('INPUT shape', inputs.shape)

      step_start = time.perf_counter()
      loss, jax_params, opt_state = step_compiled(
          jax_params, jax_buffers, opt_state, (inputs, labels), 0)
      jax.block_until_ready((loss, jax_params))
      step_end = time.perf_counter()
      print(i, 'loss', loss, 'step latency: ', step_end - step_start)
      loop_time = step_end - step_start
      min_loop_time = min(min_loop_time, loop_time)
      print('======')
      if i >= 2:
        break
    jax.profiler.stop_trace()
    return min_loop_time, compile_time


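# Synthetic data: random token ids with arbitrary (shifted-by-one mod vocab)
# labels; real text is not needed to benchmark step time.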
def fake_dataloader(size, seqlen, batch_size):
  for _ in range(size):
    x = torch.randint(0, 32000, (batch_size, seqlen), device='cpu')
    yield x, (x + 1) % 32000


def main(
    model_type='8B',
    batch_size=8,
    seqlen=2048,
    mode='regular',
    use_flash_attention=True,
):
  logging.getLogger("jax").setLevel(logging.DEBUG)
  print(f"Running with parameters {locals()}")
  global SEQLEN
  global BATCH
  SEQLEN = seqlen
  BATCH = batch_size

  mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', ))
  env = torch_xla2.default_env()
  env.config.use_tpu_flash_attention = use_flash_attention
  env.config.shmap_flash_attention = use_flash_attention

  args = llama3_configs[model_type]
  #with torch.device('meta'):
  gpt = titan.Transformer(args)

  light_mod = Module(gpt)
  light_mod.to(torch.bfloat16)

  train_loader = fake_dataloader(10, seqlen, batch_size)

  with mesh:
    trainer = Trainer()
    return trainer.fit(
        light_mod,
        train_loader
    )


if __name__ == '__main__':
  import fire
  fire.Fire(main)
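
# Example invocation (the file name is illustrative), on a TPU VM with
# torch_xla2, torchtitan, optax and fire installed:
#   python train_llama_titan.py --model_type=8B --seqlen=2048 --batch_size=8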