Fix result dtype conversion in QuantLinear.forward()
Fixes: AutoGPTQ#385 (comment)

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
vivekkhandelwal1 committed Nov 1, 2023
1 parent 878cbb0 commit 9858f29
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda.py
@@ -268,8 +268,8 @@ def forward(self, x: torch.Tensor):
             g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
             weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
         weights = torch.cat(weights,dim=1)
-        out = torch.matmul(x, weights)
-        out = out.to(dtype=weights.dtype).reshape(out_shape)
+        out = torch.matmul(x, weights).to(dtype=weights.dtype)
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
         return out

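For context, the loop shown above dequantizes the packed weights group by group before the matmul: g_idx maps each input row to its quantization group, and the gathered per-group scales and zero points are applied elementwise. Below is a minimal sketch of that gather-and-scale pattern with illustrative shapes and names (not AutoGPTQ's actual packing or tensors).

import torch

# Illustrative sizes only; the real module packs 2/3/4/8-bit weights into int32 words.
in_features, out_features, group_size = 8, 4, 4
num_groups = in_features // group_size

qweight = torch.randint(0, 16, (in_features, out_features))  # unpacked integer weights
scales = torch.rand(num_groups, out_features)                # one scale per group and column
zeros = torch.randint(1, 16, (num_groups, out_features))     # one zero point per group and column
g_idx = torch.arange(in_features) // group_size              # row index -> group index

# Gather each row's scale/zero via g_idx and dequantize, the same pattern as
# scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]) in the hunk above.
weights = scales[g_idx.long()] * (qweight - zeros[g_idx.long()])
print(weights.shape, weights.dtype)  # torch.Size([8, 4]) torch.float32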
4 changes: 2 additions & 2 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -268,8 +268,8 @@ def forward(self, x):
         weight = (scales * (weight - zeros))
         weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

-        out = torch.matmul(x, weight)
-        out = out.to(dtype=weight.dtype).reshape(out_shape)
+        out = torch.matmul(x, weight).to(dtype=weight.dtype)
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
         return out

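In both files, the changed lines cast the matmul result to the weights' dtype immediately after the multiply, and only then reshape and add the bias, so the output leaves forward() in a single, consistent dtype. Below is a minimal, self-contained sketch of that pattern; the fp16 choice, shapes, and tensor names are assumptions for illustration, not the real QuantLinear state.

import torch

# fp16 matmul needs CUDA or a recent PyTorch CPU build; pick whatever is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(2, 3, 8, dtype=torch.float16, device=device)      # activations
weights = torch.randn(8, 16, dtype=torch.float16, device=device)  # dequantized weight matrix
bias = torch.zeros(16, dtype=torch.float16, device=device)

out_shape = x.shape[:-1] + (weights.shape[-1],)  # (2, 3, 16)
x2d = x.reshape(-1, x.shape[-1])                 # flatten leading dims for the matmul

# Cast the result to the weights' dtype before reshaping and adding the bias,
# mirroring the added lines in this commit.
out = torch.matmul(x2d, weights).to(dtype=weights.dtype)
out = out.reshape(out_shape)
out = out + bias if bias is not None else out
print(out.shape, out.dtype)  # torch.Size([2, 3, 16]) torch.float16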
