In [1]:
# These are the offload flags used by MaxText.
# %env LIBTPU_INIT_ARGS=--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=ENABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2

# Flags actually applied in this notebook. Compared with the commented-out
# MaxText baseline above, this variant:
#   * sets --xla_tpu_aggressive_opt_barrier_removal=DISABLED (baseline: ENABLED)
#   * appends --xla_jf_rematerialization_percent_shared_memory_limit=10000
# NOTE(review): presumably this must run before the TPU runtime is initialised
# (i.e. before the first torch_xla device use) to have any effect -- confirm.
%env LIBTPU_INIT_ARGS=--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2 --xla_jf_rematerialization_percent_shared_memory_limit=10000

# Debugging / runtime-selection variables.
# NOTE(review): XLA_IR_DEBUG and XLA_HLO_DEBUG presumably make torch_xla attach
# Python source locations to the IR/HLO (see the `#loc` lines in the StableHLO
# dumps further down); PJRT_DEVICE presumably selects the TPU backend -- confirm.
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env PJRT_DEVICE=TPU

env: LIBTPU_INIT_ARGS=--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2 --xla_jf_rematerialization_percent_shared_memory_limit=10000
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: PJRT_DEVICE=TPU


In [2]:
import torch_xla
import torch_xla.runtime

from functorch.compile import aot_function
import torch
from torch_xla.experimental.custom_kernel import flash_attention_xla

import torch_xla.runtime

# Sanity check: trace a function that calls the XLA flash-attention custom op
# through AOTAutograd and confirm the traced version matches eager execution.
# NOTE(review): the `with` block presumably makes the XLA device the default
# device for tensors created inside it -- confirm.
with torch_xla.runtime.xla_device():
    # Call the custom op through `torch.ops.xla` so that it appears as a single
    # `torch.ops.xla.flash_attention` node in the graphs AOTAutograd extracts.
    def pallas(c):
        # The same tensor is passed three times; the trailing `True` presumably
        # corresponds to `causal=True` in the alternative call kept below.
        return torch.ops.xla.flash_attention(c, c, c, True)  # type: ignore
        # return flash_attention_xla(c, c, c, causal=True)


    # Function under test: element-wise sum of the inputs plus the attention
    # output, followed by two cosines.
    def fn(a, b, c, d):
        x = a + b + c + d + pallas(c)
        return x.cos().cos()

    # Eager reference run (forward + backward) on freshly drawn random inputs.
    a, b, c, d = [torch.randn(4, 4, 4, 4, requires_grad=True) for _ in range(4)]
    ref = fn(a, b, c, d)
    loss = ref.sum()
    loss.backward()


    # The compiler_fn is called once for the extracted forward graph and once
    # for the extracted backward graph. Here it only prints the generated code
    # and returns the GraphModule unchanged as the callable to execute.
    def compiler_fn(fx_module: torch.fx.GraphModule, _):
        print(fx_module.code)
        return fx_module

    # Use the same printing compiler for both the forward and backward graphs.
    aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

    # Run aot_print_fn once on detached copies of the inputs; the forward call
    # and the subsequent backward() trigger compilation and print both graphs.
    cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]
    cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs
    res = aot_print_fn(cloned_a, cloned_b, cloned_c, cloned_d)
    res.sum().backward()
    # Only the forward results are compared; gradients are not checked here.
    assert torch.allclose(ref, res)


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass





def forward(self, primals_1, primals_2, primals_3, primals_4):
    add = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    add_1 = torch.ops.aten.add.Tensor(add, primals_3);  add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, primals_4);  add_1 = primals_4 = None
    flash_attention = torch.ops.xla.flash_attention.default(primals_3, primals_3, primals_3, True);  primals_3 = None
    add_3 = torch.ops.aten.add.Tensor(add_2, flash_attention);  add_2 = flash_attention = None
    cos = torch.ops.aten.cos.default(add_3)
    cos_1 = torch.ops.aten.cos.default(cos)
    return (cos_1, add_3, cos)
    



