27 changes: 14 additions & 13 deletions backends/cadence/aot/compiler.py
@@ -38,7 +38,7 @@
)
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import _transform, to_edge
from executorch.exir.program._program import to_edge

from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
@@ -145,22 +145,22 @@ def convert_pt2(
# fused model, to be able to get reference numerics.
# If this does not apply, please use quantize_pt2 instead.
def fuse_pt2(
converted_program: ExportedProgram,
converted_graph_module: torch.fx.GraphModule,
quantizer: CadenceQuantizer,
) -> ExportedProgram:
) -> torch.fx.GraphModule:
"""
Fuse a converted exported program using the given quantizer.
Fuse a converted graph module using the given quantizer.
The quantizer must be the same as the one used to convert the model.
If you do not expect that behavior, please use quantize_pt2 instead,
which will instantiate a default quantizer for you if needed.
Returns an ExportedProgram with the fused model.
Returns a GraphModule with the fused model.
"""
# Get patterns and apply fusion of dq -> op -> q to qop
# pyre-ignore[16]: no attribute
patterns = [q.pattern for q in quantizer.quantizers]
fused_program = _transform(converted_program, QuantFusion(patterns))
QuantFusion(patterns)(converted_graph_module)

return fused_program
return converted_graph_module


# Note: quantizer is not optional here to force the user to supply a quantizer
@@ -210,7 +210,7 @@ def quantize_pt2(
If calibration data is provided, it will be used to calibrate the model. If
not, the inputs will be used for calibration instead, which is useful for
unit tests but should not be used for end-to-end use cases.
Returns an ExportedProgram with the quantized model.
Returns an ExportedProgram with the quantized and fused model.
Note: this function should not be called directly in general. Please use
quantize_and_export_to_executorch for most needs.
"""
@@ -227,15 +227,16 @@
dump_graphs=dump_graphs,
)

# Apply quant fusion to the exported program
program = torch.export.export(converted_gm, inputs, strict=True)
fused_program = fuse_pt2(program, quantizer)
# Get fused model
fused_gm = fuse_pt2(converted_gm, quantizer)

if dump_graphs:
logging.info("Graph after quantization and fusion:")
logging.info(fused_program.graph_module.graph.print_tabular())
logging.info(fused_gm.graph.print_tabular())

return fused_program
program = torch.export.export(fused_gm, inputs, strict=True)

return program


TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
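Taken together, the compiler.py changes move quant fusion from the ExportedProgram onto the converted torch.fx.GraphModule, with torch.export.export now running after fusion. Below is a minimal sketch of the resulting call order; the import paths, the concrete quantizer constructor, the prepare_and_convert_pt2 signature, and the toy model are assumptions for illustration, not part of this diff.

```python
import torch

# Assumed import locations, following the module paths visible in this diff.
from executorch.backends.cadence.aot.compiler import fuse_pt2, prepare_and_convert_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer


class TinyModel(torch.nn.Module):
    """Hypothetical toy model, used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel().eval()
inputs = (torch.randn(1, 8),)
quantizer = CadenceDefaultQuantizer()  # assumed concrete CadenceQuantizer

# Prepare + convert with the same quantizer that is later used for fusion
# (exact signature assumed; see the prepare_and_convert_pt2 reference below).
converted_gm = prepare_and_convert_pt2(model, inputs, quantizer)

# New contract: fuse_pt2 runs QuantFusion in place on the GraphModule and
# returns it, instead of transforming an ExportedProgram.
fused_gm = fuse_pt2(converted_gm, quantizer)

# Export only after fusion; quantize_pt2 now does the same internally.
ep = torch.export.export(fused_gm, inputs, strict=True)
```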
7 changes: 4 additions & 3 deletions backends/cadence/aot/export_example.py
@@ -63,10 +63,11 @@ def export_model(
# Get reference outputs from converted model
ref_outputs = converted_model(*example_inputs)

ep = torch.export.export(converted_model, example_inputs, strict=True)
# Fuse the quantized patterns (note: quantizer needs to be the same as
# the one used in prepare_and_convert_pt2)
quantized_model = fuse_pt2(converted_model, quantizer)

# Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2)
ep = fuse_pt2(ep, quantizer)
ep = torch.export.export(quantized_model, example_inputs, strict=True)

# Get edge program after Cadence specific passes
exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord(
43 changes: 0 additions & 43 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -33,7 +33,6 @@
)
from executorch.backends.cadence.aot.quantizer.utils import (
check_out_zero_point_is_min_range,
copy_node_metadata,
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
@@ -160,8 +159,6 @@ def get_args_and_kwargs_layer_norm(
),
{"dtype": torch.float32},
)
if len(inputs_inputs) > 0:
copy_node_metadata(weight, inputs_inputs[0])

bias = other_inputs[2] if len(other_inputs) > 2 else None

@@ -174,8 +171,6 @@
),
{"dtype": torch.float32},
)
if len(inputs_inputs) > 0:
copy_node_metadata(bias, inputs_inputs[0])

# Make the args and kwargs for the replacement op
args = tuple(inputs_inputs + [scale, zero_point])
@@ -351,8 +346,6 @@ def get_args_and_kwargs_softmax(
),
{"dtype": torch.int32},
)
if len(inputs_inputs) > 0:
copy_node_metadata(mask_tensor, inputs_inputs[0])
# Make the scale and zero_point tensors
in_scale = dequants_inputs[0].args[1]
in_zero_point = dequants_inputs[0].args[2]
@@ -402,13 +395,10 @@ def get_args_and_kwargs_mixed_w8a32_conv(
torch.ops.aten.permute.default,
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
)
copy_node_metadata(transposed_inputs, other_inputs[0])

transposed_weights = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
)
copy_node_metadata(transposed_weights, weights_inputs[0])

args = (
transposed_inputs,
@@ -592,26 +582,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
torch.ops.aten.transpose.int,
(weights_inputs[0], 0, 1),
)
if "val" in weights_inputs[0].meta:
original_val = weights_inputs[0].meta["val"]
fake_mode = original_val.fake_mode
if fake_mode is not None:
with fake_mode:
transposed_val = torch.ops.aten.transpose.int(
original_val, 0, 1
)
transposed_weights.meta["val"] = transposed_val
else:
transposed_shape = list(original_val.shape)
transposed_shape[0], transposed_shape[1] = (
transposed_shape[1],
transposed_shape[0],
)
transposed_weights.meta["val"] = torch.zeros(
transposed_shape, dtype=original_val.dtype
)
copy_node_metadata(transposed_weights, weights_inputs[0])

# Call linear with transposed weight
args, kwargs = get_args_and_kwargs_linear(
graph_module,
@@ -684,19 +654,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901

legalize_graph(graph_module)
graph_module.graph.eliminate_dead_code()
nodes_list = list(graph_module.graph.nodes)

if len(nodes_list) > 0 and nodes_list[-1].op != "output":
output_nodes = [n for n in nodes_list if n.op == "output"]
output_arg = output_nodes[0].args[0]
original_meta = output_nodes[0].meta.copy()

for out_node in output_nodes:
graph_module.graph.erase_node(out_node)

new_output_node = graph_module.graph.output(output_arg)
new_output_node.meta.update(original_meta)

graph_module.recompile()
return PassResult(graph_module, True)

40 changes: 2 additions & 38 deletions backends/cadence/aot/quantizer/utils.py
@@ -24,12 +24,6 @@
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None:
for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]:
if key in src_node.meta and src_node.meta[key]:
dest_node.meta[key] = src_node.meta[key]


def quantize_tensor_multiplier(
requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -120,45 +114,15 @@ def create_zero_bias_int32(
"""
Creates a zero bias tensor with the shape of weight[0]
"""
try:
attr_node = getattr(graph_module, weight_node.target)
except AttributeError:
if "val" in weight_node.meta:
attr_node = weight_node.meta["val"]
else:
param_dict = dict(graph_module.named_parameters())
if weight_node.target in param_dict:
attr_node = param_dict[weight_node.target]
else:
buffer_dict = dict(graph_module.named_buffers())
if weight_node.target in buffer_dict:
attr_node = buffer_dict[weight_node.target]
else:
raise AttributeError(
f"Could not find weight tensor for node {weight_node.target}. "
f"Node metadata keys: {list(weight_node.meta.keys())}"
)

attr_node = getattr(graph_module, weight_node.target)
weight_shape = list(attr_node.shape)
bias_shape = weight_shape[0]
new_node = graph_module.graph.call_function(
return graph_module.graph.call_function(
torch.ops.aten.full.default,
([bias_shape], 0.0),
{"dtype": torch.int32},
)

if "val" in weight_node.meta:
fake_mode = weight_node.meta["val"].fake_mode
if fake_mode is not None:
with fake_mode:
fake_bias = torch.zeros([bias_shape], dtype=torch.int32)
new_node.meta["val"] = fake_bias
else:
new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32)
copy_node_metadata(new_node, weight_node)

return new_node


def get_bias_qparams(
obs_or_fqs: List[ObserverOrFakeQuantize],
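As a quick illustration of the simplified create_zero_bias_int32 above: for a weight attribute of shape (16, 8), the node it emits computes the equivalent of the following (the shape is made up for the example).

```python
import torch

# Eager equivalent of the emitted aten.full node: a zero int32 bias whose
# length is weight.shape[0].
zero_bias = torch.ops.aten.full.default([16], 0.0, dtype=torch.int32)
print(zero_bias.shape, zero_bias.dtype)  # torch.Size([16]) torch.int32
```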
5 changes: 2 additions & 3 deletions util/activation_memory_profiler.py
@@ -41,10 +41,9 @@ def _get_module_hierarchy(node: torch.fx.Node) -> str:
Get the module hierarchy of the given node.
"""
module_stack = node.meta.get("nn_module_stack")
if module_stack is not None and module_stack:
if module_stack is not None:
module_values_list = list(module_stack.values())
if module_values_list:
return module_values_list[-1][0]
return module_values_list[-1][0]
return ""


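For orientation, a small sketch of what _get_module_hierarchy consumes and returns; the nn_module_stack layout shown here is an assumption for illustration, not something this diff defines.

```python
# Hypothetical node.meta contents, shaped like the nn_module_stack dict that
# torch.export attaches to FX nodes (keys and values are made up).
meta = {
    "nn_module_stack": {
        "L__self__": ("", "MyModel"),
        "L__self___linear": ("linear", "torch.nn.Linear"),
    }
}

# The helper returns the first element of the last entry's value, i.e. the
# innermost module path; with no nn_module_stack key it returns "".
module_stack = meta.get("nn_module_stack")
hierarchy = list(module_stack.values())[-1][0] if module_stack is not None else ""
print(hierarchy)  # -> "linear"
```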