From f60e57aa21607a7911b589ad6ad1cd8c3caf5cba Mon Sep 17 00:00:00 2001
From: XiaoWang
Date: Wed, 24 Sep 2025 23:59:53 -0700
Subject: [PATCH 1/3] Adds _weight_int8pack_mm pass for woq-int8

---
 torchao/dtypes/uintx/plain_layout.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index 3551214d7e..f2ac3606b9 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -253,16 +253,21 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale)
 
     # per channel int8 weight only quantizated mm
-    w_vals_int8_t = weight_tensor.tensor_impl.int_data.t()
+    w_vals_int8 = weight_tensor.tensor_impl.int_data
     scale = weight_tensor.tensor_impl.scale
-    m = torch.mm(
-        input_tensor.reshape(-1, input_tensor.shape[-1]),
-        w_vals_int8_t.to(input_tensor.dtype),
-    )
-    y = m * scale.to(m.dtype)
+    try:
+        y = torch.ops.aten._weight_int8pack_mm(input_tensor.reshape(-1, input_tensor.shape[-1]), w_vals_int8, scale.to(input_tensor.dtype))
+    except Exception:
+        w_vals_int8_t = w_vals_int8.t()
+        scale = weight_tensor.tensor_impl.scale
+        m = torch.mm(
+            input_tensor.reshape(-1, input_tensor.shape[-1]),
+            w_vals_int8_t.to(input_tensor.dtype),
+        )
+        y = m * scale.to(m.dtype)
     y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
     if bias is not None:
-        y += bias.to(m.dtype)
+        y += bias.to(input_tensor.dtype)
     return y
 
 

From b5572b8eec99374811cfb961ee274571fade903e Mon Sep 17 00:00:00 2001
From: XiaoWang
Date: Thu, 25 Sep 2025 01:03:40 -0700
Subject: [PATCH 2/3] Adds _weight_int8pack_mm pass for woq-int8

---
 torchao/dtypes/uintx/plain_layout.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index f2ac3606b9..cbee2e0022 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -256,7 +256,11 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     w_vals_int8 = weight_tensor.tensor_impl.int_data
     scale = weight_tensor.tensor_impl.scale
     try:
-        y = torch.ops.aten._weight_int8pack_mm(input_tensor.reshape(-1, input_tensor.shape[-1]), w_vals_int8, scale.to(input_tensor.dtype))
+        y = torch.ops.aten._weight_int8pack_mm(
+            input_tensor.reshape(-1, input_tensor.shape[-1]),
+            w_vals_int8,
+            scale.to(input_tensor.dtype),
+        )
     except Exception:
         w_vals_int8_t = w_vals_int8.t()
         scale = weight_tensor.tensor_impl.scale

From 39d29715f11442fc32b490eb272b357e5fb5b3c3 Mon Sep 17 00:00:00 2001
From: "Xiao, Wang" <109140002+xiaowangintel@users.noreply.github.com>
Date: Mon, 29 Sep 2025 10:50:59 +0800
Subject: [PATCH 3/3] Update plain_layout.py

---
 torchao/dtypes/uintx/plain_layout.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index cbee2e0022..30c2a39c26 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -263,7 +263,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
         )
     except Exception:
         w_vals_int8_t = w_vals_int8.t()
-        scale = weight_tensor.tensor_impl.scale
         m = torch.mm(
             input_tensor.reshape(-1, input_tensor.shape[-1]),
             w_vals_int8_t.to(input_tensor.dtype),
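
Note for reviewers: after all three patches apply, _linear_fp_act_int8_weight_impl
first attempts the fused aten._weight_int8pack_mm kernel and falls back to the
reference dequantize-and-matmul path only when that call raises. A minimal
standalone sketch of the resulting behavior follows; the wrapper name
woq_int8_linear and the example shapes are illustrative, not part of the patch.

    import torch

    def woq_int8_linear(x, w_int8, scale, bias=None):
        # Flatten leading dims so the matmul sees a 2-D [M, K] activation.
        x_2d = x.reshape(-1, x.shape[-1])
        try:
            # Fast path: fused per-channel int8 weight-only quantized matmul.
            # w_int8 is [N, K] int8 and scale is [N]; the op returns [M, N].
            y = torch.ops.aten._weight_int8pack_mm(x_2d, w_int8, scale.to(x.dtype))
        except Exception:
            # Fallback: upcast the int8 weight, run a plain matmul, then apply
            # the per-channel scale to dequantize the result.
            m = torch.mm(x_2d, w_int8.t().to(x.dtype))
            y = m * scale.to(m.dtype)
        # Restore the original leading dims and add bias in the input dtype.
        y = y.reshape(*x.shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias.to(x.dtype)
        return y

    # Example: 4 tokens, K=64 input features, N=32 output channels.
    x = torch.randn(4, 64)
    w = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    s = torch.rand(32)
    print(woq_int8_linear(x, w, s).shape)  # torch.Size([4, 32])

The try/except keeps correctness on builds where aten._weight_int8pack_mm is
unavailable or rejects the given shapes or dtypes, at the cost of silently
losing the fast path; that is the same trade-off the patch itself makes.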