3 changes: 2 additions & 1 deletion test/prototype/test_awq.py
@@ -15,6 +15,7 @@
)

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.utils import _is_fbgemm_gpu_genai_available, torch_version_at_least

@@ -76,7 +77,7 @@ def forward(self, x):
# Note: the functionality unit test doesn't work for hqq
Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d"),
],
"cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
"cpu": [Int4WeightOnlyOpaqueTensorConfig(group_size=128)],
"xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
}

@@ -15,10 +15,8 @@
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.utils import (
@@ -27,9 +25,8 @@


def get_config(group_size, use_hqq):
return Int4WeightOnlyConfig(
return Int4WeightOnlyOpaqueTensorConfig(
group_size=group_size,
int4_packing_format="opaque",
int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
)

@@ -68,7 +65,7 @@ def test_module_path(self, dtype, use_hqq):
quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4OpaqueTensor'>",
"<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
@@ -77,7 +74,7 @@ def test_module_path(self, dtype, use_hqq):
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4OpaqueTensor'>",
"<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
)

@parametrize("use_hqq", [True, False])
9 changes: 3 additions & 6 deletions torchao/prototype/awq/example.py
@@ -17,6 +17,7 @@
from torchao.prototype.awq import (
AWQConfig,
)
from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import Int4WeightOnlyConfig, quantize_


@@ -259,9 +260,7 @@ def quantize_and_eval(
group_size=group_size, int4_packing_format="plain_int32"
)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, int4_packing_format="opaque"
)
base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
else:
assert False, "Unsupported device: {}".format(device)
print(f"running {quant} prepare and calibrate")
@@ -301,9 +300,7 @@ def quantize_and_eval(
if device == "cuda":
base_config = Int4WeightOnlyConfig(group_size=group_size)
elif device == "cpu":
base_config = Int4WeightOnlyConfig(
group_size=group_size, int4_packing_format="opaque"
)
base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
else:
assert False, "Unsupported device: {}".format(device)
quantize_(model, base_config)
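For context, a minimal sketch of the CPU quantization path after this change, assuming the AWQConfig/AWQStep prepare-then-convert flow that this example script and test_awq.py exercise (`model` and `calibration_batches` are placeholders, not part of the diff):

# Sketch only: AWQ calibration + conversion on CPU with the relocated config.
# `model` and `calibration_batches` are assumed to exist already.
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import quantize_

base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=128)
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))  # insert observers
for batch in calibration_batches:  # run calibration data through the model
    model(batch)
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))  # apply AWQ-scaled int4 weights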
@@ -55,7 +55,7 @@ class Float8DynamicActivationFloat8WeightOpaqueTensorConfig(AOBaseConfig):

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
"torchao.prototype.float8_opaque_tensor.Float8DynamicActivationFloat8WeightOpaqueTensorConfig"
)
activation_granularity, weight_granularity = (
Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
7 changes: 7 additions & 0 deletions torchao/prototype/int4_opaque_tensor/__init__.py
@@ -0,0 +1,7 @@
from .inference_workflow import Int4WeightOnlyOpaqueTensorConfig
from .int4_opaque_tensor import Int4OpaqueTensor

__all__ = [
"Int4OpaqueTensor",
"Int4WeightOnlyOpaqueTensorConfig",
]
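The relocated module is used directly with quantize_; a minimal sketch on a CPU linear layer (dtype and shapes are illustrative):

import torch
from torchao.prototype.int4_opaque_tensor import (
    Int4OpaqueTensor,
    Int4WeightOnlyOpaqueTensorConfig,
)
from torchao.quantization import quantize_

linear = torch.nn.Linear(256, 128, dtype=torch.bfloat16, device="cpu")
quantize_(linear, Int4WeightOnlyOpaqueTensorConfig(group_size=128))
assert isinstance(linear.weight, Int4OpaqueTensor)  # weight is now an opaque int4 tensor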
88 changes: 88 additions & 0 deletions torchao/prototype/int4_opaque_tensor/inference_workflow.py
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass

import torch

import torchao
from torchao.core.config import AOBaseConfig

logger = logging.getLogger(__name__)
import types

from torchao.quantization.quant_api import _linear_extra_repr
from torchao.quantization.quantize_.workflows import (
Int4ChooseQParamsAlgorithm,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)

from .int4_opaque_tensor import Int4OpaqueTensor


@dataclass
class Int4WeightOnlyOpaqueTensorConfig(AOBaseConfig):
"""
Configuration for int4 weight-only quantization; only groupwise quantization is supported right now.

Args:
`group_size`: controls the granularity of quantization; smaller sizes are more fine-grained. Choices are [256, 128, 64, 32]
`int4_choose_qparams_algorithm`: variant of the choose-qparams algorithm to use for int4; currently supports TINYGEMM ("tinygemm") and HQQ ("hqq")
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values
"""

group_size: int = 128
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
Int4ChooseQParamsAlgorithm.TINYGEMM
)
set_inductor_config: bool = True

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.prototype.int4_opaque_tensor.Int4WeightOnlyOpaqueTensorConfig"
)


def _int4_weight_only_opaque_tensor_quantize(weight, config):
group_size = config.group_size
int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm

if weight.shape[-1] % group_size != 0:
logger.info(
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
)
return weight

block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])

block_size = list(block_size)

new_weight = Int4OpaqueTensor.from_hp(
weight,
block_size,
int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
)
return new_weight


@register_quantize_module_handler(Int4WeightOnlyOpaqueTensorConfig)
def _int4_weight_only_transform(
module: torch.nn.Module, config: Int4WeightOnlyOpaqueTensorConfig
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

assert hasattr(module, "weight"), (
"applying int4 weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
new_weight = _int4_weight_only_opaque_tensor_quantize(module.weight, config)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
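As a usage note, the qparams algorithm is selected per config, mirroring get_config in the tests above; the string values shown are assumptions based on the enum values quoted in the docstring:

# Sketch: tinygemm (default) vs. HQQ qparams selection for the opaque-tensor path.
config_tinygemm = Int4WeightOnlyOpaqueTensorConfig(group_size=128)
config_hqq = Int4WeightOnlyOpaqueTensorConfig(
    group_size=128,
    int4_choose_qparams_algorithm="hqq",  # passed as a string in the tests above
)
# Weights whose last dimension is not a multiple of group_size are logged and
# returned unquantized by _int4_weight_only_opaque_tensor_quantize.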
@@ -16,13 +16,14 @@
_choose_qparams_and_quantize_affine_hqq,
_quantize_affine_tinygemm,
)
from torchao.quantization.quantize_.workflows import (
Int4ChooseQParamsAlgorithm,
)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
from torchao.utils import (
TorchAOBaseTensor,
)

from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm

__all__ = [
"Int4OpaqueTensor",
]
@@ -241,7 +242,7 @@ def _(func, types, args, kwargs):
return y.to(orig_dtype)


Int4OpaqueTensor.__module__ = "torchao.quantization"
Int4OpaqueTensor.__module__ = "torchao.prototype.int4_opaque_tensor"

# Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4OpaqueTensor])
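Because Int4OpaqueTensor is registered as a safe global, a quantized state dict round-trips under the new module path; a minimal sketch mirroring test_module_path above (`linear` is assumed to be a module already quantized with Int4WeightOnlyOpaqueTensorConfig):

import tempfile

import torch

with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f, weights_only=True)  # allowed by add_safe_globals above
    assert (
        str(type(state_dict["weight"]))
        == "<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>"
    )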
2 changes: 0 additions & 2 deletions torchao/quantization/__init__.py
@@ -94,7 +94,6 @@
from .quantize_.workflows import (
Float8Tensor,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Int4PlainInt32Tensor,
Int4PreshuffledTensor,
Int4Tensor,
@@ -173,7 +172,6 @@
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Int4OpaqueTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
15 changes: 2 additions & 13 deletions torchao/quantization/quant_api.py
@@ -77,7 +77,6 @@
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Int4PackingFormat,
Int4PlainInt32Tensor,
Int4PreshuffledTensor,
@@ -1163,12 +1162,9 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size = list(block_size)

if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
assert int4_packing_format in [
Int4PackingFormat.TILE_PACKED_TO_4D,
Int4PackingFormat.OPAQUE,
], (
assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
)

if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1196,13 +1192,6 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size,
)
return new_weight
elif int4_packing_format == Int4PackingFormat.OPAQUE:
new_weight = Int4OpaqueTensor.from_hp(
weight,
block_size,
int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
)
return new_weight
elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
new_weight = Int4TilePackedTo4dTensor.from_hp(
weight,
4 changes: 0 additions & 4 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -6,9 +6,6 @@
from .int4.int4_marlin_sparse_tensor import (
Int4MarlinSparseTensor,
)
from .int4.int4_opaque_tensor import (
Int4OpaqueTensor,
)
from .int4.int4_packing_format import Int4PackingFormat
from .int4.int4_plain_int32_tensor import (
Int4PlainInt32Tensor,
@@ -39,7 +36,6 @@
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"Int4OpaqueTensor",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
"IntxChooseQParamsAlgorithm",
@@ -48,10 +48,3 @@ class Int4PackingFormat(str, Enum):
for simplification of Int4TilePackedTo4dTensor API
"""
TILE_PACKED_TO_4D = "tile_packed_to_4d"

"""
Opaque packing format that's used for tensors that does not have a predefined packing format
(that may be decided on hardware, tensor shape, library availability etc.) and it's not
needed for the rest of the system to understand the specific format that's adopted.
"""
OPAQUE = "opaque"