In [None]:
# Debugging flags
# These are set in the cell that runs before `torch_xla` is imported.
# NOTE(review): XLA_IR_DEBUG / XLA_HLO_DEBUG presumably make torch_xla attach
# extra debug information to its IR / HLO graphs — confirm against the docs.
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
# Not a debugging flag — presumably selects the device type for this run (TPU).
%env PJRT_DEVICE=TPU

In [2]:
import torch_xla
import torch_xla.runtime
from torch_xla.utils.checkpoint import checkpoint

import torch
from functorch.compile import aot_function

import time

## Trace fn without checkpoint

We obtain the forward of `fn` in terms of aten ops.

In [4]:
# Shared setup for the experiments in this notebook: the XLA device handle and
# a 4x4 weight matrix that does not require grad (only the input `a` does).
device = torch_xla.device()
w = torch.randn(4, 4, requires_grad=False, device=device)
# NOTE(review): the sync presumably materializes `w` on the device so that its
# creation is not part of the graphs traced later — confirm.
torch_xla.sync()

def fn(a, w):
  """Toy two-layer network: sin -> matmul(w) -> cos -> matmul(w) -> sigmoid.

  Prints the Python type and shape of `a` on every call, which makes it
  visible when (and with what kind of tensor) a tracer invokes the function,
  then pauses for one second before computing.

  Args:
    a: input tensor; it is matrix-multiplied by `w` twice.
    w: weight tensor used as the right-hand operand of both matmuls.

  Returns:
    The sigmoid of the second matmul result.
  """
  print("a:", type(a), a.shape)
  # NOTE(review): the reason for this pause is not evident from the code —
  # presumably it keeps the notebook's printed output in order.
  time.sleep(1)
  hidden = a
  # Each "layer" is an elementwise nonlinearity followed by a matmul with w.
  for nonlinearity in (torch.sin, torch.cos):
    hidden = nonlinearity(hidden) @ w
  return torch.sigmoid(hidden)

def compiler_fn(m: torch.fx.GraphModule, _):
  """Pass-through graph compiler that only prints the graph it receives.

  Args:
    m: the FX graph module handed over for compilation.
    _: second positional argument supplied by the caller; ignored.

  Returns:
    `m` itself, unmodified.
  """
  generated_source = m.code
  print(generated_source)
  return m

# Input that requires grad, so that both a forward and a backward graph are
# requested from the tracer.
a = torch.randn(4, 4, requires_grad=True, device=device)
torch_xla.sync()
# Wrap `fn` so that the forward and backward graphs are both handed to
# compiler_fn, which prints them and returns them unchanged.
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Fresh leaf tensor with the same values as `a`, detached from `a`'s history.
cloned_a = a.clone().detach().requires_grad_(True)
# NOTE(review): presumably flushes pending lazy ops so they do not end up in
# the traced graph — confirm.
torch_xla.sync()
res = aot_print_fn(cloned_a, w)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])
a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    


## Trace fn with a torch_xla checkpoint

Tracing `fn` through `torch_xla.utils.checkpoint.checkpoint` with `aot_function` fails with a `RuntimeError` raised by the checkpoint itself.

In [9]:
import torch_xla.utils.checkpoint

def checkpointed_fn(a, w):
  """Run `fn(a, w)` under torch_xla's activation-checkpoint wrapper.

  Args:
    a: input tensor forwarded to `fn`.
    w: weight tensor forwarded to `fn`.

  Returns:
    Whatever the torch_xla checkpoint call returns for `fn(a, w)`.
  """
  xla_checkpoint = torch_xla.utils.checkpoint.checkpoint
  return xla_checkpoint(fn, a, w)

# Trace the torch_xla-checkpointed function with the same printing compilers
# used for the plain `fn` trace.
aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Fresh leaf tensor with the same values as `a`, detached from `a`'s history.
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
# The failure is the point of this section. Catch and display it so that the
# notebook still survives "Restart Kernel -> Run All" instead of stopping here.
try:
  res = aot_print_fn(cloned_a, w)
except RuntimeError as err:
  res = None  # do not leave a stale result from the previous section behind
  print(f"RuntimeError: {err}")

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])
a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


 (Triggered internally at /workspaces/torch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


RuntimeError: Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward(). Please use .backward() and do not pass its `inputs` argument.

## Trace fn with a torch non-reentrant checkpoint

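`aot_function` appears to trace straight through the non-reentrant checkpoint
wrapper. The aten forward we get back is identical to the one traced without any
checkpoint, and it still returns the same saved tensors (`primals_1`, `primals_2`,
`mm`, `sigmoid`) for the backward pass.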

In [12]:
import torch.utils.checkpoint

# Monkey-patch: expose an `xla` attribute on the `torch` namespace.
# NOTE(review): presumably needed because torch.utils.checkpoint looks up
# `torch.<device_type>` for the tensors' device type. The value stored here is
# the object returned by torch_xla.device(), not a module — confirm this stub
# is sufficient and intended.
torch.xla = torch_xla.device()  # type:ignore

def checkpointed_fn(a, w):
  """Run `fn(a, w)` under PyTorch's non-reentrant activation checkpoint.

  Args:
    a: input tensor forwarded to `fn`.
    w: weight tensor forwarded to `fn`.

  Returns:
    The result of `fn(a, w)` as returned by the checkpoint wrapper.
  """
  run_checkpointed = torch.utils.checkpoint.checkpoint
  return run_checkpointed(fn, a, w, use_reentrant=False)

# Same tracing recipe as the earlier sections, now pointed at the function
# wrapped in torch's non-reentrant checkpoint; compiler_fn prints each graph.
aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
# Fresh leaf tensor with the same values as `a`, detached from `a`'s history.
cloned_a = a.clone().detach().requires_grad_(True)
# NOTE(review): presumably flushes pending lazy ops so they do not end up in
# the traced graph — confirm.
torch_xla.sync()
res = aot_print_fn(cloned_a, w)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])
a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    


## Use dynamo to trace the checkpointed function

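Dynamo captures the checkpoint as a higher-order op: the FX graph below contains a single `tag_activation_checkpoint` call whose arguments are the `wrap_body_0` attribute and the two inputs.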

In [21]:
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Dynamo backend that dumps the captured FX graph and runs it unchanged.

    Prints the graph as a table, then its generated Python source, pauses for
    one second, and hands back the module's own `forward`.

    Args:
        gm: FX graph module captured by Dynamo.
        example_inputs: example tensors supplied by Dynamo; not used here.

    Returns:
        `gm.forward`, a plain Python callable.
    """
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    # One call emits the header, the source and the trailing blank line.
    print("FX graph code:", gm.code, "", sep="\n")
    # NOTE(review): the reason for this pause is not evident from the code —
    # presumably it keeps the notebook's printed output in order.
    time.sleep(1)
    compiled_callable = gm.forward  # return a python callable
    return compiled_callable
  
def fn_2(a, w):
  """Toy two-layer network: sin -> matmul(w) -> cos -> matmul(w) -> sigmoid.

  Args:
    a: input tensor; it is matrix-multiplied by `w` twice.
    w: weight tensor used as the right-hand operand of both matmuls.

  Returns:
    The sigmoid of the second matmul result.
  """
  first_layer = torch.sin(a) @ w
  second_layer = torch.cos(first_layer) @ w
  return torch.sigmoid(second_layer)
  
def checkpointed_fn_2(a, w):
  """Run `fn_2(a, w)` under PyTorch's non-reentrant activation checkpoint.

  Args:
    a: input tensor forwarded to `fn_2`.
    w: weight tensor forwarded to `fn_2`.

  Returns:
    The result of `fn_2(a, w)` as returned by the checkpoint wrapper.
  """
  run_checkpointed = torch.utils.checkpoint.checkpoint
  return run_checkpointed(fn_2, a, w, use_reentrant=False)

# Compile with Dynamo, using my_compiler as the backend so the captured FX
# graph gets printed. fullgraph=True requests capture as one whole graph.
dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_compiler, fullgraph=True)
# Reuses `cloned_a` and `w` created in earlier cells. As the last expression
# in the cell, the result tensor is displayed as the cell output.
dynamo_fn(cloned_a, w)

my_compiler() called with FX graph:
opcode         name                       target                       args                            kwargs
-------------  -------------------------  ---------------------------  ------------------------------  ------------------------
placeholder    l_a_                       L_a_                         ()                              {}
placeholder    l_w_                       L_w_                         ()                              {}
get_attr       wrap_body_0                wrap_body_0                  ()                              {}
call_function  tag_activation_checkpoint  tag_activation_checkpoint    (wrap_body_0, l_a_, l_w_)       {'use_reentrant': False}
call_function  getitem                    <built-in function getitem>  (tag_activation_checkpoint, 0)  {}
output         output                     output                       ((getitem,),)                   {}

FX graph code:



def forward(self, L_a_ : torch.Tensor, L_w_ : tor

tensor([[0.6009, 0.4689, 0.3773, 0.6870],
        [0.7962, 0.3210, 0.7358, 0.7793],
        [0.7023, 0.3950, 0.6428, 0.6828],
        [0.6641, 0.4357, 0.4625, 0.7164]], device='xla:0',
       grad_fn=<SigmoidBackward0>)

## Use dynamo and then AOTAutograd

In [23]:
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Graph compiler for AOTAutograd that dumps the FX graph it receives.

    Prints the graph as a table, then its generated Python source, pauses for
    one second, and returns the module's `forward` wrapped by
    `make_boxed_func`.

    Args:
        gm: FX graph module handed over for compilation.
        example_inputs: example tensors supplied by the caller; not used here.

    Returns:
        The result of `make_boxed_func(gm.forward)`.
    """
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    # One call emits the header, the source and the trailing blank line.
    print("FX graph code:", gm.code, "", sep="\n")
    # NOTE(review): the reason for this pause is not evident from the code —
    # presumably it keeps the notebook's printed output in order.
    time.sleep(1)
    boxed_forward = make_boxed_func(gm.forward)
    return boxed_forward
  
# Build a Dynamo backend that routes the captured graph through AOTAutograd.
# Only the forward compiler is supplied; the bw_compiler argument is left
# commented out.
# NOTE(review): confirm which compiler AOTAutograd then uses for the backward graph.
my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_backend, fullgraph=True)
# Reuses `cloned_a` and `w` created in earlier cells.
o = dynamo_fn(cloned_a, w)
assert o is not None
# Reduce to a scalar and run a backward pass through the compiled function.
o.sum().backward()

my_compiler() called with FX graph:
opcode         name       target                args                                kwargs
-------------  ---------  --------------------  ----------------------------------  --------
placeholder    primals_1  primals_1             ()                                  {}
placeholder    primals_2  primals_2             ()                                  {}
call_function  sin        aten.sin.default      (primals_1,)                        {}
call_function  mm         aten.mm.default       (sin, primals_2)                    {}
call_function  cos        aten.cos.default      (mm,)                               {}
call_function  mm_1       aten.mm.default       (cos, primals_2)                    {}
call_function  sigmoid    aten.sigmoid.default  (mm_1,)                             {}
output         output     output                ((sigmoid, primals_1, primals_2),)  {}

FX graph code:



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten