This repository was archived by the owner on Aug 1, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 129
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Dynamo bug if using customized torch.autograd.Function #1899
Copy link
Copy link
Closed
Labels
Description
🐛 Describe the bug
This comes from the 7k GitHub models benchmark (pytorch-jit-paritybench): https://github.com/jansel/pytorch-jit-paritybench/blob/7dbde38cb69926ff522339ba9df0c5c9026a37b4/generated/test_eladhoffer_quantized_pytorch.py
The following code passes in plain eager mode but fails when compiled with Dynamo using the "eager" backend.
import torch
from torch import nn
import torch.nn.functional as F
import torch._dynamo
import logging
from torch.autograd.function import Function
# Raise Dynamo's log level to DEBUG and turn on verbose output for this repro.
# NOTE(review): these config attributes match the torch build the issue was
# filed against; newer releases configure logging via torch._logging and may
# reject unknown config keys — confirm before reusing.
torch._dynamo.config.log_level = logging.DEBUG
torch._dynamo.config.verbose = True
class Exp(torch.autograd.Function):
    """Element-wise ``exp`` implemented as a custom autograd Function.

    ``forward`` is deliberately a ``classmethod`` instead of the usual
    ``staticmethod``. Eager execution accepts this, but Dynamo sees
    ``Exp.forward`` as a bound method and fails with
    "expected FunctionType found method".
    """

    # The forward function can be a staticmethod or a classmethod.
    @classmethod
    def forward(cls, ctx, i):
        """Return ``i.exp()`` and save it on ``ctx`` for the backward pass."""
        result = i.exp()
        # The derivative of exp(x) is exp(x), so the forward output is all
        # that backward needs.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the saved forward output."""
        (result,) = ctx.saved_tensors
        return grad_output * result
class CustomizedFunctionModule(torch.nn.Module):
    """
    Test the customized operation from a torch.autograd.Function subclass.
    """

    def forward(self, i):
        """Apply the custom ``Exp`` autograd Function to ``i``."""
        # NOTE(review): newer PyTorch releases deprecate instantiating an
        # autograd Function (``Exp.apply(i)`` is the supported spelling). The
        # instance call is kept so this repro follows the same path as the
        # reported traceback.
        return Exp().apply(i)
# Build the module in eval mode and wrap it with Dynamo using the "eager"
# backend; nopython=True requests full-graph capture (presumably erroring
# instead of falling back on a graph break — confirm for the torch version used).
model = CustomizedFunctionModule().eval()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
# To compare against plain eager execution, replace the line above with:
# opt_model = model
x = torch.rand([3, 3])
# Plain eager run: reported to succeed.
print(model(x))
# Dynamo run: this is the call that raises the AssertionError in the error logs.
print(opt_model(x))
Error logs
Traceback (most recent call last):
File "/scratch/ybliang/work/repos/pytorch/debug/debug6.py", line 37, in <module>
print(opt_model(x))
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1427, in _call_impl
return forward_call(*input, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 66, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
return fn(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 286, in catch_errors
return callback(frame, cache_size)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
return fn(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 349, in _convert_frame_assert
return _compile(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform
tracer.run()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1624, in run
super().run()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 490, in run
and self.step()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 460, in step
getattr(self, inst.opname)(inst)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 287, in wrapper
return inner_fn(self, inst)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 917, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 395, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/misc.py", line 585, in call_function
return self.obj.call_apply(tx, args, kwargs).add_options(self)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/misc.py", line 462, in call_apply
return variables.UserFunctionVariable(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/base.py", line 27, in __call__
obj = type.__call__(cls, *args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/functions.py", line 89, in __init__
assert isinstance(
AssertionError: expected FunctionType found method <bound method Exp.forward of <class '__main__.Exp'>>
Minified repro
No response