Skip to content

Commit

Permalink
[Dynamo x torch_function] fix incorrect source
Browse files Browse the repository at this point in the history
Fixes #128964

The problem was that we were installing the source for a type
incorrectly.

Test Plan:
- new tests

[ghstack-poisoned]
  • Loading branch information
zou3519 committed Jun 18, 2024
1 parent 304c934 commit d5f4964
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
35 changes: 35 additions & 0 deletions test/dynamo/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,41 @@ def fn(x):
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)

def test_no_torch_function_recompiles(self):
class NJT:
def __repr__(self):
return f"NJT(shape={self.shape})"

def __init__(self, values, offsets):
self._values = values
self._offsets = offsets

def sin(self):
return torch.sin(self)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.sin:
self = args[0]
return NJT(func(self._values), self._offsets)
raise AssertionError("should not get here")

values1 = torch.randn(10, 3, 4, requires_grad=True)
values2 = torch.randn(10, 3, 4, requires_grad=True)
offsets = torch.tensor([0, 3, 10])
njt1 = NJT(values1, offsets)
njt2 = NJT(values2, offsets)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
return torch.sin(x)

with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
f(njt1)
f(njt2)

def test_base_torch_function_tracing(self):
def fn(x):
return torch.add(x, 1)
Expand Down
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/torch_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource
from ..source import AttrSource, GlobalSource, TypeSource
from ..utils import has_torch_function, is_tensor_base_attr_getter
from .constant import ConstantVariable
from .lists import TupleVariable
Expand Down Expand Up @@ -88,7 +88,7 @@ def _get_subclass_type_var(tx, var):
from .builder import SourcelessBuilder, VariableBuilder

if var.source:
return VariableBuilder(tx, var.source)(var.python_type())
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
else:
return SourcelessBuilder.create(tx, var.python_type())

Expand Down

0 comments on commit d5f4964

Please sign in to comment.