This notebook demonstrates offloading checkpointed activations to host memory on a model with 1D (FSDP) sharding.

Profile (v6e-8): http://shortn/_AVBaAqnFNd

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so

# MaxText flags, except that optimization barrier removal is disabled (--xla_tpu_aggressive_opt_barrier_removal=DISABLED). Result: crashes in native code.
# %env LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2

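# Previous flag set without --xla_tpu_enable_all_experimental_scheduler_features and --xla_latency_hiding_scheduler_rerun. Result: still crashes in native code.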
# %env LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100

# Removed some more flags and lowered the scheduler shared-memory limit from 100% to 1%. Result: OK.
%env LIBTPU_INIT_ARGS=--xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=1

env: PJRT_DEVICE=TPU
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so
env: LIBTPU_INIT_ARGS=--xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=1


In [2]:
import torch_xla
import torch
from torch_xla import runtime as xr

In [3]:
from torch.autograd.graph import saved_tensors_hooks
from torch_xla.experimental.stablehlo_custom_call import (
  place_to_host, place_to_device
)

class OffloadingModule(torch.nn.Module):
  """Runs a wrapped module with the host-offload saved-tensor hooks active.

  The wrapped module is stored as `self.m`. During `forward`, autograd's
  `saved_tensors_hooks` context is entered with `place_to_host` as the pack
  hook and `place_to_device` as the unpack hook.
  """

  def __init__(self, m):
    """Store `m` as the submodule that `forward` delegates to."""
    super().__init__()
    self.m = m

  def forward(self, *args, **kwargs):
    """Call the wrapped module with the given arguments and return its result."""
    offload_hooks = saved_tensors_hooks(place_to_host, place_to_device)
    with offload_hooks:
      output = self.m(*args, **kwargs)
    return output

In [4]:
import decoder_only_model
from trainer import TrainDecoderOnlyBase
import functools

## Test 1D FSDP sharding

In [5]:
import torch
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch_xla import runtime as xr
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

# See the FSDPv2 documentation at https://github.com/pytorch/xla/blob/master/docs/fsdpv2.md
class TrainDecoderOnlyFSDPv2(TrainDecoderOnlyBase):
  """Decoder-only trainer using FSDPv2 sharding with per-layer checkpoint + offload.

  The constructor:
    1. builds the base trainer with a fixed `DecoderOnlyConfig`,
    2. installs a global `(num_devices, 1)` mesh with axes `('fsdp', 'model')`,
    3. replaces the data loader with one whose batch dimension is sharded
       along `'fsdp'`,
    4. wraps every entry of `self.model.layers` with `checkpoint_module` and
       then `OffloadingModule`,
    5. wraps the model with FSDPv2 (auto-wrapping each `OffloadingModule`) and
       creates an SGD optimizer over the wrapped model's parameters.
  """

  def __init__(self):
    super().__init__(decoder_only_model.DecoderOnlyConfig(
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=16,
        num_key_value_heads=8,
        intermediate_size=8192,
        vocab_size=16384,
    ))
    # Define the mesh following common SPMD practice: all devices along the
    # first axis, and a second axis of size 1.
    num_devices = xr.global_runtime_device_count()
    mesh_shape = (num_devices, 1)
    device_ids = np.arange(num_devices)
    # Note: the mesh must have an axis named 'fsdp', which the weights and
    # activations will be sharded on.
    mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
    xs.set_global_mesh(mesh)

    # Shard the input (data parallel).
    # Scale the batch size with num_devices since there will be only one
    # process that handles all runtime devices.
    self.batch_size *= num_devices

    def random_tokens():
      # Random token ids in [0, vocab_size) of shape (batch_size, seq_len),
      # created on CPU.
      return torch.randint(
          0,
          self.config.vocab_size, (self.batch_size, self.seq_len),
          dtype=torch.int64,
          device='cpu')

    # Each sample is a pair of independently drawn token tensors
    # (presumably inputs and labels -- confirm against the base trainer).
    train_loader = xu.SampleGenerator(
        data=(random_tokens(), random_tokens()),
        sample_count=self.train_dataset_len // self.batch_size)
    self.train_device_loader = pl.MpDeviceLoader(
        train_loader,
        self.device,
        # Shard the input's batch dimension along the `fsdp` axis, no sharding
        # along other dimensions.
        input_sharding=xs.ShardingSpec(mesh, ('fsdp', None)))  # type:ignore

    # Apply checkpointing to each decoder layer, then wrap the checkpointed
    # layer with OffloadingModule so it runs under the offload hooks.
    # OffloadingModule is the outermost wrapper because the FSDP auto-wrap
    # policy below matches on that class.
    from torch_xla.distributed.fsdp import checkpoint_module
    for i, block in enumerate(self.model.layers):
      self.model.layers[i] = OffloadingModule(checkpoint_module(block))

    # Apply FSDP sharding on each (wrapped) decoder layer.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={OffloadingModule},
    )

    # FSDPv2 will use the global mesh set above.
    self.model = FSDPv2(self.model, auto_wrap_policy=auto_wrap_policy)
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.00001)


In [6]:
# Switch the XLA runtime to SPMD mode before building the trainer; the
# trainer's constructor sets the global mesh and wraps the model with FSDPv2.
xr.use_spmd()
base = TrainDecoderOnlyFSDPv2()

In [7]:
import torch_xla.debug.profiler as xp

# Profiler settings used by the second (profiled) run below.
PROFILER_PORT = 9012
PROFILE_LOGDIR = "profile"
PROFILE_DURATION_MS = 15000

# First run: a few steps so that compilation happens before the profiler
# trace is started.
print("Compiling model")
base.num_steps = 3
base.start_training()
torch_xla.sync(wait=True)

# Second run: start the profiler server, request a detached trace against it,
# then train again while the trace is being captured.
print("Profiling model")
server = xp.start_server(PROFILER_PORT)
xp.trace_detached(
    service_addr=f"localhost:{PROFILER_PORT}",
    logdir=PROFILE_LOGDIR,
    duration_ms=PROFILE_DURATION_MS)
base.num_steps = 5
base.start_training()
torch_xla.sync(wait=True)
# Drop the reference to the profiler server once the profiled run is done.
del server

Compiling model
Epoch 1 train begin  6:27AM UTC on Nov 10, 2024


  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):


epoch: 1, step: 0, loss: 9.870162010192871, rate: 3.319816552357878
epoch: 1, step: 1, loss: 9.870155334472656, rate: 46.719624509538505
Epoch 1 train end  6:30AM UTC on Nov 10, 2024
epoch: 1, step: 2, loss: 9.870158195495605, rate: 64.20032894730252
Profiling model
Epoch 1 train begin  6:30AM UTC on Nov 10, 2024
Starting to trace for 15000 ms. Remaining attempt(s): 2


2024-11-10 06:30:27.324573: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 4520830 nanoseconds and will start immediately.


epoch: 1, step: 0, loss: 9.870157241821289, rate: 73.35901042365177
epoch: 1, step: 1, loss: 9.870158195495605, rate: 74.85404335119196
epoch: 1, step: 2, loss: 9.870161056518555, rate: 75.44193497289696
epoch: 1, step: 3, loss: 9.870157241821289, rate: 75.70407848861927
Epoch 1 train end  6:30AM UTC on Nov 10, 2024
epoch: 1, step: 4, loss: 9.87015438079834, rate: 75.80861577492679


In [8]:
from visualize import visualize_tensor_sharding
# Visualize the sharding of the token-embedding weight (color disabled).
# Assigning to `_` keeps the return value from being echoed as cell output.
_ = visualize_tensor_sharding(base.model.embed_tokens.weight, use_color=False)

In [9]:
# Visualize the sharding of the first layer's attention q_proj weight.
# `.m` is the attribute where OffloadingModule stores the layer it wraps.
_ = visualize_tensor_sharding(base.model.layers[0].m.self_attn.q_proj.weight, use_color=False)

In [10]:
# Visualize the sharding of the first layer's attention o_proj weight.
# `.m` is the attribute where OffloadingModule stores the layer it wraps.
_ = visualize_tensor_sharding(base.model.layers[0].m.self_attn.o_proj.weight, use_color=False)