From 0a670bd37d5b6155df5b53f6a20838aad72b992d Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Tue, 28 Jan 2025 14:11:00 -0800
Subject: [PATCH] Improve graph logging for debug purposes (#7991)

Summary:
Cleans up some of the graph logging. Adds a few important dumps:
- the initial graph from export_for_training
- the converted graph
- the fused graph

and removes some unnecessary ones. Also standardizes the log messages.

Reviewed By: zonglinpeng

Differential Revision: D68636227
---
 backends/cadence/aot/compiler.py | 29 ++++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index c7cea31b492..f9abe1c5425 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -56,6 +56,7 @@ def convert_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: CadenceQuantizer,
+    dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare and convert a model using the given quantizer.
@@ -86,6 +87,10 @@ def convert_pt2(
         .module()
     )
 
+    if dump_graphs:
+        logging.info("Graph before quantization:")
+        logging.info(model_gm.graph.print_tabular())
+
     # Prepare
     prepared_model = prepare_pt2e(model_gm, quantizer)
 
@@ -95,6 +100,10 @@ def convert_pt2(
     # Convert
     converted_model = convert_pt2e(prepared_model)
 
+    if dump_graphs:
+        logging.info("Graph after quantization (before fusion):")
+        logging.info(converted_model.graph.print_tabular())
+
     return converted_model
 
 
@@ -127,6 +136,7 @@ def quantize_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
     quantizer: Optional[CadenceQuantizer] = None,
+    dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Prepare, convert and fuse the model using the given quantizer.
@@ -140,11 +150,15 @@ def quantize_pt2(
         quantizer = CadenceDefaultQuantizer()
 
     # Get converted graph module
-    converted_gm = convert_pt2(model, inputs, quantizer)
+    converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
 
+    if dump_graphs:
+        logging.info("Graph after quantization and fusion:")
+        logging.info(fused_gm.graph.print_tabular())
+
     return fused_gm
 
 
@@ -152,7 +166,6 @@ def quantize_pt2(
 def export_program(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    dump_graphs: bool = False,
 ) -> ExportedProgram:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
@@ -162,10 +175,6 @@ def export_program(
     # Export the model and return it.
     expo_program = export(model, inputs, strict=True)
 
-    if dump_graphs:
-        logging.info("Exported graph:")
-        expo_program.graph_module.graph.print_tabular()
-
     return expo_program
 
 
@@ -179,7 +188,7 @@ def export_to_edge(
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs, dump_graphs=dump_graphs)
+    expo_program = export_program(model, inputs)
 
     # Call to_edge to convert the graph to edge IR.
     # Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
@@ -200,8 +209,10 @@ def export_to_edge(
     )
 
     if dump_graphs:
-        logging.info("Edge graph:")
-        edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+        logging.info("Graph after Edge lowering:")
+        logging.info(
+            edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+        )
 
     return edge_prog_manager
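
With this change, dump_graphs=True threads through the whole compile flow. A
minimal usage sketch, assuming the executorch.backends.cadence.aot.compiler
import path; the toy model, inputs, and variable names below are made up for
illustration and are not part of this patch:

    import logging

    import torch

    from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2

    # Surface the logging.info() graph dumps added by this patch.
    logging.basicConfig(level=logging.INFO)

    # Toy model and example inputs, for illustration only.
    class Small(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.relu(x)

    model = Small()
    inputs = (torch.randn(1, 16),)

    # Dumps the graph before quantization, after conversion (pre-fusion),
    # and after fusion.
    quantized_gm = quantize_pt2(model, inputs, dump_graphs=True)

    # Dumps the graph after Edge lowering.
    edge_prog_manager = export_to_edge(quantized_gm, inputs, dump_graphs=True)

Since convert_pt2 now takes dump_graphs directly, the pre- and post-quantization
dumps are also available when calling it standalone with an explicit quantizer.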