In [1]:
# Debugging flags
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env PJRT_DEVICE=TPU

env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: PJRT_DEVICE=TPU


In [2]:
import torch_xla
import torch_xla.runtime

import torch
from functorch.compile import aot_function

import time

## Trace fn without checkpoint

We obtain the forward of `fn` in terms of aten ops.

In [19]:
device = torch_xla.device()
w = torch.randn(4, 4, requires_grad=False, device=device)
torch_xla.sync()

def print_traceback():
  import traceback
  traceback.print_stack()


def fn(a, w):
  """A simple function containing a few layers."""
  print("a:", type(a), a.shape)
  time.sleep(1)
  print_traceback()
  time.sleep(1)
  a = torch.sin(a)
  a = a @ w
  a = torch.cos(a)
  a = a @ w
  a = torch.sigmoid(a)
  return a

def compiler_fn(m: torch.fx.GraphModule, _):
  print(m.code)
  return m

a = torch.randn(4, 4, requires_grad=True, device=device)
torch_xla.sync()
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
res = aot_print_fn(cloned_a, w)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in




def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    


## Trace fn with a torch_xla checkpoint

Runtime error within torch_xla checkpoint.

In [20]:
import logging
import torch_xla.utils.checkpoint

def checkpointed_fn(a, w):
  return torch_xla.utils.checkpoint.checkpoint(fn, a, w)

aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()

try:
  res = aot_print_fn(cloned_a, w)
except RuntimeError as e:
  logging.exception(e)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in

## Trace fn with a torch non-reentrant checkpoint

`aot_function` appears to skip over the checkpoint wrapper entirely. We still
get back an identical aten forward that saves all the intermediate activations.

In [23]:
import torch.utils.checkpoint

torch.xla = torch_xla.device()  # type:ignore

def checkpointed_fn(a, w):
  return torch.utils.checkpoint.checkpoint(fn, a, w, use_reentrant=False)

aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
res = aot_print_fn(cloned_a, w)
print("Backward")
res.sum().backward()

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in




def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    
Backward



def forward(self, primals_1, primals_2, mm, sigmoid, tangents_1):
    detach = torch.ops.aten.detach.default(sigmoid);  sigmoid = None
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    detach_2 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    sigmoid_backward = torch.ops.aten.sigmoid_backward.default(tangents_1, detach_3);  tangents_1 = detach_3 = None
    t = torch.ops.aten.t.default(primals_2)
    mm_2 = torch.ops.aten.mm.default(sigmoid_backward, t);  sigmoid_backward = t = None
 



## Use dynamo to trace the checkpointed function

We get a higher order `checkpoint` op.

In [25]:
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return gm.forward  # return a python callable
  
def fn_2(a, w):
  """A simple function containing a few layers."""
  a = torch.sin(a)
  a = a @ w
  a = torch.cos(a)
  a = a @ w
  a = torch.sigmoid(a)
  return a
  
def checkpointed_fn_2(a, w):
  return torch.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=False)

dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_compiler, fullgraph=True)
dynamo_fn(cloned_a, w)

my_compiler() called with FX graph:
opcode         name                       target                       args                            kwargs
-------------  -------------------------  ---------------------------  ------------------------------  ------------------------
placeholder    l_a_                       L_a_                         ()                              {}
placeholder    l_w_                       L_w_                         ()                              {}
get_attr       wrap_body_0                wrap_body_0                  ()                              {}
call_function  tag_activation_checkpoint  tag_activation_checkpoint    (wrap_body_0, l_a_, l_w_)       {'use_reentrant': False}
call_function  getitem                    <built-in function getitem>  (tag_activation_checkpoint, 0)  {}
output         output                     output                       ((getitem,),)                   {}

FX graph code:



def forward(self, L_a_ : torch.Tensor, L_w_ : tor

tensor([[0.6119, 0.5826, 0.8878, 0.0959],
        [0.4757, 0.3571, 0.6625, 0.1956],
        [0.5700, 0.5923, 0.8227, 0.1141],
        [0.7144, 0.5106, 0.8910, 0.2592]], device='xla:0',
       grad_fn=<SigmoidBackward0>)

## Use dynamo and then AOTAutograd

In [7]:
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return make_boxed_func(gm.forward)
  
my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_backend, fullgraph=True)
o = dynamo_fn(cloned_a, w)
assert o is not None
o.sum().backward()

my_compiler() called with FX graph:
opcode         name       target                args                                kwargs
-------------  ---------  --------------------  ----------------------------------  --------
placeholder    primals_1  primals_1             ()                                  {}
placeholder    primals_2  primals_2             ()                                  {}
call_function  sin        aten.sin.default      (primals_1,)                        {}
call_function  mm         aten.mm.default       (sin, primals_2)                    {}
call_function  cos        aten.cos.default      (mm,)                               {}
call_function  mm_1       aten.mm.default       (cos, primals_2)                    {}
call_function  sigmoid    aten.sigmoid.default  (mm_1,)                             {}
output         output     output                ((sigmoid, primals_1, primals_2),)  {}

FX graph code:



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten

In [8]:
def checkpointed_fn_3(a, w):
  return torch_xla.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=True)

try:
  my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
  dynamo_fn = torch.compile(checkpointed_fn_3, backend=my_backend, fullgraph=True)
  o = dynamo_fn(cloned_a, w)
  assert o is not None
  o.sum().backward()
except RuntimeError as e:
  logging.exception(e)


ERROR:root:'skip function check_backward_validity in file /usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py'

from user code:
   File "/tmp/ipykernel_1664539/204424578.py", line 2, in checkpointed_fn_3
    return torch_xla.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=True)
  File "/workspaces/torch/pytorch/xla/torch_xla/utils/checkpoint.py", line 292, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/workspaces/torch/pytorch/xla/torch_xla/utils/checkpoint.py", line 87, in forward
    check_backward_validity(args)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/tmp/ipykernel_1664539/204424578.py", line 7, in <module>
    o = dynamo_fn(cloned_a, w)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.

### Use torch.export with AOTAutograd

In [9]:
import torch.export
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)
  
class FunctionModule(torch.nn.Module):
  def __init__(self, f):
    super().__init__()
    self.f = f
  
  def forward(self, *args):
    return self.f(*args)

exported = torch.export.export_for_training(FunctionModule(checkpointed_fn_2), args=(cloned_a, w))
print(exported)
module = exported.module()

# Now get the backward function
from functorch.compile import aot_module
aot_module(module, fw_compiler=my_compiler, bw_compiler=my_compiler)(cloned_a, w).sum().backward()

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[4, 4]", args_1: "f32[4, 4]"):
             # File: /tmp/ipykernel_1664539/1239672230.py:17 in fn_2, code: a = torch.sin(a)
            sin: "f32[4, 4]" = torch.ops.aten.sin.default(args_0);  args_0 = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:18 in fn_2, code: a = a @ w
            matmul: "f32[4, 4]" = torch.ops.aten.matmul.default(sin, args_1);  sin = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:19 in fn_2, code: a = torch.cos(a)
            cos: "f32[4, 4]" = torch.ops.aten.cos.default(matmul);  matmul = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:20 in fn_2, code: a = a @ w
            matmul_1: "f32[4, 4]" = torch.ops.aten.matmul.default(cos, args_1);  cos = args_1 = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:21 in fn_2, code: a = torch.sigmoid(a)
            

## Use torch.compile with an exception trick to avoid evaluation

In [10]:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint
from dataclasses import dataclass

@dataclass
class ReturnGraph(Exception):
  graph: torch.fx.GraphModule

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)

  def f(*args):
    print(args)
    raise ReturnGraph(gm)

  return make_boxed_func(f)

try:
  my_backend = aot_autograd(fw_compiler=my_compiler, bw_compiler=my_compiler)
  dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_backend, fullgraph=True)
  o = dynamo_fn(cloned_a, w)
  assert o is not None
  o.sum().backward()
except ReturnGraph as e:
  print(e)

my_compiler() called with FX graph:
opcode         name       target                args                                kwargs
-------------  ---------  --------------------  ----------------------------------  --------
placeholder    primals_1  primals_1             ()                                  {}
placeholder    primals_2  primals_2             ()                                  {}
call_function  sin        aten.sin.default      (primals_1,)                        {}
call_function  mm         aten.mm.default       (sin, primals_2)                    {}
call_function  cos        aten.cos.default      (mm,)                               {}
call_function  mm_1       aten.mm.default       (cos, primals_2)                    {}
call_function  sigmoid    aten.sigmoid.default  (mm_1,)                             {}
output         output     output                ((sigmoid, primals_1, primals_2),)  {}

FX graph code:



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten

In [11]:
import torch.export
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)
  
class FunctionModule(torch.nn.Module):
  def __init__(self, f):
    super().__init__()
    self.f = f
  
  def forward(self, *args):
    return self.f(*args)

try:
  with FakeTensorMode():
    fake_a = torch.randn(4, 4)
    fake_w = torch.randn(4, 4)
    exported = torch.export.export_for_training(FunctionModule(checkpointed_fn_2), args=(fake_a, fake_w))
    print(exported)
    module = exported.module()

    # Now get the backward function
    from functorch.compile import aot_module
    aot_module(module, fw_compiler=my_compiler, bw_compiler=my_compiler)(cloned_a, w).sum().backward()
except AssertionError as e:
  logging.exception(e)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[4, 4]", args_1: "f32[4, 4]"):
             # File: /tmp/ipykernel_1664539/1239672230.py:17 in fn_2, code: a = torch.sin(a)
            sin: "f32[4, 4]" = torch.ops.aten.sin.default(args_0);  args_0 = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:18 in fn_2, code: a = a @ w
            matmul: "f32[4, 4]" = torch.ops.aten.matmul.default(sin, args_1);  sin = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:19 in fn_2, code: a = torch.cos(a)
            cos: "f32[4, 4]" = torch.ops.aten.cos.default(matmul);  matmul = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:20 in fn_2, code: a = a @ w
            matmul_1: "f32[4, 4]" = torch.ops.aten.matmul.default(cos, args_1);  cos = args_1 = None
            
             # File: /tmp/ipykernel_1664539/1239672230.py:21 in fn_2, code: a = torch.sigmoid(a)
            

ERROR:root:Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sin.default(tensor([...], device='xla:0', size=(4, 4)))
Traceback (most recent call last):
  File "/tmp/ipykernel_1664539/1049874456.py", line 36, in <module>
    aot_module(module, fw_compiler=my_compiler, bw_compiler=my_compiler)(cloned_a, w).sum().backward()
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 945, in forward
    return compiled_f(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 897, in returned_function
    out = cached_fn(flat_args)
  File "/usr/local/lib/python3.10/si

## Test dynamo + AOTAutograd on graphs with multiple inputs/outputs

Not sure why, we need to clone any outputs that may be aliased.
Otherwise, the extracted AOT autograd graph is missing outputs.

In [12]:
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint
from torch_xla.experimental.scan import tree_map

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return make_boxed_func(gm.forward)
  
def fn_4(a, w):
  """A simple function containing a few layers."""
  a = a @ w
  w = a
  return a, w.clone()
  
my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
dynamo_fn = torch.compile(fn_4, backend=my_backend, fullgraph=True)
fake_a = torch.empty_like(cloned_a, requires_grad=True, device=torch_xla.device())
fake_w = torch.empty_like(w, requires_grad=True, device=torch_xla.device())
o = dynamo_fn(fake_a, fake_w)
# o = dynamo_fn(cloned_a, w)
assert o is not None
torch.autograd.backward(o, tree_map(lambda v: torch.ones_like(v), o))

my_compiler() called with FX graph:
opcode         name       target              args                                  kwargs
-------------  ---------  ------------------  ------------------------------------  --------
placeholder    primals_1  primals_1           ()                                    {}
placeholder    primals_2  primals_2           ()                                    {}
call_function  mm         aten.mm.default     (primals_1, primals_2)                {}
call_function  clone      aten.clone.default  (mm,)                                 {}
output         output     output              ((mm, clone, primals_1, primals_2),)  {}

FX graph code:



def forward(self, primals_1, primals_2):
    mm = torch.ops.aten.mm.default(primals_1, primals_2)
    clone = torch.ops.aten.clone.default(mm)
    return (mm, clone, primals_1, primals_2)
    

my_compiler() called with FX graph:
opcode         name        target           args                      kwargs
-------------  ----

### Test Dynamo + AOTAutograd argument ordering in complex graphs

When we extract the forward/backward graph using dynamo, it turns out that the
argument orders passed to our compiler isn't the same as the argument orders to
`fn`. The example below demonstrates that and inspects the bytecode.

Expectation: `fn` gets two PyTrees, `carry` and `x`. So `my_compiler` should
get a `GraphModule` taking for tensors, `carry['a']`, `carry['b']`,
`x['weights']`, `x['biases']`, respectively.

Reality: `carry` and `x` are reversed in the `GraphModule` given to `my_compiler`.

Hypothesis: Dyanmo is just using `my_compiler` as a backend. Within the bytecode
generated by Dynamo, it can reorder things however it wants. It can even eliminate
certain input/output arguments in case of aliasing. That's fine if `my_compiler`
is lowering the graph. It should lower whatever input graph appropriately to e.g.
CUDA, XLA, etc. However, we're abusing `my_compiler` to extract the forward and
backward graphs. There is no guarantee that the graph given to `my_compiler` will
have the same signature as `fn`. There is even no guarantee that `my_compiler`
receives the whole graph if `fullgrah=False`.

In [13]:
from torch_xla.experimental.scan import tree_flatten


device = torch_xla.device()

def fn(carry, x):
  weights = x['weights']
  biases = x['biases']
  carry_a = carry['a']
  carry_b = carry['b']
  new_carry_a = torch.sin((carry_a @ weights) + biases)
  new_carry_b = torch.cos((carry_b @ weights) + biases)
  y = torch.sigmoid(new_carry_a + new_carry_b)
  return {'a': new_carry_a, 'b': new_carry_b}, y

init = {
    'a': torch.randn(2, 3, requires_grad=True, device=device),
    'b': torch.randn(2, 3, requires_grad=True, device=device)
}
x = {
    'weights': torch.randn(3, 3, requires_grad=True, device=device),
    'biases': torch.randn(2, 3, requires_grad=True, device=device)
}


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return make_boxed_func(gm.forward)
  
my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
dynamo_fn = torch.compile(fn, backend=my_backend, fullgraph=True)
fake_init = tree_map(lambda v: torch.empty_like(v, requires_grad=v.requires_grad, device=torch_xla.device()), init)
fake_x = tree_map(lambda v: torch.empty_like(v, requires_grad=v.requires_grad, device=torch_xla.device()), x)
o = dynamo_fn(fake_init, fake_x)
assert o is not None
o, _ = tree_flatten(o)
torch.autograd.backward(o, tree_map(lambda v: torch.ones_like(v), o))

my_compiler() called with FX graph:
opcode         name       target                args                                                                          kwargs
-------------  ---------  --------------------  ----------------------------------------------------------------------------  --------
placeholder    primals_1  primals_1             ()                                                                            {}
placeholder    primals_2  primals_2             ()                                                                            {}
placeholder    primals_3  primals_3             ()                                                                            {}
placeholder    primals_4  primals_4             ()                                                                            {}
call_function  mm         aten.mm.default       (primals_3, primals_1)                                                        {}
call_function  add        aten.add.Tensor       (mm

Let's investigate the generate bytecode. We'll see that Dynamo generates a small
wrapper around the `__compiled_function_18`. Within the wrapper, it exchanged the
`carry` and `x` arguments, i.e. `weights` and `biases` comes first.

In [14]:
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(dynamo_fn))
cache_entry = cache_entries[0]
code = cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(code)
from depyf import decompile
print(decompile(code))

  6           0 LOAD_GLOBAL              4 (__compiled_fn_14)
              2 LOAD_FAST                1 (x)
              4 LOAD_CONST               1 ('weights')
              6 BINARY_SUBSCR
              8 LOAD_FAST                1 (x)
             10 LOAD_CONST               2 ('biases')
             12 BINARY_SUBSCR
             14 LOAD_FAST                0 (carry)
             16 LOAD_CONST               3 ('a')
             18 BINARY_SUBSCR
             20 LOAD_FAST                0 (carry)
             22 LOAD_CONST               4 ('b')
             24 BINARY_SUBSCR
             26 CALL_FUNCTION            4
             28 STORE_FAST               9 (graph_out_0)
             30 LOAD_CONST               3 ('a')
             32 LOAD_FAST                9 (graph_out_0)
             34 LOAD_CONST               6 (0)
             36 BINARY_SUBSCR
             38 LOAD_CONST               4 ('b')
             40 LOAD_FAST                9 (graph_out_0)
             42 LOAD_CONST

In [15]:
__compiled_fn_18  # type:ignore

NameError: name '__compiled_fn_18' is not defined

### Conclusion

- Dynamo invokes user compiler with a graph with unstable argument ordering.
- The user compiler output is then weaved into the series of Python functions,
  some of which are in bytecode.
- In contrast, JAX models this as a stack of interpreters/function transforms.
  The JAX approach is way more composable but you lose the ability to graph break
  or to call into foreign functions.
- Dynamo is missing a crucial API to get `GraphModule`s representing both
  forward and backward passes, for a particular guarded monomorphic condition.
  That will make our lives much easier.
- We need to bend backwards to extract aten forward and backward graphs with
  the expected argument ordering.

In [None]:
from torch_xla.experimental.scan import tree_flatten, tree_unflatten
import torch.func


device = torch_xla.device()

def get_fn():
  # placeholder = torch.ones(3, device=device, dtype=torch.float32)
  def fn(carry, x):
    weights = x['weights']
    biases = x['biases']
    carry_a = carry['a']
    carry_b = carry['b']
    new_carry_a = torch.sin((carry_a @ weights) + biases)
    new_carry_b = torch.cos((carry_b @ weights) + biases)
    y = torch.sigmoid(new_carry_a + new_carry_b)
    return {'a': new_carry_a, 'b': new_carry_b}, y
  return fn

fn = get_fn()

init = {
    'a': torch.randn(2, 3, requires_grad=True, device=device),
    'b': torch.randn(2, 3, requires_grad=True, device=device)
}
x = {
    'weights': torch.randn(3, 3, requires_grad=True, device=device),
    'biases': torch.randn(2, 3, requires_grad=True, device=device)
}


fake_init = tree_map(lambda v: torch.empty_like(v, requires_grad=v.requires_grad, device=torch_xla.device()), init)
fake_x = tree_map(lambda v: torch.empty_like(v, requires_grad=v.requires_grad, device=torch_xla.device()), x)

flat_fake_init, _ = tree_flatten(fake_init)
flat_fake_x, _ = tree_flatten(fake_x)


flat_init, init_spec = tree_flatten(init)
flat_xs, xs_spec = tree_flatten(x)
flat_init_len = len(flat_init)
flat_xs_len = len(flat_xs)

def fn_flattened(*args):
  flat_init = args[:flat_init_len]
  flat_xs = args[flat_init_len:]
  return fn(
      tree_unflatten(flat_init, init_spec), tree_unflatten(flat_xs, xs_spec))


def determine_args_ordering(example_inputs: List[torch.Tensor]):
  for i, t in enumerate(example_inputs):
    try:
      idx = flat_fake_init.index(t)
    except:
      idx = flat_fake_x.index(t)
      idx += len(flat_fake_init)
    print(f"arg {i} should be idx {i}")


def fw_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print(gm.meta)
    print(gm.graph)
    
    # Idea: inspect tensors in example_inputs correspond to which tensors in the input.
    # determine_args_ordering(example_inputs)
    for t in example_inputs:
      print(t.storage(), t.device, t.grad_fn)

    print()
    time.sleep(1)

    return make_boxed_func(gm.forward)
  
  
def bw_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return make_boxed_func(gm.forward)
  
  
my_backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
dynamo_fn = torch.compile(fn_flattened, backend=my_backend, fullgraph=True)
o = dynamo_fn(*flat_fake_init, *flat_fake_x)
assert o is not None
o, _ = tree_flatten(o)
torch.autograd.backward(o, tree_map(lambda v: torch.ones_like(v), o))

# These don't work
# AOTAutograd of dynamo
# aot_dynamo_fn = aot_function(dynamo_fn, fw_compiler=fw_compiler, bw_compiler=bw_compiler)
# o = aot_dynamo_fn(fake_init, fake_x)
# grad_and_value of dynamo
# o = torch.func.grad_and_value(dynamo_fn)(fake_init, fake_x)

my_compiler() called with FX graph:
opcode         name       target                args                                                                          kwargs
-------------  ---------  --------------------  ----------------------------------------------------------------------------  --------
placeholder    primals_1  primals_1             ()                                                                            {}
placeholder    primals_2  primals_2             ()                                                                            {}
placeholder    primals_3  primals_3             ()                                                                            {}
placeholder    primals_4  primals_4             ()                                                                            {}
call_function  mm         aten.mm.default       (primals_3, primals_1)                                                        {}
call_function  add        aten.add.Tensor       (mm

In [None]:
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(dynamo_fn))
cache_entry = cache_entries[0]
code = cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(code)
from depyf import decompile
print(decompile(code))

 44           0 LOAD_GLOBAL              6 (__compiled_fn_54)
              2 LOAD_FAST                0 (args)
              4 LOAD_CONST               1 (2)
              6 BINARY_SUBSCR
              8 LOAD_FAST                0 (args)
             10 LOAD_CONST               2 (3)
             12 BINARY_SUBSCR
             14 LOAD_FAST                0 (args)
             16 LOAD_CONST               3 (0)
             18 BINARY_SUBSCR
             20 LOAD_FAST                0 (args)
             22 LOAD_CONST               4 (1)
             24 BINARY_SUBSCR
             26 CALL_FUNCTION            4
             28 STORE_FAST               3 (graph_out_0)
             30 LOAD_CONST               5 ('a')
             32 LOAD_FAST                3 (graph_out_0)
             34 LOAD_CONST               3 (0)
             36 BINARY_SUBSCR
             38 LOAD_CONST               6 ('b')
             40 LOAD_FAST                3 (graph_out_0)
             42 LOAD_CONST               

In [None]:
import types

cached_func = created_function = types.FunctionType(cache_entry.code, globals())
cached_func

<function __main__.fn_flattened(*args)>

Try tracing the cached function and skipping the guard? Still broke.

In [None]:
# AOTAutograd of dynamo
aot_dynamo_fn = aot_function(cached_func, fw_compiler=fw_compiler, bw_compiler=bw_compiler)
o = aot_dynamo_fn(*flat_fake_init, *flat_fake_x)

 (Triggered internally at /workspaces/torch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


RuntimeError: 
During the backward, we encountered a tensor subclass where we guessed its
metadata incorrectly.

Expected metadata: None, expected type: <class 'torch.Tensor'>

Runtime metadata: None, runtime type: <class 'torch._subclasses.functional_tensor.FunctionalTensor'>

shape: torch.Size([2, 3])
To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.


Disassemble and modify the bytecode to load a function from kwargs.