diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 4e60fc7bd7e..59658e58f28 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -30,6 +30,21 @@ runtime.python_library(
     ]
 )
 
+runtime.python_library(
+    name = "squeeze_int4_linear_inputs",
+    srcs = [
+        "squeeze_int4_linear_inputs.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ]
+)
+
 runtime.python_library(
     name = "remove_asserts",
     srcs = ["remove_asserts.py"],
@@ -99,6 +114,7 @@ runtime.python_library(
         ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
-        ":tag_memory_meta_pass"
+        ":squeeze_int4_linear_inputs",
+        ":tag_memory_meta_pass",
     ]
 )
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 8c29f5488f3..2a4a2b4b5c9 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -1,3 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
 from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
@@ -12,6 +20,9 @@
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
+from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
+    SqueezeInt4LinearInputs,
+)
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
@@ -21,5 +32,6 @@
     "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
+    "SqueezeInt4LinearInputs",
     "TagMemoryMetaPass",
 ]
diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index 4821b613405..409cbb4b755 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -39,11 +39,12 @@ def __init__(
             from torchao.utils import find_multiple
 
             self.origin_in_features = in_features
-            in_features = find_multiple(in_features, (1024,))
+            # pyre-ignore[6]: Incompatible parameter type
+            in_features = find_multiple(in_features, 1024)
 
+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -80,6 +81,11 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
@@ -87,13 +93,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r
 
 
 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
                 new_linear = linear_class(
                     child.in_features,
                     child.out_features,
-                    bias=False,
+                    bias=child.bias is not None,
                     device=child.weight.device,
                     groupsize=groupsize,
                     inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
                 if copy_weights and child.weight.device != torch.device("meta"):
                     # pyre-fixme[16]: `Module` has no attribute `weight`.
                     new_linear.weight = child.weight
+                    if child.bias is not None:
+                        # pyre-fixme[16]: `Module` has no attribute `bias`.
+                        new_linear.bias = child.bias
                 setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                        logging.warn(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
-                        padded_in_features = find_multiple(in_features, (1024,))
+                        # pyre-ignore[6]: Incompatible parameter type
+                        padded_in_features = find_multiple(in_features, 1024)
                         weight = F.pad(
                             weight, pad=(0, padded_in_features - in_features)
                         )
diff --git a/backends/vulkan/_passes/squeeze_int4_linear_inputs.py b/backends/vulkan/_passes/squeeze_int4_linear_inputs.py
new file mode 100644
index 00000000000..95fcef7f754
--- /dev/null
+++ b/backends/vulkan/_passes/squeeze_int4_linear_inputs.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Dict, List, Tuple
+
+import executorch.backends.vulkan.custom_ops_lib  # noqa: needed to access vk op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+
+from torch.fx.node import Argument
+
+
+class SqueezeInt4LinearInputs(ExportPass):
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        def _squeezable(shape: List[int]) -> bool:
+            return len(shape) > 2 and 1 in shape
+
+        if op != exir_ops.edge.et_vk.linear_weight_int4.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # pyre-ignore[16]: `None` has no attribute `node`
+        input_shape = args[0].node.meta["val"].shape
+        output_shape = meta["val"].shape
+        if not _squeezable(input_shape):
+            return super().call_operator(op, args, kwargs, meta)
+
+        # squeeze input tensor
+        squeeze_shape = list(input_shape)
+        while _squeezable(squeeze_shape):
+            squeeze_shape.remove(1)
+
+        squeeze_out = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (args[0], squeeze_shape),
+            kwargs,
+            meta,
+        )
+        # call linear on squeezed output
+        new_args = (squeeze_out, *args[1:])
+        linear_out = super().call_operator(
+            op,
+            new_args,
+            kwargs,
+            meta,
+        )
+        # unsqueeze output
+        unsqueeze_shape = list(output_shape)
+        return super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (linear_out, unsqueeze_shape),
+            kwargs,
+            meta,
+        )
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 1042c23bcb3..ea6601502f1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
   const int group_size_val = graph.extract_scalar<int>(group_size);
   VK_CHECK_COND(K % group_size_val == 0);
 
-  VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
-
   VK_CHECK_COND(graph.has_standard_axis_map(mat1));
   VK_CHECK_COND(graph.has_standard_axis_map(out));
 }
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(
 
   const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
 
+  ValueRef mat1_W_packed = mat1;
+  ValueRef out_W_packed = out;
+  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
+  // Create temporary tensors to store the width packed versions of mat1 and out
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  TmpTensor out_tmp(
+      &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
+  if (storage_type == utils::kTexture3D) {
+    if (!graph.is_buffer_storage(out) &&
+        graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+      // Ensure mat1 is width packed
+      mat1_W_packed = mat1_tmp;
+      viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+      // Ensure out is packed correctly
+      out_W_packed = out_tmp;
+    }
+  }
+
   vkapi::ParamsBindList ubos({});
-  ubos.append(graph.logical_limits_ubo(out));
-  ubos.append(graph.sizes_ubo(mat1));
+  ubos.append(graph.logical_limits_ubo(out_W_packed));
+  ubos.append(graph.sizes_ubo(mat1_W_packed));
   ubos.append(graph.strides_ubo(mat2));
   ubos.append(graph.strides_ubo(scales_and_zeros));
 
-  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed);
   utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
@@ -335,8 +351,9 @@ void add_q_4w_linear_node(
       global_wg_size,
       local_wg_size,
       // Inputs and Outputs
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
+      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
+       {{mat1_W_packed, mat2, scales_and_zeros},
+        vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       ubos,
       // Specialization Constants
@@ -344,6 +361,10 @@ void add_q_4w_linear_node(
       // Resizing Logic
       resize_q_4w_linear_node,
       {}));
+  if (!graph.is_buffer_storage(out) &&
+      graph.packed_dim_of(out) != WHCN::kWidthDim) {
+    viewFn(graph, {out_W_packed, graph.add_none(), out});
+  }
 }
 
 void linear_weight_int4(
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index 85b962784e7..150ae32dfce 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -328,6 +328,7 @@ def define_common_targets(is_fbcode = False):
                 "//executorch/backends/transforms:fuse_dequant_linear",
                 "//executorch/backends/transforms:fuse_view_copy",
                 "//executorch/backends/transforms:remove_clone_ops",
+                "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
                 "//executorch/backends/vulkan/_passes:vulkan_passes",
                 "//executorch/backends/vulkan/serialization:lib",
                 "//executorch/exir/backend:backend_details",
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 02ca8d2bec5..c6b444e5def 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -19,10 +19,14 @@
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
+from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
+    ViewCopyToSqueezeUnsqueezePass,
+)
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
     RemoveRedundantOpsTransform,
+    SqueezeInt4LinearInputs,
     TagMemoryMetaPass,
 )
 
@@ -149,7 +153,9 @@ def preprocess(  # noqa: C901
             RemoveRedundantOpsTransform(),
             AddmmToLinearTransform(),
             FuseDequantLinearPass(),
+            SqueezeInt4LinearInputs(),
             FuseViewCopyTransform(),
+            ViewCopyToSqueezeUnsqueezePass(),
             FuseBatchNormWithConvPass(program),
             FuseClampPass(),
         ],
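
For reference, the graph rewrite performed by the new SqueezeInt4LinearInputs pass reduces to a simple shape rule: while the activation feeding et_vk.linear_weight_int4 has rank greater than 2 and contains a size-1 dim, that dim is dropped via a view_copy, and the linear output is viewed back to its original shape afterwards. A minimal standalone sketch of that rule follows; the squeeze_shape helper is illustrative only and is not part of the patch.

    # Illustrative sketch, not part of the patch: the shape rule applied by
    # SqueezeInt4LinearInputs before calling the int4 linear custom op.
    from typing import List


    def squeeze_shape(shape: List[int]) -> List[int]:
        out = list(shape)
        # mirrors _squeezable(): only squeeze while rank > 2 and a size-1 dim exists
        while len(out) > 2 and 1 in out:
            out.remove(1)
        return out


    assert squeeze_shape([1, 1, 7, 4096]) == [7, 4096]
    assert squeeze_shape([1, 7, 4096]) == [7, 4096]
    assert squeeze_shape([7, 4096]) == [7, 4096]  # already 2D: left unchanged

Because FuseViewCopyTransform and the newly registered ViewCopyToSqueezeUnsqueezePass run after this pass in vulkan_preprocess.py, the view_copy nodes it inserts can subsequently be fused or lowered to squeeze/unsqueeze ops.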
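On the quantizer side, removing "assert not bias" means nn.Linear modules that carry a bias are no longer rejected: the replacement module registers a "bias" buffer and adds it after the custom op call in forward(). A hedged usage sketch; the quantize(model) entry point and the groupsize keyword are assumed from the existing VkInt4WeightOnlyQuantizer API and are not shown in this patch.

    # Hedged usage sketch: a Linear layer with bias=True can now be quantized.
    # quantize()/groupsize are assumed API details, not shown in this patch.
    import torch

    from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

    model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=True)).eval()
    model = VkInt4WeightOnlyQuantizer(groupsize=128).quantize(model)
    # forward() now computes et_vk.linear_weight_int4(...) + bias when use_bias is set.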