diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 3dd612e650e..9c454f4339f 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -38,7 +38,7 @@ ) from executorch.exir.passes import ToOutVarPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.exir.program._program import _transform, to_edge +from executorch.exir.program._program import to_edge from torch.export.exported_program import ExportedProgram from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e @@ -145,22 +145,22 @@ def convert_pt2( # fused model, to be able to get reference numerics. # If this does not apply, please use quantize_pt2 instead. def fuse_pt2( - converted_program: ExportedProgram, + converted_graph_module: torch.fx.GraphModule, quantizer: CadenceQuantizer, -) -> ExportedProgram: +) -> torch.fx.GraphModule: """ - Fuse a converted exported program using the given quantizer. + Fuse a converted graph module using the given quantizer. The quantizer must be the same as the one used to convert the model. If you do not expect that behavior, please use quantize_pt2 instead, which will instantiate a default quantizer for you if needed. - Returns an ExportedProgram with the fused model. + Returns a GraphModule with the fused model. """ # Get patterns and apply fusion of dq -> op -> q to qop # pyre-ignore[16]: no attribute patterns = [q.pattern for q in quantizer.quantizers] - fused_program = _transform(converted_program, QuantFusion(patterns)) + QuantFusion(patterns)(converted_graph_module) - return fused_program + return converted_graph_module # Note: quantizer is not optional here to force the user to supply a quantizer @@ -210,7 +210,7 @@ def quantize_pt2( If calibration data is provided, it will be used to calibrate the model. If not, the inputs will be used for calibration instead, which is useful for unit tests but should not be used for end-to-end use cases. - Returns an ExportedProgram with the quantized model. + Returns a GraphModule with the quantized model. Note: this function should not be called directly in general. Please use quantize_and_export_to_executorch for most needs. """ @@ -227,15 +227,16 @@ def quantize_pt2( dump_graphs=dump_graphs, ) - # Apply quant fusion to the exported program - program = torch.export.export(converted_gm, inputs, strict=True) - fused_program = fuse_pt2(program, quantizer) + # Get fused model + fused_gm = fuse_pt2(converted_gm, quantizer) if dump_graphs: logging.info("Graph after quantization and fusion:") - logging.info(fused_program.graph_module.graph.print_tabular()) + logging.info(fused_gm.graph.print_tabular()) - return fused_program + program = torch.export.export(fused_gm, inputs, strict=True) + + return program TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [ diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 20719322e82..6af7a88fdc2 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -63,10 +63,11 @@ def export_model( # Get reference outputs from converted model ref_outputs = converted_model(*example_inputs) - ep = torch.export.export(converted_model, example_inputs, strict=True) + # Quantize the model (note: quantizer needs to be the same as + # the one used in prepare_and_convert_pt2) + quantized_model = fuse_pt2(converted_model, quantizer) - # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2) - ep = fuse_pt2(ep, quantizer) + ep = torch.export.export(quantized_model, example_inputs, strict=True) # Get edge program after Cadence specific passes exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord( diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 4c34adc0533..e2818f725ef 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -33,7 +33,6 @@ ) from executorch.backends.cadence.aot.quantizer.utils import ( check_out_zero_point_is_min_range, - copy_node_metadata, create_zero_bias_int32, find_sequential_partitions_aten, get_conv_args, @@ -160,8 +159,6 @@ def get_args_and_kwargs_layer_norm( ), {"dtype": torch.float32}, ) - if len(inputs_inputs) > 0: - copy_node_metadata(weight, inputs_inputs[0]) bias = other_inputs[2] if len(other_inputs) > 2 else None @@ -174,8 +171,6 @@ def get_args_and_kwargs_layer_norm( ), {"dtype": torch.float32}, ) - if len(inputs_inputs) > 0: - copy_node_metadata(bias, inputs_inputs[0]) # Make the args and kwargs for the replacement op args = tuple(inputs_inputs + [scale, zero_point]) @@ -351,8 +346,6 @@ def get_args_and_kwargs_softmax( ), {"dtype": torch.int32}, ) - if len(inputs_inputs) > 0: - copy_node_metadata(mask_tensor, inputs_inputs[0]) # Make the scale and zero_point tensors in_scale = dequants_inputs[0].args[1] in_zero_point = dequants_inputs[0].args[2] @@ -402,13 +395,10 @@ def get_args_and_kwargs_mixed_w8a32_conv( torch.ops.aten.permute.default, (other_inputs[0], [0, 2, 1]), # NCL -> NLC ) - copy_node_metadata(transposed_inputs, other_inputs[0]) - transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) - copy_node_metadata(transposed_weights, weights_inputs[0]) args = ( transposed_inputs, @@ -592,26 +582,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 torch.ops.aten.transpose.int, (weights_inputs[0], 0, 1), ) - if "val" in weights_inputs[0].meta: - original_val = weights_inputs[0].meta["val"] - fake_mode = original_val.fake_mode - if fake_mode is not None: - with fake_mode: - transposed_val = torch.ops.aten.transpose.int( - original_val, 0, 1 - ) - transposed_weights.meta["val"] = transposed_val - else: - transposed_shape = list(original_val.shape) - transposed_shape[0], transposed_shape[1] = ( - transposed_shape[1], - transposed_shape[0], - ) - transposed_weights.meta["val"] = torch.zeros( - transposed_shape, dtype=original_val.dtype - ) - copy_node_metadata(transposed_weights, weights_inputs[0]) - # Call linear with transposed weight args, kwargs = get_args_and_kwargs_linear( graph_module, @@ -684,19 +654,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 legalize_graph(graph_module) graph_module.graph.eliminate_dead_code() - nodes_list = list(graph_module.graph.nodes) - - if len(nodes_list) > 0 and nodes_list[-1].op != "output": - output_nodes = [n for n in nodes_list if n.op == "output"] - output_arg = output_nodes[0].args[0] - original_meta = output_nodes[0].meta.copy() - - for out_node in output_nodes: - graph_module.graph.erase_node(out_node) - - new_output_node = graph_module.graph.output(output_arg) - new_output_node.meta.update(original_meta) - graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index dfc31bfac8c..68fc6740cb4 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -24,12 +24,6 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY -def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None: - for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]: - if key in src_node.meta and src_node.meta[key]: - dest_node.meta[key] = src_node.meta[key] - - def quantize_tensor_multiplier( requantize_scale_tensor: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -120,45 +114,15 @@ def create_zero_bias_int32( """ Creates a zero bias tensor with the shape of weight[0] """ - try: - attr_node = getattr(graph_module, weight_node.target) - except AttributeError: - if "val" in weight_node.meta: - attr_node = weight_node.meta["val"] - else: - param_dict = dict(graph_module.named_parameters()) - if weight_node.target in param_dict: - attr_node = param_dict[weight_node.target] - else: - buffer_dict = dict(graph_module.named_buffers()) - if weight_node.target in buffer_dict: - attr_node = buffer_dict[weight_node.target] - else: - raise AttributeError( - f"Could not find weight tensor for node {weight_node.target}. " - f"Node metadata keys: {list(weight_node.meta.keys())}" - ) - + attr_node = getattr(graph_module, weight_node.target) weight_shape = list(attr_node.shape) bias_shape = weight_shape[0] - new_node = graph_module.graph.call_function( + return graph_module.graph.call_function( torch.ops.aten.full.default, ([bias_shape], 0.0), {"dtype": torch.int32}, ) - if "val" in weight_node.meta: - fake_mode = weight_node.meta["val"].fake_mode - if fake_mode is not None: - with fake_mode: - fake_bias = torch.zeros([bias_shape], dtype=torch.int32) - new_node.meta["val"] = fake_bias - else: - new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32) - copy_node_metadata(new_node, weight_node) - - return new_node - def get_bias_qparams( obs_or_fqs: List[ObserverOrFakeQuantize], diff --git a/util/activation_memory_profiler.py b/util/activation_memory_profiler.py index caf4dc1380b..80e4fac56e2 100644 --- a/util/activation_memory_profiler.py +++ b/util/activation_memory_profiler.py @@ -41,10 +41,9 @@ def _get_module_hierarchy(node: torch.fx.Node) -> str: Get the module hierarchy of the given node. """ module_stack = node.meta.get("nn_module_stack") - if module_stack is not None and module_stack: + if module_stack is not None: module_values_list = list(module_stack.values()) - if module_values_list: - return module_values_list[-1][0] + return module_values_list[-1][0] return ""