def forward(self, add_3, cos, tangents_1):
    sin = torch.ops.aten.sin.default(cos);  cos = None
    neg = torch.ops.aten.neg.default(sin);  sin = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin.default(add_3);  add_3 = None
    neg_1 = torch.ops.aten.neg.default(sin_1);  sin_1 = None
 

In [3]:
import torch
import itertools

import torch_xla.core.xla_model as xm
from torch_xla.debug.profiler import Trace
from torch_xla.experimental.stablehlo_custom_call import place_to_host, place_to_device


def offload(module: torch.nn.Module) -> torch.nn.Module:
  """Wrap `module` with AOTAutograd so tensors saved for backward are offloaded.

  The extracted forward graph is executed unchanged. Afterwards every element
  of its result except the first (i.e. the tensors AOTAutograd saves for the
  backward pass) is passed through `place_to_host`, unless the tensor *is* one
  of `module`'s parameters. The backward wrapper passes its inputs through
  `place_to_device` (same parameter exception) before running the extracted
  backward graph.

  Graph code, argument types/shapes and every offload decision are printed,
  which is intended for debugging in a notebook.

  Args:
    module: The module whose forward/backward graphs should be wrapped.

  Returns:
    The value returned by `aot_module(module, ...)`; it is called like the
    original module.

  Notes:
    * Only element 0 of the forward graph's result is treated as the module
      output. NOTE(review): a module with several outputs would have its extra
      outputs offloaded as well -- confirm before using it on such modules.
    * The backward wrapper does not distinguish saved tensors from incoming
      gradients (tangents), so tangents are passed through `place_to_device`
      too. NOTE(review): presumably harmless, but confirm it is intended.
  """
  from functorch.compile import aot_module

  def should_offload(t: torch.Tensor) -> bool:
    """Return False (and log it) when `t` is one of `module`'s parameters."""
    # Identity comparison on purpose: a view or copy of a parameter is a
    # different object and is therefore still offloaded.
    for p in module.parameters():
      if t is p:
        print(f"Skip offloading {type(t)} {t.shape}")
        return False
    return True

  def maybe_place_to_host(t):
    """Pass `t` through `place_to_host` unless it is a module parameter."""
    if should_offload(t):
      print(f"Offload {t.shape} tensor to host")
      return place_to_host(t)
    return t

  def maybe_place_to_device(t):
    """Pass `t` through `place_to_device` unless it is a module parameter."""
    if should_offload(t):
      print(f"Bring back {t.shape} tensor to device")
      return place_to_device(t)
    return t

  def shape_str(t) -> str:
    """Return `str(t.shape)`, or 'None' for values that have no shape."""
    # Backward graphs may contain `None` entries (inputs that need no
    # gradient); logging must not crash the backward pass because of them.
    return str(getattr(t, "shape", None))

  def log_call_inputs(args, kwargs) -> None:
    """Print type and shape of every positional and keyword argument."""
    for a in args:
      print("Arg type: ", type(a), shape_str(a))
    for k, v in kwargs.items():
      print(f"kwarg {k} type: ", type(v), shape_str(v))

  # Each *_comp function is called once with the extracted graph and returns
  # the callable that is executed in its place.
  def forward_comp(fx_module: torch.fx.GraphModule, _):
    print("Forward", fx_module.code)

    def compute_then_offload(*args, **kwargs):
      log_call_inputs(args, kwargs)
      with Trace("fwd"):
        res = fx_module(*args, **kwargs)
        # res[0] is the module output; the remaining entries are the tensors
        # saved for the backward pass.
        res2 = [res[0]] + [maybe_place_to_host(r) for r in res[1:]]
        xm.optimization_barrier_(res2)
        print("Forward output shapes: " + ', '.join(shape_str(a) for a in res2))
        return res2

    return compute_then_offload

  def backward_comp(fx_module, _):
    print("Backward", fx_module.code)

    def reload_then_compute(*args, **kwargs):
      log_call_inputs(args, kwargs)
      with Trace("bwd"):
        # This barrier matches what we got from a simple JAX example.
        xm.optimization_barrier_(list(itertools.chain(args, kwargs.values())))
        args = [maybe_place_to_device(r) for r in args]
        kwargs = {k: maybe_place_to_device(v) for k, v in kwargs.items()}
        res = fx_module(*args, **kwargs)
        print("Backward output shapes: " + ', '.join(shape_str(a) for a in res))
        return res

    return reload_then_compute

  # Hand both wrappers to AOTAutograd.
  return aot_module(module, fw_compiler=forward_comp, bw_compiler=backward_comp)


