Fail fast when dynamo attempts to add unspecialized int/float as additional graph inputs

Pull Request resolved: #96786

Verified that the changes catch unspecialized ints/floats being added as additional graph inputs in D44037548, prior to PR #95621.

However, with #95621 the issue this was originally meant to solve no longer arises, because ints and floats in `forward` are always specialized in export. This PR adds the assertion anyway *(it should not be hit unless there is a regression)* so that any attempt to add an unspecialized int/float to the additional graphargs is caught immediately.
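As a sketch of what "specialized" means here (mirroring the updated test below; `config` is assumed to be `torch._dynamo.config`, as imported in `test/dynamo/test_export.py`): Python int/float attributes used in `forward` are baked into the exported graph as constants, so the tensor argument remains the only placeholder.

```python
import torch
import torch._dynamo
from torch._dynamo import config

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.int_val = 100      # plain Python int
        self.float_val = 1e-5   # plain Python float

    def forward(self, x):
        # Both scalars are specialized (burned in as graph constants)
        # during export, not lifted to extra graph inputs.
        return x.cos() * self.int_val * self.float_val

mod = Foo()
inp = torch.randn(3, 128)
with config.patch(dynamic_shapes=True):
    gm, _ = torch._dynamo.export(
        mod, inp, aten_graph=True, tracing_mode="symbolic"
    )
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
assert len(placeholders) == 1  # only the tensor input remains
```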

Differential Revision: [D44075910](https://our.internmc.facebook.com/intern/diff/D44075910/)

[ghstack-poisoned]
guangyang committed Mar 16, 2023
1 parent 070cefa commit cd42558
Showing 2 changed files with 19 additions and 8 deletions.
21 changes: 13 additions & 8 deletions test/dynamo/test_export.py
@@ -2210,7 +2210,7 @@ def func(x):
         dynamo_result = exported(inp)
         self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result))
 
-    def test_export_specialized_int_float(self):
+    def test_export_specialized_int(self):
         class Foo(torch.nn.Module):
             def __init__(
                 self,
@@ -2220,19 +2220,24 @@ def __init__(
                 self.torch_module = torch.nn.LayerNorm(
                     input_dim, eps=1e-5, elementwise_affine=True
                 )
+                self.int_val = 100
 
             def forward(self, input):
-                return input.cos() * self.torch_module.eps
+                return input.cos() * self.int_val * self.torch_module.eps
 
         mod = Foo(128)
         inp = torch.randn(3, 128)
 
-        gm, _ = torch._dynamo.export(mod, inp, aten_graph=True, tracing_mode="symbolic")
-        count = 0
-        for node in gm.graph.nodes:
-            if node.op == "placeholder":
-                count += 1
-        self.assertEqual(count, 1)
+        # In export, int & float in forward should always be specialized
+        with config.patch(dynamic_shapes=True):
+            gm, _ = torch._dynamo.export(
+                mod, inp, aten_graph=True, tracing_mode="symbolic"
+            )
+            count = 0
+            for node in gm.graph.nodes:
+                if node.op == "placeholder":
+                    count += 1
+            self.assertEqual(count, 1)
 
     def test_export_pass_arg_by_name(self):
         class BasicModule(torch.nn.Module):
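The updated test can be run on its own, e.g. with `pytest test/dynamo/test_export.py -k test_export_specialized_int`.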
6 changes: 6 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -831,6 +831,12 @@ def wrap_unspecialized_primitive(self, value):
             )
         self.tx.output.unspec_variable_map[self.name] = unspec_var
         if not is_constant_source(self.get_source()):
+            if self.tx.export and not isinstance(self.get_source(), LocalInputSource):
+                raise AssertionError(
+                    "Dynamo attempts to add additional input during export: value={}, source={}".format(
+                        wrapped_value, self.get_source()
+                    )
+                )
             fake_tensor_value = None
             example_value = unspec_var.proxy.node.meta["example_value"]
             if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
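To make the guard's intent concrete, here is a minimal, self-contained toy (the names `ToyTracer`, `LocalInput`, and `GlobalCapture` are invented for this sketch and are not PyTorch APIs): during export, only values originating from the function's own inputs may become graph inputs; anything else fails fast.

```python
class Source:
    pass

class LocalInput(Source):
    """Value came from an explicit function argument."""

class GlobalCapture(Source):
    """Value was captured from elsewhere (e.g., a global)."""

class ToyTracer:
    def __init__(self, export):
        self.export = export
        self.graph_inputs = []

    def add_unspecialized_input(self, value, source):
        # Mirrors the new guard: export must never grow the graph's
        # input signature from non-input sources.
        if self.export and not isinstance(source, LocalInput):
            raise AssertionError(
                "Dynamo attempts to add additional input during export: "
                "value={}, source={}".format(value, source)
            )
        self.graph_inputs.append((value, source))

tracer = ToyTracer(export=True)
tracer.add_unspecialized_input(3.14, LocalInput())  # OK
try:
    tracer.add_unspecialized_input(2.71, GlobalCapture())  # fails fast
except AssertionError as e:
    print(e)
```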
