From b24d0d4b3275e15d2611faa830b2b02168a74209 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 18 Sep 2025 10:38:55 -0700
Subject: [PATCH] Add main tensor conversion API for packed tensors

Summary:
Added `convert_to_packed_tensor_based_on_current_hardware` to convert a tensor
from the unpacked / plain format to a packed format.

This is to enable vLLM support for packed weights: vLLM slices the quantized
weight for distributed inference, but slicing is not supported by all torchao
tensor subclasses. So we want to first ship a plain / unpacked checkpoint and
then convert it to the packed format using this API.

Test Plan:
pytest test/prototype/test_tensor_conversion.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/prototype/test_tensor_conversion.py   | 34 ++++++++++++++++++++--
 torchao/prototype/tensor_conversion/api.py | 29 +++++++++++++++++-
 2 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/test/prototype/test_tensor_conversion.py b/test/prototype/test_tensor_conversion.py
index 2cee9a08ef..1647a13693 100644
--- a/test/prototype/test_tensor_conversion.py
+++ b/test/prototype/test_tensor_conversion.py
@@ -13,10 +13,18 @@
     StretchedUnifTorchaoQuantizer,
 )
 from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
-from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
-from torchao.quantization import MappingType
+from torchao.prototype.tensor_conversion.api import (
+    _convert_model_for_aarch64,
+    convert_to_packed_tensor_based_on_current_hardware,
+)
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    MappingType,
+)
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     quantize_,
@@ -26,6 +34,7 @@
     _is_kernel_library_loaded,
 )
 from torchao.quantization.utils import compute_error
+from torchao.utils import _is_fbgemm_genai_gpu_available


 class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
@@ -178,3 +187,24 @@ def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
         assert ep.graph_module.code.count(line) == cnt, (
             f"expected {cnt} {line} in {ep.graph_module.code}"
         )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA")
+@pytest.mark.skipif(
+    not _is_fbgemm_genai_gpu_available(), reason="Requires fbgemm-gpu-genai >= 1.2.0"
+)
+def test_int4_tensor_conversion():
+    m = torch.nn.Sequential(
+        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
+    )
+    quantize_(m, Int4WeightOnlyConfig(group_size=128))
+    weight = m[0].weight
+    assert isinstance(weight, Int4Tensor)
+    example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),)
+    before_conversion = m(*example_inputs)
+    m[0].weight = torch.nn.Parameter(
+        convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False
+    )
+    after_conversion = m(*example_inputs)
+    assert isinstance(m[0].weight, Int4PreshuffledTensor)
+    assert torch.equal(before_conversion, after_conversion)
diff --git a/torchao/prototype/tensor_conversion/api.py b/torchao/prototype/tensor_conversion/api.py
index 63a1bcc2ef..6533e5de2d 100644
--- a/torchao/prototype/tensor_conversion/api.py
+++ b/torchao/prototype/tensor_conversion/api.py
@@ -7,7 +7,14 @@
 import torch
 import torch.nn as nn

-from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+# TODO: move the function to torchao.utils
+from torchao.dtypes.utils import is_device
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    IntxUnpackedToInt8Tensor,
+)
+from torchao.utils import TorchAOBaseTensor, _is_fbgemm_genai_gpu_available


 def _convert_linear_weight_to_int8_lut_tensor(module):
@@ -156,3 +163,23 @@ def _convert_model_for_aarch64(
             raise ValueError(f"Unexpected tensor_type={tensor_type}")

     return model
+
+
+def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor):
+    """Convert a plain / unpacked torchao tensor to a packed one based on the current hardware.
+
+    The goal is to get optimized performance on the current hardware while also allowing
+    us to
+    (1) distribute a single unpacked / plain format that can be used on multiple hardware platforms
+    (2) support the vLLM use case, where we need to slice the weights for distributed
+    inference. Since slicing is not always supported for packed weights, we would like to first
+    load the plain / unpacked weight, slice it, and then convert it to a packed weight for the best
+    inference speed.
+    """
+    if (
+        isinstance(tensor, Int4Tensor)
+        and is_device("cuda", tensor.device)
+        and _is_fbgemm_genai_gpu_available()
+    ):
+        return Int4PreshuffledTensor.from_int4_tensor(tensor)
+    return tensor
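
Note (illustrative, not part of the patch): a minimal sketch of the vLLM-style
flow described in the summary, assuming CUDA and fbgemm-gpu-genai are available
(matching the new test's skip conditions). The idea is to quantize or load a
model with the plain Int4Tensor weights, slice the weight for sharding, and only
then convert the slice to the packed format. The narrow() call, shapes, and
group size below are assumptions for illustration, not APIs added by this patch.

    import torch

    from torchao.prototype.tensor_conversion.api import (
        convert_to_packed_tensor_based_on_current_hardware,
    )
    from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_

    # Quantize with the plain / unpacked int4 format first, as in the new test.
    m = torch.nn.Sequential(
        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
    )
    quantize_(m, Int4WeightOnlyConfig(group_size=128))

    # Hypothetical tensor-parallel shard: slice the weight along the output
    # dimension while it is still in the unpacked format, since slicing may not
    # be supported once the weight is packed.
    shard = m[0].weight.narrow(0, 0, 256)

    # Convert the shard to the packed format best suited to the current hardware
    # (Int4PreshuffledTensor on CUDA with fbgemm-gpu-genai, otherwise the tensor
    # is returned unchanged).
    packed = convert_to_packed_tensor_based_on_current_hardware(shard)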