In [4]:
# Smallest end-to-end check of `offload`: a single 10x10 Linear layer.
with torch_xla.runtime.xla_device():
  layer = torch.nn.Linear(10, 10)
  # Keep a handle to the unwrapped module so the gradients accumulated on its
  # parameters can be inspected after the backward pass.
  orig_layer = layer
  layer = offload(layer)

  # Calling the wrapped layer triggers extraction of the forward graph and
  # prints it together with the offload decisions.
  x = torch.ones(10, dtype=torch.float32, requires_grad=True)
  y = layer(x)

Forward 


def forward(self, primals_1, primals_2, primals_3):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    unsqueeze = torch.ops.aten.unsqueeze.default(primals_3, 0);  primals_3 = None
    mm = torch.ops.aten.mm.default(unsqueeze, t)
    squeeze = torch.ops.aten.squeeze.dim(mm, 0);  mm = None
    add = torch.ops.aten.add.Tensor(squeeze, primals_2);  squeeze = primals_2 = None
    return (add, t, unsqueeze)
    
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([10, 10])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([10])
Arg type:  <class 'torch.Tensor'> torch.Size([10])
Offload torch.Size([10, 10]) tensor to host
Offload torch.Size([1, 10]) tensor to host
Forward output shapes: torch.Size([10]), torch.Size([10, 10]), torch.Size([1, 10])


In [5]:
# Run the backward pass; this invokes the backward wrapper installed by
# `offload`, which prints the backward graph and the tensors it brings back.
y.sum().backward()


Backward 


def forward(self, t, unsqueeze, tangents_1):
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(tangents_1, 0)
    t_1 = torch.ops.aten.t.default(unsqueeze_1)
    mm_1 = torch.ops.aten.mm.default(t_1, unsqueeze);  t_1 = unsqueeze = None
    t_2 = torch.ops.aten.t.default(mm_1);  mm_1 = None
    t_3 = torch.ops.aten.t.default(t);  t = None
    mm_2 = torch.ops.aten.mm.default(unsqueeze_1, t_3);  unsqueeze_1 = t_3 = None
    squeeze_1 = torch.ops.aten.squeeze.dim(mm_2, 0);  mm_2 = None
    t_4 = torch.ops.aten.t.default(t_2);  t_2 = None
    return (t_4, tangents_1, squeeze_1)
    
Arg type:  <class 'torch.Tensor'> torch.Size([10, 10])
Arg type:  <class 'torch.Tensor'> torch.Size([1, 10])
Arg type:  <class 'torch.Tensor'> torch.Size([10])
Bring back torch.Size([10, 10]) tensor to device
Bring back torch.Size([1, 10]) tensor to device
Bring back torch.Size([10]) tensor to device
Backward output shapes: torch.Size([10, 10]), torch.Size([10]), torch.Size([10])




In [6]:
# The gradient must have reached the original (unwrapped) layer's weight.
assert orig_layer.weight.grad is not None
# Dump the StableHLO for the weight gradient before synchronising.
print(xm.get_stablehlo([orig_layer.weight.grad]))
torch_xla.sync(wait=True)


#loc1 = /usr/local/lib/python3.10/site-packages/torch/utils/_device.py:106:0
#loc2 = /tmp/ipykernel_2729133/3320899134.py:21:0
#loc3 = /root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577:0
#loc4 = /root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517:0
#loc5 = /root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3334:0
#loc6 = /root/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:128:0
#loc7 = /root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3130:0
#loc8 = /root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3075:0
#loc9 = /root/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:549:0
#loc10 = /root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:449:0
#loc11 = /root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:778:0
#loc12 = /root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:362:0
#loc13 = /root/.loca

