In [1]:
# Debugging flags
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env PJRT_DEVICE=TPU

env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: PJRT_DEVICE=TPU


In [2]:
import torch_xla
import torch_xla.runtime

import torch
from functorch.compile import aot_function

import time

## Trace fn without checkpoint

We obtain the forward of `fn` in terms of aten ops.

In [3]:
device = torch_xla.device()
w = torch.randn(4, 4, requires_grad=False, device=device)
torch_xla.sync()

def fn(a, w):
  """A simple function containing a few layers."""
  print("a:", type(a), a.shape)
  time.sleep(1)
  a = torch.sin(a)
  a = a @ w
  a = torch.cos(a)
  a = a @ w
  a = torch.sigmoid(a)
  return a

def compiler_fn(m: torch.fx.GraphModule, _):
  print(m.code)
  return m

a = torch.randn(4, 4, requires_grad=True, device=device)
torch_xla.sync()
aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
res = aot_print_fn(cloned_a, w)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])
a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    


## Trace fn with a torch_xla checkpoint

Runtime error within torch_xla checkpoint.

In [4]:
import logging
import torch_xla.utils.checkpoint

def checkpointed_fn(a, w):
  return torch_xla.utils.checkpoint.checkpoint(fn, a, w)

aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()

try:
  res = aot_print_fn(cloned_a, w)
except RuntimeError as e:
  logging.exception(e)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])
a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


 (Triggered internally at /workspaces/torch/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
ERROR:root:Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward(). Please use .backward() and do not pass its `inputs` argument.
Traceback (most recent call last):
  File "/tmp/ipykernel_393588/1073587478.py", line 12, in <module>
    res = aot_print_fn(cloned_a, w)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 887, in returned_function
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 778, in _create_aot_dispatcher_function
    compiled_fn,

## Trace fn with a torch non-reentrant checkpoint

`aot_function` appears to skip over the checkpoint wrapper entirely. We still
get back an identical aten forward that saves all the intermediate activations.

In [5]:
import torch.utils.checkpoint

torch.xla = torch_xla.device()  # type:ignore

def checkpointed_fn(a, w):
  return torch.utils.checkpoint.checkpoint(fn, a, w, use_reentrant=False)

aot_print_fn = aot_function(checkpointed_fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
cloned_a = a.clone().detach().requires_grad_(True)
torch_xla.sync()
res = aot_print_fn(cloned_a, w)

a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])


a: <class 'torch._subclasses.functional_tensor.FunctionalTensor'> torch.Size([4, 4])



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    mm = torch.ops.aten.mm.default(sin, primals_2);  sin = None
    cos = torch.ops.aten.cos.default(mm)
    mm_1 = torch.ops.aten.mm.default(cos, primals_2);  cos = None
    sigmoid = torch.ops.aten.sigmoid.default(mm_1);  mm_1 = None
    return (sigmoid, primals_1, primals_2, mm, sigmoid)
    


## Use dynamo to trace the checkpointed function

We get a higher order `checkpoint` op.

In [6]:
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return gm.forward  # return a python callable
  
def fn_2(a, w):
  """A simple function containing a few layers."""
  a = torch.sin(a)
  a = a @ w
  a = torch.cos(a)
  a = a @ w
  a = torch.sigmoid(a)
  return a
  
def checkpointed_fn_2(a, w):
  return torch.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=False)

dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_compiler, fullgraph=True)
dynamo_fn(cloned_a, w)

my_compiler() called with FX graph:
opcode         name                       target                       args                            kwargs
-------------  -------------------------  ---------------------------  ------------------------------  ------------------------
placeholder    l_a_                       L_a_                         ()                              {}
placeholder    l_w_                       L_w_                         ()                              {}
get_attr       wrap_body_0                wrap_body_0                  ()                              {}
call_function  tag_activation_checkpoint  tag_activation_checkpoint    (wrap_body_0, l_a_, l_w_)       {'use_reentrant': False}
call_function  getitem                    <built-in function getitem>  (tag_activation_checkpoint, 0)  {}
output         output                     output                       ((getitem,),)                   {}

FX graph code:



def forward(self, L_a_ : torch.Tensor, L_w_ : tor

tensor([[0.5832, 0.6852, 0.6427, 0.7385],
        [0.7401, 0.1634, 0.5956, 0.1051],
        [0.6172, 0.5593, 0.6466, 0.6149],
        [0.3401, 0.9659, 0.6556, 0.6364]], device='xla:0',
       grad_fn=<SigmoidBackward0>)

## Use dynamo and then AOTAutograd

In [7]:
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    print()
    print("FX graph code:")
    print(gm.code)
    print()
    time.sleep(1)
    return make_boxed_func(gm.forward)
  
my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_backend, fullgraph=True)
o = dynamo_fn(cloned_a, w)
assert o is not None
o.sum().backward()

my_compiler() called with FX graph:
opcode         name       target                args                                kwargs
-------------  ---------  --------------------  ----------------------------------  --------
placeholder    primals_1  primals_1             ()                                  {}
placeholder    primals_2  primals_2             ()                                  {}
call_function  sin        aten.sin.default      (primals_1,)                        {}
call_function  mm         aten.mm.default       (sin, primals_2)                    {}
call_function  cos        aten.cos.default      (mm,)                               {}
call_function  mm_1       aten.mm.default       (cos, primals_2)                    {}
call_function  sigmoid    aten.sigmoid.default  (mm_1,)                             {}
output         output     output                ((sigmoid, primals_1, primals_2),)  {}

FX graph code:



def forward(self, primals_1, primals_2):
    sin = torch.ops.aten

In [8]:
def checkpointed_fn_3(a, w):
  return torch_xla.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=True)

try:
  my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
  dynamo_fn = torch.compile(checkpointed_fn_3, backend=my_backend, fullgraph=True)
  o = dynamo_fn(cloned_a, w)
  assert o is not None
  o.sum().backward()
except RuntimeError as e:
  logging.exception(e)


ERROR:root:'skip function check_backward_validity in file /usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py'

from user code:
   File "/tmp/ipykernel_393588/204424578.py", line 2, in checkpointed_fn_3
    return torch_xla.utils.checkpoint.checkpoint(fn_2, a, w, use_reentrant=True)
  File "/workspaces/torch/pytorch/xla/torch_xla/utils/checkpoint.py", line 292, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/workspaces/torch/pytorch/xla/torch_xla/utils/checkpoint.py", line 87, in forward
    check_backward_validity(args)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/tmp/ipykernel_393588/204424578.py", line 7, in <module>
    o = dynamo_fn(cloned_a, w)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py

### Use torch.export with AOTAutograd

In [9]:
import torch.export
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)
  
class FunctionModule(torch.nn.Module):
  def __init__(self, f):
    super().__init__()
    self.f = f
  
  def forward(self, *args):
    return self.f(*args)

exported = torch.export.export_for_training(FunctionModule(checkpointed_fn_2), args=(cloned_a, w))
print(exported)
module = exported.module()

# Now get the backward function
from functorch.compile import aot_module
aot_module(module, fw_compiler=my_compiler, bw_compiler=my_compiler)(cloned_a, w).sum().backward()

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[4, 4]", args_1: "f32[4, 4]"):
             # File: /tmp/ipykernel_393588/1239672230.py:17 in fn_2, code: a = torch.sin(a)
            sin: "f32[4, 4]" = torch.ops.aten.sin.default(args_0);  args_0 = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:18 in fn_2, code: a = a @ w
            matmul: "f32[4, 4]" = torch.ops.aten.matmul.default(sin, args_1);  sin = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:19 in fn_2, code: a = torch.cos(a)
            cos: "f32[4, 4]" = torch.ops.aten.cos.default(matmul);  matmul = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:20 in fn_2, code: a = a @ w
            matmul_1: "f32[4, 4]" = torch.ops.aten.matmul.default(cos, args_1);  cos = args_1 = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:21 in fn_2, code: a = torch.sigmoid(a)
            sigmo

