From 9858f29524bc33f455574403cfd0ef4c4af15002 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 31 Oct 2023 23:38:16 -0700
Subject: [PATCH] Fix result dtype conversion in QuantLinear.forward()

Fixes: https://github.com/PanQiWei/AutoGPTQ/pull/385#discussion_r1378237609

Signed-off-by: Vivek Khandelwal
---
 auto_gptq/nn_modules/qlinear/qlinear_cuda.py     | 4 ++--
 auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda.py
index 95169cfd..09203d3c 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_cuda.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda.py
@@ -268,8 +268,8 @@ def forward(self, x: torch.Tensor):
                 g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
                 weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
             weights = torch.cat(weights,dim=1)
-        out = torch.matmul(x, weights)
-        out = out.to(dtype=weights.dtype).reshape(out_shape)
+        out = torch.matmul(x, weights).to(dtype=weights.dtype)
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
         return out
 
diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
index 25299bbf..81645cfc 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -268,8 +268,8 @@ def forward(self, x):
             weight = (scales * (weight - zeros))
             weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 
-        out = torch.matmul(x, weight)
-        out = out.to(dtype=weight.dtype).reshape(out_shape)
+        out = torch.matmul(x, weight).to(dtype=weight.dtype)
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
         return out
 
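For reference, the ordering the patch settles on can be seen in isolation in the minimal, self-contained sketch below. This is not the AutoGPTQ implementation; forward_sketch and its arguments are hypothetical stand-ins, with weights playing the role of the matrix QuantLinear.forward() reconstructs from qweight/scales/qzeros before the matmul.

    from typing import Optional

    import torch

    def forward_sketch(x: torch.Tensor, weights: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Output keeps the leading batch dims of x; the last dim comes from weights.
        out_shape = x.shape[:-1] + (weights.shape[-1],)
        x2d = x.reshape(-1, x.shape[-1])
        # Patched ordering: cast to the dequantized weights' dtype right after
        # the matmul, then reshape as a separate step.
        out = torch.matmul(x2d, weights).to(dtype=weights.dtype)
        out = out.reshape(out_shape)
        return out + bias if bias is not None else out

    # Usage: float16 weights with matching float16 activations.
    x = torch.randn(2, 8, 16, dtype=torch.float16)
    w = torch.randn(16, 32, dtype=torch.float16)
    y = forward_sketch(x, w)
    assert y.dtype == torch.float16 and y.shape == (2, 8, 32)

Both the pre- and post-patch forms compute the same values; the patch attaches the dtype cast directly to the matmul that produces out, rather than chaining it onto the reshape afterwards.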