Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dynamo] report guard failure user stack, fix incorrectly skipping in…
…teresting files (#114053) Fixes #114015 Before: ``` test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS: [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94696321555200) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94696321555200) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0 [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94696321556032) [2023-11-18 23:11:09,316] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0 [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:365 in init_ambient_guards [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140084534469552)) # _dynamo/output_graph.py:371 in init_ambient_guards [2023-11-18 23:11:09,317] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1]) [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539 [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s): [2023-11-18 23:11:09,320] torch._dynamo.guards.__recompiles: [DEBUG] - L['zs'][2] == 8.0 ``` After: ``` test/dynamo/test_functions.py::DefaultsTests::test_zip_strict [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS: [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False # x = x.clone() # test/dynamo/test_functions.py:2540 in fn [2023-11-18 23:07:33,341] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'], 94568804551424) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['ys']) == 3 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'], 94568804551424) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] len(L['zs']) == 3 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][0], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][0] == 1.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][1], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][1] == 2.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['ys'][2], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['ys'][2] == 3.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][0], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][0] == 2.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][1], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][1] == 5.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(L['zs'][2], 94568804552256) # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] L['zs'][2] == 8.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:365 in init_ambient_guards [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140370726823264)) # _dynamo/output_graph.py:371 in init_ambient_guards [2023-11-18 23:07:33,342] [0/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3], stride=[1]) # x = x.clone() # test/dynamo/test_functions.py:2540 in fn [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_functions.py:2539 [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s): [2023-11-18 23:07:33,346] torch._dynamo.guards.__recompiles: [DEBUG] - L['zs'][2] == 8.0 # for y, z in zip(ys, zs, strict=True): # test/dynamo/test_functions.py:2541 in fn ``` Pull Request resolved: #114053 Approved by: https://github.com/ezyang
- Loading branch information