In [1]:
# These are the offload flags used by MaxText.
%env LIBTPU_INIT_ARGS=--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=ENABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2

env: LIBTPU_INIT_ARGS=--xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=24 --xla_tpu_aggressive_opt_barrier_removal=ENABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2


In [2]:
import torch
import torch_xla
import torch_xla.runtime
from torch_xla.experimental.stablehlo_custom_call import (
  place_to_host, place_to_device
)
from torch.autograd.graph import saved_tensors_hooks



In [3]:
with torch_xla.runtime.xla_device():
  x = torch.ones(10, dtype=torch.float32, requires_grad=True)
  with saved_tensors_hooks(place_to_host, place_to_device):
    a = torch.sin(x)
    a = torch.sin(a)
    a.sum().backward()

In [4]:
print(torch_xla._XLAC._get_xla_tensors_text([x.grad]))

IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[]
  %1 = f32[10]{0} aten::expand(%0), xla_shape=f32[10]{0}
  %2 = (f32[10]{0}) xla::custom_call(%1), xla_shape=(f32[10]{0})
  %3 = (f32[10]{0}) xla::custom_call(%2), xla_shape=(f32[10]{0})
  %4 = f32[10]{0} aten::cos(%3), xla_shape=f32[10]{0}
  %5 = f32[10]{0} aten::sin(%1), xla_shape=f32[10]{0}
  %6 = (f32[10]{0}) xla::custom_call(%5), xla_shape=(f32[10]{0})
  %7 = (f32[10]{0}) xla::custom_call(%6), xla_shape=(f32[10]{0})
  %8 = f32[10]{0} aten::cos(%7), xla_shape=f32[10]{0}
  %9 = f32[] prim::Constant(), xla_shape=f32[]
  %10 = f32[] aten::expand(%9), xla_shape=f32[]
  %11 = f32[10]{0} aten::expand(%10), xla_shape=f32[10]{0}
  %12 = f32[10]{0} aten::mul(%11, %8), xla_shape=f32[10]{0}
  %13 = f32[10]{0} aten::mul(%12, %4), xla_shape=f32[10]{0}, ROOT=0
}



In [5]:
torch_xla.sync()
x.grad

tensor([0.3600, 0.3600, 0.3600, 0.3600, 0.3600, 0.3600, 0.3600, 0.3600, 0.3600,
        0.3600], device='xla:0')

Test host offloading with optimization barrier

In [24]:
from types import MethodType
import torch
import torch.nn
from torch.autograd.function import Function
from collections.abc import Iterable
import torch_xla.core.xla_model as xm
import itertools


def _extract_tensors_from_list(inputs):
  tensor_inputs = []
  if torch.is_tensor(inputs):
    tensor_inputs.append(inputs)
  # tensor is Iterable so we need to avoid iterating through tensor
  elif isinstance(inputs, Iterable):
    for input in inputs:
      if torch.is_tensor(input):
        tensor_inputs.append(input)
  return tensor_inputs


def offload1(module: torch.nn.Module):
  """Collate all intermediate tensors into an optimization barrier op."""

  def offloaded_fn(*args):
    tensor_store = {}
    counter = 0

    def pack_fn(tensor: torch.Tensor) -> int:
      nonlocal counter
      idx = counter
      counter = counter + 1

      tensor_store[idx] = place_to_host(tensor)
      # print(f"Packing tensor {idx}")
      return idx

    def unpack_fn(idx: int) -> torch.Tensor:
      # print(f"Unpacking tensor {idx}")
      return place_to_device(tensor_store[idx])

    with saved_tensors_hooks(pack_fn, unpack_fn):
      out = module(*args)

    # Prevent offload operations from being moved after the output.
    xm.optimization_barrier_(
        _extract_tensors_from_list(itertools.chain(tensor_store.values(), out)))
    return out

  return offloaded_fn


