## Current checkpointing (`torch_xla.utils.checkpoint.checkpoint`)

In [3]:
import torch
import torch_xla
import torch_xla.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SimpleMLP(nn.Module):
  """Three-layer MLP (128 -> 64 -> 32 -> 10) with a sine after every layer.

  Args:
    checkpoint: when truthy, the middle stage (`block`) is run through
      `torch_xla.utils.checkpoint.checkpoint`; otherwise it is called directly.
  """

  def __init__(self, checkpoint=True):
    super().__init__()
    self.checkpoint = checkpoint
    # Layers are created in fc1, fc2, fc3 order, as in the forward pass.
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 32)
    self.fc3 = nn.Linear(32, 10)

  def block(self, x):
    """Middle stage: `fc2` followed by an element-wise sine."""
    return torch.sin(self.fc2(x))

  def forward(self, x):
    """Apply fc1 -> sin -> block (optionally checkpointed) -> fc3 -> sin."""
    hidden = torch.sin(self.fc1(x))
    if not self.checkpoint:
      hidden = self.block(hidden)
    else:
      hidden = torch_xla.utils.checkpoint.checkpoint(self.block, hidden)
    return torch.sin(self.fc3(hidden))


# Dummy data: a random (64, 128) batch created directly on the XLA device.
# The trailing dimension (128) matches the input width of SimpleMLP.fc1.
device = xm.xla_device()
dummy_data = torch.randn(64, 128, device=device)

# Initialize the model and optimizer
model = SimpleMLP(checkpoint=True).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001)  # type:ignore

# Cut the graph here
# (presumably so the setup above is not part of the graphs dumped below —
# TODO confirm).
xm.mark_step()

# Training step with gradient checkpointing
optimizer.zero_grad()
x = model(dummy_data)
# Presumably only here to narrow the type for static checkers — TODO confirm.
assert x is not None

# Dump the graph that produces the forward output `x`.
# NOTE(review): the label says "HLO", but the recorded output is a DOT
# `digraph` of aten::/xla:: nodes — confirm the wording.
print("HLO for forward:")
ir = torch_xla._XLAC._get_xla_tensors_dot([x])
print(ir)

# Scalar loss (sum of all outputs), backward pass, then one SGD update.
dummy_loss = x.sum()
dummy_loss.backward()
optimizer.step()

# Print the HLO graph for the forward + backward pass
print("HLO for forward + backward:")
# Only parameters that received a gradient are dumped. They are collected
# after optimizer.step(), i.e. in their post-update state.
optimizable_tensors = [p for p in model.parameters() if p.grad is not None]
ir = torch_xla._XLAC._get_xla_tensors_dot(optimizable_tensors + [dummy_loss, x])
print(ir)


HLO for forward:
digraph G {
  node0 [label="xla::device_data\nf32[10]{0}\nxla_shape=f32[10]{0}"]
  node1 [label="xla::device_data\nf32[10,32]{1,0}\nxla_shape=f32[10,32]{1,0}"]
  node2 [label="aten::permute\nf32[32,10]{0,1}\nxla_shape=f32[32,10]{0,1}"]
  node3 [label="xla::device_data\nf32[32]{0}\nxla_shape=f32[32]{0}"]
  node4 [label="xla::device_data\nf32[32,64]{1,0}\nxla_shape=f32[32,64]{1,0}"]
  node5 [label="aten::permute\nf32[64,32]{0,1}\nxla_shape=f32[64,32]{0,1}"]
  node6 [label="xla::device_data\nf32[64]{0}\nxla_shape=f32[64]{0}"]
  node7 [label="xla::device_data\nf32[64,128]{1,0}\nxla_shape=f32[64,128]{1,0}"]
  node8 [label="aten::permute\nf32[128,64]{0,1}\nxla_shape=f32[128,64]{0,1}"]
  node9 [label="xla::device_data\nf32[64,128]{1,0}\nxla_shape=f32[64,128]{1,0}"]
  node10 [label="aten::addmm\nf32[64,64]{1,0}\nxla_shape=f32[64,64]{1,0}"]
  node11 [label="aten::sin\nf32[64,64]{1,0}\nxla_shape=f32[64,64]{1,0}"]
  node12 [label="aten::addmm\nf32[64,32]{1,0}\nxla_shape=f32[64,32

## New checkpointing (`torch.utils.checkpoint.checkpoint` with a `context_fn`)

In [4]:
import torch
import torch_xla
import torch_xla.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import contextlib
from torch.utils.weak import WeakTensorKeyDictionary
from torch.overrides import TorchFunctionMode
from torch.utils._pytree import tree_map_only
from torch.utils.checkpoint import checkpoint

# TODO: solve the monkey patching
# Exposes torch_xla as the attribute `torch.xla`. Presumably needed because
# torch.utils.checkpoint looks up the device module as an attribute of `torch`
# — TODO confirm, and replace with a supported mechanism.
torch.xla = torch_xla


class MarkInputsToRegion(TorchFunctionMode):
  """Torch-function mode that routes first-seen tensor inputs through a barrier.

  While the mode is active, every tensor found in the positional or keyword
  arguments of a torch function call is checked against `is_marked`. A tensor
  that has not been seen yet is replaced by `barrier_function(tensor)` for that
  call and then recorded as seen. Tensors produced by calls made under the mode
  are recorded as seen immediately, so they are never passed to the barrier.

  Args:
    barrier_function: callable taking a tensor and returning the tensor that
      should be used in its place.
  """

  def __init__(self, barrier_function):
    super().__init__()
    # tensor -> bool; weak keys, so tracking a tensor does not keep it alive.
    self.is_marked = WeakTensorKeyDictionary()
    self.barrier_function = barrier_function

  def __enter__(self):
    print("Enter MarkInputsToRegion")
    # We could handle RNG here.
    return super().__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    print("Exit MarkInputsToRegion")
    # We could handle RNG here.
    return super().__exit__(exc_type, exc_val, exc_tb)

  # This will be called on every torch function call during the
  # recomputation of the checkpointed function.
  def __torch_function__(self, func, types, args=(), kwargs=None):
    """Wrap unseen tensor inputs with the barrier, call `func`, mark outputs.

    Returns:
      Whatever `func` returns, unchanged.
    """
    if kwargs is None:
      kwargs = {}

    def visit(x):
      # If we have not seen this tensor, wrap it with optimization barrier
      # for this function.
      if not self.is_marked.get(x, False):
        val = self.barrier_function(x)
      else:
        val = x
      self.is_marked[x] = True
      return val

    def mark(t):
      self.is_marked[t] = True
      return t

    args = tree_map_only(torch.Tensor, visit, args)
    kwargs = tree_map_only(torch.Tensor, visit, kwargs)
    out = func(*args, **kwargs)
    # Never wrap output tensors within the recomputation with optimization
    # barrier. `out` is not always a single tensor: it can be a tuple/list of
    # tensors or a non-tensor value (int, bool, ...), neither of which can be
    # used as a weak key. Mark each tensor leaf instead and ignore the rest.
    tree_map_only(torch.Tensor, mark, out)
    return out


def context_fn():
  """Build the pair of context managers passed to `checkpoint` as `context_fn`.

  Returns:
    A 2-tuple: a no-op `contextlib.nullcontext()` and a `MarkInputsToRegion`
    mode whose barrier clones a tensor and applies
    `xm.optimization_barrier_` to the clone.
  """

  def _fenced_copy(tensor: torch.Tensor):
    # Work on a clone; the barrier call's result is not captured, the clone
    # itself is handed back.
    fenced = tensor.clone()
    xm.optimization_barrier_([fenced])
    return fenced

  first_ctx = contextlib.nullcontext()
  second_ctx = MarkInputsToRegion(_fenced_copy)
  return first_ctx, second_ctx


class SimpleMLP(nn.Module):
  """Three-layer MLP (128 -> 64 -> 32 -> 10) with a sine after every layer.

  Args:
    checkpoint: when truthy, the middle stage (`block`) is run through the
      module-level `checkpoint` function with `context_fn=context_fn` and
      `use_reentrant=False`; otherwise it is called directly.
  """

  def __init__(self, checkpoint=True):
    super().__init__()
    # Note: inside this method `checkpoint` is the flag, not the imported
    # function of the same name.
    self.checkpoint = checkpoint
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 32)
    self.fc3 = nn.Linear(32, 10)

  def block(self, x):
    """Middle stage: `fc2` followed by an element-wise sine."""
    return torch.sin(self.fc2(x))

  def forward(self, x):
    """Apply fc1 -> sin -> block (optionally checkpointed) -> fc3 -> sin."""
    hidden = torch.sin(self.fc1(x))
    if not self.checkpoint:
      hidden = self.block(hidden)
    else:
      hidden = checkpoint(
          self.block, hidden, context_fn=context_fn, use_reentrant=False)
    return torch.sin(self.fc3(hidden))


# Dummy data: a random (64, 128) batch created directly on the XLA device.
# The trailing dimension (128) matches the input width of SimpleMLP.fc1.
device = xm.xla_device()
dummy_data = torch.randn(64, 128, device=device)

# Initialize the model and optimizer
model = SimpleMLP(checkpoint=True).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001)  # type:ignore