In [7]:
# With x = ones(10) and loss = y.sum(), d(loss)/d(weight) is the outer product
# of two all-ones vectors, so an all-ones 10x10 matrix is expected.
print(orig_layer.weight.grad)


tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')


# Test a 5-layer dense model

In [9]:
import torch
import torch.nn as nn

class LinearSin(nn.Module):
  """A single linear layer followed by an element-wise sine.

  Args:
    input_size: Number of input features of the linear layer.
    output_size: Number of output features of the linear layer.
    device: Device on which the linear layer's parameters are created. When
      left as None (the default) the parameters are created on
      `torch_xla.device()`, which was previously the hard-coded behaviour.
  """

  def __init__(self, input_size=16384, output_size=16384, device=None):
    super().__init__()
    if device is None:
      # Only touch the XLA runtime when the caller did not pick a device.
      device = torch_xla.device()
    self.linear = nn.Linear(input_size, output_size, device=device)

  def forward(self, x):
    """Return `sin(linear(x))`."""
    return torch.sin(self.linear(x))

In [9]:
# Build all five plain layers first, then wrap each one with `offload`. The
# unwrapped modules are kept in `orig_layers` so their parameter gradients can
# be inspected after training steps.
orig_layers = [LinearSin() for _ in range(5)]
layers = [offload(plain_layer) for plain_layer in orig_layers]
model = nn.Sequential(*layers)

In [10]:
# Random input vector matching LinearSin's default input_size (16384). Note
# that it is created without requires_grad.
inp = torch.rand(16384, device=torch_xla.device())
# Materialise the input before running the model.
torch_xla.sync(wait=True)

In [11]:
# One forward + backward pass through the five offloaded layers.
model.zero_grad()
y = model(inp)
y.sum().backward()



In [12]:
# The first layer's weight receives its gradient last in the backward pass, so
# its presence shows that the backward pass ran through all layers.
assert orig_layers[0].linear.weight.grad is not None
print(xm.get_stablehlo([orig_layers[0].linear.weight.grad]))

#loc1 = /usr/local/lib/python3.10/site-packages/torch/_ops.py:723:0
#loc2 = <eval_with_key>.26:9:0
#loc3 = /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747:0
#loc4 = /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736:0
#loc5 = /usr/local/lib/python3.10/site-packages/torch/fx/graph_module.py:387:0
#loc6 = /usr/local/lib/python3.10/site-packages/torch/fx/graph_module.py:822:0
#loc7 = /tmp/ipykernel_2307133/2910448454.py:19:0
#loc8 = /usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:100:0
#loc9 = /usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:671:0
#loc10 = /usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:489:0
#loc11 = /usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:126:0
#loc12 = /usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:1569:0
#loc13 = /usr/local/li

In [13]:
# Collect every parameter gradient that was produced, then ask torch_xla for a
# Graphviz (dot) description of the IR graph behind those gradients and `y`.
optimizable_tensors = [p.grad for p in model.parameters() if p.grad is not None]
ir = torch_xla._XLAC._get_xla_tensors_dot(optimizable_tensors + [y])

In [14]:
# Show the dot graph; node labels carry the `fwd`/`bwd` Trace scopes.
print(ir)

digraph G {
  node0 [label="prim::Constant\nf32[]\nlocation=_make_grads@__init__.py:220\nxla_shape=f32[]"]
  node1 [label="aten::expand\nf32[]\nlocation=_make_grads@__init__.py:220\nxla_shape=f32[]"]
  node2 [label="aten::expand\nf32[16384]{0}\nxla_shape=f32[16384]{0}"]
  node3 [label="prim::Constant\nf32[]\nscope=fwd.5\nlocation=__call__@_ops.py:723\nxla_shape=f32[]"]
  node4 [label="xla::device_data\nf32[16384]{0}\nscope=fwd.5\nlocation=__call__@_ops.py:723\nxla_shape=f32[16384]{0}"]
  node5 [label="xla::device_data\nf32[16384,16384]{1,0}\nscope=fwd.5\nlocation=__call__@_ops.py:723\nxla_shape=f32[16384,16384]{1,0}"]
  node6 [label="aten::permute\nf32[16384,16384]{0,1}\nscope=fwd.5\nlocation=__call__@_ops.py:723\nxla_shape=f32[16384,16384]{0,1}"]
  node7 [label="prim::Constant\nf32[]\nscope=fwd.4\nlocation=__call__@_ops.py:723\nxla_shape=f32[]"]
  node8 [label="xla::device_data\nf32[16384]{0}\nscope=fwd.4\nlocation=__call__@_ops.py:723\nxla_shape=f32[16384]{0}"]
  node9 [label="xla::d

In [15]:
# Execute the pending graph and wait for it to finish before profiling.
torch_xla.sync(wait=True)

In [16]:
import time
import torch_xla.debug.profiler as xp
# Start the profiler server on port 9012; the traces captured later connect to
# it via "localhost:9012". The handle is kept in `server`.
server = xp.start_server(9012)


In [17]:
# Warm-up: 10 untraced training steps, then clear gradients and wait for all
# pending work to finish.
for i in range(10):
  model.zero_grad()
  output = model(inp.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)

# Start profiling: request a detached 10 s trace from the profiler server and
# write it under profile/.
xp.trace_detached(service_addr="localhost:9012", logdir="profile/", duration_ms=10000)
# NOTE(review): presumably gives the detached tracer time to attach -- confirm.
time.sleep(1)
# The same 10 steps as the warm-up, this time while the trace is recording.
for i in range(10):
  model.zero_grad()
  output = model(inp.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)
# Sleep for the trace duration (10 s == duration_ms), presumably so the
# detached trace can finish before the cell returns.
time.sleep(10)

Starting to trace for 10000 ms. Remaining attempt(s): 2


2024-11-05 02:33:51.989194: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 1359950 nanoseconds and will start immediately.


# Test an LLM (decoder-only model)

In [8]:
import sys
# Make `decoder_only_model` importable.
# NOTE(review): hard-coded absolute path; it only works where the PyTorch/XLA
# checkout lives at exactly this location -- adjust for other environments.
sys.path.append('/workspaces/torch/pytorch/xla/examples')

In [9]:
import gc
# Collect garbage before building the model.
# NOTE(review): presumably meant to release memory held by earlier experiments.
gc.collect()

from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

# Model configuration: 30 decoder layers, hidden size 1024, MLP intermediate
# size 4096 and a vocabulary of 8192 tokens.
device = torch_xla.device()
config = DecoderOnlyConfig(hidden_size=1024, num_hidden_layers=30)
config.intermediate_size = 4096
config.vocab_size = 8192
model = DecoderOnlyModel(config=config).to(device)
batch_size = 16
sequence_length = 512

# Collect again now that `model` has been rebound to the new model.
gc.collect()

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length), device=device)
# Materialise the model and the inputs before continuing.
torch_xla.sync(wait=True)


In [10]:
import time
import torch_xla.debug.profiler as xp
# Start the profiler server on port 9012 for the LLM experiment.
# NOTE(review): this repeats the profiler-server cell of the dense-model
# section; starting a second server on the same port within one kernel session
# may fail -- confirm, or run this section in a fresh kernel.
server = xp.start_server(9012)

In [11]:
# Replace every decoder block by its offloading wrapper, in place.
# This is not idempotent: re-running the cell wraps the already wrapped blocks
# again, so rebuild `model` first if the cell has to be executed a second time.
for i, block in enumerate(model.layers):
  model.layers[i] = offload(block)

# Warm-up: 10 untraced training steps, then clear gradients and wait for all
# pending work to finish.
for i in range(10):
  model.zero_grad()
  output = model(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)

