Add support for CPU Inference (AutoGPTQ#385)
vivekkhandelwal1 committed Oct 31, 2023
1 parent 0d9beff commit 878cbb0
Showing 3 changed files with 20 additions and 18 deletions.
auto_gptq/modeling/_utils.py (4 changes: 2 additions & 2 deletions)
@@ -91,10 +91,10 @@ def make_quant(
                 out_features = tmp.weight.shape[1]
             if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
                 new_layer = QuantLinear(
-                    bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
+                    bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable, weight_dtype=tmp.weight.dtype
                 )
             else:
-                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable)
+                new_layer = QuantLinear(bits, group_size, in_features, out_features, True, trainable=trainable, weight_dtype=tmp.weight.dtype)
             new_layer.device = ori_layer_device
             setattr(module, attr, new_layer.to(ori_layer_device))
     for name1, child in module.named_children():
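With this change, make_quant forwards the wrapped layer's weight dtype into QuantLinear instead of letting the kernel default to float16. A minimal sketch of the effect, assuming the constructor signature shown in this diff and the import path auto_gptq.nn_modules.qlinear.qlinear_cuda_old (both may differ across versions):

import torch
import torch.nn as nn

# Hypothetical illustration of the new weight_dtype argument; the import path
# and positional argument order are taken from this diff and may vary by version.
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear

fp32_linear = nn.Linear(4096, 4096)  # a CPU model typically keeps float32 weights

layer = QuantLinear(
    4,                                      # bits
    128,                                    # group_size
    fp32_linear.in_features,
    fp32_linear.out_features,
    True,                                   # bias
    trainable=False,
    weight_dtype=fp32_linear.weight.dtype,  # float32, not the old hard-coded float16
)

# The scales and bias buffers now follow the source weight dtype.
assert layer.scales.dtype == torch.float32
assert layer.bias.dtype == torch.float32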
auto_gptq/nn_modules/qlinear/qlinear_cuda.py (17 changes: 9 additions & 8 deletions)
@@ -30,7 +30,8 @@ def __init__(
         outfeatures,
         bias,
         kernel_switch_threshold=128,
-        trainable=False
+        trainable=False,
+        weight_dtype=torch.float16,
     ):
         super().__init__()
         global _autogptq_cuda_available
@@ -55,14 +56,14 @@ def __init__(
         )
         self.register_buffer(
             'scales',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
+            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=weight_dtype)
         )
         self.register_buffer(
             'g_idx',
             torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
         )
         if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=weight_dtype))
         else:
             self.bias = None

@@ -105,9 +106,9 @@ def pack(self, linear, scales, zeros, g_idx=None):
         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()
         scale_zeros = zeros * scales
-        self.scales = scales.clone().half()
+        self.scales = scales.clone().to(dtype=linear.weight.dtype)
         if linear.bias is not None:
-            self.bias = linear.bias.clone().half()
+            self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)

         intweight = []
         for idx in range(self.infeatures):
@@ -267,10 +268,10 @@ def forward(self, x: torch.Tensor):
                 g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
                 weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
             weights = torch.cat(weights,dim=1)
-            out = torch.matmul(x.to(weights.dtype), weights)
-            out = out.half().reshape(out_shape)
+            out = torch.matmul(x, weights)
+            out = out.to(dtype=weights.dtype).reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.to(x.dtype)
+        return out


 __all__ = ["QuantLinear"]
auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py (17 changes: 9 additions & 8 deletions)
@@ -30,7 +30,8 @@ def __init__(
         bias,
         use_cuda_fp16=True,
         kernel_switch_threshold=128,
-        trainable=False
+        trainable=False,
+        weight_dtype=torch.float16,
     ):
         super().__init__()
         global _autogptq_cuda_available
@@ -54,15 +55,15 @@ def __init__(
         )
         self.register_buffer(
             'scales',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
+            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=weight_dtype)
         )
         self.register_buffer(
             'g_idx',
             torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
         )

         if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=weight_dtype))
         else:
             self.bias = None
         self.half_indim = self.infeatures // 2
@@ -105,9 +106,9 @@ def pack(self, linear, scales, zeros, g_idx):
         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()
         scale_zeros = zeros * scales
-        self.scales = scales.clone().half()
+        self.scales = scales.clone().to(dtype=linear.weight.dtype)
         if linear.bias is not None:
-            self.bias = linear.bias.clone().half()
+            self.bias = linear.bias.clone().to(dtype=linear.weight.dtype)

         intweight = []
         for idx in range(self.infeatures):
@@ -267,10 +268,10 @@ def forward(self, x):
             weight = (scales * (weight - zeros))
             weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

-            out = torch.matmul(x.to(weight.dtype), weight)
-            out = out.half().reshape(out_shape)
+            out = torch.matmul(x, weight)
+            out = out.to(dtype=weight.dtype).reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.to(x_dtype)
+        return out


 __all__ = ["QuantLinear"]
