From 20c657924d77747e19c912ec31990f19804e7d8d Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Fri, 21 Nov 2025 13:33:40 +0000
Subject: [PATCH] Move float8_opaque_tensor to prototype

---
 .../test_float8_opaque_tensor.py              |  12 +-
 .../float8_opaque_tensor/__init__.py          |   7 ++
 .../float8_opaque_tensor.py                   |   7 +-
 .../inference_workflow.py                     | 117 ++++++++++++++++++
 torchao/quantization/__init__.py              |   2 -
 torchao/quantization/quant_api.py             | 107 ++++++----------
 .../quantize_/workflows/__init__.py           |   6 -
 .../workflows/float8/float8_packing_format.py |  31 -----
 8 files changed, 173 insertions(+), 116 deletions(-)
 rename test/{quantization/quantize_/workflows/float8 => prototype}/test_float8_opaque_tensor.py (93%)
 create mode 100644 torchao/prototype/float8_opaque_tensor/__init__.py
 rename torchao/{quantization/quantize_/workflows/float8 => prototype/float8_opaque_tensor}/float8_opaque_tensor.py (98%)
 create mode 100644 torchao/prototype/float8_opaque_tensor/inference_workflow.py
 delete mode 100644 torchao/quantization/quantize_/workflows/float8/float8_packing_format.py

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/prototype/test_float8_opaque_tensor.py
similarity index 93%
rename from test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
rename to test/prototype/test_float8_opaque_tensor.py
index 9ec3f4e3ca..f031bbae6d 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py
+++ b/test/prototype/test_float8_opaque_tensor.py
@@ -15,8 +15,10 @@
 )

 from torchao import quantize_
+from torchao.prototype.float8_opaque_tensor import (
+    Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
+)
 from torchao.quantization import (
-    Float8DynamicActivationFloat8WeightConfig,
     PerGroup,
     PerRow,
     PerTensor,
@@ -29,10 +31,8 @@


 def get_config(granularity):
-    return Float8DynamicActivationFloat8WeightConfig(
-        activation_dtype=torch.float8_e4m3fn,
+    return Float8DynamicActivationFloat8WeightOpaqueTensorConfig(
         granularity=granularity,
-        float8_packing_format="opaque",
     )


@@ -133,7 +133,7 @@ def test_module_path(self, dtype):
         quantize_(linear, get_config(PerRow()))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Float8OpaqueTensor'>",
+            "<class 'torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor'>",
         )

         with tempfile.NamedTemporaryFile() as f:
@@ -142,7 +142,7 @@ def test_module_path(self, dtype):
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Float8OpaqueTensor'>",
+                "<class 'torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor'>",
             )


diff --git a/torchao/prototype/float8_opaque_tensor/__init__.py b/torchao/prototype/float8_opaque_tensor/__init__.py
new file mode 100644
index 0000000000..c6814fb1dc
--- /dev/null
+++ b/torchao/prototype/float8_opaque_tensor/__init__.py
@@ -0,0 +1,7 @@
+from .float8_opaque_tensor import Float8OpaqueTensor
+from .inference_workflow import Float8DynamicActivationFloat8WeightOpaqueTensorConfig
+
+__all__ = [
+    "Float8OpaqueTensor",
+    "Float8DynamicActivationFloat8WeightOpaqueTensorConfig",
+]
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/prototype/float8_opaque_tensor/float8_opaque_tensor.py
similarity index 98%
rename from torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py
rename to torchao/prototype/float8_opaque_tensor/float8_opaque_tensor.py
index c5c47048c8..f166267fb3 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py
+++ b/torchao/prototype/float8_opaque_tensor/float8_opaque_tensor.py
@@ -21,12 +21,13 @@
 from torchao.quantization.quantize_.common import (
     _choose_quant_func_and_quantize_tensor,
 )
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
+    QuantizeTensorToFloat8Kwargs,
+)
 from torchao.utils import (
     TorchAOBaseTensor,
 )

-from .float8_tensor import QuantizeTensorToFloat8Kwargs
-
 __all__ = [
     "Float8OpaqueTensor",
 ]
@@ -267,7 +268,7 @@ def _(func, types, args, kwargs):
     return y


-Float8OpaqueTensor.__module__ = "torchao.quantization"
+Float8OpaqueTensor.__module__ = "torchao.prototype.float8_opaque_tensor"

 # Allow a model with Float8OpaqueTensor weights to be loaded with `weights_only=True`
 torch.serialization.add_safe_globals([Float8OpaqueTensor])
diff --git a/torchao/prototype/float8_opaque_tensor/inference_workflow.py b/torchao/prototype/float8_opaque_tensor/inference_workflow.py
new file mode 100644
index 0000000000..163c8066df
--- /dev/null
+++ b/torchao/prototype/float8_opaque_tensor/inference_workflow.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+
+if TYPE_CHECKING:
+    from torchao.quantization.granularity import PerGroup, PerRow, PerTensor
+
+
+# Define FP8Granularity type alias to break circular import dependencies
+FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"]
+
+import types
+from functools import partial
+
+from torchao.quantization.quant_api import _module_extra_repr
+from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+from torchao.quantization.utils import get_block_size
+
+from .float8_opaque_tensor import Float8OpaqueTensor
+
+
+@dataclass
+class Float8DynamicActivationFloat8WeightOpaqueTensorConfig(AOBaseConfig):
+    """
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
+
+    Args:
+        activation_dtype (torch.dtype): The target data type for activation quantization. Only torch.float8_e4m3fn is supported.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Only torch.float8_e4m3fn is supported.
+        granularity (Optional[Union[FP8Granularity, List[FP8Granularity]]]):
+            The granularity for quantization. Can be either a single granularity (applied to both
+            activations and weights) or a list of two granularities (one for activations, one for weights).
+            If None, defaults to PerTensor for both. Currently the activation and weight granularities must
+            be of the same type, and only PerTensor/PerRow/PerGroup are supported.
+
+    """
+
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
+        activation_granularity, weight_granularity = (
+            Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
+        )
+        self.granularity = [activation_granularity, weight_granularity]
+
+
+def _float8_dynamic_activation_float8_weight_opaque_tensor_quantize(weight, config):
+    activation_dtype = config.activation_dtype
+    granularity = config.granularity
+
+    activation_granularity, weight_granularity = granularity
+
+    act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
+        activation_dtype,
+        activation_granularity,
+    )
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = Float8OpaqueTensor.from_hp(
+        weight,
+        block_size=block_size,
+        act_quant_kwargs=act_quant_kwargs,
+    )
+
+    return quantized_weight
+
+
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightOpaqueTensorConfig)
+def _float8_dynamic_activation_float8_weight_opaque_tensor_transform(
+    module: torch.nn.Module,
+    config: Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
+    *,
+    parameter_name: str = "weight",
+):
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, parameter_name), (
+        f"applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute"
+        + f" but {module} does not have one"
+    )
+    quantized_tensor = _float8_dynamic_activation_float8_weight_opaque_tensor_quantize(
+        getattr(module, parameter_name), config
+    )
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
+    return module
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 577ac40721..8dd6410597 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -92,7 +92,6 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
-    Float8OpaqueTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -175,7 +174,6 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
-    "Float8OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 1e176f9e9b..33b82413be 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -74,8 +74,6 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
-    Float8OpaqueTensor,
-    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -1808,23 +1806,14 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
-    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

     def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        if (
-            self.version == 2
-            and self.float8_packing_format == Float8PackingFormat.OPAQUE
-        ):
-            activation_granularity, weight_granularity = (
-                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
-            )
-        else:
-            activation_granularity, weight_granularity = _normalize_granularity(
-                self.granularity
-            )
+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
         self.granularity = [activation_granularity, weight_granularity]

         default_use_fast_accum = True
@@ -1854,48 +1843,43 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
-    float8_packing_format = config.float8_packing_format

     # Ensure works on device
+    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        # Note: right now we assume it's weights of conv2d and conv3d purely based
-        # on the dimension of weight, currently there is no conflict with linear 2d
-        # and moe weights 3d
-        # if we need to support conv1d, which also has 3d weight, we may have to
-        # pass around the module as well to distinguish between conv1d and 3d moe weight
-        if weight.dim() in [4, 5]:
-            # weights for conv2d or 3d
-            assert isinstance(activation_granularity, PerTensor) and isinstance(
-                weight_granularity, PerTensor
-            ), (
-                "4D/5D tensor only supports per tensor activation and weight quantization"
-            )
-
-            # conv3d weight dim: (C_out, C_in, K1, K2, K3)
-            # conv2d weight dim: (C_out, C_in, K1, K2)
-            # skip quantization when either C_out or C_in
-            # is not a multiple of 16
-            if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
-                return weight
-
-        elif not _fp8_mm_compat(weight):
-            # TODO(future PR): this should really throw an exception instead of silently
-            # not doing what the user asked
+    # Note: right now we assume it's weights of conv2d and conv3d purely based
+    # on the dimension of weight, currently there is no conflict with linear 2d
+    # and moe weights 3d
+    # if we need to support conv1d, which also has 3d weight, we may have to
+    # pass around the module as well to distinguish between conv1d and 3d moe weight
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or 3d
+        assert isinstance(activation_granularity, PerTensor) and isinstance(
+            weight_granularity, PerTensor
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"
+
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
            return weight
+    elif not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return weight

-        if isinstance(weight_granularity, PerRow):
-            assert weight.dtype == torch.bfloat16, (
-                "PerRow quantization only works for bfloat16 precision input weight"
-            )
+    if isinstance(weight_granularity, PerRow):
+        assert weight.dtype == torch.bfloat16, (
+            "PerRow quantization only works for bfloat16 precision input weight"
+        )

     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )
-        _check_hardware_support(granularity)
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
@@ -1926,26 +1910,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             kernel_preference=kernel_preference,
         )

-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        quantized_weight = Float8Tensor.from_hp(
-            weight,
-            float8_dtype=weight_dtype,
-            granularity=weight_granularity,
-            mm_config=mm_config,
-            kernel_preference=kernel_preference,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    elif float8_packing_format == Float8PackingFormat.OPAQUE:
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = Float8OpaqueTensor.from_hp(
-            weight,
-            block_size=block_size,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    else:
-        raise ValueError(
-            f"Unsupported float8 packing format: {float8_packing_format}"
-        )
+    quantized_weight = Float8Tensor.from_hp(
+        weight,
+        float8_dtype=weight_dtype,
+        granularity=weight_granularity,
+        mm_config=mm_config,
+        kernel_preference=kernel_preference,
+        act_quant_kwargs=act_quant_kwargs,
+    )

     return quantized_weight

@@ -1957,10 +1929,9 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    if config.float8_packing_format == Float8PackingFormat.PLAIN:
-        assert is_sm_at_least_89() or is_MI300(), (
-            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-        )
+    assert is_sm_at_least_89() or is_MI300(), (
+        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+    )

     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index e379327689..4307637f8e 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -1,7 +1,3 @@
-from .float8.float8_opaque_tensor import (
-    Float8OpaqueTensor,
-)
-from .float8.float8_packing_format import Float8PackingFormat
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -41,9 +37,7 @@
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
-    "Float8OpaqueTensor",
     "Float8Tensor",
-    "Float8PackingFormat",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py
deleted file mode 100644
index 04ae64241c..0000000000
--- a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-
-from enum import Enum
-
-
-# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
-# after python 3.10 is end of life (https://devguide.python.org/versions/)
-class Float8PackingFormat(str, Enum):
-    """Packing format for quantized data in Float8 Tensor subclasses in torchao, represents how
-    the values in quantized data are packed and laid out in memory.
-    """
-
-    """
-    plain means the format that quantized Tensor data lays out elements in Tensor sequentially,
-    for example, for a Tensor of shape (4, 6):
-    a_0_0, a_0_1, ..., a_0_5,
-    ...
-    a_3_0, a_3_1, ..., a_3_5
-    """
-    PLAIN = "plain"
-
-    """
-    Opaque packing format that's used for tensors that does not have a predefined packing format
-    (that may be decided on hardware, tensor shape, library availability etc.) and it's not
-    needed for the rest of the system to understand the specific format that's adopted.
-    """
-    OPAQUE = "opaque"
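
A minimal usage sketch of the relocated config, mirroring get_config() in the moved test above; the layer shape and dtype below are illustrative and not taken from the patch:

    import torch

    from torchao import quantize_
    from torchao.prototype.float8_opaque_tensor import (
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
    )
    from torchao.quantization import PerRow

    # Any linear module with floating-point weights; dimensions and dtype are arbitrary here.
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)

    # Dynamic float8 activation + float8 weight quantization at per-row granularity,
    # with the weight replaced by the prototype Float8OpaqueTensor subclass.
    quantize_(linear, Float8DynamicActivationFloat8WeightOpaqueTensorConfig(granularity=PerRow()))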