Bisect FX node asserts on ValidationException. (#107493)
This PR introduces a binary search for finding smaller validation errors when they occur.

We do that by bisecting the sequence of `torch._assert` FX nodes that `ShapeEnv.evaluate_expr`
calls record as source expressions for the translation validator (TV). Then, we raise the error
caused by the earliest failing node.

In summary, the changes are:
- Call `bisect` on `ValidationException` @ _torch/_dynamo/convert_frame.py_
- Implement the binary search @ _torch/fx/experimental/symbolic_shapes.py_
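
For illustration, here is a minimal sketch of the bisection idea (hypothetical helper names; the actual implementation lives in `torch/fx/experimental/symbolic_shapes.py` and replays recorded `ShapeEnv` events rather than taking a callback):

```python
# Sketch only: assumes failures are monotone in the prefix of recorded
# assertions (once a prefix fails validation, every longer prefix fails too).
def bisect_earliest_failing_assert(assert_nodes, validate_prefix):
    """
    assert_nodes:    `torch._assert` FX nodes in the order they were recorded.
    validate_prefix: hypothetical callable(index) that replays ShapeEnv events
                     up to `index` and returns the validation error raised by
                     the translation validator, or None if validation passes.
    """
    lo, hi = 0, len(assert_nodes) - 1
    earliest = None  # (node, error) for the earliest failing prefix found
    while lo <= hi:
        mid = (lo + hi) // 2
        error = validate_prefix(mid)
        if error is not None:
            # Assertions [0..mid] already fail: the culprit is at `mid` or earlier.
            earliest = (assert_nodes[mid], error)
            hi = mid - 1
        else:
            # This prefix passes: the first failing assertion comes later.
            lo = mid + 1
    return earliest
```

The monotonicity assumption is what makes binary search applicable: raising the error from the earliest failing node yields the smallest reproducer of the validation failure.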

Edit: moved `ShapeEnv` replay-recording to #107989

Pull Request resolved: #107493
Approved by: https://github.com/ezyang
ghstack dependencies: #107989
ysiraichi authored and pytorchmergebot committed Sep 15, 2023
1 parent a873f52 commit dfdc0b6
Showing 5 changed files with 371 additions and 70 deletions.
123 changes: 110 additions & 13 deletions test/dynamo/test_exc.py
@@ -195,47 +195,144 @@ def fn001(x):
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
translation_validation=True,
translation_validation_no_bisect=True,
suppress_errors=False,
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException

@torch.compile
def fn(x):
return x.reshape(-1, 4)
def fn(x, shape):
return x.split(shape)

self.assertExpectedInlineMunged(
ValidationException,
lambda: fn(torch.randn(20)),
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed.
Model:
==> L['shape'][0]: 2
==> L['shape'][1]: 2
==> L['shape'][2]: 4
==> L['x'].size()[0]: 9
==> L['x'].storage_offset(): 0
==> s0: 4
==> L['x'].stride()[0]: 1
==> L['x'].size()[0]: 4
==> s0: 9
==> s1: 2
==> s2: 2
==> s3: 4
Assertions:
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (Not (And (< L['x'].size()[0] 4) (>= L['x'].size()[0] 0)))
==> (True)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (True)
Target Expressions:
==> (>= 9223372036854775806 s0)
==> (== 4 L['x'].size()[0])
==> (!= (+ s3 (* 2 s1)) s0)
==> (!= s1 s3)
==> (<= (* 2 s1) (+ s0 (* -1 s3)))
==> (<= (* 2 s1) s0)
==> (<= (* 2 s1) s0)
==> (<= (+ s3 (* 2 s1)) s0)
==> (<= 0 (+ s0 (* -1 s1)))
==> (<= 0 s1)
==> (<= 0 s3)
==> (<= 2 s1)
==> (<= 2 s2)
==> (<= 2 s3)
==> (<= 6 s0)
==> (<= s1 s0)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s1)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (== s2 s1)
==> (> s0 0)
==> (>= 9223372036854775802 s1)
==> (>= 9223372036854775802 s2)
==> (>= 9223372036854775802 s3)
==> (>= 9223372036854775806 s0)
==> (And (<= (* 2 s1) s0) (<= (* -1 s0) (* 2 s1)))
==> (And (<= s1 s0) (<= (* -1 s0) s1))
Failed Source Expressions:
==> (!= L['shape'][0] L['shape'][1])
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])
==> (== L['shape'][0] L['shape'][2])""",
)

@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
translation_validation=True,
suppress_errors=False,
)
def test_trigger_bisect_on_error(self):
from torch.fx.experimental.validator import BisectValidationException

@torch.compile
def fn(x, shape):
return x.split(shape)

self.assertExpectedInlineMunged(
BisectValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
Failure ocurred while running node:
%split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
Model:
==> L['shape'][0]: -9223372036854775807
==> L['shape'][1]: -9223372036854775807
==> L['shape'][2]: -9223372036854775807
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: -9223372036854775807
==> s2: -9223372036854775807
==> s3: -9223372036854775807
Assertions:
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 1)
Target Expressions:
==> (!= (+ s1 s2 s3) s0)
==> (<= -9223372036854775808 s1)
==> (<= -9223372036854775808 s2)
==> (<= -9223372036854775808 s3)
==> (<= 2 s0)
==> (== 4 s0)
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (!= 4 L['x'].size()[0])""",
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
)


9 changes: 4 additions & 5 deletions torch/_dynamo/convert_frame.py
@@ -418,6 +418,8 @@ def _compile(
compile_id=None,
) -> Optional[GuardedCode]:
from torch.fx.experimental.validator import (
bisect,
BisectValidationException,
translation_validation_enabled,
ValidationException,
)
@@ -451,11 +453,7 @@ def transform(instructions, code_options):
raise
except Exception:
if translation_validation_enabled():
fakes = tracer.output.tracked_fakes
tracer.output.shape_env.produce_guards(
[a.fake for a in fakes],
[a.source for a in fakes],
)
bisect(tracer.output.shape_env)
raise

output = tracer.output
@@ -569,6 +567,7 @@ def log_bytecode(prefix, name, filename, line_no, code):
GuardOnDataDependentSymNode,
ValidationException,
UncapturedHigherOrderOpError,
BisectValidationException,
) as e:
fail_reason = str(e)
exception_handler(e, code, frame, export=export)
65 changes: 43 additions & 22 deletions torch/_dynamo/utils.py
@@ -19,6 +19,7 @@
import pstats
import sys
import textwrap
import threading
import time
import types
import typing
@@ -1387,6 +1388,23 @@ def visit(n: torch.fx.Node):
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None


_current_node = threading.local()


def get_current_node():
return getattr(_current_node, "value", None)


@contextmanager
def set_current_node(node):
old = get_current_node()
_current_node.value = node
try:
yield
finally:
_current_node.value = old


def run_node(tracer, node, args, kwargs, nnmodule):
"""
Runs a given node, with the given args and kwargs.
@@ -1404,28 +1422,31 @@ def run_node(tracer, node, args, kwargs, nnmodule):
"""
op = node.op

try:
if op == "call_function":
return node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return tracer.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except NotImplementedError as e:
# NB: mimic how wrap_fake_exception does it
from .exc import unimplemented

raise unimplemented(f"running {op} {node.target}(*{args}, **{kwargs})") from e

except Exception as e:
fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
with set_current_node(node):
try:
if op == "call_function":
return node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return tracer.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except NotImplementedError as e:
# NB: mimic how wrap_fake_exception does it
from .exc import unimplemented

raise unimplemented(
f"running {op} {node.target}(*{args}, **{kwargs})"
) from e

except Exception as e:
fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e

raise AssertionError(op)

33 changes: 28 additions & 5 deletions torch/fx/experimental/symbolic_shapes.py
@@ -66,9 +66,13 @@ class GuardOnDataDependentSymNode(RuntimeError):
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"SymDispatchMode", "guard_int", "guard_float", "guard_scalar", "wrap_node",
"method_to_operator", "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool",
"is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
]

# FX node metadata keys for symbolic shape FX graph.
SHAPEENV_EVENT_KEY = "shapeenv_event"
CURRENT_NODE_KEY = "current_node"

# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
@@ -2417,6 +2421,11 @@ def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
self.name_to_node.pop(node.name)
self.graph.erase_node(node)

def add_fx_node_metadata(self, node: torch.fx.Node) -> None:
from torch._dynamo.utils import get_current_node
node.meta[SHAPEENV_EVENT_KEY] = self.last_event_index()
node.meta[CURRENT_NODE_KEY] = get_current_node()

def _suppress_guards_tls(self):
return getattr(TLS, "suppress_guards", False)

@@ -3775,6 +3784,14 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
eql, _ = self.create_fx_call_function(operator.eq, (fx_node, concrete_val))
node, fresh = self.create_fx_call_function(torch._assert, (eql,))

assert node is not None
# If this is a fresh node, we have to remember the event index that
# corresponds to this assertion node.
# Reason: so that, given an assertion node, we can replay the ShapeEnv
# events until the point where this assertion node was freshly created.
if fresh:
self.add_fx_node_metadata(node)

# After creating the FX node corresponding to orig_expr, we must make sure that
# no error will be raised until the end of this function.
#
@@ -3813,9 +3830,12 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):

self._check_frozen(expr, concrete_val)

if torch._dynamo.config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY and isinstance(hint, bool):
if isinstance(expr, (sympy.Eq, sympy.Ne)):
expr = sympy.Not(expr)
if (
torch._dynamo.config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
and isinstance(hint, bool)
and isinstance(expr, (sympy.Eq, sympy.Ne))
):
expr = sympy.Not(expr)

if isinstance(expr, (sympy.Eq, sympy.Ne)):
self._maybe_guard_eq(expr, bool(concrete_val))
@@ -3895,7 +3915,10 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
and fx_node is not None
and not self._suppress_guards_tls()
):
self.create_fx_call_function(torch._assert, (fx_node,))
node, fresh = self.create_fx_call_function(torch._assert, (fx_node,))
assert node is not None
if fresh:
self.add_fx_node_metadata(node)

self._check_frozen(expr, sympy.true)

