65 changes: 65 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,11 +14,15 @@

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
_choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -221,5 +225,66 @@ def test_available_gpu_kernels(self):
).check_count("triton_poi_fused", 1).run(code[0])


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8StaticQuant(TorchAOIntegrationTestCase):
@common_utils.parametrize("granularity", [PerRow(), PerTensor()])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_static_activation_int8_weight(self, granularity, dtype):
torch.compiler.reset()

M, N, K = 32, 32, 32
input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")

model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
model_static_quant = copy.deepcopy(model)
model_dynamic_quant = copy.deepcopy(model)

model_out_baseline = model(input_tensor)

dynamic_config = Int8DynamicActivationInt8WeightConfig(
version=2, granularity=granularity
)
quantize_(model_dynamic_quant, dynamic_config)

dynamic_out_eager = model_dynamic_quant(input_tensor)
sqnr_dynamic_eager = compute_error(model_out_baseline, dynamic_out_eager)

model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True)

dynamic_out_compile = model_dynamic_quant(input_tensor)
sqnr_dynamic_compile = compute_error(model_out_baseline, dynamic_out_compile)

# quantize the input with the dynamic path's eager qparams so the static config reuses the same scale
int8_input = _choose_quant_func_and_quantize_tensor(
input_tensor, model_dynamic_quant.weight.act_quant_kwargs
)

static_config = Int8StaticActivationInt8WeightConfig(
scale=int8_input.scale.detach().clone(),
granularity=granularity,
)
quantize_(model_static_quant, static_config)

static_out_eager = model_static_quant(input_tensor)
sqnr_static_eager = compute_error(model_out_baseline, static_out_eager)

model_static_quant = torch.compile(model_static_quant, fullgraph=True)

static_out_compile = model_static_quant(input_tensor)
sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)

assert (
sqnr_static_compile
== sqnr_static_eager
== sqnr_dynamic_compile
== sqnr_dynamic_eager
), "SQNR should be the same for all quantization methods and eager/compile"

# eager numerics should match exactly
# for compile, we can't compare dynamic vs static because we may get slightly different qparams when fused
torch.testing.assert_close(dynamic_out_eager, static_out_eager)


if __name__ == "__main__":
common_utils.run_tests()
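As a usage sketch of the flow this test exercises (a minimal illustration assuming the APIs added in this PR; `calibration_batch` is a hypothetical placeholder for real calibration data, and `QuantizeTensorToInt8Kwargs` is assumed importable from its defining module):

import torch

from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

model = torch.nn.Linear(32, 32, bias=False).eval().to(device="cuda", dtype=torch.bfloat16)
# hypothetical stand-in for a representative calibration batch
calibration_batch = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16)

# step 1: derive an activation scale by quantizing the calibration batch once
int8_calib = _choose_quant_func_and_quantize_tensor(
    calibration_batch,
    QuantizeTensorToInt8Kwargs(
        granularity=PerTensor(), mapping_type=MappingType.SYMMETRIC
    ),
)

# step 2: freeze that scale into the static config and quantize the model
quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(
        scale=int8_calib.scale.detach().clone(), granularity=PerTensor()
    ),
)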
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
ModuleFqnToConfig,
@@ -150,6 +151,7 @@
"Int8DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Int4WeightOnlyConfig",
"Float8DynamicActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
93 changes: 87 additions & 6 deletions torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
)
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
else:

assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)
@@ -1621,7 +1618,10 @@

@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
def _int8_dynamic_activation_int8_weight_transform(
module: torch.nn.Module,
config: Int8DynamicActivationInt8WeightConfig,
*,
parameter_name="weight",
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,7 +1634,88 @@ def _int8_dynamic_activation_int8_weight_transform(
module.weight, config
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module


@dataclass
class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
"""
Configuration for applying int8 static symmetric quantization to both activations and weights.

Args:
scale (torch.Tensor): The scale tensor for activation quantization.
granularity (Granularity): The granularity of quantization; only PerRow() and PerTensor() are currently supported
act_mapping_type (MappingType): The mapping type for activation quantization; only SYMMETRIC is currently supported
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
version (int): the version of the config
"""

scale: torch.Tensor
granularity: Granularity = PerRow()
act_mapping_type: MappingType = MappingType.SYMMETRIC
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8StaticActivationInt8WeightConfig"
)


@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
def _int8_static_activation_int8_weight_transform(
module: torch.nn.Module,
config: Int8StaticActivationInt8WeightConfig,
*,
parameter_name="weight",
) -> torch.nn.Module:
assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor granularities are currently supported"
)
assert config.act_mapping_type == MappingType.SYMMETRIC, (
"Asymmetric static quantization is not currently supported"
)
assert hasattr(module, parameter_name), (
f"Expected module to have parameter `{parameter_name}`, but it was not found"
)

if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

activation_granularity = config.granularity
weight_granularity = config.granularity

quantized_tensor = Int8Tensor.from_hp(
getattr(module, parameter_name),
granularity=weight_granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
),
act_scale=config.scale.detach(),
)

setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_tensor, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)
return module
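A note on the shape of `config.scale`: it has to satisfy the shape checks in `Int8Tensor.from_hp` (later in this diff), i.e. `scale.shape[i] == input.shape[i] // block_size[i]`. A minimal sketch of what that implies per granularity, assuming `get_block_size` as used elsewhere in this PR:

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import get_block_size

M, K = 32, 64  # activation shape
for granularity in (PerRow(), PerTensor()):
    block_size = get_block_size((M, K), granularity)
    # from_hp requires scale.shape[i] == input.shape[i] // block_size[i]
    scale_shape = tuple(d // b for d, b in zip((M, K), block_size))
    # PerRow() -> (32, 1): one scale per activation row
    # PerTensor() -> (1, 1): a single scalar scale
    print(granularity, scale_shape)

Note that with PerRow() the static activation scale appears coupled to the batch dimension M, as in the test above, which reuses the same input shape for calibration and inference.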


@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import abc
from typing import ClassVar, Optional

import torch

@@ -31,7 +31,9 @@ def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)


def _choose_quant_func_and_quantize_tensor(
tensor: torch.Tensor,
quant_kwargs: QuantizeTensorKwargs,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs
quantizes tensor to the derived dtype chosen in (1)
@@ -60,6 +62,7 @@ def _choose_quant_func_and_quantize_tensor(
tensor,
quant_kwargs.granularity,
mapping_type=quant_kwargs.mapping_type,
scale=scale,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
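A minimal sketch of the new `scale` passthrough (hypothetical values; when a precomputed scale is supplied, `Int8Tensor.from_hp` in this PR skips `choose_qparams_affine` and pairs the scale with a zero zero-point):

import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

x = torch.randn(16, 32)
static_scale = torch.tensor([[0.05]])  # (1, 1) scale matching PerTensor granularity
x_int8 = _choose_quant_func_and_quantize_tensor(
    x,
    QuantizeTensorToInt8Kwargs(
        granularity=PerTensor(), mapping_type=MappingType.SYMMETRIC
    ),
    scale=static_scale,
)
assert torch.equal(x_int8.scale, static_scale)  # the provided scale is stored as-is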
47 changes: 32 additions & 15 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -53,15 +53,14 @@ class Int8Tensor(TorchAOBaseTensor):
Tensor Attributes:
qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D)
scale: scale factors for dequantization
act_scale: optional precomputed activation scale for static activation quantization

Non-Tensor Attributes:
granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
act_quant_kwargs: flags for activation quantization (dynamic when act_scale is None, static otherwise)
"""

tensor_data_names = ["qdata", "scale"]
optional_tensor_data_names = ["act_scale"]
tensor_attribute_names = ["block_size", "dtype"]
optional_tensor_attribute_names = [
"act_quant_kwargs",
@@ -73,6 +72,7 @@ def __new__(
scale: torch.Tensor,
block_size: List[int],
dtype: torch.dtype,
act_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
):
kwargs = {
@@ -88,6 +88,7 @@ def __init__(
scale: torch.Tensor,
block_size: List[int],
dtype: torch.dtype,
act_scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
):
super().__init__()
@@ -96,13 +97,15 @@
self.block_size = block_size
# don't set dtype because this gets done in __new__
self.act_quant_kwargs = act_quant_kwargs
self.act_scale = act_scale

def __repr__(self):
return (
f"{self.__class__.__name__}("
f"act_quant_kwargs={self.act_quant_kwargs}, "
f"qdata={self.qdata}, "
f"scale={self.scale}, "
f"act_scale={self.act_scale}, "
f"block_size={self.block_size}, "
f"shape={self.shape}, "
f"device={self.device}, "
@@ -114,24 +117,35 @@ def from_hp(
cls,
hp_tensor: torch.Tensor,
granularity: Granularity,
mapping_type=MappingType.SYMMETRIC,
scale: Optional[torch.Tensor] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
act_scale: Optional[torch.Tensor] = None,
):
"""Create Int8Tensor from high-precision tensor"""
block_size = get_block_size(hp_tensor.shape, granularity)
block_size = list(block_size)

if scale is None:
scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
)
else:
# Scale can be provided in the case of static quant
assert scale.ndim == hp_tensor.ndim
assert all(
(hp_tensor.shape[i] // block_size[i]) == scale.shape[i]
for i in range(hp_tensor.ndim)
)
zero_point = torch.zeros_like(scale, dtype=torch.int8)

int_data = quantize_affine(
hp_tensor,
Expand All @@ -146,6 +160,7 @@ def from_hp(
scale,
block_size,
hp_tensor.dtype,
act_scale=act_scale,
act_quant_kwargs=act_quant_kwargs,
)

@@ -185,7 +200,9 @@ def _(func, types, args, kwargs):

if weight_tensor.act_quant_kwargs is not None:
activation_tensor = _choose_quant_func_and_quantize_tensor(
activation_tensor,
weight_tensor.act_quant_kwargs,
scale=weight_tensor.act_scale,
)
# activation quantization path: dynamic when act_scale is None, static otherwise

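To make the `act_scale` plumbing concrete, here is a hedged sketch (illustrative values) of building a weight tensor that carries a frozen activation scale, so the linear dispatch above quantizes activations statically instead of recomputing qparams per batch:

import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    Int8Tensor,
    QuantizeTensorToInt8Kwargs,
)

weight = torch.randn(64, 32, dtype=torch.bfloat16)
act_scale = torch.tensor([[0.02]], dtype=torch.bfloat16)  # illustrative (1, 1) scale

qweight = Int8Tensor.from_hp(
    weight,
    granularity=PerTensor(),
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(
        granularity=PerTensor(), mapping_type=MappingType.SYMMETRIC
    ),
    act_scale=act_scale,  # frozen activation scale travels with the weight
)
assert qweight.act_scale is not None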