Skip to content

Commit

Permalink
Print AOT Autograd graph name when accuracy failed
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: c2494d8bd9fd825cb0b0568cecf7ca955451c976
Pull Request resolved: #99366
  • Loading branch information
ezyang committed Apr 17, 2023
1 parent 62a6d81 commit 610c788
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torch/_dynamo/debug_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,10 @@ def debug_wrapper(gm, example_inputs, **kwargs):

compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)

from torch._functorch.aot_autograd import get_aot_graph_name
graph_name = get_aot_graph_name()

# TODO: Why do we have to save the orig_graph?
orig_graph = copy.deepcopy(gm.graph)
assert config.repro_after in ("dynamo", "aot", None)
inner_compiled_fn = None
Expand Down Expand Up @@ -578,7 +582,7 @@ def deferred_for_real_inputs(real_inputs):
if inner_compiled_fn is None:
inner_compiled_fn = compiler_fn(gm, example_inputs)
if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn):
log.warning("Accuracy failed for the AOT Autograd graph")
log.warning("Accuracy failed for the AOT Autograd graph %s", graph_name)
dump_compiler_graph_state(
fx.GraphModule(gm, orig_graph),
copy_tensor_attrs,
Expand Down

0 comments on commit 610c788

Please sign in to comment.