From 22e0acb7c3b865cf0ef061f6e2cb363dad12c9cf Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 10 Sep 2025 17:57:41 -0700
Subject: [PATCH] Add from_int4_tensor in Int4PreshuffledTensor

Summary:
Added a classmethod `from_int4_tensor` to convert a plain `Int4Tensor` to
`Int4PreshuffledTensor`.

This is in preparation for supporting Int4PreshuffledTensor in vLLM, which
requires the tensor to be sliced before inference; see
https://github.com/pytorch/ao/blob/186aeb01664687d14108ada420c475cc783e1643/torchao/testing/utils.py#L429
for details. Int4PreshuffledTensor can't be easily sliced while also
preserving aliasing, so we plan to slice the plain Int4Tensor instead and
then convert it to Int4PreshuffledTensor at a later stage.

The next PR will add a top-level API in prototype to convert from
Int4Tensor to Int4PreshuffledTensor.

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py -k test_from_int4_tensor

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../int4/test_int4_preshuffled_tensor.py      | 30 ++++++++++++++++
 .../workflows/int4/int4_preshuffled_tensor.py | 34 +++++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index df25b650b2..3c919740ae 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import tempfile
 import unittest
 
@@ -17,6 +18,7 @@
 
 from torchao.quantization import (
     Float8DynamicActivationInt4WeightConfig,
+    Int4PreshuffledTensor,
     Int4WeightOnlyConfig,
     quantize_,
 )
@@ -82,6 +84,34 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
+    def test_from_int4_tensor(self):
+        """Test that constructing Int4PreshuffledTensor from Int4Tensor
+        gives the same results as quantizing the original weight to Int4PreshuffledTensor.
+        """
+        int4_config = Int4WeightOnlyConfig(
+            group_size=128,
+            int4_packing_format="plain",
+        )
+        int4_preshuffled_config = Int4WeightOnlyConfig(
+            group_size=128,
+            int4_packing_format="preshuffled",
+        )
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear2 = copy.deepcopy(linear1)
+
+        quantize_(linear1, int4_config)
+        quantize_(linear2, int4_preshuffled_config)
+
+        # now convert linear1.weight to Int4PreshuffledTensor
+        w1_preshuffled = Int4PreshuffledTensor.from_int4_tensor(linear1.weight)
+        linear1.weight = torch.nn.Parameter(w1_preshuffled, requires_grad=False)
+
+        example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
+
+        output1 = linear1(*example_inputs)
+        output2 = linear2(*example_inputs)
+        self.assertEqual(output1, output2)
+
     @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
     def test_to_device(self, config):
         for device in self.GPU_DEVICES:
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
index a2eca24e38..3f5a4e2b10 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from torchao.utils import (
     TorchAOBaseTensor,
 )
@@ -27,6 +28,7 @@
 ):
     quantize_int4_preshuffle = None
     quantize_fp8_row = None
+    pack_int4 = None
 else:
     from fbgemm_gpu.experimental.gen_ai.quantize import (
         quantize_fp8_row,
@@ -185,6 +187,38 @@ def from_hp(
             row_scale=row_scale,
         )
 
+    @classmethod
+    def from_int4_tensor(
+        cls,
+        tensor: Int4Tensor,
+    ):
+        assert isinstance(tensor, Int4Tensor), (
+            f"Only conversion from Int4Tensor is supported, got: {tensor}"
+        )
+        # currently Int4Tensor only supports weight-only quantization; we can extend it to fp8-int4 a bit later
+        qdata = tensor.qdata
+        group_scale = tensor.scale
+        group_zero = tensor.zero_point
+        block_size = tensor.block_size
+        original_shape = tensor.shape
+        row_scale = None
+
+        # set scales to the activation dtype (bf16)
+        group_scale = group_scale.to(torch.bfloat16)
+        group_zero = group_zero.to(torch.bfloat16)
+        # pack weights and scales into the efficient preshuffled format
+        preshuffled_qdata, group_scale = torch.ops.fbgemm.preshuffle_i4(
+            qdata, group_scale
+        )
+        return Int4PreshuffledTensor(
+            qdata=preshuffled_qdata,
+            group_scale=group_scale,
+            block_size=block_size,
+            shape=original_shape,
+            group_zero=group_zero,
+            row_scale=row_scale,
+        )
+
 
 implements = Int4PreshuffledTensor.implements
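--
Note: a minimal sketch of the intended vLLM flow described in the Summary
(quantize with the plain packing format, slice, then convert). The slicing
call and all shapes below are illustrative assumptions, not part of this
patch:

    import torch

    from torchao.quantization import (
        Int4PreshuffledTensor,
        Int4WeightOnlyConfig,
        quantize_,
    )

    # quantize to the plain int4 packing format, which supports slicing
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain"))

    # slice the plain Int4Tensor first (assumed here: a row slice aligned to
    # group_size, since the preshuffled layout can't be easily sliced)
    sliced_weight = linear.weight[:128]

    # then convert to the preshuffled format right before inference
    preshuffled_weight = Int4PreshuffledTensor.from_int4_tensor(sliced_weight)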