Fail fast when dynamo attempts to add unspecialized int/float as additional graph inputs (#96786)

Summary:
Pull Request resolved: #96786

Verified in D44037548 that the changes catch unspecialized ints/floats being added as additional graph inputs, prior to PR #95621.

However, with #95621 the original issue no longer applies, because ints & floats in `forward` are always specialized in export. This PR adds the assertion anyway *(it should not be hit unless there is a regression)* to immediately catch any attempt to add an unspecialized int/float to the additional graphargs.
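For context, a minimal sketch of the specialization behavior this relies on. The module and the single-placeholder check are adapted from the test updated in this PR (see the diff below); the standalone driver around them is assumed boilerplate for a PyTorch build of this vintage:

```python
import torch
import torch._dynamo


class Foo(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.torch_module = torch.nn.LayerNorm(
            input_dim, eps=1e-5, elementwise_affine=True
        )

    def forward(self, input):
        # `eps` is a plain Python float: in export it is specialized
        # (burned into the graph as a constant) rather than lifted to
        # an extra graph input.
        return input.cos() * self.torch_module.eps


gm, _ = torch._dynamo.export(Foo(128), torch.randn(3, 128), aten_graph=True)
# Only the tensor argument should show up as a placeholder.
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
assert len(placeholders) == 1
```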

Test Plan:
An example of the error message:
```
Dynamo attempts to add additional input: value=9.999999747378752e-06, source=NNModuleSource(inner=AttrSource(base=NNModuleSource(inner=AttrSource(base=LocalInputSource(local_name='self', pos=0), member='torch_module')), member='eps'))
```
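Read inside-out, the `source=` chain in this message pinpoints where the value came from: `LocalInputSource(local_name='self', pos=0)` is the module passed to export, and the nested `AttrSource` entries walk `self.torch_module.eps`, i.e. the LayerNorm `eps` float used in the test below.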
Passed all export tests:
```
Buck UI: https://www.internalfb.com/buck2/fea72653-5549-47e7-a9bf-740eb86a8e26
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8725724422167257
RE: reSessionID-7b3470b1-c293-4c4a-9671-dd0b7a2839b8  Up: 6.0 KiB  Down: 0 B
Jobs completed: 101. Time elapsed: 115.7s.
Tests finished: Pass 98. Fail 0. Fatal 0. Skip 0. 0 builds failed
```

Reviewed By: tugsbayasgalan

Differential Revision: D44075910

fbshipit-source-id: 968562938c8ea1fd3e065e2ee162687cbb3112fc
guangy10 authored and facebook-github-bot committed Mar 16, 2023
1 parent 397fb27 commit e96b7d3
Showing 2 changed files with 21 additions and 8 deletions.
21 changes: 13 additions & 8 deletions test/dynamo/test_export.py

```diff
@@ -2210,7 +2210,7 @@ def func(x):
         dynamo_result = exported(inp)
         self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result))
 
-    def test_export_specialized_int_float(self):
+    def test_export_specialized_int(self):
         class Foo(torch.nn.Module):
             def __init__(
                 self,
@@ -2220,19 +2220,24 @@ def __init__(
                 self.torch_module = torch.nn.LayerNorm(
                     input_dim, eps=1e-5, elementwise_affine=True
                 )
+                self.int_var = 100
 
             def forward(self, input):
-                return input.cos() * self.torch_module.eps
+                return input.cos() * self.torch_module.eps * self.int_var
 
         mod = Foo(128)
         inp = torch.randn(3, 128)
 
-        gm, _ = torch._dynamo.export(mod, inp, aten_graph=True, tracing_mode="symbolic")
-        count = 0
-        for node in gm.graph.nodes:
-            if node.op == "placeholder":
-                count += 1
-        self.assertEqual(count, 1)
+        # In export, int & float in forward should always be specialized
+        with config.patch(dynamic_shapes=True):
+            gm, _ = torch._dynamo.export(
+                mod, inp, aten_graph=True, tracing_mode="symbolic"
+            )
+            count = 0
+            for node in gm.graph.nodes:
+                if node.op == "placeholder":
+                    count += 1
+            self.assertEqual(count, 1)
 
     def test_export_pass_arg_by_name(self):
         class BasicModule(torch.nn.Module):
```
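The updated test can also be run on its own, e.g. with `python -m pytest test/dynamo/test_export.py -k test_export_specialized_int` (assuming a standard PyTorch checkout with pytest installed).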
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/builder.py

```diff
@@ -814,6 +814,14 @@ def wrap_unspecialized_primitive(self, value):
             options = {"guards": guards}
         else:
             options = {}
+
+        if self.tx.export and not isinstance(self.get_source(), LocalInputSource):
+            raise AssertionError(
+                "Dynamo attempts to add additional input: value={}, source={}".format(
+                    wrapped_value, self.get_source()
+                )
+            )
+
         options.update({"source": self.get_source()})
         if isinstance(wrapped_value, torch.Tensor):
             options.update({"raw_value": value})
```
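As a self-contained illustration of the new check's logic, a minimal sketch with stub classes standing in for Dynamo's real `Source` hierarchy (`Source`, `LocalInputSource`, and `AttrSource` here are simplified stand-ins; the real classes live in torch/_dynamo/source.py and carry more machinery):

```python
from dataclasses import dataclass


@dataclass
class Source:
    pass


@dataclass
class LocalInputSource(Source):
    local_name: str
    pos: int


@dataclass
class AttrSource(Source):
    base: Source
    member: str


def fail_fast_on_new_input(export: bool, wrapped_value, source: Source):
    # Mirrors the assertion added in this diff: during export, an
    # unspecialized primitive may only come from an actual function
    # input; anything else would force a new graph input.
    if export and not isinstance(source, LocalInputSource):
        raise AssertionError(
            "Dynamo attempts to add additional input: value={}, source={}".format(
                wrapped_value, source
            )
        )


# Fires: the value is reached via `self.torch_module.eps`, not via a
# function input, so export must not silently lift it to a graph input.
eps_source = AttrSource(
    base=AttrSource(base=LocalInputSource("self", 0), member="torch_module"),
    member="eps",
)
try:
    fail_fast_on_new_input(export=True, wrapped_value=1e-5, source=eps_source)
except AssertionError as e:
    print(e)
```

Because the check is gated on `self.tx.export`, regular (non-export) dynamo tracing, where unspecialized primitives legitimately become graph inputs, is unaffected.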
