Skip to content

Commit

Permalink
Update on "fix soundness bug with unsupported constraints"
Browse files Browse the repository at this point in the history
We do not raise constraint violations for complex binary conditions, such as conditions involving `%`. Moreover, while these constraints are discovered by our solver, the solver does not inject new constraint violations. As a result, export can succeed without adding the appropriate runtime assertions, which then leads to runtime crashes.

Now, when the solver discovers constraints that are too complex, we force-specialize the involved dimensions and raise a constraint violation when such dimensions are marked dynamic. This forces the user to remove the dynamic marking, and causes the appropriate specialization assertions to be added.

Differential Revision: [D46415786](https://our.internmc.facebook.com/intern/diff/D46415786/)

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy aakhundov

[ghstack-poisoned]
  • Loading branch information
avikchaudhuri committed Jun 9, 2023
2 parents cc80e79 + 0f4be1e commit d8a72c8
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,25 +195,6 @@ def foo(x):
# There should be nonzero view nodes in the graph
self.assertTrue(view_count > 0)

def test_export_constrain_static(self):
    """Constraining a value that export evaluates to a single static value
    must be rejected with a UserError (the message reports the static
    value, here 3, per the asserted regex)."""

    def fn(x, y):
        # Data-dependent scalar pulled out of a tensor, with declared bounds.
        unbacked = x.item()
        constrain_as_size(unbacked, min=2, max=5)
        # y.dim() is fixed by the example input's rank — a static value.
        rank = y.dim()
        constrain_as_value(rank, min=1, max=3)
        prefix = y[0:rank]
        return torch.empty((unbacked, y.shape[0])), prefix

    scalar_holder = torch.tensor([3])
    payload = torch.randn([8, 8, 6])
    inputs = (scalar_holder, payload)
    dim_constraints = [dynamic_dim(payload, 0) >= 6, dynamic_dim(payload, 0) <= 10]
    expected_msg = (
        "It appears that you're trying to set a constraint "
        "on a value which we evaluated to have a static value of 3. "
    )
    with self.assertRaisesRegex(torchdynamo.exc.UserError, expected_msg):
        export(fn, inputs, dim_constraints)

def test_export_mod_constraints(self):
class BasicDynamiShapeModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -237,6 +218,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
em(x)

def test_export_constrain_static(self):
    """Export must raise a UserError when the user constrains a value
    that was evaluated to be static (the error regex pins the reported
    static value, 3)."""

    def fn(x, y):
        # Unbacked symint from tensor data; give export explicit bounds.
        size_hint = x.item()
        constrain_as_size(size_hint, min=2, max=5)
        # The rank of y is determined by the example input, hence static.
        ndims = y.dim()
        constrain_as_value(ndims, min=1, max=3)
        head = y[0:ndims]
        return torch.empty((size_hint, y.shape[0])), head

    example_x = torch.tensor([3])
    example_y = torch.randn([8, 8, 6])
    constraints = [dynamic_dim(example_y, 0) >= 6, dynamic_dim(example_y, 0) <= 10]
    message = (
        "It appears that you're trying to set a constraint "
        "on a value which we evaluated to have a static value of 3. "
    )
    with self.assertRaisesRegex(torchdynamo.exc.UserError, message):
        export(fn, (example_x, example_y), constraints)


if __name__ == '__main__':
Expand Down

0 comments on commit d8a72c8

Please sign in to comment.