34 changes: 32 additions & 2 deletions test/prototype/test_tensor_conversion.py
@@ -13,10 +13,18 @@
     StretchedUnifTorchaoQuantizer,
 )
 from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
-from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
-from torchao.quantization import MappingType
+from torchao.prototype.tensor_conversion.api import (
+    _convert_model_for_aarch64,
+    convert_to_packed_tensor_based_on_current_hardware,
+)
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    MappingType,
+)
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
+    Int4WeightOnlyConfig,
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     quantize_,
@@ -26,6 +34,7 @@
     _is_kernel_library_loaded,
 )
 from torchao.quantization.utils import compute_error
+from torchao.utils import _is_fbgemm_genai_gpu_available
 
 
 class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
@@ -178,3 +187,24 @@ def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
         assert ep.graph_module.code.count(line) == cnt, (
             f"expected {cnt} {line} in {ep.graph_module.code}"
         )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA")
+@pytest.mark.skipif(
+    not _is_fbgemm_genai_gpu_available(), reason="Requires fbgemm-gpu-genai >= 1.2.0"
+)
+def test_int4_tensor_conversion():
+    m = torch.nn.Sequential(
+        torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
+    )
+    quantize_(m, Int4WeightOnlyConfig(group_size=128))
+    weight = m[0].weight
+    assert isinstance(weight, Int4Tensor)
+    example_inputs = (torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"),)
+    before_conversion = m(*example_inputs)
+    m[0].weight = torch.nn.Parameter(
+        convert_to_packed_tensor_based_on_current_hardware(weight), requires_grad=False
+    )
+    after_conversion = m(*example_inputs)
+    assert isinstance(m[0].weight, Int4PreshuffledTensor)
+    assert torch.equal(before_conversion, after_conversion)
29 changes: 28 additions & 1 deletion torchao/prototype/tensor_conversion/api.py
@@ -7,7 +7,14 @@
 import torch
 import torch.nn as nn
 
-from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor
+# TODO: move the function to torchao.utils
+from torchao.dtypes.utils import is_device
+from torchao.quantization import (
+    Int4PreshuffledTensor,
+    Int4Tensor,
+    IntxUnpackedToInt8Tensor,
+)
+from torchao.utils import TorchAOBaseTensor, _is_fbgemm_genai_gpu_available
 
 
 def _convert_linear_weight_to_int8_lut_tensor(module):
@@ -156,3 +163,23 @@ def _convert_model_for_aarch64(
         raise ValueError(f"Unexpected tensor_type={tensor_type}")
 
     return model
+
+
+def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor):
+    """Convert a plain / unpacked torchao tensor to a packed one based on the current hardware.
+
+    The goal is to get optimal inference performance on the current hardware, while also
+    allowing us to
+    (1) distribute a single unpacked / plain format that can be used on multiple hardware targets
+    (2) support the vLLM use case, where we need to slice the weights for distributed
+        inference. Since slicing is not always supported for packed weights, we would like to
+        first load the plain / unpacked weight, slice it, and then convert it to a packed
+        weight to get the best inference speed.
+    """
+    if (
+        isinstance(tensor, Int4Tensor)
+        and is_device("cuda", tensor.device)
+        and _is_fbgemm_genai_gpu_available()
+    ):
+        return Int4PreshuffledTensor.from_int4_tensor(tensor)
+    return tensor
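
A minimal usage sketch (not part of the diff above) of the flow this helper enables for distributed inference: load a model with plain / unpacked Int4Tensor weights, slice each weight for the local rank, then pack the slice for the current hardware. The shard_and_pack helper, the column-parallel shard math, and the commented usage lines are illustrative assumptions, not torchao APIs.

# Illustrative sketch (not part of this PR): slice a plain Int4Tensor weight for the
# local tensor-parallel rank, then pack the slice for the current hardware.
import torch

from torchao.prototype.tensor_conversion.api import (
    convert_to_packed_tensor_based_on_current_hardware,
)
from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_


def shard_and_pack(linear: torch.nn.Linear, rank: int, world_size: int) -> torch.nn.Parameter:
    # The weight is still in the plain / unpacked format, which supports slicing.
    weight = linear.weight
    shard = weight.shape[0] // world_size  # column-parallel: shard the output dimension
    sliced = weight[rank * shard : (rank + 1) * shard]
    # Convert the slice to the packed layout that is fastest on this hardware,
    # e.g. Int4PreshuffledTensor on CUDA when fbgemm-gpu-genai is available.
    packed = convert_to_packed_tensor_based_on_current_hardware(sliced)
    return torch.nn.Parameter(packed, requires_grad=False)


# Hypothetical single-GPU usage (world_size=1), mirroring the new test:
#   m = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))
#   quantize_(m, Int4WeightOnlyConfig(group_size=128))
#   m[0].weight = shard_and_pack(m[0], rank=0, world_size=1)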