diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 40807a87232..eaabc6589b5 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -172,29 +172,18 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
+# Note: quantizer is not optional here to force the user to supply a quantizer
+# and ensure consistency is more likely to be maintained.
+def get_fake_quant_model(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     program = trace(model, inputs, dump_graphs=dump_graphs)
 
     if dump_graphs:
@@ -214,6 +203,37 @@ def quantize_pt2(
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
 
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs,
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
@@ -237,7 +257,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
 
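Usage sketch (not part of the patch): how the two entry points might be called after this refactor. The import paths, the toy model, and the input shapes below are assumptions for illustration; quantize_pt2 remains the high-level API with an optional quantizer, while get_fake_quant_model now requires the caller to pass one explicitly.

    import torch

    # Assumed import paths, mirroring the backends/cadence/aot layout in the executorch tree.
    from executorch.backends.cadence.aot.compiler import (
        get_fake_quant_model,
        quantize_pt2,
    )
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

    # Toy model and example inputs, purely illustrative.
    model = torch.nn.Linear(16, 8)
    inputs = (torch.randn(1, 16),)

    # High-level path: the quantizer still defaults to CadenceDefaultQuantizer when omitted.
    quantized = quantize_pt2(model, inputs)

    # Lower-level path: the new helper takes a mandatory quantizer and returns the
    # converted (fake-quant) torch.fx.GraphModule before fusion.
    quantizer = CadenceDefaultQuantizer()
    fake_quant_gm = get_fake_quant_model(model, inputs, quantizer=quantizer)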