def offload2(module: torch.nn.Module):
  """Collate all intermediate tensors into an optimization barrier op,
  and also clear the tensor store when done with those intermediate tensors.
  """

  def offloaded_fn(*args):
    tensor_store = {}
    counter = [0]  # Use list to make it mutable within closures
    refcount = [0]  # Manual reference counter

    def pack_fn(tensor: torch.Tensor):
      idx = counter[0]
      counter[0] += 1
      tensor_store[idx] = place_to_host(tensor)
      # Increase the reference count
      refcount[0] += 1

      class PackedTensor:

        def __init__(self, index):
          self.index = index

        def __del__(self):
          # Decrease the reference count
          refcount[0] -= 1
          if refcount[0] == 0:
            # Clear tensor_store when refcount reaches zero
            tensor_store.clear()

      return PackedTensor(idx)

    def unpack_fn(packed_tensor):
      idx = packed_tensor.index
      return place_to_device(tensor_store[idx])

    with saved_tensors_hooks(pack_fn, unpack_fn):
      out = module(*args)

    # Prevent offload operations from being moved after the output.
    xm.optimization_barrier_(
        _extract_tensors_from_list(itertools.chain(tensor_store.values(), out)))
    return out

  return offloaded_fn


def offload3(module: torch.nn.Module):
  """No optimization barrier. Simply move tensors between host and device."""

  def offloaded_fn(*args):

    def pack_fn(tensor: torch.Tensor):
      return place_to_host(tensor)

    def unpack_fn(input) -> torch.Tensor:
      return place_to_device(input)

    with saved_tensors_hooks(pack_fn, unpack_fn):
      out = module(*args)

    return out

  return offloaded_fn


def offload4(module: torch.nn.Module):
  """Each intermediate activation tensor gets its own optimization barrier."""

  def offloaded_fn(*args):

    def pack_fn(tensor: torch.Tensor):
      t = place_to_host(tensor)
      xm.optimization_barrier_([t])
      return t

    def unpack_fn(input) -> torch.Tensor:
      return place_to_device(input)

    with saved_tensors_hooks(pack_fn, unpack_fn):
      out = module(*args)

    return out

  return offloaded_fn


from torch_xla.distributed.spmd.xla_sharding import apply_backward_optimization_barrier


def offload5(module: torch.nn.Module):
  """Use an optimization barrier to tie together the transfer ops with the backward input."""

  counter = 0
  tensor_store = {}
  moved = False


  def offloaded_fn(*args):
    nonlocal counter
    nonlocal tensor_store
    nonlocal moved

    def pack_fn(tensor: torch.Tensor) -> int:
      nonlocal counter
      nonlocal tensor_store
      nonlocal moved

      moved = False

      # Record the tensor to some list
      idx = counter
      tensor_store[idx] = tensor
      counter = counter + 1
      print(f"Packing tensor {idx}")
      return idx

    def unpack_fn(idx: int) -> torch.Tensor:
      print(f"Unpacking tensor {idx}")
      return tensor_store[idx]

    # Too late.
    # torch.nn.modules.module.register_module_full_backward_hook(transfer_and_add_optimization_barrier)

    # Too late.
    # module.register_full_backward_hook(transfer_and_add_optimization_barrier, prepend=True)

    with saved_tensors_hooks(pack_fn, unpack_fn):
      out = module._xla_checkpointed_forward_original(*args)

    return out

  def transfer_and_add_optimization_barrier(module, grad_output):
    nonlocal tensor_store
    nonlocal moved
    nonlocal counter

    if moved:
      raise RuntimeError("Already moved once during a previous transfer_and_add_optimization_barrier")

    # Transfer all tensors to host
    for k, v in tensor_store.items():
      print(f"Place tensor {k} to host")
      tensor_store[k] = place_to_host(v)

    # Wrap with barrier
    from torch_xla.utils.checkpoint import CheckpointFunction
    gradients = []
    for param in module.parameters():
      if param.grad != None:
        gradients.append(param.grad)
    print(f"Add optimization barrier")
    xm.optimization_barrier_(
        CheckpointFunction._extract_tensors_from_list(list(tensor_store.values()) +
                                                      gradients +
                                                      list(grad_output)))

    # Transfer all tensor to device
    for k, v in tensor_store.items():
      print(f"Place tensor {k} to device")
      tensor_store[k] = place_to_device(v)

    moved = True
    counter = 0

    # Return the modified grad_output
    return tuple(grad_output)

  # Just right.
  module.register_full_backward_pre_hook(transfer_and_add_optimization_barrier, prepend=True)

  def _xla_checkpointed_forward_no_kwargs(m, num_args, num_kwargs,
                                          *packed_args):
    # unpack packed_args into args and kwargs
    assert num_args + num_kwargs * 2 == len(packed_args)
    args = packed_args[:num_args]
    kwargs = packed_args[num_args:]
    kwargs = dict(zip(kwargs[:num_kwargs], kwargs[num_kwargs:]))
    return m._xla_checkpointed_forward_original(*args, **kwargs)

  def _forward_with_checkpoint(m, *args, **kwargs):
    # pack args and kwargs together as `torch_xla.utils.checkpoint.checkpoint`
    # doesn't support keyword arguments
    packed_args = args + tuple(kwargs.keys()) + tuple(kwargs.values())
    input_requires_grad = any(
        isinstance(t, torch.Tensor) and t.requires_grad for t in packed_args)
    if input_requires_grad:
      outputs = offloaded_fn(len(args), len(kwargs), *packed_args)
    else:
      # No input requires gradients so we won't checkpoint this forward pass.
      # Note that `m`` might have parameters that require gradients, but they
      # are beyond what `torch_xla.utils.checkpoint.checkpoint` can handle.
      outputs = m._xla_checkpointed_forward_original(*args, **kwargs)
    return outputs

  assert isinstance(module, torch.nn.Module)
  # replace `module`'s forward method with its checkpointed version
  module._xla_checkpointed_forward_original = module.forward  # type: ignore
  module._xla_checkpointed_forward_no_kwargs = MethodType(    # type: ignore
      _xla_checkpointed_forward_no_kwargs, module)
  module.forward = MethodType(_forward_with_checkpoint, module)
  return module


