Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Dynamo bug if using customized torch.autograd.Function #1899

@yanboliang

Description

@yanboliang

🐛 Describe the bug

This repro comes from the 7k GitHub models set: https://github.com/jansel/pytorch-jit-paritybench/blob/7dbde38cb69926ff522339ba9df0c5c9026a37b4/generated/test_eladhoffer_quantized_pytorch.py

The following code passes in eager mode but fails when run through Dynamo with the "eager" backend.

import torch
from torch import nn
import torch.nn.functional as F
import torch._dynamo
import logging
from torch.autograd.function import Function

# Set Dynamo's log level to DEBUG and turn on its verbose flag for this repro.
torch._dynamo.config.log_level = logging.DEBUG
torch._dynamo.config.verbose = True

class Exp(torch.autograd.Function):
    """Custom autograd function that returns the exponential of its input.

    ``forward`` is declared as a ``classmethod`` here, so ``Exp.forward`` is
    a bound method rather than a plain function.
    """

    # the forward function can be staticmethod or classmethod
    @classmethod
    def forward(cls, ctx, i):
        """Return ``i.exp()`` and save that result on ``ctx`` for backward."""
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` multiplied by the result saved in forward."""
        # The derivative of exp(x) is exp(x), so the saved forward output is
        # reused as the local gradient.
        result, = ctx.saved_tensors
        return grad_output * result


class CustomizedFunctionModule(torch.nn.Module):
    """
    Test the customized operation from a torch.autograd.Function subclass.

    The module defines no ``__init__`` of its own; ``forward`` only passes
    its input through ``Exp``.
    """
    def forward(self, i):
        """Apply ``Exp`` to ``i`` and return the result."""
        # Note: ``apply`` is invoked on a freshly created ``Exp()`` instance,
        # not directly on the ``Exp`` class.
        return Exp().apply(i)

# Build the module in eval mode, then wrap it with Dynamo using the "eager"
# backend. NOTE(review): nopython=True presumably forces whole-graph capture
# with no fallback to Python -- confirm against the Dynamo docs.
model = CustomizedFunctionModule().eval()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
# Uncomment the next line to bypass Dynamo and run the plain module instead.
# opt_model = model
x = torch.rand([3, 3])
# First call goes through the unwrapped module.
print(model(x))
# Second call goes through the Dynamo-wrapped module; the reported traceback
# originates from this line.
print(opt_model(x))

Error logs

Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/pytorch/debug/debug6.py", line 37, in <module>
    print(opt_model(x))
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1427, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 66, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 174, in _fn
    return fn(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/eval_frame.py", line 286, in catch_errors
    return callback(frame, cache_size)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 349, in _convert_frame_assert
    return _compile(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform
    tracer.run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1624, in run
    super().run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 490, in run
    and self.step()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 460, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 287, in wrapper
    return inner_fn(self, inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 917, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 395, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/misc.py", line 585, in call_function
    return self.obj.call_apply(tx, args, kwargs).add_options(self)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/misc.py", line 462, in call_apply
    return variables.UserFunctionVariable(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/base.py", line 27, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/functions.py", line 89, in __init__
    assert isinstance(
AssertionError: expected FunctionType found method <bound method Exp.forward of <class '__main__.Exp'>>

Minified repro

No response

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions