From bea304151f7d7e5907d5e095a99df80eacf7541e Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Wed, 5 Feb 2025 11:59:29 -0800
Subject: [PATCH 1/3] [ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases

While LLaMa does not use biases, some models do have biases in their linear
modules. Add support for biases in the source transform quantizer.

Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/)

[ghstack-poisoned]
---
 .../_passes/int4_weight_only_quantizer.py | 24 ++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index 4821b613405..409cbb4b755 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -39,11 +39,12 @@ def __init__(
         from torchao.utils import find_multiple
 
         self.origin_in_features = in_features
-        in_features = find_multiple(in_features, (1024,))
+        # pyre-ignore[6]: Incompatible parameter type
+        in_features = find_multiple(in_features, 1024)
 
+        self.use_bias = bias
         self.in_features = in_features
         self.out_features = out_features
-        assert not bias, "require bias=False"
         self.device = device
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
@@ -80,6 +81,11 @@ def __init__(
                 device=device,
             ),
         )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.empty((out_features,), dtype=torch.float32, device=device),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
@@ -87,13 +93,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # The forward method is replaced. In the original implementation, the forward
         # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
         # operator is called instead.
-        return torch.ops.et_vk.linear_weight_int4(
+        r = torch.ops.et_vk.linear_weight_int4(
             input,
             self.weight,
             self.groupsize,
             self.scales_and_zeros,
             self.inner_k_tiles,
         )
+        if self.use_bias:
+            return r + self.bias
+        return r
 
 
 # This function is coped from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
             new_linear = linear_class(
                 child.in_features,
                 child.out_features,
-                bias=False,
+                bias=child.bias is not None,
                 device=child.weight.device,
                 groupsize=groupsize,
                 inner_k_tiles=inner_k_tiles,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
             if copy_weights and child.weight.device != torch.device("meta"):
                 # pyre-fixme[16]: `Module` has no attribute `weight`.
                 new_linear.weight = child.weight
+                if child.bias is not None:
+                    # pyre-fixme[16]: `Module` has no attribute `bias`.
+                    new_linear.bias = child.bias
             setattr(module, name, new_linear)
         else:
             _vk_replace_linear_int4(
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
                 mod.out_features < self.feature_limit
                 and mod.in_features < self.feature_limit
             ):
-                assert not mod.bias
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
                     logging.warn(
                         f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                     )
-                    padded_in_features = find_multiple(in_features, (1024,))
+                    # pyre-ignore[6]: Incompatible parameter type
+                    padded_in_features = find_multiple(in_features, 1024)
                     weight = F.pad(
                         weight, pad=(0, padded_in_features - in_features)
                     )

From b4349e4d4e0e9e8fbce0f3c912f1618713824fca Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Wed, 5 Feb 2025 14:04:10 -0800
Subject: [PATCH 2/3] Update on "[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases"

While LLaMa does not use biases, some models do have biases in their linear
modules. Add support for biases in the source transform quantizer.

Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/)

[ghstack-poisoned]

From 6e36efe05116cc75af49f6f3279e1444ff6efb03 Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Wed, 5 Feb 2025 14:19:34 -0800
Subject: [PATCH 3/3] Update on "[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases"

While LLaMa does not use biases, some models do have biases in their linear
modules. Add support for biases in the source transform quantizer.

Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/)

[ghstack-poisoned]
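
Reviewer note: the sketch below is illustrative only. It uses plain PyTorch instead of the `torch.ops.et_vk.linear_weight_int4` custom op, and the class name `Int4LinearWithOptionalBias` and its placeholder buffers are made up for this example. It shows the pattern the patch adopts: register a `bias` buffer only when the replaced `nn.Linear` has one, and add that buffer to the op's output in `forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Int4LinearWithOptionalBias(nn.Module):
    """Illustrative stand-in for the patched module.

    The int4 Vulkan custom op is replaced by a plain matmul so the sketch
    runs without the ExecuTorch Vulkan backend; only the bias handling
    mirrors the patch.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.use_bias = bias
        # Placeholder for the packed int4 weight / scales_and_zeros buffers.
        self.register_buffer("weight", torch.randn(out_features, in_features))
        if bias:
            # Mirrors the patch: the bias buffer exists only when the
            # source nn.Linear module had one.
            self.register_buffer(
                "bias", torch.zeros(out_features, dtype=torch.float32)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In the patch this call is torch.ops.et_vk.linear_weight_int4(...).
        r = F.linear(x, self.weight)
        if self.use_bias:
            return r + self.bias
        return r


# Usage: the bias path is exercised only when the replaced linear had a bias.
layer = Int4LinearWithOptionalBias(16, 8, bias=True)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])
```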