@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import tempfile
import unittest

@@ -17,6 +18,7 @@

from torchao.quantization import (
    Float8DynamicActivationInt4WeightConfig,
    Int4PreshuffledTensor,
    Int4WeightOnlyConfig,
    quantize_,
)
@@ -82,6 +84,34 @@ def forward(self, x):
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

    def test_from_int4_tensor(self):
        """Test that constructing an Int4PreshuffledTensor from an Int4Tensor
        gives the same result as quantizing the original weight directly
        to Int4PreshuffledTensor.
        """
        int4_config = Int4WeightOnlyConfig(
            group_size=128,
            int4_packing_format="plain",
        )
        int4_preshuffled_config = Int4WeightOnlyConfig(
            group_size=128,
            int4_packing_format="preshuffled",
        )
        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        linear2 = copy.deepcopy(linear1)

        quantize_(linear1, int4_config)
        quantize_(linear2, int4_preshuffled_config)

        # Now convert linear1.weight (an Int4Tensor) to Int4PreshuffledTensor
        w1_preshuffled = Int4PreshuffledTensor.from_int4_tensor(linear1.weight)
        linear1.weight = torch.nn.Parameter(w1_preshuffled, requires_grad=False)

        example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)

        output1 = linear1(*example_inputs)
        output2 = linear2(*example_inputs)
        self.assertEqual(output1, output2)

    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
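A minimal usage sketch of the conversion path this test exercises, assuming a CUDA device with fbgemm_gpu installed; the Linear module and shapes are illustrative, and the imports match what the test file above uses:

# Minimal sketch of the conversion path exercised by test_from_int4_tensor.
# The Linear layer here stands in for any module already quantized with the
# "plain" int4 packing format; shapes are illustrative.
import torch

from torchao.quantization import (
    Int4PreshuffledTensor,
    Int4WeightOnlyConfig,
    quantize_,
)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain"))

# Repack the plain Int4Tensor weight into the preshuffled layout in place.
linear.weight = torch.nn.Parameter(
    Int4PreshuffledTensor.from_int4_tensor(linear.weight),
    requires_grad=False,
)

The second diff, in the Int4PreshuffledTensor implementation, adds the from_int4_tensor classmethod this relies on.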
@@ -10,6 +10,7 @@

import torch

from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from torchao.utils import (
    TorchAOBaseTensor,
)
@@ -27,6 +28,7 @@
):
    quantize_int4_preshuffle = None
    quantize_fp8_row = None
    pack_int4 = None
else:
    from fbgemm_gpu.experimental.gen_ai.quantize import (
        quantize_fp8_row,
@@ -185,6 +187,38 @@ def from_hp(
            row_scale=row_scale,
        )

    @classmethod
    def from_int4_tensor(
        cls,
        tensor: Int4Tensor,
    ):
        assert isinstance(tensor, Int4Tensor), (
            f"Only conversion from Int4Tensor is supported, got: {tensor}"
        )
        # Currently Int4Tensor only supports weight-only quantization; this
        # can be extended to fp8-int4 later
        qdata = tensor.qdata
        group_scale = tensor.scale
        group_zero = tensor.zero_point
        block_size = tensor.block_size
        original_shape = tensor.shape
        row_scale = None

        # Set scales to activation type.
        group_scale = group_scale.to(torch.bfloat16)
        group_zero = group_zero.to(torch.bfloat16)
        # Pack weights and scales into the efficient preshuffled format
        preshuffled_qdata, group_scale = torch.ops.fbgemm.preshuffle_i4(
            qdata, group_scale
        )
        return Int4PreshuffledTensor(
            qdata=preshuffled_qdata,
            group_scale=group_scale,
            block_size=block_size,
            shape=original_shape,
            group_zero=group_zero,
            row_scale=row_scale,
        )


implements = Int4PreshuffledTensor.implements

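Note that quantize_int4_preshuffle, quantize_fp8_row, and the newly added pack_int4 all fall back to None when fbgemm_gpu is unavailable, so callers repacking weights outside this module may want a similar availability guard. A hedged sketch, mirroring that None-fallback pattern; maybe_preshuffle is an illustrative helper name, not part of this PR:

# Guard the conversion when fbgemm_gpu may be missing; falls back to the
# plain Int4Tensor layout instead of raising at repack time.
import torch

from torchao.quantization import Int4PreshuffledTensor


def maybe_preshuffle(weight: torch.Tensor) -> torch.Tensor:
    try:
        # Same import this module guards; fails cleanly without fbgemm_gpu.
        from fbgemm_gpu.experimental.gen_ai.quantize import pack_int4  # noqa: F401
    except ImportError:
        return weight  # keep the plain Int4Tensor layout
    return Int4PreshuffledTensor.from_int4_tensor(weight)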