From a10de01bd7bb129e64afa99e5ce51c5392136197 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:01:03 -0800 Subject: [PATCH 1/6] print device and stride when print module --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index aee089cad9..cc3f5b783d 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -36,7 +36,7 @@ def export_joint( ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False)) + logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True) tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -195,14 +195,14 @@ def compiler( passes = DEFAULT_COMPILER_PASSES logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) + logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) return gm From 8827380e5ac5c795df83af82f1f5c966e01b374c Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:03:51 -0800 Subject: [PATCH 2/6] nit --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index cc3f5b783d..259014b4c7 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -36,7 +36,7 @@ def export_joint( ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True) + logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): From 745cc1e41e72a48052f475b2c5f23989cf505080 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:05:07 -0800 Subject: [PATCH 3/6] lint --- .../experiments/compiler_toolkit/graph_utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 259014b4c7..c642e7d26b 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -36,7 +36,11 @@ def export_joint( ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) + logger.info( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -195,14 +199,18 @@ def compiler( passes = DEFAULT_COMPILER_PASSES logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) + logger.info( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) for pass_fn in passes: logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False, include_stride=True, include_device=True)) + logger.info( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) return gm From 842c431b453bdbc5603da4dc39e28cf457f9970f Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:21:34 -0800 Subject: [PATCH 4/6] log debug instead of info --- .../experiments/compiler_toolkit/graph_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index ac49a1f768..c8831dd68e 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -46,8 +46,8 @@ def export_joint( torch.fx.traceback.preserve_node_meta(), ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) - logger.info("Dynamo gm:") - logger.info( + logger.debug("Dynamo gm:") + logger.debug( gm.print_readable( print_output=False, include_stride=True, include_device=True ) @@ -227,18 +227,18 @@ def compiler( if passes is None: passes = DEFAULT_COMPILER_PASSES - logger.info(f"{name} before compiler:") - logger.info( + logger.debug(f"{name} before compiler:") + logger.debug( gm.print_readable(print_output=False, include_stride=True, include_device=True) ) _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: - logger.info(f"Applying pass: {pass_fn.__name__}") + logger.debug(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) - logger.info(f"{name} after compiler:") - logger.info( + logger.debug(f"{name} after compiler:") + logger.debug( gm.print_readable(print_output=False, include_stride=True, include_device=True) ) _dump_gm(dump_folder, gm, f"{name}_after_compiler") @@ -295,6 +295,6 @@ def get_compiler_passes_from_config(job_config: JobConfig): compiler_passes.append(AVAILABLE_PASSES[pass_name]) if pass_names: - logger.info(f"Using compiler passes from config: {pass_names}") + logger.debug(f"Using compiler passes from config: {pass_names}") return compiler_passes From 3482e3e00de920dd4adaa88505141cda099693a2 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:22:16 -0800 Subject: [PATCH 5/6] nit --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index c8831dd68e..99ab4e6ef4 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -234,7 +234,7 @@ def compiler( _dump_gm(dump_folder, gm, f"{name}_before_compiler") for pass_fn in passes: - logger.debug(f"Applying pass: {pass_fn.__name__}") + logger.info(f"Applying pass: {pass_fn.__name__}") gm = pass_fn(gm, example_inputs) logger.debug(f"{name} after compiler:") @@ -295,6 +295,6 @@ def get_compiler_passes_from_config(job_config: JobConfig): compiler_passes.append(AVAILABLE_PASSES[pass_name]) if pass_names: - logger.debug(f"Using compiler passes from config: {pass_names}") + logger.info(f"Using compiler passes from config: {pass_names}") return compiler_passes From 3c31bf4b9b77316eab783d3f6b01f1df9bf1c6c2 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Fri, 14 Nov 2025 10:26:39 -0800 Subject: [PATCH 6/6] also print in graph dump --- torchtitan/experiments/compiler_toolkit/graph_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 99ab4e6ef4..cd758438b3 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -29,7 +29,9 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No output_path = Path(dump_folder) / "compiler" / f"{name}.txt" output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(gm.print_readable(print_output=False)) + output_path.write_text( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) def export_joint(