Skip to content

Commit b4faa6b

Browse files
jon-chuang
authored and pytorchmergebot committed
[dynamo] report guard failure user stack, fix incorrectly skipping interesting files (#114053)
Fixes #114015 Before: ``` test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS: [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94696321555200) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94696321555200) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0 [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] 
utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:365 in init_ambient_guards [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140084534469552)) # _dynamo/output_graph.py:371 in init_ambient_guards [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1]) [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539 [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s): [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] - L['zs'][2] == 8.0 ``` After: ``` test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS: [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False # x = x.clone() # test/dynamo/test_functions.py:2540 in fn [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94568804551424) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94568804551424) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn 
[2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94568804552256) # for y, z in zip(ys, zs, 
strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:365 in init_ambient_guards [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140370726823264)) # _dynamo/output_graph.py:371 in init_ambient_guards [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1]) # x = x.clone() # test/dynamo/test_functions.py:2540 in fn [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539 [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s): [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] - L['zs'][2] == 8.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn ``` Pull Request resolved: #114053 Approved by: https://github.com/ezyang
1 parent 2b72543 commit b4faa6b

File tree

4 files changed

+73
-33
lines changed

4 files changed

+73
-33
lines changed

test/dynamo/test_aot_autograd.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ def guard_fail_fn(failure):
302302
fxy = torch._dynamo.optimize(cc, guard_fail_fn=guard_fail_fn)(F())
303303
compare_equal_outs_and_grads(self, F(), fxy, (x, y))
304304
compare_equal_outs_and_grads(self, F(), fxy, (x, z))
305-
self.assertExpectedInline(
306-
failure_reason,
305+
self.assertIn(
307306
"""tensor 'L['y']' requires_grad mismatch. expected requires_grad=1""",
307+
failure_reason,
308308
)
309309

310310
# Reset failure reason
@@ -421,7 +421,7 @@ def guard_fail_fn(failure):
421421
fxx(x3, x3)
422422
fxx(x4, y4)
423423
self.assertEqual(cc.frame_count, 2)
424-
self.assertExpectedInline(failure_reason, """L['x'] is L['y']""")
424+
self.assertIn("""L['x'] is L['y']""", failure_reason)
425425

426426
@patch("torch._functorch.config.debug_assert", True)
427427
def test_arg_dupe_via_dynamo_recompiles_many_args_param_non_tensor_arg(self):
@@ -456,9 +456,9 @@ def guard_fail_fn(failure):
456456
f(a1, a1, a1, a1, 2, 2)
457457
f(a2, b2, b2, b2, 2, 2)
458458
self.assertEqual(cc.frame_count, 2)
459-
self.assertExpectedInline(
460-
failure_reason,
459+
self.assertIn(
461460
"""L['a'] is L['b']""",
461+
failure_reason,
462462
)
463463

464464
torch._dynamo.reset()
@@ -474,7 +474,7 @@ def guard_fail_fn(failure):
474474
f(a3, b3, c3, c3, 3, 3)
475475
f(a4, b4, c4, d4, 3, 3)
476476
self.assertEqual(cc.frame_count, 2)
477-
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
477+
self.assertIn("""L['c'] is L['d']""", failure_reason)
478478

479479
@patch("torch._functorch.config.debug_assert", True)
480480
def test_arg_dupe_via_dynamo_recompiles_many_with_global(self):
@@ -512,9 +512,9 @@ def guard_fail_fn(failure):
512512
f(a1, a1, a1, a1, 2, 2)
513513
f(a2, b2, b2, b2, 2, 2)
514514
self.assertEqual(cc.frame_count, 2)
515-
self.assertExpectedInline(
516-
failure_reason,
515+
self.assertIn(
517516
"""L['a'] is L['b']""",
517+
failure_reason,
518518
)
519519

520520
@patch("torch._functorch.config.debug_assert", True)
@@ -550,9 +550,9 @@ def guard_fail_fn(failure):
550550
f([3, 2, 1], [4, 5, 6], a1, a1, a1, a1)
551551
f([3, 2, 1], [4, 5, 6], a2, b2, b2, b2)
552552
self.assertEqual(cc.frame_count, 2)
553-
self.assertExpectedInline(
554-
failure_reason,
553+
self.assertIn(
555554
"""L['a'] is L['b']""",
555+
failure_reason,
556556
)
557557

558558
torch._dynamo.reset()
@@ -602,9 +602,9 @@ def guard_fail_fn(failure):
602602
f(a1, a1, a1, a1)
603603
f(a2, b2, b2, b2)
604604
self.assertEqual(cc.frame_count, 2)
605-
self.assertExpectedInline(
606-
failure_reason,
605+
self.assertIn(
607606
"""L['a'] is L['b']""",
607+
failure_reason,
608608
)
609609

610610
torch._dynamo.reset()
@@ -620,7 +620,7 @@ def guard_fail_fn(failure):
620620
f(a3, b3, c3, c3)
621621
f(a4, b4, c4, d4)
622622
self.assertEqual(cc.frame_count, 2)
623-
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
623+
self.assertIn("""L['c'] is L['d']""", failure_reason)
624624

625625
@patch("torch._functorch.config.debug_assert", True)
626626
def test_arg_dupe_via_dynamo_recompiles_many_args(self):
@@ -651,9 +651,9 @@ def guard_fail_fn(failure):
651651
f(a1, a1, a1, a1)
652652
f(a2, b2, b2, b2)
653653
self.assertEqual(cc.frame_count, 2)
654-
self.assertExpectedInline(
655-
failure_reason,
654+
self.assertIn(
656655
"""L['a'] is L['b']""",
656+
failure_reason,
657657
)
658658

659659
torch._dynamo.reset()
@@ -669,7 +669,7 @@ def guard_fail_fn(failure):
669669
f(a3, b3, c3, c3)
670670
f(a4, b4, c4, d4)
671671
self.assertEqual(cc.frame_count, 2)
672-
self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")
672+
self.assertIn("""L['c'] is L['d']""", failure_reason)
673673

674674
@expectedFailureDynamic # https://github.com/pytorch/pytorch/issues/103539
675675
@torch._dynamo.config.patch(automatic_dynamic_shapes=False)

test/dynamo/test_logging.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,38 @@ def fn(x):
596596
~~^~~""",
597597
)
598598

599+
@make_logging_test(guards=True, recompiles=True)
600+
def test_guards_recompiles(self, records):
601+
def fn(x, ys, zs):
602+
return inner(x, ys, zs)
603+
604+
def inner(x, ys, zs):
605+
for y, z in zip(ys, zs):
606+
x += y * z
607+
return x
608+
609+
ys = [1.0, 2.0]
610+
zs = [3.0]
611+
x = torch.tensor([1.0])
612+
613+
fn_opt = torch._dynamo.optimize("eager")(fn)
614+
fn_opt(x, ys, zs)
615+
fn_opt(x, ys[:1], zs)
616+
617+
record_str = "\n".join(r.getMessage() for r in records)
618+
619+
self.assertIn(
620+
"""\
621+
L['zs'][0] == 3.0 # for y, z in zip(ys, zs):""",
622+
record_str,
623+
)
624+
self.assertIn(
625+
"""\
626+
triggered by the following guard failure(s):\n\
627+
- len(L['ys']) == 2 # for y, z in zip(ys, zs):""",
628+
record_str,
629+
)
630+
599631
@make_logging_test(**torch._logging.DEFAULT_LOGGING)
600632
def test_default_logging(self, records):
601633
def fn(a):

test/dynamo/test_misc.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,9 @@ def guard_failures(failure):
650650
)(compare_shapes)
651651
opt_fn(torch.randn([3, 4]))
652652
opt_fn(torch.randn([4, 3]))
653-
self.assertExpectedInline(
654-
guard_failure.reason,
653+
self.assertIn(
655654
"""tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
655+
guard_failure.reason,
656656
)
657657

658658
def test_builtin_abs(self):
@@ -716,9 +716,9 @@ def fn(x, y):
716716
),
717717
sorted(guard_code),
718718
)
719-
self.assertExpectedInline(
720-
"\n".join(guard_code),
721-
"""\
719+
guard_code_str = "\n".join(guard_code)
720+
721+
for line in """\
722722
2 <= L['x'].size()[0]
723723
L['x'] is L['y']
724724
L['x'].ndimension() == 2
@@ -734,8 +734,13 @@ def fn(x, y):
734734
not ___dict_contains('cccccccc', G['sys'].modules)
735735
str(L['x'].device) == 'cpu'
736736
str(L['x'].dtype) == 'torch.float32'
737-
utils_device.CURRENT_DEVICE == None""",
738-
)
737+
utils_device.CURRENT_DEVICE == None""".split(
738+
"\n"
739+
):
740+
self.assertIn(
741+
line,
742+
guard_code_str,
743+
)
739744

740745
def test_fold(self):
741746
def fn(a):
@@ -5240,12 +5245,12 @@ def guard_failures(failure):
52405245
self.assertTrue(guard_failure is not None)
52415246
first_guard_failure = guard_failure[0].partition("\n")[0]
52425247
if torch._dynamo.config.assume_static_by_default:
5243-
self.assertExpectedInline(
5244-
first_guard_failure,
5248+
self.assertIn(
52455249
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
5250+
first_guard_failure,
52465251
)
52475252
else:
5248-
self.assertExpectedInline(first_guard_failure, """L['x'].size()[0] < 3""")
5253+
self.assertIn("""L['x'].size()[0] < 3""", first_guard_failure)
52495254

52505255
def test_guard_failure_fn2(self):
52515256
def fn(x, y):
@@ -5273,9 +5278,9 @@ def guard_failures(failure):
52735278
opt_fn(x2, y2)
52745279

52755280
if torch._dynamo.config.assume_static_by_default:
5276-
self.assertExpectedInline(
5277-
guard_failure[0],
5281+
self.assertIn(
52785282
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
5283+
guard_failure[0],
52795284
)
52805285
else:
52815286
self.assertTrue(guard_failure is None)
@@ -5308,9 +5313,9 @@ def guard_failures(failure):
53085313

53095314
# guard is expected for both static and dynamic shapes
53105315
self.assertTrue(guard_failure is not None)
5311-
self.assertExpectedInline(
5312-
guard_failure[0],
5316+
self.assertIn(
53135317
"""len(L['x']) == 10""",
5318+
guard_failure[0],
53145319
)
53155320

53165321
def test_restore_graphstate(self):

torch/_dynamo/guards.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,15 +1031,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn):
10311031

10321032
# Don't report this guard, it's always the same, useless!
10331033
code_parts = ["___guarded_code.valid", "___check_global_state()"]
1034+
verbose_code_parts = code_parts[:]
10341035

10351036
def add_code_part(code, guard, log_only=False):
10361037
extra = ""
10371038
if guard.user_stack:
10381039
for fs in reversed(guard.user_stack):
10391040
if fs.filename not in uninteresting_files():
1041+
extra = f" # {format_frame(fs, line=True)}"
10401042
break
1041-
else:
1042-
extra = f" # {format_frame(fs, line=True)}"
10431043
elif guard.stack:
10441044
extra = f" # {format_frame(guard.stack.summary()[-1])}"
10451045

@@ -1064,6 +1064,7 @@ def add_code_part(code, guard, log_only=False):
10641064

10651065
if not log_only:
10661066
code_parts.append(code)
1067+
verbose_code_parts.append(f"{code:<60}{extra}")
10671068

10681069
seen = set()
10691070
for gcl in builder.code:
@@ -1113,6 +1114,7 @@ def convert(size_or_stride):
11131114
)
11141115
# Do this manually, to un-stagger the guards in log message
11151116
code_parts.append(f"___check_tensors({tensor_check_args})")
1117+
verbose_code_parts.append(f"___check_tensors({tensor_check_args})")
11161118
tensor_check_guards = builder.tensor_check_guards
11171119

11181120
for i, name in enumerate(tensor_check_names):
@@ -1183,6 +1185,7 @@ def convert(size_or_stride):
11831185
# TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
11841186
guard_fn.args = largs
11851187
guard_fn.code_parts = code_parts
1188+
guard_fn.verbose_code_parts = verbose_code_parts
11861189
# Grab only G, but preserve "G" because guards access it as "G"
11871190
guard_fn.global_scope = {
11881191
"G": builder.scope["G"],
@@ -1282,7 +1285,7 @@ def get_guard_fail_reason(
12821285
scope.update(guard_fn.closure_vars)
12831286
scope["___check_tensors"] = scope["___check_tensors_verbose"]
12841287
reasons: List[str] = []
1285-
for part in guard_fn.code_parts:
1288+
for part in guard_fn.verbose_code_parts:
12861289
global_scope = dict(guard_fn.global_scope)
12871290
global_scope["__compile_source__"] = part
12881291
with report_compile_source_on_error():

0 commit comments

Comments
 (0)