From df3b18ace9815c977e3a25bde09b3310e0a243b9 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 5 Jul 2024 11:51:38 -0700
Subject: [PATCH 1/2] Allow Int4WeightOnlyQuantizer to set different dtype for scales_and_zeros

As titled. Currently `Int4WeightOnlyQuantizer` is hardcoded to return
`scales_and_zeros` with dtype `torch.bfloat16`. This change adds a `dtype`
argument through the flow so that `scales_and_zeros` can use a different
dtype.
---
 torchao/quantization/GPTQ.py  | 25 ++++++++++++++++---------
 torchao/quantization/utils.py |  8 ++++----
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 038eae8d4..f93a82f81 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -525,14 +525,14 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
         return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
     return k_divisible_by_groupsize

-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize, dtype=torch.bfloat16):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        x.to(torch.bfloat16),
+        x.to(dtype),
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(torch.bfloat16)
+        scales_and_zeros.to(dtype)
     ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
@@ -546,12 +546,12 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
             self, in_features: int, out_features: int,
-            bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+            bias=False, device=None, dtype=torch.bfloat16, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
-            from model import find_multiple
+            from .utils import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)

@@ -567,9 +567,10 @@ def __init__(
             "weight",
             torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
         )
+        self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=self.dtype)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -578,10 +579,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight, self.scales_and_zeros, self.out_features, self.groupsize, self.dtype
         )

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None, dtype=torch.bfloat16):

     for name, child in module.named_children():
         if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
@@ -589,9 +590,10 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
                     groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    dtype=dtype,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func, dtype)

 class Int4WeightOnlyQuantizer(Quantizer):
     def __init__(
@@ -600,6 +602,7 @@ def __init__(
         padding_allowed: bool = True,
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cuda"),
+        precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -609,6 +612,7 @@ def __init__(
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
+        self.precision: torch.dtype = precision

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -648,6 +652,7 @@ def _create_quantized_state_dict(
                     weight,
                     4, # n_bit
                     self.groupsize,
+                    self.precision, # precision for scales_and_zeros
                 )
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
@@ -660,6 +665,8 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
             self.groupsize,
             self.inner_k_tiles,
             self.padding_allowed,
+            skip_layer_func=None,
+            dtype=self.precision,
         )
         return model

diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index e6c24ea27..cab8f9b62 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -307,9 +307,9 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16
     ).reshape(w.shape[0], -1)


-def pack_tinygemm_scales_and_zeros(scales, zeros):
-    guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size())
-    guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16)
+def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
+    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
+    guard_dtype_size(zeros, "zeros", dtype=dtype)
     return (
         torch.cat(
             [
@@ -376,7 +376,7 @@ def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bflo
     w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
         w, scales, zeros, n_bit, groupsize
     )
-    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros, dtype)
     return w_int4x8, scales_and_zeros


From f3c320a8f2d576bfc49609f68427c3882271484f Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 5 Jul 2024 14:01:01 -0700
Subject: [PATCH 2/2] Add comment

---
 torchao/quantization/GPTQ.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index f93a82f81..99b7621ee 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -612,6 +612,7 @@ def __init__(
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
+        # precision and dtype are being used interchangeably here
         self.precision: torch.dtype = precision

     @torch.no_grad()
@@ -652,7 +653,7 @@ def _create_quantized_state_dict(
                     weight,
                     4, # n_bit
                     self.groupsize,
-                    self.precision, # precision for scales_and_zeros
+                    self.precision, # dtype for scales_and_zeros
                 )
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
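Note (not part of the patches above): a minimal usage sketch of the new `precision`
argument. It assumes the quantizer's `quantize()` entry point inherited from the
`Quantizer` base class, a CUDA device with the int4 tinygemm kernels available, and a
made-up toy model; with `precision=torch.float16` the `scales_and_zeros` buffers are
produced in float16 instead of the previously hard-coded bfloat16.

    import torch
    import torch.nn as nn

    from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

    # Toy model (illustrative only); in_features should be divisible by groupsize
    # and by inner_k_tiles * 16 so no padding is required.
    model = nn.Sequential(nn.Linear(1024, 1024, bias=False)).to("cuda")

    quantizer = Int4WeightOnlyQuantizer(
        groupsize=128,
        inner_k_tiles=8,
        device=torch.device("cuda"),
        precision=torch.float16,  # new argument: dtype used for scales_and_zeros
    )
    # Assumed entry point; replaces eligible nn.Linear modules with WeightOnlyInt4Linear
    # whose scales_and_zeros buffers now carry the requested dtype.
    model = quantizer.quantize(model)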