From 6b560c12ec6fe99d56691795dac499465246ee0b Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Sun, 6 Oct 2024 17:02:43 -0700 Subject: [PATCH] Cleanup export_model API calls (#5882) Summary: Lots of things are redundant and a few need to move to utils. Subsequent changes will split the export function and separate the run part. Main changes: - call `fuse_pt2` after `convert_pt2` instead of `quantize_pt2`, and avoid calling `convert_pt2` twice - move `print_ops_info` into `export_to_cadence` - remove the need to call `export_to_edge` in `export_model` - move the serialization utils to `utils.py` Reviewed By: zonglinpeng Differential Revision: D63795843 --- backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/compiler.py | 17 +++++-- backends/cadence/aot/export_example.py | 67 ++++++-------------------- backends/cadence/aot/utils.py | 37 +++++++++++++- 4 files changed, 64 insertions(+), 58 deletions(-) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 08093efe317..ae60c299f2c 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -22,6 +22,7 @@ python_library( deps = [ "fbsource//third-party/pypi/tabulate:tabulate", "//caffe2:torch", + "//executorch/exir:lib", "//executorch/exir:memory", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index fe8fc721245..5b151a3b6a4 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -36,6 +36,8 @@ from torch.export import export from torch.export.exported_program import ExportedProgram +from .utils import print_ops_info + # Note: this is not meant as a primary API since it can create inconsistencies # if the quantizer here is different from the quantizer used to convert. It is @@ -193,16 +195,17 @@ def export_to_edge( # Export the model and lower it to an EdgeProgramManager (in edge IR), and -# apply passes specific to Cadence DSP execution. +# apply passes specific to Cadence DSP execution. Return both to print the +# differences. def export_to_cadence( model: torch.nn.Module, inputs: tuple[object, ...], dump_graphs: bool = False, ) -> EdgeProgramManager: - edge_program_manager = export_to_edge(model, inputs) + edge_prog_manager = export_to_edge(model, inputs) # Run a couple required passes for quant/dequant ops - cadence_program_manager = edge_program_manager.transform( + cadence_prog_manager = edge_prog_manager.transform( [ InitializePipeline(), RemoveZeroSizedCatArgsPass(), @@ -216,4 +219,10 @@ def export_to_cadence( ] ) - return cadence_program_manager + # Print some information to terminal + print_ops_info( + edge_prog_manager.exported_program().graph_module, + cadence_prog_manager.exported_program().graph_module, + ) + + return cadence_prog_manager diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index f7920f0b8fb..10433016e38 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -10,14 +10,12 @@ import tempfile from executorch.backends.cadence.aot.ops_registrations import * # noqa -import os from typing import Any, Tuple from executorch.backends.cadence.aot.compiler import ( convert_pt2, export_to_cadence, - export_to_edge, - quantize_pt2, + fuse_pt2, ) from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.backends.cadence.runtime import runtime @@ -25,46 +23,13 @@ from executorch.exir import ExecutorchProgramManager from torch import nn -from .utils import print_ops_info +from .utils import save_bpte_program, save_pte_program FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) -def _save_pte_program( - prog: ExecutorchProgramManager, model_name: str, output_dir: str = "" -) -> None: - if model_name.endswith(".pte"): - filename = model_name - else: - filename = os.path.join(output_dir, f"{model_name}.pte") - - try: - with open(filename, "wb") as file: - prog.write_to_file(file) - logging.info(f"Saved exported program to {filename}") - except Exception as e: - logging.error(f"Error while saving to {filename}: {e}") - - -def _save_bpte_program( - buffer: bytes, - model_name: str, - output_dir: str = "", -) -> None: - if model_name.endswith(".bpte"): - filename = model_name - else: - filename = os.path.join(output_dir, f"{model_name}.bpte") - try: - with open(filename, "wb") as f: - f.write(buffer) - logging.info(f"Saved exported program to {filename}") - except Exception as e: - logging.error(f"Error while saving to {output_dir}: {e}") - - def export_model( model: nn.Module, example_inputs: Tuple[Any, ...], @@ -74,32 +39,28 @@ def export_model( working_dir = tempfile.mkdtemp(dir="/tmp") logging.debug(f"Created work directory {working_dir}") - # convert the model (also called in quantize_pt2) - converted_model = convert_pt2(model, example_inputs, CadenceQuantizer()) + # Instantiate the quantizer + quantizer = CadenceQuantizer() - # Get reference outputs from quantized_model - ref_outputs = converted_model(*example_inputs) + # Convert the model + converted_model = convert_pt2(model, example_inputs, quantizer) - # Quantize the model - quantized_model = quantize_pt2(model, example_inputs) + # Get reference outputs from converted model + ref_outputs = converted_model(*example_inputs) - # Get edge program (also called in export_to_cadence) - edge_prog_manager = export_to_edge(quantized_model, example_inputs) + # Quantize the model (note: quantizer needs to be the same as + # the one used in convert_pt2) + quantized_model = fuse_pt2(converted_model, quantizer) # Get edge program after Cadence specific passes cadence_prog_manager = export_to_cadence(quantized_model, example_inputs) + # Get executorch program after Cadence specific passes exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch() logging.info("Final exported graph:\n") exec_prog.exported_program().graph_module.graph.print_tabular() - # Print some information to terminal - print_ops_info( - edge_prog_manager.exported_program().graph_module, - cadence_prog_manager.exported_program().graph_module, - ) - forward_test_data = BundledProgramManager.bundled_program_test_data_gen( method="forward", inputs=example_inputs, expected_outputs=ref_outputs ) @@ -110,9 +71,9 @@ def export_model( forward_test_data, ) # Save the program as pte (default name is CadenceDemoModel.pte) - _save_pte_program(exec_prog, file_name, working_dir) + save_pte_program(exec_prog, file_name, working_dir) # Save the program as btpe (default name is CadenceDemoModel.bpte) - _save_bpte_program(buffer, file_name, working_dir) + save_bpte_program(buffer, file_name, working_dir) logging.debug( f"Executorch bundled program buffer saved to {file_name} is {len(buffer)} total bytes" diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index f081036ccc1..9e32f3472da 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -8,10 +8,12 @@ import logging import operator +import os from typing import Dict, List, Tuple import torch -from executorch.exir import memory + +from executorch.exir import ExecutorchProgramManager, memory from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket from tabulate import tabulate @@ -185,3 +187,36 @@ def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool: if node.target == torch.ops.aten.scaled_dot_product_attention.default: return True return False + + +def save_pte_program( + prog: ExecutorchProgramManager, model_name: str, output_dir: str = "" +) -> None: + if model_name.endswith(".pte"): + filename = model_name + else: + filename = os.path.join(output_dir, f"{model_name}.pte") + + try: + with open(filename, "wb") as file: + prog.write_to_file(file) + logging.info(f"Saved exported program to {filename}") + except Exception as e: + logging.error(f"Error while saving to {filename}: {e}") + + +def save_bpte_program( + buffer: bytes, + model_name: str, + output_dir: str = "", +) -> None: + if model_name.endswith(".bpte"): + filename = model_name + else: + filename = os.path.join(output_dir, f"{model_name}.bpte") + try: + with open(filename, "wb") as f: + f.write(buffer) + logging.info(f"Saved exported program to {filename}") + except Exception as e: + logging.error(f"Error while saving to {output_dir}: {e}")