@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
)


def get_config(granularity):
return Float8DynamicActivationFloat8WeightConfig(
activation_dtype=torch.float8_e4m3fn,
granularity=granularity,
float8_packing_format="opaque",
)


class ToyLinearModel(torch.nn.Module):
def __init__(self, K=64, N=32, bias=False):
super().__init__()
self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
return (
torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
* 0.1,
)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


class TestFloat8OpaqueTensor(TestCase):
"""Test cases for Float8OpaqueTensor on CPU"""

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [1, 160])
@common_utils.parametrize(
"x_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
@common_utils.parametrize(
"w_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
def test_dynamic_float8_linear(
self, dtype, x_dim, bias, bs, x_granularity, w_granularity
):
if isinstance(x_granularity, PerGroup):
if not isinstance(w_granularity, PerGroup):
return
if w_granularity.group_size != x_granularity.group_size:
return
device = "cpu"
m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
quantize_(
m,
get_config([x_granularity, w_granularity]),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [4, 128])
def test_dynamic_float8_linear_ref(self, dtype, x_dim, bias, bs):
device = "cpu"
# the shape is not supported by cpp kernel, so the ref path will be used.
m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

with torch.no_grad():
quantize_(
m,
get_config(PerRow()),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype)
quantize_(linear, get_config(PerRow()))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)


common_utils.instantiate_parametrized_tests(TestFloat8OpaqueTensor)


if __name__ == "__main__":
run_tests()
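For reference, a minimal standalone sketch of the workflow these tests exercise: quantize a CPU linear layer with the opaque float8 packing format, then run it eagerly and under torch.compile. It mirrors the parametrized tests above (same config arguments and shapes) and assumes the torchao CPU float8 kernels are built, which is the same condition the skip decorators check.

```python
import torch
from torchao import quantize_
from torchao.quantization import PerGroup
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

linear = torch.nn.Linear(256, 256).eval()
x = torch.rand(8, 256) * 0.1

config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=[PerGroup(128), PerGroup(128)],
    float8_packing_format="opaque",
)

with torch.no_grad():
    quantize_(linear, config)  # weight becomes a Float8OpaqueTensor
    y_eager = linear(x)
    y_compiled = torch.compile(linear, fullgraph=True, dynamic=True)(x)
```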
28 changes: 21 additions & 7 deletions torchao/float8/inference.py
@@ -14,6 +14,7 @@
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerGroup,
PerRow,
PerTensor,
)
@@ -204,28 +205,41 @@ def _normalize_granularity(
list[FP8Granularity],
]
],
supported_granularities: tuple[FP8Granularity] = (PerTensor, PerRow),
support_different_granularities: bool = False,
Review thread:

Contributor:
This is weird. I think _normalize_granularity should only do normalization, not validation as well.

Collaborator (author):
I feel the same, actually. Where should we put the validation? Thanks.

Contributor (jerryzh168, Oct 14, 2025):
This seems to be _normalize_and_validate_granularities. Can you define separate functions for the float8 tensor and the float8 opaque tensor in the tensor files themselves, i.e. float8_tensor.py and float8_opaque_tensor.py?

It would probably be clearer to do this in a separate PR: first move the original _normalize function to float8_tensor.py and change all the call sites, and then in this PR you only need to add a new one for float8_opaque_tensor.py.

Collaborator (author):
Sounds good. Will do. Thanks.

Collaborator (author):
How about version=1? Call _validate_granularity explicitly? In that case, _validate_granularity cannot be bound to a specific tensor type, I guess. And _normalize_granularity (with checks) is called elsewhere too; how shall we do validation at those locations?

Collaborator (author):
Hi @jerryzh168, do you have any suggestions for this? Thanks.

Contributor:
I think just creating _normalize_and_validate_granularities for both float8_tensor.py and float8_opaque_tensor.py in their own files, without sharing anything, seems to be the cleanest.

Contributor:
Oh, actually it looks like

(act_granularity, weight_granularity) = _normalize_granularity(

needs a pure normalize function. I think you can do the following:

torchao/quantization/quantize_/workflows/float8/utils.py: _normalize_granularity (make a single element into a tuple)
torchao/quantization/quantize_/workflows/float8/float8_tensor.py: _validate_granularity (per row, per tensor)
torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py: _validate_granularity (per row, per tensor, per block)

) -> Tuple[FP8Granularity, FP8Granularity]:
processed_granularity = None
if granularity is None:
processed_granularity = (PerTensor(), PerTensor())
elif isinstance(granularity, (PerTensor, PerRow)):
elif isinstance(granularity, supported_granularities):
processed_granularity = (granularity, granularity)
elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
if not (
isinstance(granularity[0], (PerTensor, PerRow))
and isinstance(granularity[1], (PerTensor, PerRow))
isinstance(granularity[0], supported_granularities)
and isinstance(granularity[1], supported_granularities)
):
raise ValueError(
f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
f"Invalid granularity types: {granularity}, only {supported_granularities} are supported."
)
if not isinstance(granularity[0], type(granularity[1])):
if not support_different_granularities and not isinstance(
granularity[0], type(granularity[1])
):
raise ValueError(
f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
f"Different granularities for activation and weight are not supported: {granularity}, only {supported_granularities} are supported."
)
if isinstance(granularity[0], PerGroup):
if not isinstance(granularity[1], PerGroup):
raise ValueError(
"When granularity for activation is PerGroup, granularity for weight must be PerGroup, too."
)
if granularity[0].group_size != granularity[1].group_size:
raise ValueError(
f"Group sizes for activation and weight must be the same, got {granularity[0].group_size} and {granularity[1].group_size}."
)
processed_granularity = tuple(granularity)
else:
raise ValueError(
f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
f"Invalid granularity specification: {granularity}, only {supported_granularities} are supported."
)
return processed_granularity

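A minimal sketch of the split proposed in the review thread above: a pure _normalize_granularity helper plus separate per-tensor-type validation functions. The module layout and function names follow the reviewer's suggestion and are assumptions, not code that exists in this PR (in separate files both validators would simply be named _validate_granularity; the suffixed names here are only for a single-file sketch). The checks themselves mirror the validation logic shown in the diff.

```python
from typing import Optional, Sequence, Tuple, Union

from torchao.quantization.granularity import PerGroup, PerRow, PerTensor

FP8Granularity = Union[PerTensor, PerRow, PerGroup]


def _normalize_granularity(
    granularity: Optional[Union[FP8Granularity, Sequence[FP8Granularity]]],
) -> Tuple[FP8Granularity, FP8Granularity]:
    """Normalization only: turn None or a single granularity into a pair."""
    if granularity is None:
        return (PerTensor(), PerTensor())
    if isinstance(granularity, (PerTensor, PerRow, PerGroup)):
        return (granularity, granularity)
    if isinstance(granularity, (tuple, list)) and len(granularity) == 2:
        return (granularity[0], granularity[1])
    raise ValueError(f"Invalid granularity specification: {granularity}")


def _validate_granularity_plain(act, weight) -> None:
    """Validation for the plain Float8Tensor path: PerTensor/PerRow only, matching types."""
    for g in (act, weight):
        if not isinstance(g, (PerTensor, PerRow)):
            raise ValueError(f"Unsupported granularity: {g}")
    if not isinstance(act, type(weight)):
        raise ValueError(
            f"Different granularities for activation and weight are not supported: {(act, weight)}"
        )


def _validate_granularity_opaque(act, weight) -> None:
    """Validation for the opaque (CPU) Float8OpaqueTensor path: PerGroup is also allowed."""
    for g in (act, weight):
        if not isinstance(g, (PerTensor, PerRow, PerGroup)):
            raise ValueError(f"Unsupported granularity: {g}")
    if isinstance(act, PerGroup):
        if not isinstance(weight, PerGroup):
            raise ValueError(
                "When granularity for activation is PerGroup, granularity for weight must be PerGroup, too."
            )
        if act.group_size != weight.group_size:
            raise ValueError(
                f"Group sizes for activation and weight must be the same, "
                f"got {act.group_size} and {weight.group_size}."
            )
```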
4 changes: 2 additions & 2 deletions torchao/float8/types.py
@@ -12,8 +12,8 @@
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.granularity import PerGroup, PerRow, PerTensor


# Define FP8Granularity type alias to break circular import dependencies
FP8Granularity = Union["PerTensor", "PerRow"]
FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"]
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -78,6 +78,7 @@
quantize_affine,
)
from .quantize_.workflows import (
Float8OpaqueTensor,
Float8Tensor,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
@@ -148,6 +149,7 @@
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Int4OpaqueTensor",
"Float8OpaqueTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
78 changes: 56 additions & 22 deletions torchao/quantization/quant_api.py
@@ -72,6 +72,8 @@
KernelPreference,
)
from torchao.quantization.quantize_.workflows import (
Float8OpaqueTensor,
Float8PackingFormat,
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4MarlinSparseTensor,
@@ -1676,6 +1678,22 @@ def _input_activation_quant_func_fp8(
return activation


def _input_activation_quant_cpu_fp8(
x: torch.Tensor,
activation_granularity: FP8Granularity,
activation_dtype: torch.dtype,
):
"""Dynamic quantize activation to fp8 for CPU."""
block_size = get_block_size(x.shape, activation_granularity)
return to_affine_quantized_floatx(
input_float=x,
block_size=block_size,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
_layout=PlainLayout(),
)


def _fp8_mm_compat(weight: torch.Tensor) -> bool:
"""
Check if a weight tensor meets float8 quantization requirements.
@@ -1734,15 +1752,24 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
kernel_preference: KernelPreference = KernelPreference.AUTO
set_inductor_config: bool = True
version: int = 2
float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
)
if self.mm_config is None:
self.mm_config = Float8MMConfig(use_fast_accum=True)
supported_granularities = ()
if self.float8_packing_format == Float8PackingFormat.PLAIN:
supported_granularities = (PerTensor, PerRow)
elif self.float8_packing_format == Float8PackingFormat.OPAQUE:
supported_granularities = (PerTensor, PerRow, PerGroup)
support_different_granularities = (
self.float8_packing_format == Float8PackingFormat.OPAQUE
)
activation_granularity, weight_granularity = _normalize_granularity(
self.granularity
self.granularity, supported_granularities, support_different_granularities
)
self.granularity = [activation_granularity, weight_granularity]

@@ -1755,17 +1782,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
activation_value_lb = config.activation_value_lb
activation_value_ub = config.activation_value_ub
kernel_preference = config.kernel_preference
float8_packing_format = config.float8_packing_format

# Ensure works on device
_check_hardware_support(granularity)
activation_granularity, weight_granularity = granularity

if not _fp8_mm_compat(weight):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return weight

if isinstance(weight_granularity, PerRow):
if weight.device.type != "cpu" and isinstance(weight_granularity, PerRow):
assert weight.dtype == torch.bfloat16, (
"PerRow quantization only works for bfloat16 precision input weight"
)
Expand All @@ -1775,6 +1797,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
"Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
)

_check_hardware_support(granularity)
if not _fp8_mm_compat(weight):
# TODO(future PR): this should really throw an exception instead of silently
# not doing what the user asked
return weight

block_size = get_block_size(weight.shape[-2:], weight_granularity)
if weight.dim() == 3:
block_size = tuple([1] + list(block_size))
@@ -1805,14 +1833,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
kernel_preference=kernel_preference,
)

quantized_weight = Float8Tensor.from_hp(
weight,
float8_dtype=weight_dtype,
granularity=weight_granularity,
mm_config=mm_config,
kernel_preference=kernel_preference,
act_quant_kwargs=act_quant_kwargs,
)
if float8_packing_format == Float8PackingFormat.PLAIN:
quantized_weight = Float8Tensor.from_hp(
weight,
float8_dtype=weight_dtype,
granularity=weight_granularity,
mm_config=mm_config,
kernel_preference=kernel_preference,
act_quant_kwargs=act_quant_kwargs,
)
elif float8_packing_format == Float8PackingFormat.OPAQUE:
block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = Float8OpaqueTensor.from_hp(
weight,
block_size=block_size,
act_quant_kwargs=act_quant_kwargs,
)
else:
raise ValueError(
f"Unsupported float8 packing format: {float8_packing_format}"
)

return quantized_weight

@@ -1821,12 +1861,6 @@ def _float8_dynamic_activation_float8_weight_transform(
def _float8_dynamic_activation_float8_weight_transform(
module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
):
assert is_sm_at_least_89() or is_MI300(), (
"Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
)
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

assert hasattr(module, "weight"), (
"applying float8 dynamic activation quant requires module to have weight attribute"
+ f"but {module} does not have one"