diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 62c19cb3de..cd758438b3 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -29,7 +29,9 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No output_path = Path(dump_folder) / "compiler" / f"{name}.txt" output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(gm.print_readable(print_output=False)) + output_path.write_text( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) def export_joint( @@ -47,7 +49,11 @@ def export_joint( ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) logger.debug("Dynamo gm:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) _dump_gm(dump_folder, gm, "dynamo_gm") tracing_context = gm.meta["tracing_context"] @@ -224,7 +230,9 @@ def compiler( passes = DEFAULT_COMPILER_PASSES logger.debug(f"{name} before compiler:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: @@ -232,7 +240,9 @@ def compiler( gm = pass_fn(gm, example_inputs) logger.debug(f"{name} after compiler:") - logger.debug(gm.print_readable(print_output=False)) + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) _dump_gm(dump_folder, gm, f"{name}_after_compiler") return gm