[dynamo] report guard failure user stack, fix incorrectly skipping interesting files (#114053)

Fixes #114015
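
To surface these logs yourself, enable the `guards` and `recompiles` logging artifacts that the new test below relies on. A minimal sketch (`TORCH_LOGS` is the equivalent command-line switch):

```python
import torch

# Enable guard and recompilation debug output programmatically;
# TORCH_LOGS="guards,recompiles" achieves the same from the shell.
torch._logging.set_logs(guards=True, recompiles=True)
```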

Before:
```
test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94696321555200)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94696321555200)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94696321556032)
[2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:365 in init_ambient_guards
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140084534469552))  # _dynamo/output_graph.py:371 in init_ambient_guards
[2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1])
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG]     - L['zs'][2] == 8.0

```

After:
```
test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # x = x.clone()  # test/dynamo/test_functions.py:2540 in fn
[2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94568804551424)                     # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94568804551424)                     # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94568804552256)                  # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:365 in init_ambient_guards
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140370726823264))  # _dynamo/output_graph.py:371 in init_ambient_guards
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1])  # x = x.clone()  # test/dynamo/test_functions.py:2540 in fn
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG]     - L['zs'][2] == 8.0                                             # for y, z in zip(ys, zs, strict=True):  # test/dynamo/test_functions.py:2541 in fn

```
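
For reference, here is a sketch of the `test_zip_strict` function the logs above refer to, reconstructed from the guard annotations (`x = x.clone()` at test_functions.py:2540, the `zip` loop at :2541); the actual test body may differ:

```python
def fn(x, ys, zs):
    x = x.clone()                          # test_functions.py:2540 in the logs
    for y, z in zip(ys, zs, strict=True):  # test_functions.py:2541 in the logs
        x += y * z
    return x

x = torch.ones(3)     # matches check_tensor(..., size=[3], stride=[1])
ys = [1.0, 2.0, 3.0]  # matches the L['ys'][i] guards
zs = [2.0, 5.0, 8.0]  # matches the L['zs'][i] guards
```

The point of the change is visible in the trailing `# <source line>  # <file:line> in fn` annotations: every guard now cites the user code that installed it, and the recompile reason at the bottom carries the same annotation.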

Pull Request resolved: #114053
Approved by: https://github.com/ezyang
jon-chuang authored and pytorchmergebot committed Nov 22, 2023
1 parent 2b72543 commit b4faa6b
Showing 4 changed files with 73 additions and 33 deletions.

test/dynamo/test_aot_autograd.py (16 additions, 16 deletions):

```diff
@@ -302,9 +302,9 @@ def guard_fail_fn(failure):
         fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
         compare_equal_outs_and_grads(self, F(), fxy, (x, y))
         compare_equal_outs_and_grads(self, F(), fxy, (x, z))
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
+            failure_reason,
         )
 
         # Reset failure reason
@@ -421,7 +421,7 @@ def guard_fail_fn(failure):
         fxx(x3, x3)
         fxx(x4, y4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['x'] is L['y']""")
+        self.assertIn("""L['x'] is L['y']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
@@ -456,9 +456,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1, 2, 2)
         f(a2, b2, b2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -474,7 +474,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3, 3, 3)
         f(a4, b4, c4, d4, 3, 3)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -512,9 +512,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1, 2, 2)
         f(a2, b2, b2, b2, 2, 2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
     @patch("torch._functorch.config.debug_assert", True)
@@ -550,9 +550,9 @@ def guard_fail_fn(failure):
         f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
         f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -602,9 +602,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1)
         f(a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -620,7 +620,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3)
         f(a4, b4, c4, d4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @patch("torch._functorch.config.debug_assert", True)
     def test_arg_dupe_via_dynamo_recompiles_many_args(self):
@@ -651,9 +651,9 @@ def guard_fail_fn(failure):
         f(a1, a1, a1, a1)
         f(a2, b2, b2, b2)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(
-            failure_reason,
+        self.assertIn(
             """L['a'] is L['b']""",
+            failure_reason,
         )
 
         torch._dynamo.reset()
@@ -669,7 +669,7 @@ def guard_fail_fn(failure):
         f(a3, b3, c3, c3)
         f(a4, b4, c4, d4)
         self.assertEqual(cc.frame_count, 2)
-        self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
+        self.assertIn("""L['c'] is L['d']""", failure_reason)
 
     @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
     @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
```
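
The switch from `assertExpectedInline` to `assertIn` in this file (and in test_misc.py below) follows from the logging change: guard-failure reasons are now rendered from `verbose_code_parts`, so they carry a trailing source annotation whose file path and line number vary across checkouts. The tests therefore match on the stable guard expression only; for example (annotation illustrative):

```python
# failure_reason is now the verbose form, roughly:
#   "L['x'] is L['y']                      # f(a, b)  # /path/to/test.py:421 in fn"
# so an exact-match golden string would be machine-specific:
self.assertIn("""L['x'] is L['y']""", failure_reason)
```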

test/dynamo/test_logging.py (32 additions, 0 deletions):

```diff
@@ -596,6 +596,38 @@ def fn(x):
 ~~^~~""",
         )
 
+    @make_logging_test(guards=True, recompiles=True)
+    def test_guards_recompiles(self, records):
+        def fn(x, ys, zs):
+            return inner(x, ys, zs)
+
+        def inner(x, ys, zs):
+            for y, z in zip(ys, zs):
+                x += y * z
+            return x
+
+        ys = [1.0, 2.0]
+        zs = [3.0]
+        x = torch.tensor([1.0])
+
+        fn_opt = torch._dynamo.optimize("eager")(fn)
+        fn_opt(x, ys, zs)
+        fn_opt(x, ys[:1], zs)
+
+        record_str = "\n".join(r.getMessage() for r in records)
+
+        self.assertIn(
+            """\
+L['zs'][0] == 3.0                                             # for y, z in zip(ys, zs):""",
+            record_str,
+        )
+        self.assertIn(
+            """\
+triggered by the following guard failure(s):\n\
+- len(L['ys']) == 2                                             # for y, z in zip(ys, zs):""",
+            record_str,
+        )
+
     @make_logging_test(**torch._logging.DEFAULT_LOGGING)
     def test_default_logging(self, records):
         def fn(a):
```
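
A standalone sketch of what `test_guards_recompiles` exercises, outside the `make_logging_test` harness (assumes the logging setup shown at the top; names match the test):

```python
import torch

def inner(x, ys, zs):
    for y, z in zip(ys, zs):
        x += y * z
    return x

def fn(x, ys, zs):
    return inner(x, ys, zs)

fn_opt = torch._dynamo.optimize("eager")(fn)

x = torch.tensor([1.0])
fn_opt(x, [1.0, 2.0], [3.0])  # first compile installs len(L['ys']) == 2, L['zs'][0] == 3.0, ...
fn_opt(x, [1.0], [3.0])       # len(L['ys']) == 2 fails -> recompile, reason cites the zip loop
```

Note that the guards are attributed to `for y, z in zip(ys, zs):` inside `inner`, not to the top-level `fn`; that is exactly the user-stack reporting this PR fixes.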

test/dynamo/test_misc.py (19 additions, 14 deletions):

```diff
@@ -650,9 +650,9 @@ def guard_failures(failure):
         )(compare_shapes)
         opt_fn(torch.randn([3, 4]))
         opt_fn(torch.randn([4, 3]))
-        self.assertExpectedInline(
-            guard_failure.reason,
+        self.assertIn(
             """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
+            guard_failure.reason,
         )
 
     def test_builtin_abs(self):
@@ -716,9 +716,9 @@ def fn(x, y):
             ),
             sorted(guard_code),
         )
-        self.assertExpectedInline(
-            "\n".join(guard_code),
-            """\
+        guard_code_str = "\n".join(guard_code)
+
+        for line in """\
 2 <= L['x'].size()[0]
 L['x'] is L['y']
 L['x'].ndimension() == 2
@@ -734,8 +734,13 @@ def fn(x, y):
 not ___dict_contains('cccccccc', G['sys'].modules)
 str(L['x'].device) == 'cpu'
 str(L['x'].dtype) == 'torch.float32'
-utils_device.CURRENT_DEVICE == None""",
-        )
+utils_device.CURRENT_DEVICE == None""".split(
+            "\n"
+        ):
+            self.assertIn(
+                line,
+                guard_code_str,
+            )
 
     def test_fold(self):
         def fn(a):
@@ -5240,12 +5245,12 @@ def guard_failures(failure):
         self.assertTrue(guard_failure is not None)
         first_guard_failure = guard_failure[0].partition("\n")[0]
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(
-                first_guard_failure,
+            self.assertIn(
                 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
+                first_guard_failure,
             )
         else:
-            self.assertExpectedInline(first_guard_failure, """L['x'].size()[0] < 3""")
+            self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure)
 
     def test_guard_failure_fn2(self):
         def fn(x, y):
@@ -5273,9 +5278,9 @@ def guard_failures(failure):
         opt_fn(x2, y2)
 
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(
-                guard_failure[0],
+            self.assertIn(
                 """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
+                guard_failure[0],
             )
         else:
             self.assertTrue(guard_failure is None)
@@ -5308,9 +5313,9 @@ def guard_failures(failure):
 
         # guard is expected for both static and dynamic shapes
         self.assertTrue(guard_failure is not None)
-        self.assertExpectedInline(
-            guard_failure[0],
+        self.assertIn(
             """len(L['x']) == 10""",
+            guard_failure[0],
         )
 
     def test_restore_graphstate(self):
```

torch/_dynamo/guards.py (6 additions, 3 deletions):

```diff
@@ -1031,15 +1031,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
 
         # Don't report this guard, it's always the same, useless!
         code_parts = ["___guarded_code.valid", "___check_global_state()"]
+        verbose_code_parts = code_parts[:]
 
         def add_code_part(code, guard, log_only=False):
             extra = ""
             if guard.user_stack:
                 for fs in reversed(guard.user_stack):
                     if fs.filename not in uninteresting_files():
+                        extra = f"  # {format_frame(fs, line=True)}"
                         break
-                else:
-                    extra = f"  # {format_frame(fs, line=True)}"
             elif guard.stack:
                 extra = f"  # {format_frame(guard.stack.summary()[-1])}"
 
@@ -1064,6 +1064,7 @@ def add_code_part(code, guard, log_only=False):
 
             if not log_only:
                 code_parts.append(code)
+                verbose_code_parts.append(f"{code:<60}{extra}")
 
         seen = set()
         for gcl in builder.code:
@@ -1113,6 +1114,7 @@ def convert(size_or_stride):
             )
             # Do this manually, to un-stagger the guards in log message
             code_parts.append(f"___check_tensors({tensor_check_args})")
+            verbose_code_parts.append(f"___check_tensors({tensor_check_args})")
             tensor_check_guards = builder.tensor_check_guards
 
             for i, name in enumerate(tensor_check_names):
@@ -1183,6 +1185,7 @@ def convert(size_or_stride):
         # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
         guard_fn.args = largs
         guard_fn.code_parts = code_parts
+        guard_fn.verbose_code_parts = verbose_code_parts
        # Grab only G, but preserve "G" because guards access it as "G"
         guard_fn.global_scope = {
             "G": builder.scope["G"],
@@ -1282,7 +1285,7 @@ def get_guard_fail_reason(
     scope.update(guard_fn.closure_vars)
     scope["___check_tensors"] = scope["___check_tensors_verbose"]
     reasons: List[str] = []
-    for part in guard_fn.code_parts:
+    for part in guard_fn.verbose_code_parts:
         global_scope = dict(guard_fn.global_scope)
         global_scope["__compile_source__"] = part
         with report_compile_source_on_error():
```
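
The two-line deletion in `add_code_part` is the "incorrectly skipping interesting files" fix from the title. Previously the loop hit `break` on the first frame from an interesting file without recording it, and the `for`/`else` arm (which only runs when the loop finishes *without* breaking) annotated guards only when every frame was uninteresting. A minimal sketch of the pitfall, with plain strings standing in for the real `FrameSummary` objects and `uninteresting_files()`:

```python
UNINTERESTING = {"torch/_dynamo/convert_frame.py"}  # stand-in for uninteresting_files()

def annotate_old(frames):
    extra = ""
    for fs in reversed(frames):
        if fs not in UNINTERESTING:
            break              # interesting frame found, but never recorded (the bug)
    else:
        extra = f"  # {fs}"    # for/else: only reached when no break occurred
    return extra

def annotate_new(frames):
    extra = ""
    for fs in reversed(frames):
        if fs not in UNINTERESTING:
            extra = f"  # {fs}"  # record the innermost interesting frame
            break
    return extra

stack = ["torch/_dynamo/convert_frame.py", "test/dynamo/test_functions.py:2541"]
print(annotate_old(stack))  # ""  (user frame silently dropped)
print(annotate_new(stack))  # "  # test/dynamo/test_functions.py:2541"
```

The new `verbose_code_parts` list keeps these annotated strings separate from the executable `code_parts`, and `get_guard_fail_reason` now iterates the verbose list, so recompile messages carry the same source attribution as the GUARDS log.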
