From 8686e98e80da481c3e82767081e06fb048c08ae3 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Mon, 29 Sep 2025 18:25:36 -0400
Subject: [PATCH] [ET-VK] AOT logic for quantized conv2d

Pull Request resolved: https://github.com/pytorch/executorch/pull/14648

### Changes

As title; this diff adds the necessary export logic required to enable fusing
quantized convolution patterns into the custom ops introduced in the below diff.

ghstack-source-id: 312809808
@exported-using-ghexport

Differential Revision: [D83437826](https://our.internmc.facebook.com/intern/diff/D83437826/)
---
 backends/vulkan/_passes/TARGETS               |  14 ++
 backends/vulkan/_passes/__init__.py           |   2 +
 backends/vulkan/_passes/replace_qdq.py        |  93 +++++++++++
 backends/vulkan/custom_ops_lib.py             | 158 +++++++++++++++---
 backends/vulkan/op_registry.py                |  41 ++++-
 .../vulkan/patterns/quantized_convolution.py  | 100 +++++++----
 backends/vulkan/runtime/VulkanBackend.cpp     |   4 +
 backends/vulkan/serialization/schema.fbs      |   2 +
 .../serialization/vulkan_graph_schema.py      |   2 +
 backends/vulkan/test/test_vulkan_delegate.py  |   2 +
 backends/vulkan/utils.py                      |  10 ++
 backends/vulkan/vulkan_preprocess.py          |   2 +
 12 files changed, 377 insertions(+), 53 deletions(-)
 create mode 100644 backends/vulkan/_passes/replace_qdq.py

diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index aed41114ada..ae1a0b79654 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -117,6 +117,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "replace_qdq",
+    srcs = ["replace_qdq.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_patterns",
     srcs = ["fuse_patterns.py"],
@@ -150,6 +163,7 @@ runtime.python_library(
         ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
+        ":replace_qdq",
         ":squeeze_unsqueeze_inputs",
         ":tag_memory_meta_pass",
     ]
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index f4ef6b2ac0e..169bd60543c 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -22,6 +22,7 @@ from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
+from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass
 from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import (
     SqueezeUnsqueezeInputs,
 )
@@ -36,6 +37,7 @@
     "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
+    "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",
     "TagMemoryMetaPass",
 ]
diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py
new file mode 100644
index 00000000000..3613c5bf53c
--- /dev/null
+++ b/backends/vulkan/_passes/replace_qdq.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.utils as utils
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ReplaceQDQPass(ExportPass):
+    """
+    Replace standard quantize/dequantize ops with custom conv-specific ops when they
+    feed into/from quantized convolution operations. This optimization allows the
+    backend to handle quantization more efficiently for convolution operations.
+    """
+
+    def __init__(self):
+        super(ReplaceQDQPass, self).__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        # Track nodes that need to be replaced
+        nodes_to_replace = []
+
+        for node in graph_module.graph.nodes:
+            # Check if this is the custom quantized conv2d op
+            if node.target in [
+                exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
+                exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
+            ]:
+                # Replace quantize op feeding into conv2d (first argument is the quantized input)
+                quantized_input_node = node.args[0]
+                if isinstance(
+                    quantized_input_node, torch.fx.Node
+                ) and utils.is_quant_node(quantized_input_node):
+                    # Get the arguments from the original quantize node
+                    input_tensor = quantized_input_node.args[0]
+                    scale = quantized_input_node.args[1]
+                    zero_point = quantized_input_node.args[2]
+
+                    nodes_to_replace.append(
+                        {
+                            "old_node": quantized_input_node,
+                            "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+                            "args": (input_tensor, scale, zero_point),
+                            "node_type": "quantize_input",
+                        }
+                    )
+
+                # Find dequantize ops that consume the output of this conv2d
+                for user in node.users:
+                    if utils.is_dequant_node(user):
+                        # Get the arguments from the original dequantize node
+                        scale = user.args[1]
+                        zero_point = user.args[2]
+
+                        nodes_to_replace.append(
+                            {
+                                "old_node": user,
+                                "new_target": exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
+                                "args": (
+                                    node,
+                                    scale,
+                                    zero_point,
+                                ),  # node is the conv2d output
+                                "node_type": "dequantize_output",
+                            }
+                        )
+
+        # Apply the replacements
+        for replacement in nodes_to_replace:
+            old_node = replacement["old_node"]
+            new_target = replacement["new_target"]
+            new_args = replacement["args"]
+
+            with graph_module.graph.inserting_before(old_node):
+                new_node = graph_module.graph.create_node(
+                    "call_function", new_target, args=new_args
+                )
+                new_node.meta = old_node.meta.copy()
+                old_node.replace_all_uses_with(new_node)
+
+        # Clean up the graph
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        # Re-trace to validate everything is ok
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 56e803b9127..314c470e5db 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -354,18 +354,20 @@ def linear_q8ta_q8csw(
 lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
 qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)
 
-#######################
-## conv2d_q8ta_q8csw ##
-#######################
+############################
+## conv2d_q8ta_q8csw_q8to ##
+############################
 
 
-def conv2d_q8ta_q8csw(
+def conv2d_q8ta_q8csw_q8to(
     x: torch.Tensor,
     input_scale: float,
     input_zero_point: int,
     weights: torch.Tensor,
     weight_sums: torch.Tensor,
     weight_scales: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
     bias: Optional[torch.Tensor],
     kernel_size: list,
     stride: list,
@@ -373,27 +375,103 @@ def conv2d_q8ta_q8csw(
     dilation: list,
     groups: int,
 ):
-    IC = x.shape[1]
+    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x, input_scale, input_zero_point, -128, 127, x.dtype
+    )
+
+    # Calculate weight dimensions
+    OC = weights.shape[0]
+    assert OC % groups == 0, "Output channels must be divisible by groups"
+    IC_per_group = int(x.shape[1] / groups)
     K_h, K_w = kernel_size[0], kernel_size[1]
-    canonical_weight_K_dim = K_h * K_w * IC
+    orig_weight_K_dim = K_h * K_w * IC_per_group
 
+    # Remove any padding added to in_features dim to align to a multiple of 4
+    if weights.shape[-1] > orig_weight_K_dim:
+        weights = weights[:, :orig_weight_K_dim]
+
     # Remove any padding added to output channels dim to align to a multiple of 4
-    if weights.shape[-1] != canonical_weight_K_dim:
-        weights = weights[:, :canonical_weight_K_dim]
-        weight_scales = weight_scales[:canonical_weight_K_dim]
+    if weight_scales.shape[0] > OC:
+        weight_scales = weight_scales[:OC]
         if bias is not None:
-            bias = bias[:canonical_weight_K_dim]
+            bias = bias[:OC]
+
+    # Reshape to original 4D format (OC, IC, H, W)
+    weights = weights.view(OC, IC_per_group, K_h, K_w)
 
     weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
+    # Dequantize weights
+    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weights,
+        weight_scales,
+        weight_zeros,
+        0,  # axis=0 for output channel quantization
+        -127,
+        127,
+        torch.int8,
+    )
 
-    # Calculate dimensions
-    OC = weights.shape[0]
-    in_features = weights.shape[1]
-    IC = in_features // (K_h * K_w)
+    # Perform convolution
+    out = torch.nn.functional.conv2d(
+        x, weights, bias, stride, padding, dilation, groups
+    )
 
-    # Reshape to original 4D format (OC, IC, H, W)
-    weights = weights.view(OC, IC, K_h, K_w)
+    out = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
+    return out
+
+
+name = "conv2d_q8ta_q8csw_q8to"
+lib.define(
+    f"""
+    {name}(
+        Tensor x,
+        float input_scale,
+        int input_zero_point,
+        Tensor weights,
+        Tensor weight_sums,
+        Tensor weight_scales,
+        float output_scale,
+        int output_zero_point,
+        Tensor? bias,
+        SymInt[] kernel_size,
+        SymInt[] stride,
+        SymInt[] padding,
+        SymInt[] dilation,
+        SymInt groups) -> Tensor
+    """
+)
+lib.impl(name, conv2d_q8ta_q8csw_q8to, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+
+
+def conv2d_q8ta_q8csw_q8to_dw(
+    x: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    weights: torch.Tensor,
+    weight_sums: torch.Tensor,
+    weight_scales: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    bias: Optional[torch.Tensor],
+    kernel_size: list,
+    stride: list,
+    padding: list,
+    dilation: list,
+    groups: int,
 ):
+    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x, input_scale, input_zero_point, -128, 127, x.dtype
+    )
+
+    # Restore weight to original data layout
+    K_h, K_w, OC = weights.shape
+    weights = weights.permute(2, 0, 1).reshape(OC, 1, K_h, K_w)
+
+    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
     # Dequantize weights
     weights = torch.ops.quantized_decomposed.dequantize_per_channel(
         weights,
@@ -410,10 +488,14 @@ def conv2d_q8ta_q8csw(
         x, weights, bias, stride, padding, dilation, groups
     )
 
+    out = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
     return out
 
 
-name = "conv2d_q8ta_q8csw"
+name = "conv2d_q8ta_q8csw_q8to_dw"
 lib.define(
     f"""
     {name}(
@@ -423,6 +505,8 @@ def conv2d_q8ta_q8csw(
         Tensor weights,
         Tensor weight_sums,
         Tensor weight_scales,
+        float output_scale,
+        int output_zero_point,
         Tensor? bias,
         SymInt[] kernel_size,
         SymInt[] stride,
@@ -431,8 +515,8 @@
         SymInt groups) -> Tensor
     """
 )
-lib.impl(name, conv2d_q8ta_q8csw, "CompositeExplicitAutograd")
-conv2d_q8ta_q8csw_op = getattr(getattr(torch.ops, namespace), name)
+lib.impl(name, conv2d_q8ta_q8csw_q8to_dw, "CompositeExplicitAutograd")
+conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)
 
 ######################
 ## apply_rotary_emb ##
 ######################
@@ -452,3 +536,39 @@ def apply_rotary_emb_impl(
 )
 lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
 apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
+
+#############################
+## quantize/dequantize ops ##
+#############################
+
+
+def quantize_q8ta_for_conv2d_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+):
+    return torch.ops.quantized_decomposed.quantize_per_tensor(
+        input, scale, zero_point, -128, 127, torch.int8
+    )
+
+
+name = "quantize_q8ta_for_conv2d"
+lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
+lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd")
+quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name)
+
+
+def dequantize_q8to_from_conv2d_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+):
+    return torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input, scale, zero_point, -128, 127, input.dtype
+    )
+
+
+name = "dequantize_q8to_from_conv2d"
+lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
+lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
+dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 4c686e0cfc5..8d67a5275d7 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -495,18 +495,21 @@ def register_convolution_op():
 
 @update_features(
     [
-        exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
+        exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
+        exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
     ]
 )
 def register_quantized_conv_op():
     return OpFeatures(
         inputs_storage=[
-            utils.CHANNELS_PACKED_TEXTURE,  # input
+            utils.PACKED_INT8_4W4C_BUFFER,  # input
             utils.NO_STORAGE,  # input_scale (non tensor)
             utils.NO_STORAGE,  # input_zero_point (non tensor)
             utils.NO_STORAGE,  # weight (prepacked)
             utils.NO_STORAGE,  # weight_sums (prepacked)
             utils.NO_STORAGE,  # weight_scales (prepacked)
+            utils.NO_STORAGE,  # output_scale (non tensor)
+            utils.NO_STORAGE,  # output_zero_point (non tensor)
             utils.NO_STORAGE,  # bias (prepacked)
             utils.NO_STORAGE,  # kernel_size (non tensor)
             utils.NO_STORAGE,  # stride (non tensor)
@@ -520,6 +523,40 @@ def register_quantized_conv_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
+    ]
+)
+def register_quantize_for_conv2d_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CHANNELS_PACKED_TEXTURE,
+        ],
+        outputs_storage=[
+            utils.PACKED_INT8_4W4C_BUFFER,
+        ],
+        supports_resize=False,
+    )
+
+
+@update_features(
+    [
+        exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default,
+    ]
+)
+def register_dequantize_for_conv2d_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.PACKED_INT8_4W4C_BUFFER,
+        ],
+        outputs_storage=[
+            utils.CHANNELS_PACKED_TEXTURE,
+        ],
+        supports_resize=False,
+    )
+
+
 @update_features("llama::sdpa_with_kv_cache")
 def register_sdpa_with_kv_cache_op():
     return OpFeatures(
diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py
index 65b51b5e103..522a19c58d6 100644
--- a/backends/vulkan/patterns/quantized_convolution.py
+++ b/backends/vulkan/patterns/quantized_convolution.py
@@ -76,11 +76,13 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
         # Identify output node
         self.output_node = self.anchor_node
 
-        out_channels = self.output_node.meta["val"].shape[-1]
-        # The implementation requires that for grouped convolutions, a group does not
-        # cross any texel boundary. The output channels per group must be a multiple of
-        # 4. If this is not true, then don't match the pattern.
-        if self.groups > 1 and (out_channels / self.groups) % 4 == 0:
+        out_channels = self.output_node.meta["val"].shape[-3]
+        # The implementation requires that for non-depthwise grouped convolutions, a
+        # group does not cross the texel boundary. The output channels per group must be
+        # a multiple of 4. If this is not true, then don't match the pattern.
+        if (self.groups > 1 and self.groups < out_channels) and (
+            out_channels / self.groups
+        ) % 4 != 0:
             return
 
         # Identify bias node, if applicable
@@ -93,23 +95,37 @@ def __init__(self, conv_node: torch.fx.Node) -> None:
             self.all_nodes.extend(arg_chain)
 
         # Identify input node
-        self.fp_input_node, self.quantize_input_node, dq_node = (
-            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
-        )
-        assert self.fp_input_node is not None
-        self.all_nodes.append(self.fp_input_node)
-        assert self.quantize_input_node is not None
-        assert dq_node is not None
-
-        self.input_scales_node = self.quantize_input_node.args[1]
-        self.input_zeros_node = self.quantize_input_node.args[2]
-
-        self.all_nodes.extend(
-            [
-                self.quantize_input_node,
-                dq_node,
-            ]
-        )
+        primary_input_node = self.anchor_node.args[0]
+        assert isinstance(primary_input_node, torch.fx.Node)
+        # Argument must be a dequant node for static quantization
+        if not utils.is_dequant_node(primary_input_node):
+            return
+
+        self.dequantize_input_node = primary_input_node
+        self.quantize_input_node = self.dequantize_input_node.args[0]
+
+        self.input_scales_node = self.dequantize_input_node.args[1]
+        self.input_zeros_node = self.dequantize_input_node.args[2]
+
+        self.all_nodes.extend([self.dequantize_input_node])
+
+        # The convolution output must have only one user; it will be either a relu node
+        # or a dequantize node.
+        if len(self.output_node.users) != 1:
+            return
+
+        cur_node = list(self.output_node.users)[0]
+        self.relu_node = None
+        if cur_node.target == exir_ops.edge.aten.relu.default:
+            self.relu_node = cur_node
+            cur_node = list(cur_node.users)[0]
+
+        if not utils.is_quant_node(cur_node):
+            return
+
+        self.quantize_output_node = cur_node
+        self.output_scales_node = self.quantize_output_node.args[1]
+        self.output_zeros_node = self.quantize_output_node.args[2]
 
         self.match_found = True
 
@@ -161,13 +177,26 @@ def make_conv2d_q8ta_q8csw_custom_op(
     bias_tensor = get_param_tensor(ep, match.bias_node)
     assert bias_tensor is not None
 
-    OC, IC, H, W = weight_tensor.shape
+    OC, IC_per_group, H, W = weight_tensor.shape
 
-    # Reshape weight tensor from (OC, IC, H, W) to (OC, H * W * IC) (i.e. matrix format)
-    # This prepares the weights for Im2Col-based convolution
-    weight_tensor = (
-        weight_tensor.permute(0, 2, 3, 1).contiguous().view(OC, H * W * IC).contiguous()
-    )
+    is_depthwise_conv = IC_per_group == 1 and match.groups == OC
+
+    if is_depthwise_conv:
+        assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4"
+        # Depthwise convs use a specialized layout; the weight tensor is reshaped to
+        # (H, W, OC)
+        weight_tensor = (
+            weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous()
+        )
+    else:
+        # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group)
+        # (i.e. matrix format). This prepares the weights for Im2Col-based convolution.
+        weight_tensor = (
+            weight_tensor.permute(0, 2, 3, 1)
+            .contiguous()
+            .view(OC, H * W * IC_per_group)
+            .contiguous()
+        )
 
     # Need to make sure that OC dim is a multiple of 4 so that data load/stores are well
     # aligned with texel boundaries. Add padding to align to the next multiple of 4 if
@@ -178,6 +207,7 @@ def make_conv2d_q8ta_q8csw_custom_op(
     utils.align_width_and_update_state_dict(
         ep, match.weight_scales_node, weight_scales_tensor
     )
+
     if bias_tensor is not None:
         utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)
 
@@ -185,7 +215,7 @@ def make_conv2d_q8ta_q8csw_custom_op(
     with graph_module.graph.inserting_before(first_graph_node):
         qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
         # Pre-compute the weight sums which are needed to apply activation zero point
-        # when using integer accumulation. For the reshaped 2D weight matrix (IC * H * W, OC),
+        # when using integer accumulation. For the reshaped 2D weight matrix (IC_per_group * H * W, OC),
         # sum over dimension 0 to get sums per output channel
         sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()
         sums_name = qweight_tensor_name + "_sums"
@@ -201,16 +231,22 @@ def make_conv2d_q8ta_q8csw_custom_op(
     )
 
     with graph_module.graph.inserting_before(match.output_node):
+        op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default
+        if is_depthwise_conv:
+            op_target = exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default
+
         qconv_node = graph_module.graph.create_node(
             "call_function",
-            exir_ops.edge.et_vk.conv2d_q8ta_q8csw.default,
+            op_target,
             args=(
-                match.fp_input_node,
+                match.quantize_input_node,
                 match.input_scales_node,
                 match.input_zeros_node,
                 match.weight_node,
                 weight_sums_node,
                 match.weight_scales_node,
+                match.output_scales_node,
+                match.output_zeros_node,
                 match.bias_node,  # Add bias after weight_scales
                 [H, W],  # Pass kernel size information before stride
                 match.stride,
@@ -221,4 +257,4 @@ def make_conv2d_q8ta_q8csw_custom_op(
     )
 
     qconv_node.meta["val"] = match.output_node.meta["val"]
-    match.output_node.replace_all_uses_with(qconv_node)
+    match.quantize_output_node.replace_all_uses_with(qconv_node)
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 67b646ae1a8..fe8cc83c481 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -139,6 +139,10 @@ utils::GPUMemoryLayout get_memory_layout(
       return utils::kHeightPacked;
     case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
       return utils::kChannelsPacked;
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4W4C:
+      return utils::kPackedInt8_4W4C;
+    case vkgraph::VkMemoryLayout::PACKED_INT8_4H4W:
+      return utils::kPackedInt8_4H4W;
     default:
       break;
   }
diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs
index 4bc12208ce7..9d738bc386f 100644
--- a/backends/vulkan/serialization/schema.fbs
+++ b/backends/vulkan/serialization/schema.fbs
@@ -40,6 +40,8 @@ enum VkMemoryLayout : ubyte {
   TENSOR_WIDTH_PACKED = 0,
   TENSOR_HEIGHT_PACKED = 1,
   TENSOR_CHANNELS_PACKED = 2,
+  PACKED_INT8_4W4C = 3,
+  PACKED_INT8_4H4W = 4,
   DEFAULT_LAYOUT = 255,
 }
 
diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py
index cf5326f40cf..236183ce42f 100644
--- a/backends/vulkan/serialization/vulkan_graph_schema.py
+++ b/backends/vulkan/serialization/vulkan_graph_schema.py
@@ -48,6 +48,8 @@ class VkMemoryLayout(IntEnum):
     TENSOR_WIDTH_PACKED = 0
     TENSOR_HEIGHT_PACKED = 1
     TENSOR_CHANNELS_PACKED = 2
+    PACKED_INT8_4W4C = 3
+    PACKED_INT8_4H4W = 4
     DEFAULT_LAYOUT = 255
 
     def __str__(self) -> str:
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index f8194f0b32c..f92cea64767 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -2482,6 +2482,7 @@ def forward(self, x):
             rtol=1e-1,
         )
 
+    @unittest.skip("Cannot run on swiftshader due to no integer dot product support")
     def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence(self):
         """
         Test a sequence of convolution layers quantized with PT2E quantization.
@@ -2572,6 +2573,7 @@ def forward(self, x):
             rtol=1e-1,
         )
 
+    @unittest.skip("Cannot run on swiftshader due to no integer dot product support")
     def test_vulkan_backend_xnnpack_pt2e_quantized_conv_sequence_all_reduced(self):
         """
         Test a sequence of convolution layers quantized with PT2E quantization.
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 96f200eecbc..972a4f26c1b 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -348,6 +348,8 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     VkMemoryLayout.TENSOR_WIDTH_PACKED,
     VkMemoryLayout.TENSOR_HEIGHT_PACKED,
     VkMemoryLayout.TENSOR_CHANNELS_PACKED,
+    VkMemoryLayout.PACKED_INT8_4W4C,
+    VkMemoryLayout.PACKED_INT8_4H4W,
 }
 
 MemoryLayoutSet = Set[VkMemoryLayout]
@@ -400,6 +402,12 @@ def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageEx
         height = (height + 3) // 4
     elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
         channels = (channels + 3) // 4
+    elif layout == VkMemoryLayout.PACKED_INT8_4W4C:
+        width = (width + 3) // 4
+        channels = (channels + 3) // 4
+    elif layout == VkMemoryLayout.PACKED_INT8_4H4W:
+        height = (height + 3) // 4
+        width = (width + 3) // 4
     else:
         raise RuntimeError(f"Unsupported memory layout {layout}")
 
@@ -692,6 +700,8 @@ def make_filtered_tensor_repset(
 
 ## Convenience TensorRepSet definitions
 
+PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set())
+
 CONTIGUOUS_ANY = TensorRepSet(
     {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED}
 )
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 95da66494e0..2f91d97ff58 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -24,6 +24,7 @@
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
     RemoveRedundantOpsTransform,
+    ReplaceQDQPass,
    SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
@@ -162,6 +163,7 @@ def preprocess(  # noqa: C901
                 RemoveRedundantOpsTransform(),
                 AddmmToLinearTransform(),
                 FuseQuantizedOpsTransform(program),
+                ReplaceQDQPass(),
                 FoldQDQPass(program),
                 SqueezeUnsqueezeInputs(),
                 FuseViewCopyTransform(),