# Test different approaches
offload = offload5

In [25]:
with torch_xla.runtime.xla_device():
  layer = torch.nn.Linear(10, 10)
  orig_layer = layer
  layer = offload(layer)

  x = torch.ones(10, dtype=torch.float32, requires_grad=True)
  y = layer(x)

TypeError: Linear.forward() takes 2 positional arguments but 4 were given

In [17]:
y.sum().backward()


Place tensor 0 to host
Place tensor 1 to host
Add optimization barrier
Place tensor 0 to device
Place tensor 1 to device
Unpacking tensor 0
Unpacking tensor 1


In [18]:
print(torch_xla._XLAC._get_xla_tensors_hlo([orig_layer.weight.grad]))
torch_xla.sync(wait=True)


HloModule IrToHlo.54, entry_computation_layout={(s64[], f32[], f32[])->(f32[10,10]{0,1})}

ENTRY %IrToHlo.54 (p0.14: s64[], p1.19: f32[], p2.20: f32[]) -> (f32[10,10]) {
  %constant.17 = s64[] constant(2531011)
  %constant.15 = s64[] constant(214013)
  %p0.14 = s64[] parameter(0)
  %multiply.16 = s64[] multiply(s64[] %constant.15, s64[] %p0.14)
  %add.18 = s64[] add(s64[] %constant.17, s64[] %multiply.16)
  %convert.21 = u64[] convert(s64[] %add.18)
  %reshape.23 = u64[1]{0} reshape(u64[] %convert.21)
  %constant.22 = u64[] constant(0)
  %reshape.24 = u64[1]{0} reshape(u64[] %constant.22)
  %concatenate.25 = u64[2]{0} concatenate(u64[1]{0} %reshape.23, u64[1]{0} %reshape.24), dimensions={0}
  %rng-bit-generator.26 = (u64[2]{0}, u32[10,10]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.25), algorithm=rng_default
  %get-tuple-element.28 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[10,10]{1,0}) %rng-bit-generator.26), index=0
  %get-tuple-element.27 = u32[10,10]{1,0} get-tuple-element((u

In [19]:
print(orig_layer.weight.grad)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='xla:0')


In [20]:
import sys
sys.path.append('/workspaces/torch/pytorch/xla/examples')

In [21]:
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

device = torch_xla.device()
config = DecoderOnlyConfig(hidden_size=1024, num_hidden_layers=40)
config.intermediate_size = 4096
config.vocab_size = 8192
model = DecoderOnlyModel(config=config).to(device)
batch_size = 16
sequence_length = 512

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length), device=device)
torch_xla.sync(wait=True)


