
Commit a951643

Add main tensor conversion API for packed tensors (#3029)
Summary: Added `convert_to_packed_tensor_based_on_current_hardware` to convert a tensor from the unpacked / plain format to a packed format. This is to enable vLLM for packed weights: vLLM slices the quantized weight, but slicing is not supported for all torchao tensor subclasses, so we want to first ship a plain / unpacked checkpoint and then convert it to the packed format using this API.

Test Plan: pytest test/prototype/test_tensor_conversion.py
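As a rough illustration of the intended load → slice → pack flow, here is a minimal sketch based on the new test below. The slicing step is an assumption drawn from the vLLM use case described in the summary, and the imports mirror the test file; this is not part of the commit itself.

import torch

from torchao.prototype.tensor_conversion.api import (
    convert_to_packed_tensor_based_on_current_hardware,
)
from torchao.quantization import Int4Tensor
from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_

# Quantize to the plain / unpacked Int4Tensor; this is the format the
# checkpoint would be shipped in.
m = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)
quantize_(m, Int4WeightOnlyConfig(group_size=128))
assert isinstance(m[0].weight, Int4Tensor)

# vLLM-style step (assumed here): slice the plain weight, e.g. for
# tensor-parallel sharding, which the unpacked format is meant to support.
shard = m[0].weight[0:256]

# Convert the slice to the packed layout for the current hardware; on CUDA
# with fbgemm-gpu-genai this yields an Int4PreshuffledTensor, otherwise the
# input is returned unchanged.
packed = convert_to_packed_tensor_based_on_current_hardware(shard)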
1 parent cfa39c8 commit a951643

2 files changed, +60 -3 lines changed

test/prototype/test_tensor_conversion.py

Lines changed: 32 additions & 2 deletions
@@ -13,10 +13,18 @@
     StretchedUnifTorchaoQuantizer,
 )
 from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
-from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
-from torchao.quantization import MappingType
+from torchao.prototype.tensor_conversion.api import (
+    _convert_model_for_aarch64,
+    convert_to_packed_tensor_based_on_current_hardware,
+)
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    MappingType,
+)
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     quantize_,
@@ -26,6 +34,7 @@
     _is_kernel_library_loaded,
 )
 from torchao.quantization.utils import compute_error
+from torchao.utils import _is_fbgemm_genai_gpu_available
 
 
 class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
@@ -178,3 +187,24 @@ def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
         assert ep.graph_module.code.count(line) == cnt, (
             f"expected {cnt} {line} in {ep.graph_module.code}"
         )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA")
+@pytest.mark.skipif(
+    not _is_fbgemm_genai_gpu_available(), reason="Requires fbgemm-gpu-genai >= 1.2.0"
+)
+def test_int4_tensor_conversion():
+    m = torch.nn.Sequential(
+        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
+    )
+    quantize_(m, Int4WeightOnlyConfig(group_size=128))
+    weight = m[0].weight
+    assert isinstance(weight, Int4Tensor)
+    example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),)
+    before_conversion = m(*example_inputs)
+    m[0].weight = torch.nn.Parameter(
+        convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False
+    )
+    after_conversion = m(*example_inputs)
+    assert isinstance(m[0].weight, Int4PreshuffledTensor)
+    assert torch.equal(before_conversion, after_conversion)

torchao/prototype/tensor_conversion/api.py

Lines changed: 28 additions & 1 deletion
@@ -7,7 +7,14 @@
 import torch
 import torch.nn as nn
 
-from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+# TODO: move the function to torchao.utils
+from torchao.dtypes.utils import is_device
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    IntxUnpackedToInt8Tensor,
+)
+from torchao.utils import TorchAOBaseTensor, _is_fbgemm_genai_gpu_available
 
 
 def _convert_linear_weight_to_int8_lut_tensor(module):
@@ -156,3 +163,23 @@ def _convert_model_for_aarch64(
         raise ValueError(f"Unexpected tensor_type={tensor_type}")
 
     return model
+
+
+def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor):
+    """Convert a plain / unpacked torchao tensor to a packed one based on the current hardware.
+
+    The goal is to get optimized performance on the current hardware while also allowing
+    us to
+    (1) distribute a single unpacked / plain format that can be used on multiple kinds of hardware, and
+    (2) support the vLLM use case, where we need to slice the weights for distributed
+    inference. Since slicing is not always supported for packed weights, we would like to first
+    load the plain / unpacked weight, slice it, and then convert it to a packed weight to get the best
+    inference speed.
+    """
+    if (
+        isinstance(tensor, Int4Tensor)
+        and is_device("cuda", tensor.device)
+        and _is_fbgemm_genai_gpu_available()
+    ):
+        return Int4PreshuffledTensor.from_int4_tensor(tensor)
+    return tensor
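Because the new function passes through any tensor it does not recognize, a caller could apply it blindly across a model's weights. The sketch below shows one way to do that; the `_convert_model_weights` helper is hypothetical and not part of this commit, and the `nn.Parameter` wrapping mirrors the new test above.

import torch

from torchao.prototype.tensor_conversion.api import (
    convert_to_packed_tensor_based_on_current_hardware,
)


def _convert_model_weights(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: repack every linear weight for the current hardware.
    # Tensors the conversion does not recognize are returned unchanged, so this
    # is safe to run on a mixed or unquantized model.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            module.weight = torch.nn.Parameter(
                convert_to_packed_tensor_based_on_current_hardware(module.weight),
                requires_grad=False,
            )
    return model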
