# How to get the forward/backward of a Dynamo function

We (PyTorch/XLA) would like to get a symbolic graph representation (e.g. FX graph)
of the forward and backward pass of a torch function, in order to implement a
scan operator similar to `jax.lax.scan`. The forward pass returns the output and
intermediate activations. The backward pass takes in intermediate activations and
gradients w.r.t the output and returns gradients w.r.t the input. The forward of
a `scan(fn, ...)` is built using the forward of `fn`, and the backward of a
`scan(fn, ...)` is built using the backward of `fn`.
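
To make that structure concrete, here is a minimal sketch in plain Python, using scalars instead of tensors. The toy step function and the names `fn_forward`, `fn_backward`, `scan_forward`, and `scan_backward` are illustrative assumptions, not PyTorch/XLA APIs. The backward of each step is written by hand here, which is exactly the part we would like to obtain symbolically.

```python
import math


def fn_forward(carry, x):
    """Forward of a toy step: new_carry = sin(carry) * x, and y = new_carry.

    Returns the outputs plus the activations the backward pass needs.
    """
    new_carry = math.sin(carry) * x
    y = new_carry
    activations = (carry, x)
    return (new_carry, y), activations


def fn_backward(activations, grad_new_carry, grad_y):
    """Backward of the toy step: maps output gradients to input gradients."""
    carry, x = activations
    g = grad_new_carry + grad_y  # y is the same value as new_carry
    grad_carry = g * math.cos(carry) * x
    grad_x = g * math.sin(carry)
    return grad_carry, grad_x


def scan_forward(init, xs):
    """Forward of scan: loop the forward of fn and save activations."""
    carry, ys, saved = init, [], []
    for x in xs:
        (carry, y), activations = fn_forward(carry, x)
        ys.append(y)
        saved.append(activations)
    return carry, ys, saved


def scan_backward(saved, grad_carry, grad_ys):
    """Backward of scan: loop the backward of fn in reverse order."""
    grad_xs = [0.0] * len(saved)
    for i in reversed(range(len(saved))):
        grad_carry, grad_xs[i] = fn_backward(saved[i], grad_carry, grad_ys[i])
    return grad_carry, grad_xs


# Usage: gradients of loss = final_carry + sum(ys).
final_carry, ys, saved = scan_forward(0.5, [1.0, 2.0, 3.0])
grad_init, grad_xs = scan_backward(saved, 1.0, [1.0, 1.0, 1.0])
```

The real operator would replace `fn_forward` and `fn_backward` with graphs extracted from `fn`, which is why both need to be available as separate symbolic functions.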

There are a few ways to get a symbolic representation. Since Dynamo has support
for higher order operators such as `torch.utils.checkpoint` [1], we investigate
getting the forward/backward of a `fn` using Dynamo in this notebook.

[1]: https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565

In [15]:
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func, aot_function # type:ignore
from typing import List
import torch.utils.checkpoint
import time

### Graph extraction using the `aot_autograd` backend

`torch.compile` takes a `backend` argument. We can extract the graph using an
`aot_autograd` backend, created with two no-op compilers that just save the
input graph.

This is a promising approach, but the signature of the graph extracted by
`aot_autograd` is not necessarily the same as that of `fn`. For example, if `fn`
references free variables in its closure, the extracted graph will contain
additional arguments. Dynamo may also reorder the arguments. There doesn't
appear to be a way to learn the argument mapping from the compiled function.
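
As a partial aid, Python's standard `inspect.getclosurevars` lists the nonlocal and global names a function references, which tells us which captured values are candidates for becoming extra graph inputs. This is only a sketch: it uses float stand-ins instead of tensors, `make_fn` and `scale` are hypothetical names, and it says nothing about the order in which Dynamo places those values in the graph signature.

```python
import inspect

h = 2.0  # stand-in for a module-level tensor referenced by the function


def make_fn(scale):
    """Build a function that captures `scale` (nonlocal) and `h` (global)."""
    def fn_inner(a, w):
        return scale * (a * w) + h
    return fn_inner


fn_inner = make_fn(3.0)

# Names the function references beyond its explicit arguments (a, w).
captured = inspect.getclosurevars(fn_inner)
candidate_extra_inputs = sorted(set(captured.nonlocals) | set(captured.globals))
print(candidate_extra_inputs)
```

This narrows down which values may appear as additional arguments, but it does not recover the argument mapping itself.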

In [19]:
h = torch.rand(4, 4, requires_grad=True)

def fn_inner(a, w):
  """A simple function containing a few layers."""
  a = torch.sin(a)
  a = a @ w
  a = torch.cos(a)
  a = a @ h
  a = torch.sigmoid(a)
  return a

def fn(a, w):
  return torch.utils.checkpoint.checkpoint(fn_inner, a, w, use_reentrant=False)

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=my_compiler)
dynamo_fn = torch.compile(fn, backend=my_backend, fullgraph=True)
a = torch.rand(4, requires_grad=True)
w = torch.rand(4, 4, requires_grad=True)

# This will print a `forward` pass with three arguments.
# Which one is a, w, h respectively? We don't know.
o = dynamo_fn(a, w)

FX graph code:



def forward(self, primals_1, primals_2, primals_3):
    sin = torch.ops.aten.sin.default(primals_1)
    unsqueeze = torch.ops.aten.unsqueeze.default(sin, 0);  sin = None
    mm = torch.ops.aten.mm.default(unsqueeze, primals_2);  unsqueeze = None
    squeeze = torch.ops.aten.squeeze.dim(mm, 0);  mm = None
    cos = torch.ops.aten.cos.default(squeeze);  squeeze = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(cos, 0);  cos = None
    mm_1 = torch.ops.aten.mm.default(unsqueeze_1, primals_3);  unsqueeze_1 = None
    squeeze_1 = torch.ops.aten.squeeze.dim(mm_1, 0);  mm_1 = None
    sigmoid = torch.ops.aten.sigmoid.default(squeeze_1);  squeeze_1 = None
    return (sigmoid, primals_1, primals_2, primals_3)
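
The cell above only prints the forward graph. The following is a minimal sketch, assuming the same `aot_autograd` and `make_boxed_func` helpers, that stores both the forward and backward graphs in a dictionary. The names `captured`, `make_saver`, and `f` are hypothetical. The backward compiler may only be invoked once a backward pass actually runs, so the sketch calls `.backward()` before inspecting the result.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

captured = {}  # maps "forward"/"backward" to the FX graph that was seen


def make_saver(name):
    """Return a no-op compiler that records the graph under `name`."""
    def compiler(gm, example_inputs):
        captured[name] = gm
        return make_boxed_func(gm.forward)
    return compiler


backend = aot_autograd(
    fw_compiler=make_saver("forward"),
    bw_compiler=make_saver("backward"),
)


def f(a, w):
    # No free variables here, so the graph inputs should be just a and w.
    return torch.sigmoid(torch.sin(a) @ w)


compiled_f = torch.compile(f, backend=backend, fullgraph=True)
a_in = torch.rand(4, requires_grad=True)
w_in = torch.rand(4, 4, requires_grad=True)

# Running backward gives the backward compiler a chance to be called.
compiled_f(a_in, w_in).sum().backward()

for name, gm in captured.items():
    print(name)
    print(gm.code)
```

This yields both graphs, but the signature problem described earlier still applies: the saved graphs use `primals_*` and `tangents_*` style inputs rather than the arguments of the original function.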
    



### Symbolically trace a dynamo function

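Another approach is to symbolically trace a Dynamo-compiled function using AOT
Autograd (`aot_function`). The intuition is that since Dynamo decomposes
`torch.utils.checkpoint` [2], we can use AOT Autograd to record the result of
that decomposition in terms of ATen ops. Unfortunately, this is also not
supported: tracing fails with a FakeTensor assertion at the matmul with the
captured tensor `h`, as the output of the next cell shows.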

[2]: https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565#torchutilscheckpointcheckpoint-8

In [20]:
dynamo_fn = torch.compile(fn, backend='aot_eager', fullgraph=True)
aot_dynamo_fn = aot_function(dynamo_fn, fw_compiler=my_compiler)
o = aot_dynamo_fn(a, w)

W1113 23:11:06.145000 894409 site-packages/torch/_dynamo/utils.py:1136] ChromiumEventLogger: Start event not in stack, ignoring


AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.mm.default(FakeTensor(..., size=(1, 4)), tensor([...], size=(4, 4), requires_grad=True))

While executing %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%unsqueeze_1, %primals_3), kwargs = {})
Original traceback:
  File "/tmp/ipykernel_894409/2380202194.py", line 13, in fn
    return torch.utils.checkpoint.checkpoint(fn_inner, a, w, use_reentrant=False)
  File "/tmp/ipykernel_894409/2380202194.py", line 8, in fn_inner
    a = a @ h


### Feature request

In a nutshell, we would like a version of `aot_function` that supports higher
order ops such as `torch.utils.checkpoint`. Reading the documentation of the
`aot_eager` backend, it's possible that Dynamo already has the forward/backward
functions internally. We just need to get access to those and use them in `scan`.