In [22]:
import time
import torch_xla.debug.profiler as xp
server = xp.start_server(9012)

In [13]:
# No offload
# Compile
for i in range(10):
  model.zero_grad()
  output = model.forward(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)

# Start profiling
xp.trace_detached(service_addr="localhost:9012", logdir="profile/", duration_ms=60000)
time.sleep(1)
for i in range(10):
  model.zero_grad()
  output = model.forward(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)
time.sleep(60)

Starting to trace for 60000 ms. Remaining attempt(s): 2


2024-11-04 07:13:38.134417: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 1118070 nanoseconds and will start immediately.


In [13]:
# model.zero_grad()
# model.requires_grad_(False)
# model.layers[0].mlp.up_proj.weight.requires_grad_(True)
# model.layers[0].self_attn.requires_grad_(True)
# model.layers[0].requires_grad_(True)
# model.layers[1].requires_grad_(True)
# 
# # This doesn't work.
# model.embed_tokens.requires_grad_(True)
#
# model_offloaded = offload(model)
# output = model_offloaded(input_ids.clone())
# output.sum().backward()
# print(torch_xla._XLAC._get_xla_tensors_hlo([model.embed_tokens.weight.grad]))

In [None]:
# Everything offload
# Compile
from types import MethodType


def offload_module(module):
  """
  Wrap a `module`'s `forward` method with gradient checkpointing (also called
  activation checkpointing) via `torch_xla.utils.checkpoint.checkpoint`.
  """

  def _xla_checkpointed_forward_no_kwargs(m, num_args, num_kwargs,
                                          *packed_args):
    # unpack packed_args into args and kwargs
    assert num_args + num_kwargs * 2 == len(packed_args)
    args = packed_args[:num_args]
    kwargs = packed_args[num_args:]
    kwargs = dict(zip(kwargs[:num_kwargs], kwargs[num_kwargs:]))
    return m._xla_checkpointed_forward_original(*args, **kwargs)

  def _forward_with_checkpoint(m, *args, **kwargs):
    # pack args and kwargs together as `torch_xla.utils.checkpoint.checkpoint`
    # doesn't support keyword arguments
    packed_args = args + tuple(kwargs.keys()) + tuple(kwargs.values())
    input_requires_grad = any(
        isinstance(t, torch.Tensor) and t.requires_grad for t in packed_args)
    if input_requires_grad:
      outputs = offload(m._xla_checkpointed_forward_no_kwargs)(len(args), len(kwargs), *packed_args)
    else:
      # No input requires gradients so we won't checkpoint this forward pass.
      # Note that `m`` might have parameters that require gradients, but they
      # are beyond what `torch_xla.utils.checkpoint.checkpoint` can handle.
      outputs = m._xla_checkpointed_forward_original(*args, **kwargs)
    return outputs

  assert isinstance(module, torch.nn.Module)
  # replace `module`'s forward method with its checkpointed version
  module._xla_checkpointed_forward_original = module.forward  # type: ignore
  module._xla_checkpointed_forward_no_kwargs = MethodType(    # type: ignore
      _xla_checkpointed_forward_no_kwargs, module)
  module.forward = MethodType(_forward_with_checkpoint, module)
  return module


for i, block in enumerate(model.layers):
  model.layers[i] = offload_module(block)
  apply_backward_optimization_barrier(model.layers[i])

for i in range(10):
  model.zero_grad()
  output = model(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)

# Start profiling
xp.trace_detached(service_addr="localhost:9012", logdir="profile/", duration_ms=60000)
time.sleep(1)
for i in range(10):
  model.zero_grad()
  output = model(input_ids.clone())
  output.sum().backward()
  torch_xla.sync()
torch_xla.sync(wait=True)
model.zero_grad()
torch_xla.sync(wait=True)
time.sleep(60)

AttributeError: 'function' object has no attribute 'register_full_backward_pre_hook'