From 7ae1a951c55071a622823de51227f4d60b318e22 Mon Sep 17 00:00:00 2001 From: ssjia Date: Sun, 7 Sep 2025 22:02:39 -0400 Subject: [PATCH] [ET-VK][AOT] Enable exporting Q8 Quantized Linear + Convolution Pull Request resolved: https://github.com/pytorch/executorch/pull/13818 As title. Introduce fusion patterns to enable fusing quantized convolution and linear graph patterns into a custom op. ## Changes Introduce the concept of using custom pattern detection functions to detect graph patterns rather than solely relying on SubgraphMatcher. The issue with SubgraphMatcher is that a large number of graph patterns may need to be exported to obtain variants for different combinations of decompositions/quantization workflows. Having a custom detection function improves maintainability. Implement detection + replacement functions for quantized linear and quantized conv2d. ghstack-source-id: 308092874 @exported-using-ghexport Differential Revision: [D81323425](https://our.internmc.facebook.com/intern/diff/D81323425/) --- .github/workflows/pull.yml | 4 + backends/vulkan/_passes/TARGETS | 14 + backends/vulkan/_passes/__init__.py | 2 + backends/vulkan/_passes/fold_qdq.py | 41 ++ backends/vulkan/custom_ops_lib.py | 131 +++++ backends/vulkan/op_registry.py | 40 ++ .../vulkan/partitioner/vulkan_partitioner.py | 9 +- backends/vulkan/patterns/TARGETS | 1 + backends/vulkan/patterns/__init__.py | 41 +- backends/vulkan/patterns/pattern_registry.py | 64 ++- .../vulkan/patterns/quantized_convolution.py | 224 ++++++++ backends/vulkan/patterns/quantized_linear.py | 480 ++++++++++-------- backends/vulkan/patterns/rope.py | 16 +- backends/vulkan/test/TARGETS | 21 + backends/vulkan/test/test_vulkan_delegate.py | 328 ++++++++++-- backends/vulkan/test/utils.py | 256 ++++++++-- backends/vulkan/utils.py | 164 +++++- backends/vulkan/vulkan_preprocess.py | 2 + 18 files changed, 1556 insertions(+), 282 deletions(-) create mode 100644 backends/vulkan/_passes/fold_qdq.py create mode 100644 backends/vulkan/patterns/quantized_convolution.py diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 1500fceebb2..37c6623ca97 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -934,6 +934,10 @@ jobs: ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d + # Run e2e testing for selected operators. More operators will be tested via this + # route in the future. 
+ python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*" + nxp-build-test: name: nxp-build-test uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 3263d273b72..8558a2eea93 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -118,6 +118,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "fold_qdq", + srcs = ["fold_qdq.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan:utils_lib", + "//executorch/exir:pass_base", + ], +) + runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -144,6 +157,7 @@ runtime.python_library( "//executorch/examples/...", ], deps = [ + ":fold_qdq", ":fuse_patterns", ":fuse_quantized_ops", ":insert_prepack_nodes", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index ccf15fd2c7f..2c4588ac43d 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -6,6 +6,7 @@ # pyre-strict +from executorch.backends.vulkan._passes.fold_qdq import FoldQDQPass from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, @@ -30,6 +31,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ + "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", "insert_prepack_nodes", diff --git a/backends/vulkan/_passes/fold_qdq.py b/backends/vulkan/_passes/fold_qdq.py new file mode 100644 index 00000000000..3beccc2205c --- /dev/null +++ b/backends/vulkan/_passes/fold_qdq.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.vulkan.utils as utils +import torch + +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class FoldQDQPass(ExportPass): + """ + Erase Q/DQ chain introduced by PT2E quantization workflow. It is assumed that all + valid quant op patterns have already been fused before this pass. + """ + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(FoldQDQPass, self).__init__() + self.edge_program = edge_program + + def call(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if utils.is_quant_node(node): + original_node = node.args[0] + assert isinstance(original_node, torch.fx.Node) + # For each direct user that is a dequant node, connect the original + # node to the users of the dequant node. 
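+ # The folded Q/DQ nodes become dead code and are cleaned up by the dead code elimination pass invoked below.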
+ for user in node.users: + if utils.is_dequant_node(user): + dq_node = user + dq_node.replace_all_uses_with(original_node) + + graph_module.recompile() + dead_code_elimination_pass(graph_module) + # Re-trace to validate everything is ok + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index bc61b44ce78..3ef3a6b45ea 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional + import executorch.backends.vulkan.patterns as vk_patterns import torch.library @@ -321,6 +323,135 @@ def linear_qta8a_qga4w( lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd") linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name) +################# +## qaqw_linear ## +################# + + +def linear_q8ta_q8csw( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + bias: Optional[torch.Tensor] = None, +): + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, + -127, + 127, + torch.int8, + ) + + # Perform linear operation + out = torch.nn.functional.linear(x, weights) + if bias is not None: + out = out + bias + + return out + + +name = "linear_q8ta_q8csw" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + Tensor? bias = None) -> Tensor + """ +) +lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") +qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) + +################## +## conv2d_q8ta_q8csw ## +################## + + +def conv2d_q8ta_q8csw( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + bias: Optional[torch.Tensor], + kernel_size: list, + stride: list, + padding: list, + dilation: list, + groups: int, +): + IC = x.shape[1] + K_h, K_w = kernel_size[0], kernel_size[1] + + canonical_weight_K_dim = K_h * K_w * IC + # Remove any padding added to output channels dim to align to a multiple of 4 + if weights.shape[-1] != canonical_weight_K_dim: + weights = weights[:, :canonical_weight_K_dim] + weight_scales = weight_scales[:canonical_weight_K_dim] + if bias is not None: + bias = bias[:canonical_weight_K_dim] + + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + + # Calculate dimensions + OC = weights.shape[0] + in_features = weights.shape[1] + IC = in_features // (K_h * K_w) + + # Reshape to original 4D format (OC, IC, H, W) + weights = weights.view(OC, IC, K_h, K_w) + + # Dequantize weights + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, # axis=0 for output channel quantization + -127, + 127, + torch.int8, + ) + + # Perform convolution + out = torch.nn.functional.conv2d( + x, weights, bias, stride, padding, dilation, groups + ) + + return out + + +name = "conv2d_q8ta_q8csw" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + Tensor? 
bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] dilation, + SymInt groups) -> Tensor + """ +) +lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd") +conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 55c36463b51..f9f8aeb79e3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -318,6 +318,19 @@ def register_int8_mm_op(): ) +@update_features( + [ + exir_ops.edge.et_vk.linear_q8ta_q8csw.default, + ] +) +def register_qa_qw_linear(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + supports_prepacking=True, + supports_resize=False, + ) + + @update_features( [ exir_ops.edge.et_vk.linear_weight_int4.default, @@ -457,6 +470,33 @@ def register_convolution_op(): ) +@update_features( + [ + exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default, + ] +) +def register_quantized_conv_op(): + return OpFeatures( + inputs_storage=[ + utils.CHANNELS_PACKED_TEXTURE, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # kernel_size (non tensor) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # groups (non tensor) + utils.NO_STORAGE, # original OC count (non tensor) + ], + supports_resize=False, + supports_prepacking=True, + ) + + @update_features("llama::sdpa_with_kv_cache") def register_sdpa_with_kv_cache_op(): return OpFeatures( diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 06db2a58f12..e5b2d0f7864 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -22,6 +22,8 @@ vulkan_supported_ops, ) +from executorch.backends.vulkan.patterns import PatternMatch + from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, @@ -41,7 +43,6 @@ from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase -from torch.fx.passes.utils.matcher_utils import InternalMatch # pyre-ignore ops_not_to_decompose = [ @@ -60,7 +61,7 @@ def __init__( require_dynamic_shape: bool = False, operator_blocklist: Optional[Set[OpKey]] = None, operator_allowlist: Optional[Set[OpKey]] = None, - fusable_subgraphs: Optional[List[InternalMatch]] = None, + fusable_subgraphs: Optional[List[PatternMatch]] = None, nn_module_blocklist: Optional[Set[str]] = None, nn_module_allowlist: Optional[Set[str]] = None, ) -> None: @@ -72,13 +73,13 @@ def __init__( operator_blocklist if operator_blocklist is not None else set() ) self.operator_allowlist = operator_allowlist - self.fusable_subgraphs: List[InternalMatch] = ( + self.fusable_subgraphs: List[PatternMatch] = ( fusable_subgraphs if fusable_subgraphs is not None else [] ) # Create a set of all nodes that are part of fusable subgraphs for quick lookup self.fusable_nodes: Set[torch.fx.Node] = set() for match in self.fusable_subgraphs: - self.fusable_nodes.update(match.nodes_map.values()) + self.fusable_nodes.update(match.all_nodes) self.nn_module_blocklist = 
nn_module_blocklist self.nn_module_allowlist = nn_module_allowlist diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index f58ff4e9adf..791edf58984 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -10,6 +10,7 @@ runtime.python_library( "pattern_registry.py", "rope.py", "quantized_linear.py", + "quantized_convolution.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index b8026f517e6..8ffad98b3c3 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -6,6 +6,8 @@ from typing import List +import executorch.backends.vulkan.patterns.quantized_convolution # noqa + import executorch.backends.vulkan.patterns.quantized_linear # noqa import executorch.backends.vulkan.patterns.rope # noqa @@ -13,9 +15,13 @@ import torch from executorch.backends.vulkan.patterns.pattern_registry import ( + create_pattern_match_from_internal_match, CreateReplacementFn, + DetectorFn, fusable_patterns, GetGraphFn, + PatternMatch, + register_pattern_detector, register_pattern_graph, register_pattern_replacement, ) @@ -24,15 +30,18 @@ from executorch.exir import ExportedProgram -from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher __all__ = [ + "PatternMatch", "GetGraphFn", + "DetectorFn", "CreateReplacementFn", "RotaryEmbeddingPattern", "fusable_patterns", "register_pattern_graph", + "register_pattern_detector", "register_pattern_replacement", ] @@ -48,14 +57,22 @@ def all_fusable_graph_patterns() -> List[torch.fx.GraphModule]: def get_all_fusable_subgraphs( graph_module: torch.fx.GraphModule, -) -> List[InternalMatch]: +) -> List[PatternMatch]: fusable_subgraphs = [] fuse_patterns = all_fusable_graph_patterns() for pattern in fuse_patterns: sm = SubgraphMatcher(pattern.graph, ignore_literals=True) matches = list(sm.match(graph_module.graph)) - fusable_subgraphs.extend(matches) + for match in matches: + fusable_subgraphs.append(create_pattern_match_from_internal_match(match)) + + for node in graph_module.graph.nodes: + for entry in fusable_patterns.values(): + if entry.detector_fn is not None: + maybe_match = entry.detector_fn(node) + if maybe_match is not None: + fusable_subgraphs.append(maybe_match) return fusable_subgraphs @@ -73,7 +90,8 @@ def create_replacement_for_pattern( matches = list(sm.match(graph_module.graph)) for partition_to_replace in matches: - create_replacement_func(ep, graph_module, partition_to_replace) + pattern = create_pattern_match_from_internal_match(partition_to_replace) + create_replacement_func(ep, graph_module, pattern) total_replaced += 1 # Remove dead code so they won't be matched again graph_module.graph.eliminate_dead_code() @@ -87,6 +105,7 @@ def replace_all_fusable_subgraphs( ) -> int: total_replaced = 0 + # Handle patterns identified with SubgraphMatcher for entry in fusable_patterns.values(): if entry.get_graphs_fn is not None and entry.create_replacement_fn is not None: total_replaced += create_replacement_for_pattern( @@ -97,4 +116,18 @@ def replace_all_fusable_subgraphs( entry.create_replacement_fn, ) + # Handle patterns identified with custom detector function + for node in graph_module.graph.nodes: + for entry in fusable_patterns.values(): + if ( + entry.detector_fn is not None + and entry.create_replacement_fn is not None + ): + maybe_match = entry.detector_fn(node) + if maybe_match is not 
None: + assert entry.create_replacement_fn is not None + entry.create_replacement_fn(ep, graph_module, maybe_match) + total_replaced += 1 + + graph_module.graph.eliminate_dead_code() return total_replaced diff --git a/backends/vulkan/patterns/pattern_registry.py b/backends/vulkan/patterns/pattern_registry.py index 37fa0bcca8c..9a906cd8770 100644 --- a/backends/vulkan/patterns/pattern_registry.py +++ b/backends/vulkan/patterns/pattern_registry.py @@ -13,22 +13,65 @@ from torch.fx.passes.utils.matcher_utils import InternalMatch GetGraphFn = Callable[[], List[torch.fx.GraphModule]] + + +class PatternMatch: + __slots__ = ("input_nodes", "output_nodes", "all_nodes", "anchor_node") + """ + The design of this class is based on InternalMatch from + torch.fx.passes.utils.matcher_utils. It represents nodes in a graph that + match a particular pattern. + + The reason to not use InternalMatch directly is to enable more (i.e. custom) + methods to detect and represent matches other than through SubgraphMatcher. + """ + + def __init__( + self, + input_nodes: List[torch.fx.Node], + output_nodes: List[torch.fx.Node], + all_nodes: List[torch.fx.Node], + anchor_node: Optional[torch.fx.Node] = None, + ): + self.input_nodes = input_nodes + self.output_nodes = output_nodes + self.all_nodes = all_nodes + self.anchor_node = anchor_node + + +def create_pattern_match_from_internal_match( + internal_match: InternalMatch, +) -> PatternMatch: + return PatternMatch( + internal_match.placeholder_nodes, + internal_match.returning_nodes, + list(internal_match.nodes_map.values()), + ) + + CreateReplacementFn = Callable[ - [ExportedProgram, torch.fx.GraphModule, InternalMatch], None + [ExportedProgram, torch.fx.GraphModule, PatternMatch], None ] +DetectorFn = Callable[[torch.fx.Node], Optional[PatternMatch]] + + class PatternEntry: def __init__( self, get_graphs_fn: Optional[GetGraphFn] = None, + detector_fn: Optional[DetectorFn] = None, create_replacement_fn: Optional[CreateReplacementFn] = None, ): self.get_graphs_fn = get_graphs_fn + self.detector_fn = detector_fn self.create_replacement_fn = create_replacement_fn def is_valid(self): - return self.get_graphs_fn is not None and self.create_replacement_fn is not None + return ( + self.get_graphs_fn is not None or self.detector_fn is not None + ) and self.create_replacement_fn is not None fusable_patterns: Dict[str, PatternEntry] = {} @@ -39,7 +82,24 @@ def decorator(fn: GetGraphFn): if pattern_name not in fusable_patterns: fusable_patterns[pattern_name] = PatternEntry() + # Cannot define both get_graphs_fn and detector_fn + assert fusable_patterns[pattern_name].detector_fn is None fusable_patterns[pattern_name].get_graphs_fn = fn + + return fn + + return decorator + + +def register_pattern_detector(pattern_name: str): + def decorator(fn: DetectorFn): + if pattern_name not in fusable_patterns: + fusable_patterns[pattern_name] = PatternEntry() + + # Cannot define both get_graphs_fn and detector_fn + assert fusable_patterns[pattern_name].get_graphs_fn is None + fusable_patterns[pattern_name].detector_fn = fn + return fn return decorator diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py new file mode 100644 index 00000000000..65b51b5e103 --- /dev/null +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.export.graph_signature import InputKind + + +class QuantizedConvolutionMatch(PatternMatch): + def __init__(self, conv_node: torch.fx.Node) -> None: + self.anchor_node = conv_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # Extract convolution parameters + self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1] + self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0] + self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1] + self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1 + + const_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[1] + ) + + # weight is not a constant tensor - no match + if const_node is None: + return + + dequantize_weight_node = None + # Search for a dequantize node in the arg chain of weight + for node in arg_chain: + if isinstance(node, torch.fx.Node) and utils.is_dequant_node(node): + dequantize_weight_node = node + # weight is not quantized - no match + if dequantize_weight_node is None: + return + + self.weight_node = const_node + self.dequantize_weight_node = dequantize_weight_node + self.all_nodes.extend(arg_chain) + + # Identify weight quantization parameter nodes + self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[1] + ) + assert self.weight_scales_node is not None + self.all_nodes.extend(arg_chain) + + self.weight_zeros_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[2] + ) + assert self.weight_zeros_node is not None + self.all_nodes.extend(arg_chain) + + # Identify output node + self.output_node = self.anchor_node + + out_channels = self.output_node.meta["val"].shape[-1] + # The implementation requires that for grouped convolutions, a group does not + # cross any texel boundary. The output channels per group must be a multiple of + # 4. If this is not true, then don't match the pattern. 
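+ # (For example, groups=2 with 16 output channels gives 8 channels per group, i.e. two full texels per group.)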
+ if self.groups > 1 and (out_channels / self.groups) % 4 == 0: + return + + # Identify bias node, if applicable + self.bias_node = None + if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None: + self.bias_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[2] + ) + if self.bias_node is not None: + self.all_nodes.extend(arg_chain) + + # Identify input node + self.fp_input_node, self.quantize_input_node, dq_node = ( + utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) + ) + assert self.fp_input_node is not None + self.all_nodes.append(self.fp_input_node) + assert self.quantize_input_node is not None + assert dq_node is not None + + self.input_scales_node = self.quantize_input_node.args[1] + self.input_zeros_node = self.quantize_input_node.args[2] + + self.all_nodes.extend( + [ + self.quantize_input_node, + dq_node, + ] + ) + + self.match_found = True + + +convolution_anchor_nodes = { + exir_ops.edge.aten.conv2d.default, + exir_ops.edge.aten.convolution.default, +} + + +@register_pattern_detector("quantized_convolution") +def find_quantized_convolution_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedConvolutionMatch]: + if node.target not in convolution_anchor_nodes: + return None + + matched_pattern = QuantizedConvolutionMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("quantized_convolution") +def make_conv2d_q8ta_q8csw_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedConvolutionMatch, +): + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None + + assert match.weight_scales_node is not None + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + + assert match.weight_zeros_node is not None + weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) + assert weight_zeros_tensor is not None + + bias_tensor = None + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + assert bias_tensor is not None + + OC, IC, H, W = weight_tensor.shape + + # Reshape weight tensor from (OC, IC, H, W) to (OC, H * W * IC) (i.e. matrix format) + # This prepares the weights for Im2Col-based convolution + weight_tensor = ( + weight_tensor.permute(0, 2, 3, 1).contiguous().view(OC, H * W * IC).contiguous() + ) + + # Need to make sure that OC dim is a multiple of 4 so that data load/stores are well + # aligned with texel boundaries. Add padding to align to the next multiple of 4 if + # needed. + utils.align_width_and_update_state_dict( + ep, match.weight_node, weight_tensor, force_update=True + ) + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + # Pre-compute the weight sums which are needed to apply activation zero point + # when using integer accumulation. 
For the reshaped 2D weight matrix (IC * H * W, OC), + # sum over dimension 0 to get sums per output channel + sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + sums_name = qweight_tensor_name + "_sums" + # Sanitize the name + sums_name = sums_name.replace(".", "_") + + weight_sums_node = create_constant_placeholder( + exp_program=ep, + graph=graph_module.graph, + kind=InputKind.CONSTANT_TENSOR, + name=sums_name, + data=sum_per_output_channel, + ) + + with graph_module.graph.inserting_before(match.output_node): + qconv_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default, + args=( + match.fp_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.bias_node, # Add bias after weight_scales + [H, W], # Pass kernel size information before stride + match.stride, + match.padding, + match.dilation, + match.groups, + ), + ) + + qconv_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(qconv_node) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 34476adeeb4..514abd78bf4 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -4,131 +4,191 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from functools import lru_cache -from typing import Callable, List, Optional +from typing import Optional import executorch.backends.vulkan.utils as utils import torch import torch.nn.functional as F -from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) from executorch.backends.vulkan.patterns.pattern_registry import ( - register_pattern_graph, + PatternMatch, + register_pattern_detector, register_pattern_replacement, ) -from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge +from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import export -from torch.fx.passes.utils.matcher_utils import InternalMatch +from torch.export.graph_signature import InputKind -from torchao.quantization.granularity import PerGroup -from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ -from torchao.utils import unwrap_tensor_subclass +class QuantizedLinearMatch(PatternMatch): + def __init__(self, mm_node: torch.fx.Node) -> None: + self.anchor_node = mm_node + self.match_found = False + self.all_nodes = [self.anchor_node] -class TorchAOWeightOnlyQuantizedLinearPattern(torch.nn.Module): - """ - Quantized linear pattern produced when quantizing linear layers using - `torchao.quantization.quant_api.quantize_()` with IntxWeightOnlyConfig. 
- """ + const_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[1] + ) + + # mat2 is not a constant tensor - no match + if const_node is None: + return + + dequantize_weight_node = None + # Search for a dequantize node in the arg chain of weight + for node in arg_chain: + if isinstance(node, torch.fx.Node) and utils.is_dequant_node(node): + dequantize_weight_node = node + # weight is not quantized - no match + if dequantize_weight_node is None: + return + + self.weight_node = const_node + self.dequantize_weight_node = dequantize_weight_node + self.all_nodes.extend(arg_chain) + + # By default, assume dequant node is from quantized_decomposed namespace + scales_arg_idx = 1 + zeros_arg_idx = 2 + # torchao dequantize has a different function schema than quantized_decomposed + if ( + self.dequantize_weight_node.target + == exir_ops.edge.torchao.dequantize_affine.default + ): + scales_arg_idx = 2 + zeros_arg_idx = 3 + + # Identify weight quantization parameter nodes + self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[scales_arg_idx] + ) + assert self.weight_scales_node is not None + self.all_nodes.extend(arg_chain) + + self.weight_zeros_node, arg_chain = utils.trace_args_until_placeholder( + self.dequantize_weight_node.args[zeros_arg_idx] + ) + assert self.weight_zeros_node is not None + self.all_nodes.extend(arg_chain) + + # Identify output node + self.output_node = self.anchor_node + + # The implementation has a limitation that output channels must be a + # multiple of 4. This is to ensure that data loads are aligned well with + # texel boundaries. If this is not true, then don't match the pattern. + out_channels = self.output_node.meta["val"].shape[-1] + if out_channels % 4 != 0: + return + + # Identify input node + self.fp_input_node, self.quantize_input_node, dq_node = ( + utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) + ) + assert self.fp_input_node is not None + self.all_nodes.append(self.fp_input_node) + + # The implementation has a limitation that input channels must be a + # multiple of 4. This is to ensure that data loads are aligned well with + # texel boundaries. If this is not true, then don't match the pattern. 
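+ # (For example, 128 input channels can be fused, while 130 would prevent the pattern from matching.)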
+ in_channels = self.fp_input_node.meta["val"].shape[-1] + if in_channels % 4 != 0: + return + + # Identify bias node, if applicable + self.bias_node = None + if self.anchor_node.target == exir_ops.edge.aten.addmm.default: + self.bias_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[2] + ) + assert self.bias_node is not None + self.all_nodes.extend(arg_chain) + + # If input is not quantized, then we are done + if self.quantize_input_node is None: + self.match_found = True + return + + self.input_scales_node = self.quantize_input_node.args[1] + self.input_zeros_node = self.quantize_input_node.args[2] + + assert dq_node is not None + self.all_nodes.extend( + [ + self.quantize_input_node, + dq_node, + ] + ) - def __init__( - self, - in_features: int = 512, - out_features: int = 256, - bias: bool = False, - group_size: int = 64, - weight_bits: int = 4, - granularity_class: Optional[Callable] = None, - ) -> None: - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features, bias=bias) - self.group_size = group_size - self.weight_bits = weight_bits - - if self.weight_bits == 4: - # pyre-ignore[16] - self.weight_dtype = torch.int4 - else: - self.weight_dtype = torch.int8 - - if granularity_class is not None: - self.quant_granularity = granularity_class(self.group_size) - else: - self.quant_granularity = PerGroup(self.group_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - def apply_quantization(self): - q_config = IntxWeightOnlyConfig( - weight_dtype=self.weight_dtype, - granularity=self.quant_granularity, + self.match_found = True + + def is_weight_only_quantized(self) -> bool: + return self.quantize_input_node is None + + def is_weight_pergroup_quantized(self) -> bool: + weight_shape = self.weight_node.meta["val"].shape + scales_shape = self.weight_scales_node.meta["val"].shape + if len(scales_shape) != 2: + return False + + # Check that: + # height dim of scales is same as height dim of weight (N / output channels dim) + # width dim of weight (K / in channels dim) is divisible by width dim of scales + # (number of quantization groups) + return scales_shape[-2] == weight_shape[-2] and ( + weight_shape[-1] % scales_shape[-1] == 0 ) - quantize_(self, q_config) - unwrap_tensor_subclass(self) - return self - - -@lru_cache(maxsize=None) -@register_pattern_graph("torchao_wo_quantized_linear") -def get_torchao_wo_quantized_linear_graphs() -> List[torch.fx.GraphModule]: - graphs = [] - - # Different configurations to test - configs = [ - # gemv pattern - (1, 1, 128, 128, False, 64, 4, PerGroup), - # gemm pattern - (1, 8, 128, 128, False, 64, 4, PerGroup), - ] - - for ( - batch_size, - seq_len, - in_features, - out_features, - bias, - group_size, - weight_bits, - granularity_class, - ) in configs: - for dtype in [torch.float32]: - xs = [] - xs.append(torch.randn(batch_size, seq_len, in_features, dtype=dtype)) - if batch_size == 1: - xs.append(torch.randn(seq_len, in_features, dtype=dtype)) - - for x in xs: - # Create and quantize the pattern - pattern = TorchAOWeightOnlyQuantizedLinearPattern( - in_features=in_features, - out_features=out_features, - bias=bias, - group_size=group_size, - weight_bits=weight_bits, - granularity_class=granularity_class, - ) - - # Apply quantization - pattern = pattern.apply_quantization() - - # Export the quantized pattern - edge = to_edge( - export( - pattern, - (x,), - ), - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - gm = edge.exported_program().graph_module - 
graphs.append(gm) - - return graphs + + def is_weight_perchannel_quantized(self) -> bool: + weight_shape = self.weight_node.meta["val"].shape + scales_shape = self.weight_scales_node.meta["val"].shape + if len(scales_shape) != 1: + return False + + # scales should have same size as weight's output channels dim + return scales_shape[0] == weight_shape[-2] + + def is_input_static_per_tensor_quantized(self) -> bool: + if self.quantize_input_node is None: + return False + + # For static quantization per tensor quantization, the scales and zeros + # are scalars. + return isinstance(self.input_scales_node, float) + + +linear_anchor_nodes = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.addmm.default, +} + + +@register_pattern_detector("quantized_linear") +def find_quantized_linear_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedLinearMatch]: + if node.target not in linear_anchor_nodes: + return None + + matched_pattern = QuantizedLinearMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Constant tensor manipulation +## def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: @@ -192,117 +252,139 @@ def make_combined_scales_and_zeros_tensor( return torch.cat((scales_reshaped, zeros_scaled), dim=2) -def identify_wo_quantized_linear_io_nodes( # noqa: C901 - ep: ExportedProgram, - graph_module: torch.fx.GraphModule, - match: InternalMatch, -) -> Optional[List[torch.fx.Node]]: - dequant_node = None - # First, find the dequant node - for node in match.nodes_map.values(): - if utils.is_dequant_node(node): - dequant_node = node - break - - if dequant_node is None: - return None +## +## Pattern Replacement +## - quantized_weight = dequant_node.args[0] - quant_scales = dequant_node.args[2] - quant_zeros = dequant_node.args[3] - if not isinstance(quantized_weight, torch.fx.Node) or not is_param_node( - ep, quantized_weight - ): - return None - if not isinstance(quant_scales, torch.fx.Node) or not is_param_node( - ep, quant_scales - ): - return None - if not isinstance(quant_zeros, torch.fx.Node) or not is_param_node(ep, quant_zeros): - return None - - input_nodes = match.placeholder_nodes - if len(input_nodes) != 4: - return None - - in_tensor_node = None - for node in input_nodes: - if node not in dequant_node.args: - in_tensor_node = node - break - - if in_tensor_node is None: - return None - - output_nodes = match.returning_nodes - - if len(output_nodes) != 1: - return None - - out_tensor_node = output_nodes[0] - if not isinstance(out_tensor_node, torch.fx.Node): - return None - - return [ - in_tensor_node, - quantized_weight, - quant_scales, - quant_zeros, - out_tensor_node, - ] - - -# wo = "weight only" -@register_pattern_replacement("torchao_wo_quantized_linear") -def create_wo_quantized_linear_custom_op( +def make_linear_q4ga_op( ep: ExportedProgram, graph_module: torch.fx.GraphModule, - match: InternalMatch, + match: QuantizedLinearMatch, + weight_tensor: torch.Tensor, + weight_scales_tensor: torch.Tensor, + weight_zeros_tensor: torch.Tensor, ): - io_nodes = identify_wo_quantized_linear_io_nodes(ep, graph_module, match) - if io_nodes is None: - return - - assert len(io_nodes) == 5 - in_tensor, quantized_weight, quant_scales, quant_zeros, out_tensor = io_nodes - - quantized_weight_tensor = get_param_tensor(ep, quantized_weight) - if not isinstance(quantized_weight_tensor, torch.Tensor): - return - packed_quantized_weight_tensor = pack_4bit_weight_tensor(quantized_weight_tensor) + 
packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor) utils.update_program_state_dict( - ep, quantized_weight.name, packed_quantized_weight_tensor + ep, match.weight_node.name, packed_quantized_weight_tensor + ) + # Need to make sure corresponding FakeTensor has same size + match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to( + torch.uint8 ) - quantized_weight.meta["val"] = quantized_weight.meta["val"][:, ::2].to(torch.uint8) - - quant_scales_tensor = get_param_tensor(ep, quant_scales) - quant_zeros_tensor = get_param_tensor(ep, quant_zeros) - - assert quantized_weight_tensor is not None - assert quant_scales_tensor is not None - assert quant_zeros_tensor is not None - group_size = quantized_weight_tensor.shape[1] // quant_scales_tensor.shape[1] + group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1] combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor( - quant_scales_tensor, quant_zeros_tensor + weight_scales_tensor, weight_zeros_tensor ) - combined_scales_zeros_name = f"{quantized_weight.name}_scales_zeros" + combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros" graph_module.register_parameter( combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor) ) - with graph_module.graph.inserting_before(out_tensor): + with graph_module.graph.inserting_before(match.output_node): combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name) - wo_qlinear = graph_module.graph.create_node( + linear_q4ga_node = graph_module.graph.create_node( "call_function", exir_ops.edge.et_vk.linear_weight_int4.default, - args=(in_tensor, quantized_weight, group_size, combined_scales_zeros, 1), + args=( + match.fp_input_node, + match.weight_node, + group_size, + combined_scales_zeros, + 1, + ), + ) + + linear_q4ga_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(linear_q4ga_node) + + +def make_linear_q8ta_q8csw_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, + weight_tensor: torch.Tensor, +): + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + # Pre-compute the weight sums which are needed to apply activation zero point + # when using integer accumulation. 
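+ # Here the weight is treated as an (OC, K) matrix, so summing along dim=1 yields one sum per output channel.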
+ sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + sums_name = weight_tensor_name + "_sums" + # Sanitize the name + sums_name = sums_name.replace(".", "_") + + weight_sums_node = create_constant_placeholder( + exp_program=ep, + graph=graph_module.graph, + kind=InputKind.CONSTANT_TENSOR, + name=sums_name, + data=sum_per_output_channel, + ) + + with graph_module.graph.inserting_before(match.output_node): + qlinear_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.linear_q8ta_q8csw.default, + args=( + match.fp_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + ), ) - if hasattr(out_tensor, "meta") and "val" in out_tensor.meta: - wo_qlinear.meta["val"] = out_tensor.meta["val"] + qlinear_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(qlinear_node) + + +@register_pattern_replacement("quantized_linear") +def replace_quantized_linear_patterns( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, +): + # Extract relevant tensors + weight_tensor = get_param_tensor(ep, match.weight_node) + assert weight_tensor is not None + + assert match.weight_scales_node is not None + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + + assert match.weight_zeros_node is not None + weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) + assert weight_zeros_tensor is not None + + # Biases not supported at the moment + if match.bias_node is not None: + return + + # Route to appropriate custom op + if ( + match.is_weight_only_quantized() + and match.is_weight_pergroup_quantized() + and utils.is_in_4bit_range(weight_tensor) + ): + make_linear_q4ga_op( + ep, + graph_module, + match, + weight_tensor, + weight_scales_tensor, + weight_zeros_tensor, + ) + elif ( + match.is_input_static_per_tensor_quantized() + and match.is_weight_perchannel_quantized() + ): + make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor) - out_tensor.replace_all_uses_with(wo_qlinear) + # No-op for unsupported quant patterns diff --git a/backends/vulkan/patterns/rope.py b/backends/vulkan/patterns/rope.py index e0c2e4c5501..b174224ab78 100644 --- a/backends/vulkan/patterns/rope.py +++ b/backends/vulkan/patterns/rope.py @@ -12,6 +12,7 @@ import torch from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, register_pattern_graph, register_pattern_replacement, ) @@ -20,7 +21,6 @@ from executorch.exir.dialects._ops import ops as exir_ops from torch.export import export -from torch.fx.passes.utils.matcher_utils import InternalMatch class RotaryEmbeddingPattern(torch.nn.Module): @@ -111,16 +111,16 @@ def get_rope_graphs() -> List[torch.fx.GraphModule]: def identify_rotary_emb_io_nodes( ep: ExportedProgram, graph_module: torch.fx.GraphModule, - match: InternalMatch, + match: PatternMatch, ) -> Optional[List[torch.fx.Node]]: - # Get the input placeholders (xq, xk, freqs_cos, freqs_sin) - placeholder_nodes = match.placeholder_nodes - if len(placeholder_nodes) != 4: + # Get the input inputs (xq, xk, freqs_cos, freqs_sin) + input_nodes = match.input_nodes + if len(input_nodes) != 4: return None - xq, xk, freqs_cos, freqs_sin = placeholder_nodes + xq, xk, freqs_cos, freqs_sin = input_nodes - output_nodes = match.returning_nodes + output_nodes = match.output_nodes if len(output_nodes) != 2: return None @@ -133,7 
+133,7 @@ def identify_rotary_emb_io_nodes( def create_rotary_emb_custom_op( ep: ExportedProgram, graph_module: torch.fx.GraphModule, - match: InternalMatch, + match: PatternMatch, ): io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match) if io_nodes is None: diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index ef429ff21fa..53fad86f90c 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -14,6 +14,7 @@ python_unittest( "//executorch/kernels/portable:custom_ops_generated_lib", ], deps = [ + ":test_utils", "//caffe2:torch", "//executorch/backends/transforms:convert_dtype_pass", "//executorch/backends/vulkan:vulkan_preprocess", @@ -68,3 +69,23 @@ runtime.python_library( "//executorch/backends/vulkan:vulkan_preprocess", ] ) + +runtime.python_library( + name = "test_utils", + srcs = [ + "utils.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan:vulkan_preprocess", + "//executorch/backends/vulkan/partitioner:vulkan_partitioner", + "//executorch/backends/xnnpack:xnnpack_preprocess", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools:lib", + "//executorch/devtools/bundled_program/serialize:lib", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ], +) diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 687a8761c6b..00a357b0b67 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -10,6 +10,8 @@ import unittest from typing import Tuple +import executorch.backends.vulkan.test.utils as test_utils + import torch from executorch.backends.transforms.convert_dtype_pass import I64toI32 @@ -18,12 +20,23 @@ from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) + from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, ExecutorchProgramManager, + to_edge_transform_and_lower, +) +from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, ) -from torch.export import Dim, export, export_for_training, ExportedProgram +from executorch.extension.pytree import tree_flatten +from torch.export import Dim, export, ExportedProgram + from torchao.quantization.granularity import PerGroup from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -32,14 +45,10 @@ from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ from torchao.utils import unwrap_tensor_subclass -ctypes.CDLL("libvulkan.so.1") - - -from executorch.exir import to_edge_transform_and_lower -from executorch.extension.pybindings.portable_lib import ( # @manual - _load_for_executorch_from_buffer, -) -from executorch.extension.pytree import tree_flatten +try: + ctypes.CDLL("libvulkan.so.1") +except: + pass def lower_module( @@ -83,7 +92,7 @@ def quantize_and_lower_module( _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. 
) - program = export_for_training( + program = export( model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() @@ -129,37 +138,47 @@ def assert_outputs_equal( # Multiple outputs executor always returns tuple, even if there is one output self.assertTrue(len(ref_output) == len(model_output)) if first_output_only: - self.assertTrue( - torch.allclose( - model_output[0], - ref_output[0], + result = torch.allclose( + model_output[0], + ref_output[0], + atol=atol, + rtol=rtol, + equal_nan=equal_nan, + ) + if not result: + test_utils.print_tensor_comparison_errors( + model_output[0], ref_output[0], atol, rtol + ) + self.assertTrue(result) + else: + for i in range(len(ref_output)): + result = torch.allclose( + model_output[i], + ref_output[i], atol=atol, rtol=rtol, equal_nan=equal_nan, ) - ) - else: - for i in range(len(ref_output)): - self.assertTrue( - torch.allclose( - model_output[i], - ref_output[i], - atol=atol, - rtol=rtol, - equal_nan=equal_nan, + if not result: + print(f"\n=== Output {i} comparison failed ===") + test_utils.print_tensor_comparison_errors( + model_output[i], ref_output[i], atol, rtol ) - ) + self.assertTrue(result) else: # If one output, eager returns tensor while executor tuple of size 1 - self.assertTrue( - torch.allclose( - model_output[0], - ref_output, - atol=atol, - rtol=rtol, - equal_nan=equal_nan, - ) + result = torch.allclose( + model_output[0], + ref_output, + atol=atol, + rtol=rtol, + equal_nan=equal_nan, ) + if not result: + test_utils.print_tensor_comparison_errors( + model_output[0], ref_output, atol, rtol + ) + self.assertTrue(result) def check_no_delegation(self, et_program: ExecutorchProgramManager): self.assertEqual( @@ -2388,3 +2407,246 @@ def apply_quantization(self): self.lower_module_and_test_output( quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 ) + + def test_vulkan_backend_xnnpack_pt2e_quantized_linear_sequence(self): + """ + Test a sequence of linear layers quantized with XNNPACK quantization config. + This test creates a module with multiple linear layers in sequence and applies + XNNPACK symmetric quantization to test the quantized model execution. 
+ """ + + import executorch.backends.vulkan.test.utils as test_utils + + class LinearSequenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + self.linear3 = torch.nn.Linear(32, 16, bias=False) + + MAX = 0.75 + MIN = -0.25 + self.linear1.weight.data = test_utils.random_uniform_tensor( + self.linear1.weight.shape, MIN, MAX + ) + self.linear2.weight.data = test_utils.random_uniform_tensor( + self.linear2.weight.shape, MIN, MAX + ) + self.linear3.weight.data = test_utils.random_uniform_tensor( + self.linear3.weight.shape, MIN, MAX + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + return x + + # Create the module + linear_sequence_module = LinearSequenceModule() + + M = 32 + # Create sample inputs + sample_inputs = ( + ( + test_utils.random_uniform_tensor( + (M, linear_sequence_module.linear1.in_features), + -0.25, + 0.75, + ) + ), + ) + + # Create XNNPACK quantizer with symmetric quantization config + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + # Test the quantized module using the existing quantize_and_lower_module function + # Use higher tolerance since quantization introduces some error + edge_program = quantize_and_lower_module( + linear_sequence_module, sample_inputs, quantizer + ) + + et_program = edge_program.to_executorch() + self.check_vk_delegation(et_program) + + self.run_delegated_model_and_check_output( + et_program, + linear_sequence_module, + sample_inputs, + atol=1e-2, + rtol=1e-1, + ) + + def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence(self): + """ + Test a sequence of convolution layers quantized with PT2E quantization. + This test creates a module with multiple Conv2d layers in sequence and applies + XNNPACK symmetric quantization to test the quantized model execution. + Similar to the linear sequence test but using convolution layers. 
+ """ + + import executorch.backends.vulkan.test.utils as test_utils + + class ConvSequenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + bias=False, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=16, + out_channels=32, + kernel_size=3, + padding=1, + bias=False, + ) + self.conv3 = torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=3, + padding=1, + bias=False, + ) + + MAX = 0.75 + MIN = -0.25 + self.conv1.weight.data = test_utils.random_uniform_tensor( + self.conv1.weight.shape, MIN, MAX + ) + self.conv2.weight.data = test_utils.random_uniform_tensor( + self.conv2.weight.shape, MIN, MAX + ) + self.conv3.weight.data = test_utils.random_uniform_tensor( + self.conv3.weight.shape, MIN, MAX + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + + # Create the module + conv_sequence_module = ConvSequenceModule() + + input_tensor = test_utils.random_uniform_tensor( + (1, 3, 32, 32), + -0.25, + 0.75, + ) + + # Create sample inputs + sample_inputs = (input_tensor,) + + # Create XNNPACK quantizer with symmetric quantization config + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + # Test the quantized module using the existing quantize_and_lower_module function + # Use higher tolerance since quantization introduces some error + edge_program = quantize_and_lower_module( + conv_sequence_module, sample_inputs, quantizer + ) + + et_program = edge_program.to_executorch() + self.check_vk_delegation(et_program) + + self.run_delegated_model_and_check_output( + et_program, + conv_sequence_module, + sample_inputs, + atol=1e-2, + rtol=1e-1, + ) + + def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_all_reduced(self): + """ + Test a sequence of convolution layers quantized with PT2E quantization. + This test creates a module with multiple Conv2d layers in sequence and applies + XNNPACK symmetric quantization to test the quantized model execution. + Similar to the linear sequence test but using convolution layers. 
+ """ + + import executorch.backends.vulkan.test.utils as test_utils + + class ConvSequenceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d( + in_channels=3, + out_channels=32, + kernel_size=3, + padding=1, + bias=False, + ) + self.conv2 = torch.nn.Conv2d( + in_channels=32, + out_channels=1, + kernel_size=3, + padding=1, + bias=False, + ) + + MAX = 0.75 + MIN = -0.25 + self.conv1.weight.data = test_utils.random_uniform_tensor( + self.conv1.weight.shape, MIN, MAX + ) + self.conv2.weight.data = test_utils.random_uniform_tensor( + self.conv2.weight.shape, MIN, MAX + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + # Create the module + conv_sequence_module = ConvSequenceModule() + + input_tensor = test_utils.random_uniform_tensor( + (1, 3, 32, 32), + -0.25, + 0.75, + ) + + # Create sample inputs + sample_inputs = (input_tensor,) + + # Create XNNPACK quantizer with symmetric quantization config + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + # Test the quantized module using the existing quantize_and_lower_module function + # Use higher tolerance since quantization introduces some error + edge_program = quantize_and_lower_module( + conv_sequence_module, sample_inputs, quantizer + ) + + et_program = edge_program.to_executorch() + self.check_vk_delegation(et_program) + + self.run_delegated_model_and_check_output( + et_program, + conv_sequence_module, + sample_inputs, + atol=1e-2, + rtol=1e-1, + ) diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index 0e9ea6bc9d8..363ee37058d 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -7,7 +7,10 @@ import logging from collections import OrderedDict -from typing import List, Optional, Tuple +from copy import deepcopy + +from enum import auto, Enum +from typing import Any, List, Optional, Tuple import executorch.backends.vulkan.utils as utils @@ -16,6 +19,11 @@ from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend from executorch.devtools import BundledProgram from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite @@ -29,6 +37,47 @@ from executorch.extension.pytree import tree_flatten from torch.export import export, export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class QuantizationMode(Enum): + """Enum to describe how a model should be quantized.""" + + NONE = auto() + INT8_STATIC_PER_CHANNEL = auto() + + +def get_exported_graph( + model, + sample_inputs, + dynamic_shapes=None, + qmode=QuantizationMode.NONE, +) -> torch.fx.GraphModule: + export_training_graph = export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + if qmode == QuantizationMode.NONE: + return export_training_graph + + quantizer = XNNPACKQuantizer() + + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) + + prepared_graph = 
prepare_pt2e(export_training_graph, quantizer) + prepared_graph(*sample_inputs) + converted_graph = convert_pt2e(prepared_graph) + + return converted_graph + + +def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None): + if dtype is None: + dtype = torch.float32 + + return torch.empty(shape, device=device, dtype=dtype).uniform_(low, high) + def export_model_to_vulkan( model, @@ -36,18 +85,19 @@ def export_model_to_vulkan( dynamic_shapes=None, operator_blocklist=None, operator_allowlist=None, + nn_module_blocklist=None, + nn_module_allowlist=None, + qmode=QuantizationMode.NONE, ): - """Helper to export a model to Vulkan backend.""" compile_options = {} - export_training_graph = export_for_training( - model, sample_inputs, strict=True - ).module() + exported_graph = get_exported_graph(model, sample_inputs, qmode=qmode) program = export( - export_training_graph, + exported_graph, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True, ) + edge_program = to_edge_transform_and_lower( program, partitioner=[ @@ -55,6 +105,8 @@ def export_model_to_vulkan( compile_options, operator_blocklist=operator_blocklist, operator_allowlist=operator_allowlist, + nn_module_blocklist=nn_module_blocklist, + nn_module_allowlist=nn_module_allowlist, ) ], transform_passes=None, @@ -75,18 +127,25 @@ def export_model_to_vulkan( return executorch_program -def export_model_to_xnnpack(model, sample_inputs, dynamic_shapes=None): - """Helper to export a model to XNNPACK backend.""" +def export_model_to_xnnpack( + model, + sample_inputs, + dynamic_shapes=None, + operator_blocklist=None, + operator_allowlist=None, + nn_module_blocklist=None, + nn_module_allowlist=None, + qmode=QuantizationMode.NONE, +): compile_options = {} - export_training_graph = export_for_training( - model, sample_inputs, strict=True - ).module() + exported_graph = get_exported_graph(model, sample_inputs, qmode=qmode) program = export( - export_training_graph, + exported_graph, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True, ) + edge_program = to_edge_transform_and_lower( program, partitioner=[XnnpackPartitioner(compile_options)], @@ -108,6 +167,74 @@ def export_model_to_xnnpack(model, sample_inputs, dynamic_shapes=None): return executorch_program +def print_tensor_comparison_errors( + tensor1, tensor2, atol=1e-03, rtol=1e-03, max_errors=10 +): + """ + Print the first max_errors tensor indexes that exceed the absolute/relative tolerance + and the error at each of those locations. 
+ + Args: + tensor1: First tensor to compare + tensor2: Second tensor to compare + atol: Absolute tolerance + rtol: Relative tolerance + max_errors: Maximum number of errors to print (default: 10) + """ + # Handle lists/tuples of tensors + if isinstance(tensor1, (list, tuple)) and isinstance(tensor2, (list, tuple)): + if len(tensor1) != len(tensor2): + print(f"Tensor count mismatch: {len(tensor1)} vs {len(tensor2)}") + return + + for i, (t1, t2) in enumerate(zip(tensor1, tensor2)): + print(f"\n=== Tensor {i} comparison ===") + print_tensor_comparison_errors(t1, t2, atol, rtol, max_errors) + return + + # Handle single tensor comparison + if not isinstance(tensor1, torch.Tensor) or not isinstance(tensor2, torch.Tensor): + print("Error: Both inputs must be torch.Tensor objects") + return + + if tensor1.shape != tensor2.shape: + print(f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}") + return + + # Calculate absolute and relative errors + abs_diff = torch.abs(tensor1 - tensor2) + rel_diff = abs_diff / ( + torch.abs(tensor2) + 1e-8 + ) # Add small epsilon to avoid division by zero + + # Find locations where tolerance is exceeded + tolerance_mask = (abs_diff > atol) & (rel_diff > rtol) + + if not tolerance_mask.any(): + print("All values are within tolerance") + return + + # Get indices where tolerance is exceeded + error_indices = torch.nonzero(tolerance_mask, as_tuple=False) + total_errors = error_indices.shape[0] + + print(f"Found {total_errors} values exceeding tolerance (atol={atol}, rtol={rtol})") + print(f"Showing first {min(max_errors, total_errors)} errors:") + print("Index -> tensor1_value, tensor2_value, abs_error, rel_error") + + # Print first max_errors locations + for i in range(min(max_errors, total_errors)): + idx = tuple(error_indices[i].tolist()) + val1 = tensor1[idx].item() + val2 = tensor2[idx].item() + abs_err = abs_diff[idx].item() + rel_err = rel_diff[idx].item() + + print( + f"{idx} -> {val1:.6f}, {val2:.6f}, abs_err={abs_err:.6f}, rel_err={rel_err:.6f}" + ) + + def check_outputs_equal( model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False ): @@ -123,19 +250,34 @@ def check_outputs_equal( if isinstance(ref_output, tuple) or isinstance(ref_output, list): # Multiple outputs executor always returns tuple, even if there is one output if len(ref_output) != len(model_output): + print_tensor_comparison_errors(model_output, ref_output, atol, rtol) return False if first_output_only: - return torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol) + result = torch.allclose( + model_output[0], ref_output[0], atol=atol, rtol=rtol + ) + if not result: + print_tensor_comparison_errors( + model_output[0], ref_output[0], atol, rtol + ) + return result else: for i in range(len(ref_output)): if not torch.allclose( model_output[i], ref_output[i], atol=atol, rtol=rtol ): + print(f"\n=== Output {i} comparison failed ===") + print_tensor_comparison_errors( + model_output[i], ref_output[i], atol, rtol + ) return False return True else: # If one output, eager returns tensor while executor tuple of size 1 - return torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + result = torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + if not result: + print_tensor_comparison_errors(model_output[0], ref_output, atol, rtol) + return result def run_and_check_output( @@ -183,6 +325,16 @@ def run_and_check_output( ) +def make_copy_of_inputs(sample_inputs: Tuple[Any]) -> Tuple[Any]: + sample_inputs_copy = [] + for input_val in sample_inputs: + 
if isinstance(input_val, torch.Tensor): + sample_inputs_copy.append(input_val.clone()) + else: + sample_inputs_copy.append(deepcopy(input_val)) + return tuple(sample_inputs_copy) + + def lower_module_and_test_output( model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], @@ -193,6 +345,9 @@ def lower_module_and_test_output( first_output_only=False, operator_blocklist=None, operator_allowlist=None, + nn_module_allowlist=None, + nn_module_blocklist=None, + xnnpack=False, ) -> bool: """ Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with @@ -203,16 +358,33 @@ def lower_module_and_test_output( bool: True if all comparisons pass, False otherwise. """ # Export model to Vulkan using the helper function - executorch_program = export_model_to_vulkan( - model, sample_inputs, dynamic_shapes, operator_blocklist, operator_allowlist - ) + if xnnpack: + executorch_program = export_model_to_xnnpack( + model, + make_copy_of_inputs(sample_inputs), + dynamic_shapes, + operator_blocklist, + operator_allowlist, + nn_module_blocklist, + nn_module_allowlist, + ) + else: + executorch_program = export_model_to_vulkan( + model, + make_copy_of_inputs(sample_inputs), + dynamic_shapes, + operator_blocklist=operator_blocklist, + operator_allowlist=operator_allowlist, + nn_module_blocklist=nn_module_blocklist, + nn_module_allowlist=nn_module_allowlist, + ) executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) inputs_flattened, _ = tree_flatten(sample_inputs) model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) - ref_output = model(*sample_inputs) + ref_output = model(*make_copy_of_inputs(sample_inputs)) if not check_outputs_equal( model_output, @@ -455,14 +627,14 @@ def op_ablation_test( # noqa: C901 all_operators = list(operator_frequencies.keys()) logger.info(f"Found {len(all_operators)} unique operators in the graph") - # Sort operators by frequency (least frequent first for binary search) + # Sort operators by frequency (most frequent first for binary search) operators_by_frequency = sorted( - all_operators, key=lambda op: operator_frequencies[op] + all_operators, key=lambda op: operator_frequencies[op], reverse=True ) - logger.info("Operator frequencies (sorted by occurrence, least frequent first):") + logger.info("Operator frequencies (sorted by occurrence, most frequent first):") for op in operators_by_frequency: - logger.info(f" {op}: {operator_frequencies[op]} occurrences") + logger.info(f" {op.name()}: {operator_frequencies[op]} occurrences") # Global test counter test_count = 0 @@ -489,6 +661,17 @@ def test_operator_set(ops_to_test: List, known_good_ops: List) -> bool: operator_allowlist=test_allowlist, ) logger.info(f" {'✓ PASS' if success else '✗ FAIL'}") + + # Log known good ops + logger.info(" Known good:") + for op in known_good_ops: + logger.info(f" * {op.name()}") + + # Log tested ops + logger.info(" Tested ops:") + for op in ops_to_test: + logger.info(f" * {op.name()}") + return success except Exception as e: logger.info(f" ! 
Error: {e}") @@ -510,10 +693,10 @@ def find_bad_operators( # Base case: single operator op = ops_to_test[0] if test_operator_set([op], known_good_ops): - logger.info(f" Single operator {op} is GOOD") + logger.info(f" Single operator {op.name()} is GOOD") return [op], [] else: - logger.info(f" Single operator {op} is BAD") + logger.info(f" Single operator {op.name()} is BAD") return [], [op] # Split ops_to_test into two halves @@ -525,20 +708,33 @@ def find_bad_operators( f"Splitting {len(ops_to_test)} operators: {len(first_half)} + {len(second_half)}" ) - # Test each half - first_half_good = test_operator_set(first_half, known_good_ops) - second_half_good = test_operator_set(second_half, known_good_ops) + # Log known good ops + logger.info(" Known good:") + for op in known_good_ops: + logger.info(f" * {op.name()}") + + # Log first half ops + logger.info(" First half ops:") + for op in first_half: + logger.info(f" * {op.name()}") + + # Log second half ops + logger.info(" Second half ops:") + for op in second_half: + logger.info(f" * {op.name()}") good_ops = [] bad_ops = [] - # Process first half + first_half_good = test_operator_set(first_half, known_good_ops) if first_half_good: logger.info( f"First half ({len(first_half)} ops) is good - adding to known good" ) good_ops.extend(first_half) known_good_ops.extend(first_half) + + second_half_good = test_operator_set(second_half, known_good_ops) if second_half_good: logger.info( f"Second half ({len(second_half)} ops) is good - adding to known good" @@ -569,11 +765,11 @@ def find_bad_operators( logger.info(f"\n=== Binary search complete after {test_count} tests ===") logger.info(f"Good operators ({len(good_operators)}):") for op in good_operators: - logger.info(f" ✓ {op} (frequency: {operator_frequencies[op]})") + logger.info(f" ✓ {op.name()} (frequency: {operator_frequencies[op]})") logger.info(f"Bad operators ({len(bad_operators)}):") for op in bad_operators: - logger.info(f" ✗ {op} (frequency: {operator_frequencies[op]})") + logger.info(f" ✗ {op.name()} (frequency: {operator_frequencies[op]})") print_occurrences(edge_program, bad_operators) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 3b3e27acfbd..1291eb62936 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -22,15 +22,17 @@ from executorch.exir.tensor import TensorSpec -from torch._export.utils import is_buffer, is_param +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param -from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter from torch.export import ExportedProgram from torch.export.exported_program import InputKind from torch.export.graph_signature import TensorArgument +TorchOpType = Union[EdgeOpOverload, torch._ops.OpOverload, str] + _DQ_OPS = { "dequantize_per_tensor.tensor", "dequantize_per_tensor.default", @@ -275,6 +277,45 @@ def node_comes_from_any_nn_module_in_set( return False +def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str: + if node is None: + return "" + if is_param(exp_prog, node): + return exp_prog.graph_signature.inputs_to_parameters[node.name] + elif is_buffer(exp_prog, node): + return exp_prog.graph_signature.inputs_to_buffers[node.name] + elif is_lifted_tensor_constant(exp_prog, node): + return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name] + else: + assert isinstance(node.target, str) + return node.target + + return "" + + +def find_dequant_user(node: torch.fx.Node) -> 
Optional[torch.fx.Node]:
+    """
+    Search the direct users of the given node and return the first one that is a
+    dequantization op. Returns None if no dequantization op is found.
+    """
+    for user in node.users:
+        if is_dequant_node(user):
+            return user
+    return None
+
+
+def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
+    """
+    Search the direct users of the given node and return the first one that is a
+    quantization op. Returns None if no quantization op is found.
+    """
+    for user in node.users:
+        if is_quant_node(user):
+            return user
+
+    return None
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
@@ -1068,6 +1109,89 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
     return get_node_spec_attr(node, "etvk_node_repr", False)
 
 
+##
+## Graph Pattern Matching
+##
+
+
+def maybe_skip_q_dq_arg_chain(
+    arg: torch.fx.node.Argument,
+) -> Tuple[Optional[torch.fx.Node], Optional[torch.fx.Node], Optional[torch.fx.Node]]:
+    """
+    Check if the given node argument is part of a Quantize/Dequantize chain produced by
+    the quant workflow. If so, return the source tensor that is the input to the Q/DQ
+    chain along with the quantize and dequantize nodes in the chain. Otherwise, return
+    the argument as-is and None for both the quantize and dequantize nodes.
+    """
+    if not isinstance(arg, torch.fx.Node):
+        return None, None, None
+
+    if is_dequant_node(arg):
+        dequant_node = arg
+        quant_node = dequant_node.args[0]
+        assert isinstance(quant_node, torch.fx.Node)
+        source_arg = quant_node.args[0]
+        assert isinstance(source_arg, torch.fx.Node)
+        return source_arg, quant_node, dequant_node
+    else:
+        return arg, None, None
+
+
+def trace_args_until_placeholder(
+    node: torch.fx.node.Argument, max_search_depth: int = 4
+) -> Tuple[Optional[torch.fx.Node], List[torch.fx.Node]]:
+    """
+    Trace through node.args[0] of a given initial node until a placeholder node is found,
+    then return it and the list of nodes traversed. If no placeholder node is found,
+    returns None and an empty list.
+    """
+    cur_node = node
+    search_depth = 0
+
+    if not isinstance(cur_node, torch.fx.Node):
+        return None, []
+
+    traversed = [cur_node]
+    while cur_node.op != "placeholder" and search_depth < max_search_depth:
+        # Break if cur_node has no args
+        if len(cur_node.args) == 0:
+            break
+
+        cur_node = cur_node.args[0]
+        if not isinstance(cur_node, torch.fx.Node):
+            break
+        traversed.append(cur_node)
+        search_depth += 1
+
+    if not isinstance(cur_node, torch.fx.Node):
+        return None, []
+    if cur_node.op != "placeholder":
+        return None, []
+
+    assert isinstance(cur_node, torch.fx.Node)
+    return cur_node, traversed
+
+
+def is_in_4bit_range(tensor: torch.Tensor) -> bool:
+    """
+    Check if the given tensor is in the range of 4-bit quantization and is of integer type.
+    """
+    if tensor.dtype not in (torch.int8, torch.uint8):
+        return False
+
+    return tensor.min().item() >= -8 and tensor.max().item() <= 7
+
+
+def is_in_8bit_range(tensor: torch.Tensor) -> bool:
+    """
+    Check if the given tensor is in the range of 8-bit quantization and is of integer type.
+ """ + if tensor.dtype not in (torch.int8, torch.uint8): + return False + + return tensor.min().item() >= -128 and tensor.max().item() <= 127 + + ## ## Misc ## @@ -1143,3 +1267,39 @@ def update_program_state_dict( # Finally, overwrite the current tensor with updated tensor program.state_dict[target_name] = updated_tensor + + +def align_width_and_update_state_dict( + ep: ExportedProgram, + node: torch.fx.Node, + cur_tensor: torch.Tensor, + align_to: int = 4, + force_update: bool = False, +) -> torch.Tensor: + """ + Align the width of the given tensor to the given alignment value and update the + state dict of the program with the aligned tensor. + """ + added_padding = False + cur_width = cur_tensor.shape[-1] + # Only align the width of the tensor if it is not already aligned + if cur_width % align_to != 0: + num_padding = align_to - (cur_width % align_to) + # Align the width of the tensor to the given alignment value + aligned_tensor = torch.nn.functional.pad( + cur_tensor, (0, num_padding) + ).contiguous() + added_padding = True + else: + aligned_tensor = cur_tensor + + if added_padding or force_update: + update_program_state_dict(ep, node.name, aligned_tensor) + # FakeTensor needs to match updated tensor + cur_fake_tensor = node.meta["val"] + node.meta["val"] = FakeTensorConverter().from_real_tensor( + cur_fake_tensor.fake_mode, + aligned_tensor, + ) + + return aligned_tensor diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 5db5d7a4ff4..69d3cdef75d 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -19,6 +19,7 @@ ViewCopyToSqueezeUnsqueezePass, ) from executorch.backends.vulkan._passes import ( + FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, @@ -157,6 +158,7 @@ def preprocess( # noqa: C901 RemoveRedundantOpsTransform(), AddmmToLinearTransform(), FuseQuantizedOpsTransform(program), + FoldQDQPass(program), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(),