In this notebook we study how a toy decoder layer is lowered into LazyTensor IR,
both directly and after first tracing it with AOTAutograd, and compare the
resulting IR and tracing metrics.

In [1]:
# Initialization

from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel
from aot_flash_attention import flash_attention_2

import time
import os
import torch_xla
import torch_xla.debug.metrics
import torch
import torch_xla.distributed.spmd as xs
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.profiler as xp
from torch_xla import runtime as xr
from itertools import chain
from tqdm import tqdm


# Sharding: lay every device out on a 2-D ('fsdp', 'tensor') SPMD mesh.
num_devices = xr.global_runtime_device_count()
tensor_axis = 2
# xs.Mesh needs prod(mesh_shape) == number of devices, so the device count must
# split evenly over the tensor axis. Fail early with a readable message instead
# of letting Mesh reject a (0, 2) or (1, 2) shape later on.
if num_devices < tensor_axis or num_devices % tensor_axis != 0:
  raise ValueError(
      f"Expected a device count divisible by tensor_axis={tensor_axis}, "
      f"got {num_devices} device(s)")
fsdp_axis = num_devices // tensor_axis
mesh_shape = (fsdp_axis, tensor_axis)
print(f"Single-slice sharding: mesh={mesh_shape}")
spmd_mesh = xs.Mesh(
    list(range(num_devices)), mesh_shape, ('fsdp', 'tensor'))
xs.set_global_mesh(spmd_mesh)
xr.use_spmd()

print("Building model")
# Seed the host RNG so the randomly initialized weights are the same on every
# Restart & Run All. The model is constructed (presumably on CPU) before being
# moved to the XLA device below, so the CPU generator is the one that matters.
SEED = 0
torch.manual_seed(SEED)
device = torch_xla.device()
config = DecoderOnlyConfig(
    hidden_size=1024,
    num_hidden_layers=3,
    use_flash_attention=True)
# Overridden after construction rather than passed to the constructor.
config.intermediate_size = 4096
config.vocab_size = 8192
model = DecoderOnlyModel(config=config).bfloat16().to(device)
batch_size = 32
sequence_length = 1024

# NOTE(review): semantics inferred from method names only — offload appears to
# be disabled and scan enabled; confirm against decoder_only_model.
model.use_offload_(False)
model.use_scan_(True)
# Route every layer's attention through flash_attention_2 (from aot_flash_attention).
for layer in model.layers:
  layer.self_attn.flash_attention_impl = flash_attention_2  # type: ignore

# Mark model weights to be sharded.
# Rules are tried top to bottom; the first one with a key occurring in the
# tensor name decides the partition spec. Tensors matching no rule (e.g. the
# small layernorm weights) are intentionally left unsharded.
_SHARDING_RULES = (
    (('embed_tokens',), ('fsdp', 'tensor')),
    (('q_proj', 'k_proj', 'v_proj'), ('tensor', 'fsdp')),
    (('o_proj',), ('fsdp', 'tensor')),
    (('gate_proj', 'up_proj'), ('tensor', 'fsdp')),
    (('down_proj',), ('fsdp', 'tensor')),
    (('lm_head',), (('tensor', 'fsdp'), None)),
)

for name, param in chain(model.named_parameters(), model.named_buffers()):
  print('> [2D] Sharding tensor', name, param.shape)

  partition_spec = next(
      (spec for keys, spec in _SHARDING_RULES
       if any(key in name for key in keys)),
      None)
  if partition_spec is not None:
    xs.mark_sharding(param, spmd_mesh, partition_spec)

  # Show the sharding annotation the tensor ended up with (if any).
  print(f'{name} {torch_xla._XLAC._get_xla_sharding_spec(param)}')

# Generate random input_ids within the range of the vocabulary size.
# They are drawn on the host from a dedicated, seeded generator so the inputs
# are reproducible across runs, then moved to the XLA device.
INPUT_SEED = 0
input_generator = torch.Generator().manual_seed(INPUT_SEED)
input_ids = torch.randint(
    0, config.vocab_size, (batch_size, sequence_length),
    generator=input_generator).to(device)
# Shard the input data too: dim 0 (batch) over 'fsdp', dim 1 (sequence) replicated.
# The global mesh was already registered right after the mesh was built, so it
# is not set a second time here.
xs.mark_sharding(input_ids, spmd_mesh, ('fsdp', None))

optimizer = torch.optim.SGD(model.parameters(), lr=0.00001)
# Flush all pending lazy work (weight transfers, sharding) before tracing below.
torch_xla.sync(wait=True)




Single-slice sharding: mesh=(4, 2)
Building model
> [2D] Sharding tensor embed_tokens.weight torch.Size([8192, 1024])
embed_tokens.weight {devices=[4,2]0,1,2,3,4,5,6,7}
> [2D] Sharding tensor layers.0.self_attn.q_proj.weight torch.Size([1024, 1024])
layers.0.self_attn.q_proj.weight {devices=[2,4]0,2,4,6,1,3,5,7}
> [2D] Sharding tensor layers.0.self_attn.k_proj.weight torch.Size([512, 1024])
layers.0.self_attn.k_proj.weight {devices=[2,4]0,2,4,6,1,3,5,7}
> [2D] Sharding tensor layers.0.self_attn.v_proj.weight torch.Size([512, 1024])
layers.0.self_attn.v_proj.weight {devices=[2,4]0,2,4,6,1,3,5,7}
> [2D] Sharding tensor layers.0.self_attn.o_proj.weight torch.Size([1024, 1024])
layers.0.self_attn.o_proj.weight {devices=[4,2]0,1,2,3,4,5,6,7}
> [2D] Sharding tensor layers.0.mlp.gate_proj.weight torch.Size([4096, 1024])
layers.0.mlp.gate_proj.weight {devices=[2,4]0,2,4,6,1,3,5,7}
> [2D] Sharding tensor layers.0.mlp.up_proj.weight torch.Size([4096, 1024])
layers.0.mlp.up_proj.weight {devices=[

In [2]:
# Index of the decoder layer studied in isolation below; change it to inspect a
# different layer (presumably 0 .. config.num_hidden_layers - 1).
LAYER_INDEX = 0
decoder_layer = model.layers[LAYER_INDEX]
type(decoder_layer)

decoder_only_model.DecoderLayer

In [3]:
# Compute the token embeddings once and keep a detached copy, so the decoder
# layer below receives a standalone tensor rather than the embedding's output.
inputs_embeds = model.embed_tokens(input_ids).clone().detach()
# Execute the pending embedding work now so it is not folded into the
# decoder-layer IR inspected in the following cells.
torch_xla.sync(wait=True)
print(f"{type(inputs_embeds)} {inputs_embeds.shape} {inputs_embeds.dtype}")

<class 'torch.Tensor'> torch.Size([32, 1024, 1024]) torch.bfloat16


Lower the toy decoder layer directly into LazyTensor IR (no AOTAutograd).

In [4]:
# Reset all counters so the report below reflects only the direct lowering.
from torch_xla.debug import metrics as met
met.clear_all()

In [5]:
# Run the layer on lazy XLA tensors: this records LazyTensor IR for the
# forward pass (printed below) rather than executing it immediately.
decoder_out = decoder_layer(inputs_embeds)
# Dump the pending IR *before* syncing. After sync the graph is executed and
# the output is presumably backed by device data, so the traced ops would no
# longer appear in this text.
print(torch_xla._XLAC._get_xla_tensors_text([decoder_out]))
# Compile and run the pending graph, blocking until it completes.
torch_xla.sync(wait=True)

IR {
  %0 = bf16[] prim::Constant(), xla_shape=bf16[]
  %1 = bf16[1024,4096]{1,0} xla::device_data(), xla_shape=bf16[1024,4096]{1,0}
  %2 = bf16[4096,1024]{0,1} aten::permute(%1), xla_shape=bf16[4096,1024]{0,1}
  %3 = bf16[4096,1024]{1,0} xla::device_data(), xla_shape=bf16[4096,1024]{1,0}
  %4 = bf16[1024,4096]{0,1} aten::permute(%3), xla_shape=bf16[1024,4096]{0,1}
  %5 = f32[] prim::Constant(), xla_shape=f32[]
  %6 = f32[] xla::device_data(), xla_shape=f32[]
  %7 = f32[] prim::Constant(), xla_shape=f32[]
  %8 = bf16[] prim::Constant(), xla_shape=bf16[]
  %9 = bf16[1024,1024]{1,0} xla::device_data(), xla_shape=bf16[1024,1024]{1,0}
  %10 = bf16[1024,1024]{0,1} aten::permute(%9), xla_shape=bf16[1024,1024]{0,1}
  %11 = bf16[512,1024]{1,0} xla::device_data(), xla_shape=bf16[512,1024]{1,0}
  %12 = bf16[1024,512]{0,1} aten::permute(%11), xla_shape=bf16[1024,512]{0,1}
  %13 = f32[] prim::Constant(), xla_shape=f32[]
  %14 = f32[] xla::device_data(), xla_shape=f32[]
  %15 = f32[] prim::Constant

In [6]:
# Show the metrics collected since they were last cleared.
from torch_xla.debug import metrics as met
print(met.metrics_report())

Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 034.990us
  ValueRate: 05s887ms871.508us / second
  Rate: 279330 / second
  Percentiles: 1%=008.080us; 5%=008.080us; 10%=008.080us; 20%=008.080us; 50%=026.910us; 80%=026.910us; 90%=026.910us; 95%=026.910us; 99%=026.910us
Metric: InputOutputAliasCount
  TotalSamples: 1
  Accumulator: 0.00
  Percentiles: 1%=0.00; 5%=0.00; 10%=0.00; 20%=0.00; 50%=0.00; 80%=0.00; 90%=0.00; 95%=0.00; 99%=0.00
Metric: LazyTracing
  TotalSamples: 132
  Accumulator: 023ms453.661us
  ValueRate: 050ms623.296us / second
  Rate: 279.286 / second
  Percentiles: 1%=007.330us; 5%=008.140us; 10%=023.490us; 20%=058.730us; 50%=089.600us; 80%=253.070us; 90%=360.600us; 95%=583.459us; 99%=001ms129.530us
Metric: TensorToData
  TotalSamples: 1
  Accumulator: 990.460us
  Percentiles: 1%=990.460us; 5%=990.460us; 10%=990.460us; 20%=990.460us; 50%=990.460us; 80%=990.460us; 90%=990.460us; 95%=990.460us; 99%=990.460us
Metric: TensorsGraphSize
  TotalSamples: 1
  Accumulator: 

Now lower the same decoder layer in two stages: first trace it into ATen ops with AOTAutograd, then lower those ops into LazyTensor IR.

In [7]:
# Start from clean counters so the report below covers only the AOTAutograd path.
from torch_xla.debug import metrics as met
met.clear_all()

In [8]:
from functorch.compile import aot_function, make_boxed_func  # type: ignore
from torch.func import functional_call  # type: ignore

def print_graph_compiler(gm, _):
  """AOTAutograd compiler hook that prints the traced graph and runs it as-is.

  Args:
    gm: the traced graph module handed over by AOTAutograd; its ``code``
      attribute is printed.
    _: example inputs supplied by AOTAutograd (unused).

  Returns:
    ``gm`` wrapped by ``make_boxed_func``, i.e. a callable taking a single
    list of arguments.
  """
  print(f"AOTAutograd graph:\n{gm.code}")
  boxed = make_boxed_func(gm)
  return boxed

# Wrap the layer as a pure function of (input, params) via functional_call so
# AOTAutograd can trace it; the parameters become explicit graph inputs (the
# primals_* arguments in the printed graph). print_graph_compiler shows the
# traced ATen graph and then runs it unchanged.
# NOTE(review): only named_parameters() are passed; any buffers of the layer are
# not supplied to functional_call — confirm that is intended.
compiled_layer = aot_function(lambda x, params: functional_call(decoder_layer, params, x), print_graph_compiler)
decoder_out = compiled_layer(inputs_embeds, dict(decoder_layer.named_parameters()))
# Dump the LazyTensor IR recorded from the ATen graph *before* syncing; after
# sync the pending graph is executed and the traced ops would no longer show.
print(torch_xla._XLAC._get_xla_tensors_text([decoder_out]))
# Compile and run the pending graph, blocking until it completes.
torch_xla.sync(wait=True)

AOTAutograd graph:



def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10):
    offload_name = torch.ops.xla.offload_name.default(primals_1, 'decoder_input');  primals_1 = None
    _to_copy = torch.ops.aten._to_copy.default(offload_name, dtype = torch.float32)
    pow_1 = torch.ops.aten.pow.Tensor_Scalar(_to_copy, 2)
    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
    add = torch.ops.aten.add.Tensor(mean, 1e-06);  mean = None
    rsqrt = torch.ops.aten.rsqrt.default(add);  add = None
    mul = torch.ops.aten.mul.Tensor(_to_copy, rsqrt);  _to_copy = rsqrt = None
    _to_copy_1 = torch.ops.aten._to_copy.default(mul, dtype = torch.bfloat16);  mul = None
    mul_1 = torch.ops.aten.mul.Tensor(primals_9, _to_copy_1);  primals_9 = None
    t = torch.ops.aten.t.default(primals_2);  primals_2 = None
    view = torch.ops.aten.view.default(mul_1, [32768, 1024])
    mm = torch.ops.aten.mm.default(view,

In [9]:
# Print the metrics gathered since the previous clear, for comparison with the
# direct-lowering report above.
from torch_xla.debug import metrics as met
print(met.metrics_report())

Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 035.420us
  ValueRate: 05s940ms027.894us / second
  Rate: 278940 / second
  Percentiles: 1%=008.380us; 5%=008.380us; 10%=008.380us; 20%=008.380us; 50%=027.040us; 80%=027.040us; 90%=027.040us; 95%=027.040us; 99%=027.040us
Metric: InputOutputAliasCount
  TotalSamples: 1
  Accumulator: 0.00
  Percentiles: 1%=0.00; 5%=0.00; 10%=0.00; 20%=0.00; 50%=0.00; 80%=0.00; 90%=0.00; 95%=0.00; 99%=0.00
Metric: LazyTracing
  TotalSamples: 151
  Accumulator: 017ms099.420us
  ValueRate: 705ms210.694us / second
  Rate: 6227.51 / second
  Percentiles: 1%=006.940us; 5%=007.240us; 10%=013.440us; 20%=024.140us; 50%=083.870us; 80%=121.610us; 90%=239.510us; 95%=318.030us; 99%=001ms002.750us
Metric: TensorsGraphSize
  TotalSamples: 1
  Accumulator: 112.00
  Percentiles: 1%=112.00; 5%=112.00; 10%=112.00; 20%=112.00; 50%=112.00; 80%=112.00; 90%=112.00; 95%=112.00; 99%=112.00
Metric: UnwrapXlaData
  TotalSamples: 1
  Accumulator: 018.730us
  Percentiles: 1%=0