Skip to content

Commit

Permalink
Run translation validation on tracing error. (#106645)
Browse files Browse the repository at this point in the history
This PR wraps `InstructionTranslator` run with a try-catch block so as to run the
translation validation (TV) if it ends up raising an error.

In this context, we run TV so as to catch simplification errors. These may turn
`ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.

For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since
it's run only in the end of the tracing.

Pull Request resolved: #106645
Approved by: https://github.com/ezyang
  • Loading branch information
ysiraichi authored and pytorchmergebot committed Aug 14, 2023
1 parent 937cd37 commit d8ad748
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 283 deletions.
50 changes: 49 additions & 1 deletion test/dynamo/test_exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_utils import munge_exc
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import munge_exc, TEST_Z3
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


Expand Down Expand Up @@ -189,6 +190,53 @@ def fn001(x):
ReluCompileError:""",
)

@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
translation_validation=True,
suppress_errors=False,
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException

@torch.compile
def fn(x):
return x.reshape(-1, 4)

self.assertExpectedInlineMunged(
ValidationException,
lambda: fn(torch.randn(20)),
"""\
translation validation failed.
Model:
==> L['x'].storage_offset(): 0
==> s0: 4
==> L['x'].stride()[0]: 1
==> L['x'].size()[0]: 4
Assertions:
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (Not (And (< L['x'].size()[0] 4) (>= L['x'].size()[0] 0)))
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (True)
Target Expressions:
==> (>= 9223372036854775806 s0)
==> (== 4 L['x'].size()[0])
==> (== 0 L['x'].storage_offset())
==> (> s0 0)
==> (== 1 L['x'].stride()[0])
==> (<= 2 s0)
==> (== 4 s0)
Failed Source Expressions:
==> (!= 4 L['x'].size()[0])""",
)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
300 changes: 147 additions & 153 deletions test/test_fx_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,171 +1702,165 @@ def test_normalize_args_op_overload(self):
self.assertIs(kwargs["the_template"], inp2)


class TestTranslationValidator(TestCase):
def _prepare_for_translation_validation(self):
from torch.fx.experimental.validator import TranslationValidator

validator = TranslationValidator()

# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

return (s0, s1, s2), (z0, z1, z2), validator

@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_sympy_to_z3_translation(self):
import z3
from torch.utils._sympy.functions import FloorDiv, Mod
from torch.fx.experimental.validator import SympyToZ3

(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()

test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
if TEST_Z3:
import z3

toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
import torch._dynamo.config

@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
from torch.utils._sympy.functions import FloorDiv, Mod

validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
class TestTranslationValidation(TestCase):
def _prepare_for_translation_validation(self):
validator = TranslationValidator()

# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

validator.validate()
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_unsat(self):
from torch.fx.experimental.validator import ValidationException
return (s0, s1, s2), (z0, z1, z2), validator

(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
def test_sympy_to_z3(self):

validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()

test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]

toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)

# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
def test_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()

with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)

@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_z3str(self):
import z3
from torch.fx.experimental.validator import z3str

a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")

test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)

for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
validator.validate()

def test_unsat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()

validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)

# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)

with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()

def test_z3str(self):
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")

test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]

for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ def is_fbcode():
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False

# wraps (un)equalities with 'Not' class after recording the correct expression
# in the FX graph. This should incorrectly construct the divisible and replacement
# lists, and incorrectly issue guards.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False

_autograd_backward_strict_mode_banned_ops = [
"stride",
"requires_grad",
Expand Down

0 comments on commit d8ad748

Please sign in to comment.