# Start profiling: request a detached 60 s trace from the profiler server and
# write it under profile/.
xp.trace_detached(service_addr="localhost:9012", logdir="profile/", duration_ms=60000)
# NOTE(review): presumably gives the detached tracer time to attach -- confirm.
time.sleep(1)
# The same 10 steps as the warm-up, this time while the trace is recording.
for i in range(10):
  model.zero_grad()
  output = model(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)
# Sleep for the trace duration (60 s == duration_ms), presumably so the
# detached trace can finish before the cell returns.
time.sleep(60)

Forward 


def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10):
    pow_1 = torch.ops.aten.pow.Tensor_Scalar(primals_10, 2)
    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
    add = torch.ops.aten.add.Tensor(mean, 1e-06);  mean = None
    rsqrt = torch.ops.aten.rsqrt.default(add);  add = None
    mul = torch.ops.aten.mul.Tensor(primals_10, rsqrt)
    mul_1 = torch.ops.aten.mul.Tensor(primals_8, mul)
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    view = torch.ops.aten.view.default(mul_1, [8192, 1024])
    mm = torch.ops.aten.mm.default(view, t)
    _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [16, 512, 1024]);  mm = None
    t_1 = torch.ops.aten.t.default(primals_2);  primals_2 = None
    view_1 = torch.ops.aten.view.default(mul_1, [8192, 1024])
    mm_1 = torch.ops.aten.mm.default(view_1, t_1)
    _unsafe_view_1 = torch.ops.aten._unsafe_view.default(mm_1, [16,



Backward 


def forward(self, primals_8, primals_9, primals_10, rsqrt, mul, t, view, t_1, view_1, t_2, view_2, _unsafe_view_5, view_6, _softmax, view_9, view_10, t_3, view_14, add_1, rsqrt_1, mul_2, t_4, view_15, _unsafe_view_7, t_5, view_16, _unsafe_view_8, silu, t_6, view_17, tangents_1):
    detach = torch.ops.aten.detach.default(rsqrt)
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    detach_2 = torch.ops.aten.detach.default(_softmax);  _softmax = None
    detach_3 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    detach_4 = torch.ops.aten.detach.default(rsqrt_1)
    detach_5 = torch.ops.aten.detach.default(detach_4);  detach_4 = None
    view_18 = torch.ops.aten.view.default(tangents_1, [8192, 1024])
    t_7 = torch.ops.aten.t.default(view_18)
    mm_7 = torch.ops.aten.mm.default(t_7, view_17);  t_7 = view_17 = None
    t_8 = torch.ops.aten.t.default(mm_7);  mm_7 = None
    t_9 = torch.ops.aten.t.default(t_6);  t_6 = None
    mm_8 = torch.ops.a

2024-11-05 08:52:35.800751: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 1401149 nanoseconds and will start immediately.


Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([512, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([512, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([1024, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([4096, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([4096, 1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([1024, 4096])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([1024])
Arg type:  <class 'torch.nn.parameter.Parameter'> torch.Size([1024])
Arg type:  <class 'torch.Tensor'> torch.Size([16, 512, 1024])
Skip offloading <class 'torch.nn.parameter.Parameter'> torch.Size([1024])
Skip offloading <class 'torch.nn.parameter.Parameter'> torch.Size([1024])
Offload torch.Size([16, 512, 1024]) tensor to host
Offload torch.Size([16, 512, 1]) tensor to host
Offload torch.Size([16, 512, 1

In [12]:
import os
# Show the flags that are in effect. The value printed below contains three
# flags that the first cell did not set (--xla_tpu_prefer_async_allgather_to_allreduce,
# --xla_tpu_enable_flash_attention, --xla_tpu_use_enhanced_launch_barrier).
# NOTE(review): presumably appended by torch_xla during initialisation -- confirm.
os.getenv("LIBTPU_INIT_ARGS")

'--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=DISABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2 --xla_jf_rematerialization_percent_shared_memory_limit=10000 --xla_tpu_prefer_async_allgather_to_allreduce=true --xla_tpu_enable_flash_attention=false --xla_tpu_use_enhanced_launch_barrier=true'