
import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
+from executorch.backends.cadence.aot.compiler_funcs import (
+    convert as convert_fn,
+    prepare as prepare_fn,
+    trace as trace_fn,
+)
from executorch.backends.cadence.aot.memory_planning import (
    CadenceMemoryPlanning,
    print_memory_planning_info,
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge
-from torch._inductor.decomposition import remove_decompositions

from torch.export.exported_program import ExportedProgram
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .passes import apply_exir_ops_passes, apply_torch_ops_passes

from .utils import print_ops_info

-
default_quantizer = CadenceDefaultQuantizer()

@@ -62,13 +64,6 @@ def trace(
    Trace the model with export and return an ExportedProgram.
    """

-    # Make the model inference mode by calling model.eval()
-    model.eval()
-
-    # Get default decompositions
-    decomp_table = torch.export.default_decompositions()
-
-    # Select ops to keep
    ops_to_keep = [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
@@ -78,63 +73,54 @@ def trace(
        torch.ops.aten.rms_norm.default,
    ]

-    # Remove decompositions for the ops we want to keep
-    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
-    remove_decompositions(decomp_table, ops_to_keep)
-
-    # Export with dynamo
-    program = torch.export.export(model, inputs, strict=True).run_decompositions(
-        decomp_table
+    program = trace_fn(
+        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
    )

    if dump_graphs:
        logging.info("Graph before quantization:")
-        logging.info(program.module().graph.print_tabular())
+        logging.info(program.graph_module.graph.print_tabular())

    return program

-def prepare_and_convert_pt2(
+def prepare_pt2(
    program: ExportedProgram,
-    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
-    calibration_data: Optional[list[tuple[object, ...]]] = None,
    dump_graphs: bool = False,
) -> torch.fx.GraphModule:
    """
-    Prepare and convert a model using the given quantizer.
+    Prepare a model using the given quantizer.
    The quantizer must be supplied and be the same as the one used to
    fuse the model later, if applicable. If you do not expect that behavior,
    please use quantize_and_fuse_pt2 instead, which will instantiate a
    default quantizer for you if needed.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the converted model.
+    Returns a GraphModule with the prepared model.
    """

-    # Get the graph module from the ExportedProgram
-    model_gm = program.module()
+    prepared_model = prepare_fn(program, quantizer, is_qat=False)

-    assert isinstance(model_gm, torch.fx.GraphModule)
+    if dump_graphs:
+        logging.info("Graph after preparation:")
+        logging.info(prepared_model.graph.print_tabular())

-    # Prepare
-    prepared_model = prepare_pt2e(model_gm, quantizer)
+    return prepared_model

-    # Calibrate
-    # If no calibration data is provided, use the inputs
-    if calibration_data is None:
-        calibration_data = [inputs]

-    for samples in calibration_data:
-        prepared_model(*samples)
+def convert_pt2(
+    graph_module: torch.fx.GraphModule,
+    dump_graphs: bool = False,
+) -> torch.fx.GraphModule:
+    """
+    Convert the prepared model.
+    Returns a GraphModule with the converted model.
+    """

-    # Convert
-    converted_model = convert_pt2e(prepared_model)
+    converted_model = convert_fn(graph_module)

    if dump_graphs:
-        logging.info("Graph after quantization (before fusion):")
-        logging.info(model_gm.graph.print_tabular())
+        logging.info("Graph after convert:")
+        logging.info(converted_model.graph.print_tabular())

    return converted_model

@@ -192,10 +178,19 @@ def quantize_pt2(
        logging.info("Graph after trace:")
        logging.info(program.graph.print_tabular())

+    # Get prepared graph module
+    prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs)
+
+    # Calibrate
+    # If no calibration data is provided, use the inputs
+    if calibration_data is None:
+        calibration_data = [inputs]
+
+    for samples in calibration_data:
+        prepared_gm(*samples)
+
    # Get converted graph module
-    converted_gm = prepare_and_convert_pt2(
-        program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
-    )
+    converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)

    # Get fused model
    fused_gm = fuse_pt2(converted_gm, quantizer)
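
Taken together, these hunks split the old prepare_and_convert_pt2 helper into two stages: prepare_pt2 (observer insertion via prepare_fn) and convert_pt2 (conversion via convert_fn), with the calibration loop hoisted into quantize_pt2. The sketch below shows the resulting end-to-end flow. It is a minimal illustration, not the repository's own example: the compiler module path and the CadenceDefaultQuantizer import path are assumptions inferred from this file's imports, and SmallModel is a hypothetical toy module; only the signatures visible in the diff are relied on.

import torch

# Assumed location of the functions changed in this diff.
from executorch.backends.cadence.aot.compiler import (
    convert_pt2,
    fuse_pt2,
    prepare_pt2,
    trace,
)

# Assumed import path; the diff only shows CadenceDefaultQuantizer being used.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
)


class SmallModel(torch.nn.Module):
    """Hypothetical toy model, for illustration only."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


model = SmallModel()
inputs = (torch.randn(1, 16),)
quantizer = CadenceDefaultQuantizer()

# 1. Trace the model into an ExportedProgram (the eval/decomposition
#    handling removed above presumably lives inside trace_fn now).
program = trace(model, inputs)

# 2. Prepare: insert observers according to the quantizer.
prepared_gm = prepare_pt2(program, quantizer)

# 3. Calibrate by running representative samples through the prepared module
#    (quantize_pt2 falls back to `inputs` when no calibration data is given).
for samples in [inputs]:
    prepared_gm(*samples)

# 4. Convert observers into quantize/dequantize ops.
converted_gm = convert_pt2(prepared_gm)

# 5. Fuse with the same quantizer that prepared the model.
fused_gm = fuse_pt2(converted_gm, quantizer)

One practical consequence of the split is that a single prepared module can be calibrated with arbitrary data between the prepare and convert steps, rather than threading calibration_data through one fused helper.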