In [1]:
# Select the TPU device for the PJRT runtime used by PyTorch/XLA.
%env PJRT_DEVICE=TPU
# Debug switches for PyTorch/XLA. Presumably these attach Python source
# information to the lazy IR and propagate it into the HLO metadata so that
# profile traces can be mapped back to model code -- confirm against the
# PyTorch/XLA troubleshooting docs.
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1

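# MaxText's LIBTPU_INIT_ARGS flags, except that aggressive optimization-barrier
# removal is disabled (--xla_tpu_aggressive_opt_barrier_removal=DISABLED).
# The line below is left commented out; uncomment it to apply these flags.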
# %env LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2

env: PJRT_DEVICE=TPU
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1


In [2]:
import torch_xla
import torch
from torch_xla import runtime as xr

In [3]:
from torch.autograd.graph import saved_tensors_hooks
from torch_xla.experimental.stablehlo_custom_call import (
  place_to_host, place_to_device
)

class OffloadingModule(torch.nn.Module):
  """Wrapper that runs a module's forward pass under saved-tensor hooks.

  The wrapped module is invoked inside
  `saved_tensors_hooks(place_to_host, place_to_device)`: `place_to_host` is
  installed as the pack hook (applied to tensors autograd saves for backward)
  and `place_to_device` as the unpack hook (applied when those tensors are
  read back).

  Args:
    m: The module (or any callable) to wrap. It is stored as `self.m`.
  """

  def __init__(self, m):
    super().__init__()
    self.m = m

  def forward(self, *args, **kwargs):
    """Forward all arguments to the wrapped module and return its result."""
    offload_hooks = saved_tensors_hooks(place_to_host, place_to_device)
    with offload_hooks:
      output = self.m(*args, **kwargs)
    return output

In [4]:
import decoder_only_model
from trainer import TrainDecoderOnlyBase
import functools

## Test 1D FSDP sharding

In [5]:
import torch
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch_xla import runtime as xr
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Check out our doc at https://github.com/pytorch/xla/blob/master/docs/fsdpv2.md
class TrainDecoderOnlyFSDPv2(TrainDecoderOnlyBase):
  """Decoder-only trainer set up for 1D FSDP sharding via FSDPv2 (SPMD).

  On construction this class:
    1. builds a `(num_devices, 1)` mesh with axes `('fsdp', 'model')` and
       installs it as the global SPMD mesh;
    2. multiplies `self.batch_size` by the device count and creates a
       synthetic all-zeros data loader whose batch dimension is sharded along
       the `fsdp` axis;
    3. wraps every entry of `self.model.layers` in `checkpoint_module` and then
       in `OffloadingModule`;
    4. wraps the whole model in FSDPv2 with an auto-wrap policy that targets
       `OffloadingModule`, and creates an SGD optimizer over its parameters.

  `self.batch_size`, `self.seq_len`, `self.train_dataset_len`, `self.device`
  and `self.model` are read here but not defined in this class; they are
  presumably initialized by `TrainDecoderOnlyBase.__init__` -- verify there.
  """

  def __init__(self):
    super().__init__(decoder_only_model.DecoderOnlyConfig(
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=16,
        num_key_value_heads=8,
        intermediate_size=8192,
        vocab_size=16384,
    ))
    # Define the mesh following common SPMD practice: every device sits on the
    # first axis, and the second axis has size 1.
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (num_devices, 1)
    device_ids = np.arange(num_devices)
    # Note: the mesh must have an axis named 'fsdp'; the weights and
    # activations are sharded along it.
    mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
    xs.set_global_mesh(mesh)

    # Shard the input (data parallel).
    # Scale the batch size with num_devices since there will be only one
    # process that handles all runtime devices.
    self.batch_size *= num_devices
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64),
              torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size)
    self.train_device_loader = pl.MpDeviceLoader(
        train_loader,
        self.device,
        # Shard the input's batch dimension along the `fsdp` axis, no sharding
        # along other dimensions.
        input_sharding=xs.ShardingSpec(mesh, ('fsdp', None)))  # type:ignore

    # Local alias that only narrows the static type of the model.
    model: decoder_only_model.DecoderOnlyModel = self.model  # type:ignore

    # Wrap each layer once: checkpointing on the inside, offloading of the
    # saved tensors on the outside.
    from torch_xla.distributed.fsdp import checkpoint_module
    for i, block in enumerate(model.layers):
      model.layers[i] = OffloadingModule(checkpoint_module(block))

    # Apply FSDP sharding to each wrapped layer. After the loop above every
    # layer is an OffloadingModule, so that is the class the policy matches.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            OffloadingModule
        },
    )

    # FSDPv2 will use the global mesh set above.
    self.model = FSDPv2(model, auto_wrap_policy=auto_wrap_policy)
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.00001)


In [6]:
# Switch PyTorch/XLA to SPMD mode. This is called first, before the trainer
# (and therefore the model and the global mesh) is constructed.
xr.use_spmd()

# Warm-up run: train for a few steps so that compilation happens here rather
# than inside the profiled run below.
print("Compiling model")
base = TrainDecoderOnlyFSDPv2()
base.num_steps = 3
base.start_training()
# wait=True: presumably blocks until the pending device work has finished, so
# the warm-up is fully done before profiling starts -- confirm in the
# torch_xla.sync docs.
torch_xla.sync(wait=True)

# Profiled run: start a profiler server on port 9012, request a 10000 ms trace
# from it (written under profile/), then train again on the same trainer.
print("Profiling model")
import torch_xla.debug.profiler as xp
server = xp.start_server(9012)
# The captured output shows "Starting to trace" after the epoch has begun, so
# the trace request appears to run concurrently with the training call below.
xp.trace_detached(
    service_addr="localhost:9012", logdir="profile/", duration_ms=10000)
base.num_steps = 5
base.start_training()
# Drop the reference to the profiler server (presumably this lets it shut down).
del server

Compiling model


Epoch 1 train begin  6:52AM UTC on Nov 09, 2024


  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):


epoch: 1, step: 0, loss: 10.497030258178711, rate: 3.885934977170958
epoch: 1, step: 1, loss: 10.308744430541992, rate: 59.265707754513116
Epoch 1 train end  6:54AM UTC on Nov 09, 2024
epoch: 1, step: 2, loss: 10.120159149169922, rate: 81.529329088687
Profiling model
Epoch 1 train begin  6:54AM UTC on Nov 09, 2024
Starting to trace for 10000 ms. Remaining attempt(s): 2


2024-11-09 06:54:28.902295: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 951080 nanoseconds and will start immediately.


epoch: 1, step: 0, loss: 9.933348655700684, rate: 92.53957494213485
epoch: 1, step: 1, loss: 9.745992660522461, rate: 94.92735715523048
epoch: 1, step: 2, loss: 9.557472229003906, rate: 95.92802707093577
epoch: 1, step: 3, loss: 9.36929988861084, rate: 96.16945501334894
epoch: 1, step: 4, loss: 9.18035888671875, rate: 96.41094065149827
epoch: 1, step: 5, loss: 8.997147560119629, rate: 96.5329254760474
epoch: 1, step: 6, loss: 8.807817459106445, rate: 96.57038474682722
epoch: 1, step: 7, loss: 8.617790222167969, rate: 96.56631265184365
epoch: 1, step: 8, loss: 8.432229042053223, rate: 96.55531166528806
epoch: 1, step: 9, loss: 8.246435165405273, rate: 96.54987734217524
epoch: 1, step: 10, loss: 8.05585765838623, rate: 96.5551424700034
epoch: 1, step: 11, loss: 7.868907928466797, rate: 96.52002129625203
epoch: 1, step: 12, loss: 7.681149482727051, rate: 96.50362104908496
epoch: 1, step: 13, loss: 7.494668960571289, rate: 96.45459015651032
epoch: 1, step: 14, loss: 7.307222843170166, rate

## Test 2D (FSDP, TP) sharding