Skip to content

Commit

Permalink
Revert "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466
Browse files Browse the repository at this point in the history
)"

This reverts commit 6bb9d60.

Reverted #126466 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the ONNX test failure looks legit, not flaky, as it starts failing in trunk https://hud.pytorch.org/pytorch/pytorch/commit/6bb9d6080d33c817fcbf9e5ae8a59b76812a53d2 ([comment](#126466 (comment)))
  • Loading branch information
pytorchmergebot authored and ZelboK committed May 19, 2024
1 parent 367a0c5 commit 0ac2cec
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 57 deletions.
57 changes: 0 additions & 57 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch._dynamo.utils import ifdynstaticdefault
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter

Expand Down Expand Up @@ -1105,37 +1104,6 @@ def forward(self, x):
return self.m(x)


class ModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(4, 4)
self.step = 10

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step


class UnspecInlinableModule(torch.nn.Module):
torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule

def forward(self, x):
return torch.sin(x)


class UnspecModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = UnspecInlinableModule()
self.step = 10

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step


def make_test(fn, expected_ops=None):
def test_fn(self):
return torch._dynamo.testing.standard_test(
Expand Down Expand Up @@ -1389,31 +1357,6 @@ def forward(self, x):
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))

def test_nn_module_unspec_int_attr(self):
for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]:
mod = module_class()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod))
x = torch.randn(3, 4)

# Compiling self.step as static.
ref1 = mod(x)
res1 = opt_mod(x)
self.assertTrue(torch.allclose(ref1, res1))
self.assertEqual(cnt.frame_count, 1)

# Compiling self.step as dynamic.
ref2 = mod(x)
res2 = opt_mod(x)
self.assertTrue(torch.allclose(ref2, res2))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))

# No re-compilation!
ref3 = mod(x)
res3 = opt_mod(x)
self.assertTrue(torch.allclose(ref3, res3))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))

# RuntimeError: SymIntArrayRef expected to contain only concrete integers
@expectedFailureDynamic
def test_lazy_module1(self):
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,10 @@ def wrap_literal(self, value):
value in self._common_constants()
# Assume integers from global variables want to be specialized
or not self.source.guard_source().is_local()
# Assume that integers that came from NN modules want to be
# specialized (as we don't expect users to be changing the
# NN modules on the fly)
or self.source.guard_source().is_nn_module()
or is_from_defaults(self.source)
or is_cell_contents(self.source)
):
Expand Down

0 comments on commit 0ac2cec

Please sign in to comment.