## Use torch.compile with an exception trick to avoid evaluation

In [12]:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)
  
with FakeTensorMode():
  fake_a = torch.randn(4, 4)
  fake_w = torch.randn(4, 4)
  my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler
  dynamo_fn = torch.compile(checkpointed_fn_2, backend=my_backend, fullgraph=True)
  o = dynamo_fn(fake_a, fake_w)
  assert o is not None
  o.sum().backward()

BackendCompilerFailed: backend='compiler_fn' raised:
AssertionError: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x75aa061c6ce0>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x75a9e3f31870>) from fake tensor input 0

fake mode from tracing context 0 allocated at:
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
  File "/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
  File "/root/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
    result = self._run_cell(
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    result = runner(coro)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_393588/2443371153.py", line 23, in <module>
    o = dynamo_fn(fake_a, fake_w)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 556, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1447, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 550, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 979, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 709, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 744, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1348, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 234, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 663, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2914, in run
    super().run()
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1120, in run
    while self.step():
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1032, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3105, in RETURN_VALUE
    self._return(inst)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3090, in _return
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1077, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1337, in compile_and_call_fx_graph
    backend_fake_mode = torch._subclasses.FakeTensorMode(
  File "/usr/local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1204, in __init__
    self._stack_trace = traceback.extract_stack()

fake mode from fake tensor input 0 allocated at:
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/root/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/root/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/local/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    await self.process_one()
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    await dispatch(*args)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    await result
  File "/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    await super().execute_request(stream, ident, parent)
  File "/root/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    reply_content = await reply_content
  File "/root/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    res = shell.run_cell(
  File "/root/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
    result = self._run_cell(
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
    result = runner(coro)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_393588/2443371153.py", line 18, in <module>
    with FakeTensorMode():
  File "/usr/local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1204, in __init__
    self._stack_trace = traceback.extract_stack()


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [13]:
import torch.export
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func  # type:ignore
from typing import List
import torch
import torch.utils.checkpoint

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
  print("my_compiler() called with FX graph:")
  gm.graph.print_tabular()
  print()
  print("FX graph code:")
  print(gm.code)
  print()
  time.sleep(1)
  return make_boxed_func(gm.forward)
  
class FunctionModule(torch.nn.Module):
  def __init__(self, f):
    super().__init__()
    self.f = f
  
  def forward(self, *args):
    return self.f(*args)

with FakeTensorMode():
  fake_a = torch.randn(4, 4)
  fake_w = torch.randn(4, 4)
  exported = torch.export.export_for_training(FunctionModule(checkpointed_fn_2), args=(fake_a, fake_w))
  print(exported)
  module = exported.module()

  # Now get the backward function
  from functorch.compile import aot_module
  aot_module(module, fw_compiler=my_compiler, bw_compiler=my_compiler)(cloned_a, w).sum().backward()

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[4, 4]", args_1: "f32[4, 4]"):
             # File: /tmp/ipykernel_393588/1239672230.py:17 in fn_2, code: a = torch.sin(a)
            sin: "f32[4, 4]" = torch.ops.aten.sin.default(args_0);  args_0 = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:18 in fn_2, code: a = a @ w
            matmul: "f32[4, 4]" = torch.ops.aten.matmul.default(sin, args_1);  sin = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:19 in fn_2, code: a = torch.cos(a)
            cos: "f32[4, 4]" = torch.ops.aten.cos.default(matmul);  matmul = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:20 in fn_2, code: a = a @ w
            matmul_1: "f32[4, 4]" = torch.ops.aten.matmul.default(cos, args_1);  cos = args_1 = None
            
             # File: /tmp/ipykernel_393588/1239672230.py:21 in fn_2, code: a = torch.sigmoid(a)
            sigmo

AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sin.default(tensor([...], device='xla:0', size=(4, 4)))