4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
@@ -152,7 +152,7 @@ def test_invalid_granularity(self):
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
match="Unsupported granularity types",
):
Float8DynamicActivationFloat8WeightConfig(
granularity=(PerTensor(), PerRow())
@@ -165,7 +165,7 @@ def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
with pytest.raises(ValueError, match="Unsupported granularity types"):
Float8DynamicActivationFloat8WeightConfig(
granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
)
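For context, a minimal usage sketch (not part of this diff) of how the config treats granularity pairs after this change; it assumes a torchao build that includes this PR:

# Illustrative only; assumes torchao with this PR applied.
import pytest
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    PerRow,
    PerTensor,
)

# Mixed PerTensor/PerRow pairs are still rejected, now with the unified message.
with pytest.raises(ValueError, match="Unsupported granularity types"):
    Float8DynamicActivationFloat8WeightConfig(granularity=(PerTensor(), PerRow()))

# The new 1x128-activation / 128x128-weight blockwise pair passes granularity
# normalization; hardware support is validated separately.
Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128)))
)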
47 changes: 38 additions & 9 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,6 +18,7 @@
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
PerBlock,
PerRow,
PerTensor,
quantize_,
@@ -64,7 +65,10 @@ def setUp(self):
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"granularity",
[PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
)
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,7 +78,7 @@
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
((32, 128), 256, 512),
],
)
def test_fp8_linear_variants(
@@ -86,13 +90,24 @@ def test_fp8_linear_variants(
kernel_preference: KernelPreference,
sizes: Tuple,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
return unittest.skip(
"per tensor with fbgemm kernel preferece does not work yet"
)
if isinstance(granularity, PerTensor):
if kernel_preference is KernelPreference.FBGEMM:
return unittest.skip(
"per tensor with fbgemm kernel preference does not work yet"
)
elif mode == "weight-only":
return unittest.skip("unimplemented")

elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
if dtype is not torch.bfloat16:
return unittest.skip("unimplemented")
elif mode != "dynamic":
return unittest.skip("unimplemented")
elif kernel_preference not in (
KernelPreference.AUTO,
KernelPreference.TORCH,
):
return unittest.skip("unimplemented")

error_message = None
if isinstance(granularity, PerRow):
@@ -137,6 +152,20 @@ def test_fp8_linear_variants(

quantize_(quantized_model, config)

# ensure weight scaling is what we expect
qs1 = quantized_model.linear1.weight.scale
qs2 = quantized_model.linear2.weight.scale
if granularity == PerTensor():
assert qs1.shape == (1, 1)
assert qs2.shape == (1, 1)
elif granularity == PerRow():
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

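For reference, a minimal standalone sketch (not part of the diff) of the scale-shape arithmetic the new assertions encode, for a weight of shape (N, K); the names below are illustrative only:

# Pure-Python illustration of the expected weight-scale shapes asserted above.
N, K = 256, 512  # matches the sizes used in the parametrized test

def expected_scale_shape(granularity: str, N: int, K: int) -> tuple:
    if granularity == "per_tensor":
        return (1, 1)                 # one scale for the whole tensor
    if granularity == "per_row":
        return (N, 1)                 # one scale per output row
    if granularity == "per_block_128x128":
        return (N // 128, K // 128)   # one scale per 128x128 tile
    raise ValueError(granularity)

assert expected_scale_shape("per_tensor", N, K) == (1, 1)
assert expected_scale_shape("per_row", N, K) == (256, 1)
assert expected_scale_shape("per_block_128x128", N, K) == (2, 4)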
46 changes: 46 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -14,9 +14,11 @@
MappingType,
ZeroPointDomain,
_choose_qparams_affine_tinygemm,
_choose_scale_float8,
_fake_quantize_affine,
_fake_quantize_affine_cachemask,
_maybe_expand_scale_to_tensor_shape,
_quantize_affine_float8,
choose_qparams_affine,
dequantize_affine,
quantize_affine,
@@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs):
return output1


# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100
# with scale modified to be the inverse of the version in PT core
def _tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = amax / torch.finfo(float8_dtype).max
x = x.div(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale


# Legacy tinygemm ops
def _get_groupwise_affine_qparams(
w,
@@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self):
self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))

def test_float8_blockwise_scaling(self):
M, K = 512, 1024
hp_tensor = torch.randn(M, K, dtype=torch.float)
# make the scales from some of the blocks obviously different
hp_tensor[0:128, 0:128] *= 3.0
hp_tensor[0:128, 128:256] *= 7.0
hp_tensor[128:256, 0:128] *= 2.0
hp_tensor[128:256, 128:256] *= 100.0

block_size = (128, 128)

scale = _choose_scale_float8(
hp_tensor,
float8_dtype=torch.float8_e4m3fn,
block_size=block_size,
hp_value_lb=None,
hp_value_ub=None,
)
data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)

ref_data, ref_scale = _tensor_to_scale_block(
hp_tensor, torch.float8_e4m3fn, 128, 128
)

torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
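A small usage sketch (not part of the diff) for the `_tensor_to_scale_block` reference helper above; it assumes a PyTorch build where casts to torch.float8_e4m3fn are supported on the current device:

import torch

x = torch.randn(256, 256, dtype=torch.float)
data, scale = _tensor_to_scale_block(x, torch.float8_e4m3fn, 128, 128)

# The quantized data keeps the original shape; the scale holds one value
# per 128x128 block, so a 256x256 input yields a 2x2 scale.
assert data.shape == (256, 256)
assert data.dtype == torch.float8_e4m3fn
assert scale.shape == (2, 2)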
75 changes: 59 additions & 16 deletions torchao/float8/inference.py
@@ -7,13 +7,15 @@
Defines an nn module designed to be used during inference
"""

import math
from typing import List, NamedTuple, Optional, Tuple, Union

import torch

from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerBlock,
PerRow,
PerTensor,
)
@@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
)


def _is_1_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 1x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128


def _is_128_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 128x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
return len(b) == 2 and b[0] == 128 and b[1] == 128


def _granularity_is_a_1_128_w_128_128(
g: Union[
FP8Granularity,
Tuple[FP8Granularity, FP8Granularity],
list[FP8Granularity],
],
) -> bool:
return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))


def _normalize_granularity(
granularity: Optional[
Union[
@@ -211,22 +243,23 @@ def _normalize_granularity(
elif isinstance(granularity, (PerTensor, PerRow)):
processed_granularity = (granularity, granularity)
elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
if not (
isinstance(granularity[0], (PerTensor, PerRow))
and isinstance(granularity[1], (PerTensor, PerRow))
):
raise ValueError(
f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
granularity[1], PerTensor
)
is_per_row = isinstance(granularity[0], PerRow) and isinstance(
granularity[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)

if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
raise ValueError(f"Unsupported granularity types: {granularity}.")
if not isinstance(granularity[0], type(granularity[1])):
raise ValueError(
f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
f"Different granularities for activation and weight are not supported: {granularity}."
)
processed_granularity = tuple(granularity)
else:
raise ValueError(
f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
)
raise ValueError(f"Invalid granularity specification: {granularity}.")
return processed_granularity


Expand All @@ -243,12 +276,22 @@ def _check_hardware_support(
AssertionError: If hardware doesn't support the requested granularity
ValueError: If invalid granularity type is provided
"""
for _granularity in granularities:
if not isinstance(_granularity, (PerTensor, PerRow)):
raise ValueError(
f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
granularities[1], PerTensor
)
is_per_row = isinstance(granularities[0], PerRow) and isinstance(
granularities[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

if is_per_tensor or is_per_row:
assert is_sm_at_least_89() or is_MI300(), (
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
)
elif is_a_1_128_w_128_128:
# TODO(future PR): look into AMD support
assert is_sm_at_least_89(), (
"Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
)
else:
raise ValueError(f"Invalid granularities {granularities}.")