Skip to content

Commit

Permalink
fix regression which creates a new fake tensor (#111864)
Browse files Browse the repository at this point in the history
Fixes regression identified here: https://github.com/pytorch/pytorch/pull/111565/files/ccd6b373b5fba81b6141a6b93b2ffb6595dd49ae#r1369334484

Now that `get_fake_value` will identify aliases, we should not try to wrap the fake value again.

Pull Request resolved: #111864
Approved by: https://github.com/eellison
  • Loading branch information
jon-chuang authored and pytorchmergebot committed Oct 24, 2023
1 parent 0e0f6a2 commit 6d78f34
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
6 changes: 3 additions & 3 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2221,18 +2221,18 @@ def foo(mod, x):

mod = Mod()
foo(mod, torch.rand([4]))
self.assertEqual(compiles_without_buffers, 1)
self.assertEqual(compiles_without_buffers, 0)

foo(mod, torch.rand([4], dtype=torch.half))
self.assertEqual(compiles_without_buffers, 2)
self.assertEqual(compiles_without_buffers, 1)

class Mod2(Mod):
def __setattr__(self, name, value):
return super().__setattr__(name, value)

foo(Mod2(), torch.rand([4]))
# causes two compilations, bc unimplemented custom setattr
self.assertTrue(compiles_without_buffers >= 4)
self.assertTrue(compiles_without_buffers >= 2)

def test_unspec_non_inlinable_module(self):
mod = UnspecNonInlinableModule()
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,7 @@ def get_fake_value(node, tx):

op = node.op

# FX Node should always return the same value
# FX Node should always return the same fake value
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
return node.meta["example_value"]

Expand Down
9 changes: 2 additions & 7 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,13 +1238,8 @@ def call_setattr(
getattr_var = None

if isinstance(getattr_var, variables.TensorVariable):
# get_fake_value will return a real tensor here because it's an attribute on the module (get_attr node)
existing_attr = get_fake_value(getattr_var.as_proxy().node, tx)
existing_fake_attr = (
variables.builder.wrap_to_fake_tensor_and_record(
existing_attr, tx, source=getattr_var.source, is_tensor=True
)
)
# get_fake_value will get the same fake tensor
existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx)

# same tensor identity, setattr is a no-op
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__")
Expand Down

0 comments on commit 6d78f34

Please sign in to comment.