In [1]:
# Runtime configuration. These are set in the first cell, ahead of the
# `import torch_xla` cell, so they are in the environment when the library loads.
# Select the TPU device for the PJRT runtime.
%env PJRT_DEVICE=TPU
# Debug switches for the XLA IR / HLO layers (presumably attach Python source
# information to the generated graphs — TODO confirm against torch_xla docs).
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
# NOTE(review): absolute, machine-specific path to a local libtpu build;
# adjust when running outside this workspace.
%env TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so
# Ask XLA to dump compiler artifacts into the relative `xla_dumps` directory.
%env XLA_FLAGS=--xla_dump_to=xla_dumps

# Attempt 1: the MaxText flag set, except that aggressive optimization-barrier removal is disabled (--xla_tpu_aggressive_opt_barrier_removal=DISABLED). Result: crashes in native code.
%env LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2

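# Attempt 2: same as attempt 1 but without --xla_tpu_enable_all_experimental_scheduler_features and --xla_latency_hiding_scheduler_rerun. Result: still crashes in native code.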
# %env LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100

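# Attempt 3: additionally removed the async all-gather / async collective-fusion, einsum-decomposition, scoped-VMEM-limit and SPMD RNG flags, and lowered --xla_tpu_scheduler_percent_shared_memory_limit from 100 to 1. Result: OK.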
# %env LIBTPU_INIT_ARGS=--xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=1

env: PJRT_DEVICE=TPU
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: TPU_LIBRARY_PATH=/workspaces/torch/_libtpu.so
env: XLA_FLAGS=--xla_dump_to=xla_dumps
env: LIBTPU_INIT_ARGS=--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_decompose_all_gather_einsum=true --xla_tpu_decompose_einsum_reduce_scatter=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED 

In [2]:
import torch_xla
import torch
from torch_xla import runtime as xr

In [3]:
from torch.autograd.graph import saved_tensors_hooks
from torch_xla.experimental.stablehlo_custom_call import (
  place_to_host, place_to_device
)

class OffloadingModule(torch.nn.Module):
  """Wrapper that runs a module's forward pass under saved-tensor hooks.

  While the wrapped module executes, every tensor autograd saves for the
  backward pass is passed through `place_to_host` when it is stored and
  through `place_to_device` when it is retrieved.
  """

  def __init__(self, m):
    """Store the module to wrap.

    Args:
      m: The module whose forward pass should run under the hooks.
    """
    super().__init__()
    self.m = m

  def forward(self, *args, **kwargs):
    """Call the wrapped module with the given arguments and return its output."""
    offload_hooks = saved_tensors_hooks(place_to_host, place_to_device)
    with offload_hooks:
      output = self.m(*args, **kwargs)
    return output

In [4]:
import decoder_only_model
from trainer import TrainDecoderOnlyBase
import functools

## Test 2D FSDP+TP sharding

In [5]:
import torch
import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
from torch_xla import runtime as xr
from itertools import chain

class TrainDecoderOnlyFSDPv2(TrainDecoderOnlyBase):
  """Decoder-only trainer sharded over a 2D ``('fsdp', 'tensor')`` SPMD mesh.

  Embedding, attention/MLP projection and LM-head weights are annotated with
  2D shardings, the input batch is sharded along the ``'fsdp'`` axis, and each
  decoder layer is wrapped first with gradient checkpointing and then with
  ``OffloadingModule``.
  """

  def __init__(self, tensor_axis=4):
    """Build the model, the mesh, the sharded data loader and the optimizer.

    Args:
      tensor_axis: Size of the ``'tensor'`` mesh axis. The ``'fsdp'`` axis
        receives ``num_devices // tensor_axis`` devices. Defaults to 4.

    Raises:
      ValueError: If ``tensor_axis`` is not a positive divisor of the number
        of runtime devices, since the mesh could not cover every device.
    """
    super().__init__(decoder_only_model.DecoderOnlyConfig(
        hidden_size=4096,
        num_hidden_layers=2,
        num_attention_heads=16,
        num_key_value_heads=8,
        intermediate_size=8192,
        vocab_size=16384,
    ))

    # Define the mesh following common SPMD practice.
    num_devices = xr.global_runtime_device_count()
    if tensor_axis <= 0 or num_devices % tensor_axis != 0:
      # Fail early with a readable message instead of building a mesh whose
      # shape does not match the list of device ids.
      raise ValueError(
          f"tensor_axis={tensor_axis} must be a positive divisor of the "
          f"runtime device count ({num_devices})")
    fsdp_axis = num_devices // tensor_axis
    mesh_shape = (fsdp_axis, tensor_axis)
    print(f"Single-slice sharding: mesh={mesh_shape}")
    spmd_mesh = xs.Mesh(list(range(num_devices)), mesh_shape, ('fsdp', 'tensor'))
    xs.set_global_mesh(spmd_mesh)

    model: decoder_only_model.DecoderOnlyModel = self.model  # type:ignore
    self.model = model

    # Ordered (name fragments, partition spec) rules; the first rule with a
    # fragment contained in the tensor name wins. Tensors matching no rule are
    # intentionally left unannotated because they are small (the original
    # author cites layernorm and moe.gate weights).
    sharding_rules = (
        (('embed_tokens',), ('fsdp', 'tensor')),
        (('q_proj', 'k_proj', 'v_proj'), ('tensor', 'fsdp')),
        (('o_proj',), ('fsdp', 'tensor')),
        (('gate_proj', 'up_proj'), ('tensor', 'fsdp')),
        (('down_proj',), ('fsdp', 'tensor')),
        (('lm_head',), (('tensor', 'fsdp'), None)),
    )

    # Mark model weights (and buffers) to be sharded.
    for name, param in chain(model.named_parameters(), model.named_buffers()):
      for fragments, partition_spec in sharding_rules:
        if any(fragment in name for fragment in fragments):
          xs.mark_sharding(param, spmd_mesh, partition_spec)
          break

    # Shard the input.
    # Scale the batch size with num_devices since there will be only one
    # process that handles all runtime devices.
    self.batch_size *= num_devices
    sample_shape = (self.batch_size, self.seq_len)
    train_loader = xu.SampleGenerator(
        # Two random int64 token tensors of identical shape, generated on CPU.
        data=(torch.randint(
            0,
            self.config.vocab_size,
            sample_shape,
            dtype=torch.int64,
            device='cpu'),
              torch.randint(
                  0,
                  self.config.vocab_size,
                  sample_shape,
                  dtype=torch.int64,
                  device='cpu')),
        sample_count=self.train_dataset_len // self.batch_size)
    self.train_device_loader = pl.MpDeviceLoader(
        train_loader,
        self.device,
        # Shard the input's batch dimension along the `fsdp` axis, no sharding along other dimensions
        input_sharding=xs.ShardingSpec(spmd_mesh, ('fsdp', None)))  # type:ignore

    # Wrap each DecoderLayer with checkpointing, then with offloading, so the
    # offloading wrapper is the outermost module of every layer.
    from torch_xla.distributed.fsdp import checkpoint_module
    for i, block in enumerate(self.model.layers):
      self.model.layers[i] = OffloadingModule(checkpoint_module(block))

    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.00001)
    torch_xla.sync(wait=True)


In [None]:
# Switch the XLA runtime to SPMD mode before building the trainer, whose
# constructor creates a mesh and marks shardings.
# NOTE(review): the ordering requirement is assumed — confirm against torch_xla docs.
xr.use_spmd()
# Build the sharded trainer; `base` is reused by the training/profiling cell.
base = TrainDecoderOnlyFSDPv2()

Single-slice sharding: mesh=(2, 4)


: 

In [None]:
import torch_xla.debug.profiler as xp

# Single source of truth for the profiler port: the server and the trace
# request must agree on it.
PROFILER_PORT = 9012


def _train_and_sync(num_steps):
  """Run `num_steps` training steps on `base`, then sync with `wait=True`."""
  base.num_steps = num_steps
  base.start_training()
  torch_xla.sync(wait=True)


# Warm-up run so that compilation happens before the profiled run.
print("Compiling model")
_train_and_sync(3)

print("Profiling model")
server = xp.start_server(PROFILER_PORT)
try:
  # Request a 15 s trace in the background, then run the steps to capture.
  xp.trace_detached(
      service_addr=f"localhost:{PROFILER_PORT}",
      logdir="profile/",
      duration_ms=15000)
  _train_and_sync(5)
finally:
  # Drop the server handle even if training raises, so re-running the cell
  # does not find a stale `server` still bound in the kernel namespace.
  del server

Compiling model
Epoch 1 train begin  6:24AM UTC on Nov 10, 2024


  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
https://symbolize.corp.google.com/r/?trace=7987c70fb0ad,7995a9c49dcf,7987c70fa2bc,7987c71020c3,7987cc08cc11,7987cc08c678,7987c70a6ee1,7987c709ceb0,7987c97c41b2,7987ceaf4d12,7987ceafaed5,7987ceb03a44,7987cec9f602,7995a9bf6ea6&map= 
*** SIGSEGV (@0x18), see go/stacktraces#s15 received by PID 14024 (TID 14962) on cpu 90; stack trace: ***
PC: @     0x7987c70fb0ad  (unknown)  xla::jellyfish::(anonymous namespace)::HasActivationSemantics()
    @     0x7987cedc80e1       1888  FailureSignalHandler()
    @     0x7995a9c49dd0  (unknown)  (unknown)
    @     0x7987c70fa2bd        384  xla::jellyfish::AllReduceDecomposer::TryMatchAllGatherEinsumOrEinsumReduceScatter()
    @     0x7987c71020c4        400  xla::jellyfish::AllReduceDecomposer::Run()
    @     0x7987cc08cc12        416  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7987cc08c679         80  xla::HloPassPipeline::Run()