From 6efceea092c496c8ac113627139f4fc4c6c6e739 Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Wed, 5 Feb 2025 16:56:01 -0800 Subject: [PATCH 1/2] [ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases Pull Request resolved: https://github.com/pytorch/executorch/pull/8224 While LLaMa does not have biases, there are some models which will have biases in their linear modules. Add support in the source transform quantizer for biases. ghstack-source-id: 264952608 @exported-using-ghexport Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/) --- .../_passes/int4_weight_only_quantizer.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py index 4821b613405..409cbb4b755 100644 --- a/backends/vulkan/_passes/int4_weight_only_quantizer.py +++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py @@ -39,11 +39,12 @@ def __init__( from torchao.utils import find_multiple self.origin_in_features = in_features - in_features = find_multiple(in_features, (1024,)) + # pyre-ignore[6]: Incompatible parameter type + in_features = find_multiple(in_features, 1024) + self.use_bias = bias self.in_features = in_features self.out_features = out_features - assert not bias, "require bias=False" self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles @@ -80,6 +81,11 @@ def __init__( device=device, ), ) + if bias: + self.register_buffer( + "bias", + torch.empty((out_features,), dtype=torch.float32, device=device), + ) def forward(self, input: torch.Tensor) -> torch.Tensor: if self.padding: @@ -87,13 +93,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # The forward method is replaced. In the original implementation, the forward # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom # operator is called instead. - return torch.ops.et_vk.linear_weight_int4( + r = torch.ops.et_vk.linear_weight_int4( input, self.weight, self.groupsize, self.scales_and_zeros, self.inner_k_tiles, ) + if self.use_bias: + return r + self.bias + return r # This function is coped from torchao.quantization.GPTQ._replace_linear_int4 @@ -128,7 +137,7 @@ def _vk_replace_linear_int4( new_linear = linear_class( child.in_features, child.out_features, - bias=False, + bias=child.bias is not None, device=child.weight.device, groupsize=groupsize, inner_k_tiles=inner_k_tiles, @@ -138,6 +147,9 @@ def _vk_replace_linear_int4( if copy_weights and child.weight.device != torch.device("meta"): # pyre-fixme[16]: `Module` has no attribute `weight`. new_linear.weight = child.weight + if child.bias is not None: + # pyre-fixme[16]: `Module` has no attribute `bias`. + new_linear.bias = child.bias setattr(module, name, new_linear) else: _vk_replace_linear_int4( @@ -189,7 +201,6 @@ def _create_quantized_state_dict( mod.out_features < self.feature_limit and mod.in_features < self.feature_limit ): - assert not mod.bias out_features = mod.out_features in_features = mod.in_features logging.info(f"linear: {fqn}, in={in_features}, out={out_features}") @@ -210,7 +221,8 @@ def _create_quantized_state_dict( logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) - padded_in_features = find_multiple(in_features, (1024,)) + # pyre-ignore[6]: Incompatible parameter type + padded_in_features = find_multiple(in_features, 1024) weight = F.pad( weight, pad=(0, padded_in_features - in_features) ) From 65117d57ccd0c5f2ea67baeb4dfe36bde16104a6 Mon Sep 17 00:00:00 2001 From: Nathanael See Date: Wed, 5 Feb 2025 16:56:02 -0800 Subject: [PATCH 2/2] [ET-VK][int4] patch 4-bit linear op for ensuring w-packed in/out Pull Request resolved: https://github.com/pytorch/executorch/pull/8225 If the partitioner is using channels-packed setting for activations, then the checks will throw. Remove the checks and conditionally re-pack the input/output tensors if they are not width-packed. ghstack-source-id: 264952605 @exported-using-ghexport Differential Revision: [D68813946](https://our.internmc.facebook.com/intern/diff/D68813946/) --- .../graph/ops/impl/QuantizedLinear.cpp | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 1042c23bcb3..ea6601502f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -260,9 +260,6 @@ void check_q_4w_linear_args( const int group_size_val = graph.extract_scalar(group_size); VK_CHECK_COND(K % group_size_val == 0); - VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); - VK_CHECK_COND(graph.has_standard_axis_map(mat1)); VK_CHECK_COND(graph.has_standard_axis_map(out)); } @@ -320,13 +317,32 @@ void add_q_4w_linear_node( const uint32_t group_size_val = graph.extract_scalar(group_size); + ValueRef mat1_W_packed = mat1; + ValueRef out_W_packed = out; + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + // Create temporary tensors to store the width packed versions of mat1 and out + TmpTensor mat1_tmp( + &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); + TmpTensor out_tmp( + &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); + if (storage_type == utils::kTexture3D) { + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + // Ensure mat1 is width packed + mat1_W_packed = mat1_tmp; + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + // Ensure out is packed correctly + out_W_packed = out_tmp; + } + } + vkapi::ParamsBindList ubos({}); - ubos.append(graph.logical_limits_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.logical_limits_ubo(out_W_packed)); + ubos.append(graph.sizes_ubo(mat1_W_packed)); ubos.append(graph.strides_ubo(mat2)); ubos.append(graph.strides_ubo(scales_and_zeros)); - utils::uvec3 global_wg_size = graph.logical_limits_of(out); + utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed); utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); graph.execute_nodes().emplace_back(new DispatchNode( @@ -335,8 +351,9 @@ void add_q_4w_linear_node( global_wg_size, local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, + {{out_W_packed, vkapi::MemoryAccessType::WRITE}, + {{mat1_W_packed, mat2, scales_and_zeros}, + vkapi::MemoryAccessType::READ}}, // Shader params buffers ubos, // Specialization Constants @@ -344,6 +361,10 @@ void add_q_4w_linear_node( // Resizing Logic resize_q_4w_linear_node, {})); + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(out) != WHCN::kWidthDim) { + viewFn(graph, {out_W_packed, graph.add_none(), out}); + } } void linear_weight_int4(