Allow Int4WeightOnlyQuantizer to set different dtype for scales_and_zeros #479

Merged 2 commits on Jul 5, 2024. The diff below shows changes from 1 commit.
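For orientation, here is a minimal sketch of how the new knob could be used once this lands. It is an assumption-laden example, not part of the PR: it assumes the quantizer's standard quantize(model) entry point, and the toy float16 model on CUDA is a stand-in.

import torch
import torch.nn as nn
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# Stand-in model; any module containing nn.Linear layers with int4-friendly
# shapes (in_features divisible by groupsize and inner_k_tiles * 16) would do.
model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
)
model = model.to(device="cuda", dtype=torch.float16)

quantizer = Int4WeightOnlyQuantizer(
    groupsize=128,
    inner_k_tiles=8,
    device=torch.device("cuda"),
    precision=torch.float16,  # new in this PR: dtype used for scales_and_zeros
)
# Weights are packed to int4; scales_and_zeros are kept in float16 instead of
# the previously hard-coded bfloat16.
model = quantizer.quantize(model)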
torchao/quantization/GPTQ.py (25 changes: 16 additions & 9 deletions)
@@ -525,14 +525,14 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
         return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
     return k_divisible_by_groupsize

-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
+def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize, dtype=torch.bfloat16):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
     c = torch.ops.aten._weight_int4pack_mm(
-        x.to(torch.bfloat16),
+        x.to(dtype),
         weight_int4pack,
         groupsize,
-        scales_and_zeros.to(torch.bfloat16)
+        scales_and_zeros.to(dtype)
     ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
@@ -546,12 +546,12 @@ class WeightOnlyInt4Linear(torch.nn.Module):

     def __init__(
         self, in_features: int, out_features: int,
-        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
+        bias=False, device=None, dtype=torch.bfloat16, groupsize: int = 128, inner_k_tiles: int = 8,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
-            from model import find_multiple
Contributor Author: I don't think there's a module called model.
Member: Thanks, I think this is a relic of when gptq was more deeply coupled with gpt-fast.
+            from .utils import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)

@@ -567,9 +567,10 @@ def __init__(
             "weight",
             torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
         )
+        self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
+            torch.empty((in_features // groupsize, out_features, 2), dtype=self.dtype)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -578,20 +578,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
         return linear_forward_int4(
             input,
-            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+            self.weight, self.scales_and_zeros, self.out_features, self.groupsize, self.dtype
         )

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None, dtype=torch.bfloat16):

     for name, child in module.named_children():
         if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
                     groupsize=groupsize, inner_k_tiles=inner_k_tiles,
+                    dtype=dtype,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func, dtype)

 class Int4WeightOnlyQuantizer(Quantizer):
     def __init__(
@@ -600,6 +602,7 @@ def __init__(
         padding_allowed: bool = True,
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cuda"),
+        precision: torch.dtype = torch.bfloat16,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -609,6 +612,7 @@ def __init__(
         self.groupsize: int = groupsize
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
+        self.precision: torch.dtype = precision

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -648,6 +652,7 @@ def _create_quantized_state_dict(
                     weight,
                     4, # n_bit
                     self.groupsize,
+                    self.precision, # precision for scales_and_zeros
                 )
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
@@ -660,6 +665,8 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
             self.groupsize,
             self.inner_k_tiles,
             self.padding_allowed,
+            skip_layer_func=None,
+            dtype=self.precision,
         )
         return model

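A small sketch of how the dtype argument added above threads from replace_linear_int4 into the swapped-in WeightOnlyInt4Linear modules. The toy parent module and the float16 choice are illustrative assumptions; in practice the quantizer's _convert_for_runtime performs this call with dtype=self.precision.

import torch
import torch.nn as nn
from torchao.quantization.GPTQ import WeightOnlyInt4Linear, replace_linear_int4

# Toy module, illustration only; 256 in_features satisfies _check_linear_int4_k
# for groupsize=128 and inner_k_tiles=8.
parent = nn.Sequential(nn.Linear(256, 512, bias=False))

replace_linear_int4(
    parent,
    groupsize=128,
    inner_k_tiles=8,
    padding_allowed=True,
    skip_layer_func=None,
    dtype=torch.float16,
)

swapped = parent[0]
assert isinstance(swapped, WeightOnlyInt4Linear)
# The scales_and_zeros buffer is now allocated in the requested dtype
# (previously it was hard-coded to torch.bfloat16).
print(swapped.scales_and_zeros.dtype)  # torch.float16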
torchao/quantization/utils.py (8 changes: 4 additions & 4 deletions)
@@ -307,9 +307,9 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16
     ).reshape(w.shape[0], -1)


-def pack_tinygemm_scales_and_zeros(scales, zeros):
-    guard_dtype_size(scales, "scales", dtype=torch.bfloat16, size=zeros.size())
-    guard_dtype_size(zeros, "zeros", dtype=torch.bfloat16)
+def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
+    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
+    guard_dtype_size(zeros, "zeros", dtype=dtype)
     return (
         torch.cat(
             [
@@ -376,7 +376,7 @@ def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bflo
     w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
         w, scales, zeros, n_bit, groupsize
     )
-    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros, dtype)
     return w_int4x8, scales_and_zeros

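And a quick sketch of the utils.py side: with the guard now keyed on the caller-supplied dtype, groupwise_affine_quantize_tensor can emit scales_and_zeros in a dtype other than bfloat16. The tensor shapes and the float16 choice below are illustrative assumptions, not part of the PR.

import torch
from torchao.quantization.utils import groupwise_affine_quantize_tensor

# w is laid out like nn.Linear.weight: (out_features, in_features).
# in_features equals groupsize here, so each output row forms a single group.
w = torch.randn(64, 128, dtype=torch.float16)

w_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
    w,
    n_bit=4,
    groupsize=128,
    dtype=torch.float16,  # previously the packed scales/zeros were always bfloat16
)

print(scales_and_zeros.dtype)  # torch.float16
print(scales_and_zeros.shape)  # (in_features // groupsize, out_features, 2) -> (1, 64, 2)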