# Cut the graph here
# (presumably so the setup above is not part of the graphs dumped below —
# TODO confirm).
xm.mark_step()

# Training step with gradient checkpointing
optimizer.zero_grad()
x = model(dummy_data)
# Presumably only here to narrow the type for static checkers — TODO confirm.
assert x is not None

# Dump the graph that produces the forward output `x`.
# NOTE(review): the label says "HLO", but the recorded output is a DOT
# `digraph` of aten::/xla:: nodes — confirm the wording.
print("HLO for forward:")
ir = torch_xla._XLAC._get_xla_tensors_dot([x])
print(ir)

# Scalar loss (sum of all outputs), backward pass, then one SGD update.
# Per the comment on MarkInputsToRegion.__torch_function__, the mode is meant
# to run during recomputation of the checkpointed function, so its Enter/Exit
# prints are presumably expected after the next message — TODO confirm.
print("Do backward pass")
dummy_loss = x.sum()
dummy_loss.backward()
optimizer.step()

# Print the HLO graph for the forward + backward pass
print("HLO for forward + backward:")
# Only parameters that received a gradient are dumped. They are collected
# after optimizer.step(), i.e. in their post-update state.
optimizable_tensors = [p for p in model.parameters() if p.grad is not None]
ir = torch_xla._XLAC._get_xla_tensors_dot(optimizable_tensors + [dummy_loss, x])
print(ir)


HLO for forward:
digraph G {
  node0 [label="xla::device_data\nf32[10]{0}\nxla_shape=f32[10]{0}"]
  node1 [label="xla::device_data\nf32[10,32]{1,0}\nxla_shape=f32[10,32]{1,0}"]
  node2 [label="aten::permute\nf32[32,10]{0,1}\nxla_shape=f32[32,10]{0,1}"]
  node3 [label="xla::device_data\nf32[32]{0}\nxla_shape=f32[32]{0}"]
  node4 [label="xla::device_data\nf32[32,64]{1,0}\nxla_shape=f32[32,64]{1,0}"]
  node5 [label="aten::permute\nf32[64,32]{0,1}\nxla_shape=f32[64,32]{0,1}"]
  node6 [label="xla::device_data\nf32[64]{0}\nxla_shape=f32[64]{0}"]
  node7 [label="xla::device_data\nf32[64,128]{1,0}\nxla_shape=f32[64,128]{1,0}"]
  node8 [label="aten::permute\nf32[128,64]{0,1}\nxla_shape=f32[128,64]{0,1}"]
  node9 [label="xla::device_data\nf32[64,128]{1,0}\nxla_shape=f32[64,128]{1,0}"]
  node10 [label="aten::addmm\nf32[64,64]{1,0}\nxla_shape=f32[64,64]{1,0}"]
  node11 [label="aten::sin\nf32[64,64]{1,0}\nxla_shape=f32[64,64]{1,0}"]
  node12 [label="aten::addmm\nf32[64,32]{1,0}\nxla_shape=f32[64,32