From 990ef89e6b5d4a785011787c938077f1310178ae Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 04:05:16 -0700 Subject: [PATCH 01/14] Update [ghstack-poisoned] --- .../quantize_/workflows/float8/test_float8_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 786e0cf59f..010682474e 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -294,6 +294,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity): self._test_slice_and_copy_similar_to_vllm(config) @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai") def test_bmm(self): # only support per row quantization config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) @@ -406,6 +407,7 @@ def test_cat(self, granularity, sizes): self.assertEqual(cat_qweight2.scale, ref_scale) @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai") def test_moe_weight_reshape_ops(self): # only per row quantization is supported for bmm granularity = PerRow() @@ -416,6 +418,7 @@ def test_moe_weight_reshape_ops(self): # that should be moved here after v1 config is deprecated: # https://github.com/pytorch/ao/issues/2649 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai") def test_expected_gpu_kernel_fbgemm(self): """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels and the bias add happens in the gemm kernel for per row quantization From cce08f0a71883504f526c7f75594b334e325b1d6 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 04:05:20 -0700 Subject: [PATCH 02/14] Update [ghstack-poisoned] --- benchmarks/benchmark_blockwise_scaled_linear_triton.py | 2 +- test/{prototype => kernel}/test_blockwise_triton.py | 2 +- .../blockwise_quantization.py | 0 torchao/prototype/blockwise_fp8_inference/__init__.py | 5 +++-- .../prototype/blockwise_fp8_inference/blockwise_linear.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) rename test/{prototype => kernel}/test_blockwise_triton.py (96%) rename torchao/{prototype/blockwise_fp8_inference => kernel}/blockwise_quantization.py (100%) diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py index ffdd63ec8d..26ba04f2ce 100644 --- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -13,7 +13,7 @@ from triton.testing import do_bench from torchao.float8.float8_utils import compute_error - from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( + from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_quant, diff --git a/test/prototype/test_blockwise_triton.py b/test/kernel/test_blockwise_triton.py similarity index 96% rename from test/prototype/test_blockwise_triton.py rename to test/kernel/test_blockwise_triton.py index 89f8cf869e..5de88ab7d9 100644 --- a/test/prototype/test_blockwise_triton.py +++ b/test/kernel/test_blockwise_triton.py @@ -11,7 +11,7 @@ triton = pytest.importorskip("triton", reason="Triton required to run this test") -from 
torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant, diff --git a/torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py similarity index 100% rename from torchao/prototype/blockwise_fp8_inference/blockwise_quantization.py rename to torchao/kernel/blockwise_quantization.py diff --git a/torchao/prototype/blockwise_fp8_inference/__init__.py b/torchao/prototype/blockwise_fp8_inference/__init__.py index f2842417e4..eb6b7824bc 100644 --- a/torchao/prototype/blockwise_fp8_inference/__init__.py +++ b/torchao/prototype/blockwise_fp8_inference/__init__.py @@ -1,11 +1,12 @@ -from .blockwise_linear import BlockwiseQuantLinear -from .blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant, fp8_blockwise_weight_quant, ) +from .blockwise_linear import BlockwiseQuantLinear + __all__ = [ "blockwise_fp8_gemm", "BlockwiseQuantLinear", diff --git a/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py index ebed3a84a4..a43574fa11 100644 --- a/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py +++ b/torchao/prototype/blockwise_fp8_inference/blockwise_linear.py @@ -7,7 +7,7 @@ import torch from torch import nn -from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import ( +from torchao.kernel.blockwise_quantization import ( blockwise_fp8_gemm, fp8_blockwise_act_quant, ) From 681277af9706ba8c0bea6290a1813d85cdc2979b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 04:05:23 -0700 Subject: [PATCH 03/14] Update [ghstack-poisoned] --- .../workflows/float8/test_float8_tensor.py | 30 ++++++-- torchao/float8/inference.py | 68 ++++++++++++++----- .../workflows/float8/float8_tensor.py | 43 +++++++++--- 3 files changed, 110 insertions(+), 31 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 010682474e..7d8de03908 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -18,6 +18,7 @@ from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, + PerBlock, PerRow, PerTensor, quantize_, @@ -61,20 +62,37 @@ def setUp(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) - @common_utils.parametrize("mode", ["dynamic", "weight-only"]) - @common_utils.parametrize("compile", [True, False]) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + # @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize( + "dtype", + [ + torch.bfloat16, + ], + ) + # @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize( + "mode", + [ + "dynamic", + ], + ) + # @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize("compile", [False]) + # @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "granularity", [(PerBlock((1, 128)), PerBlock((128, 128)))] + ) @common_utils.parametrize( "kernel_preference", - 
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], + # [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], + [KernelPreference.TORCH], ) # Inputs are (M,..), K, N @common_utils.parametrize( "sizes", [ ((128,), 256, 128), - ((32, 128), 64, 256), + # ((32, 128), 64, 256), ], ) def test_fp8_linear_variants( diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f15d38576c..21384c1387 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -14,6 +14,7 @@ from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul from torchao.float8.types import FP8Granularity from torchao.quantization.granularity import ( + PerBlock, PerRow, PerTensor, ) @@ -196,6 +197,26 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool: ) +def _is_1_128_scaled(x: torch.Tensor) -> bool: + """Checks if a quantized tensor is scaled with a block size of 1x128 + Args: + x: quantized tensor (should have `block_size` attribute) + """ + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + b = x.block_size + return len(b) == 2 and b[0] == 1 and b[1] == 128 + + +def _is_128_128_scaled(x: torch.Tensor) -> bool: + """Checks if a quantized tensor is scaled with a block size of 128x128 + Args: + x: quantized tensor (should have `block_size` attribute) + """ + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + b = x.block_size + return len(b) == 2 and b[0] == 128 and b[1] == 128 + + def _normalize_granularity( granularity: Optional[ Union[ @@ -211,22 +232,25 @@ def _normalize_granularity( elif isinstance(granularity, (PerTensor, PerRow)): processed_granularity = (granularity, granularity) elif isinstance(granularity, (tuple, list)) and len(granularity) == 2: - if not ( - isinstance(granularity[0], (PerTensor, PerRow)) - and isinstance(granularity[1], (PerTensor, PerRow)) - ): - raise ValueError( - f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." - ) + is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance( + granularity[1], PerTensor + ) + is_per_row = isinstance(granularity[0], PerRow) and isinstance( + granularity[1], PerRow + ) + is_a_1_128_w_128_128 = granularity[0] == PerBlock((1, 128)) and granularity[ + 1 + ] == PerBlock((128, 128)) + + if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128): + raise ValueError(f"Unsupported granularity types: {granularity}.") if not isinstance(granularity[0], type(granularity[1])): raise ValueError( - f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + f"Different granularities for activation and weight are not supported: {granularity}." ) processed_granularity = tuple(granularity) else: - raise ValueError( - f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported." - ) + raise ValueError(f"Invalid granularity specification: {granularity}.") return processed_granularity @@ -243,12 +267,24 @@ def _check_hardware_support( AssertionError: If hardware doesn't support the requested granularity ValueError: If invalid granularity type is provided """ - for _granularity in granularities: - if not isinstance(_granularity, (PerTensor, PerRow)): - raise ValueError( - f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported." 
- ) + is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance( + granularities[1], PerTensor + ) + is_per_row = isinstance(granularities[0], PerRow) and isinstance( + granularities[1], PerRow + ) + is_a_1_128_w_128_128 = granularities[0] == PerBlock((1, 128)) and granularities[ + 1 + ] == PerBlock((128, 128)) + if is_per_tensor or is_per_row: assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+." ) + elif is_a_1_128_w_128_128: + # TODO(future PR): look into AMD support + assert is_sm_at_least_89(), ( + "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9." + ) + else: + raise ValueError(f"Invalid granularities {granularities}.") diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 47395a15af..bc0ab7afbf 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -15,6 +15,8 @@ from torchao.float8.inference import ( Float8MMConfig, FP8Granularity, + _is_1_128_scaled, + _is_128_128_scaled, _is_rowwise_scaled, _is_tensorwise_scaled, _slice_scale_for_dimension, @@ -22,6 +24,9 @@ preprocess_data, preprocess_scale, ) +from torchao.kernel.blockwise_quantization import ( + blockwise_fp8_gemm, +) from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.quant_primitives import ( _choose_scale_float8, @@ -337,19 +342,39 @@ def _(func, types, args, kwargs): "Input tensor must be rowwise block size" ) w_scale = w_scale.transpose(-1, -2) + elif _is_128_128_scaled(weight_tensor): + assert _is_1_128_scaled(input_tensor), ( + "input_tensor must be 1x128 scaled" + ) + w_scale = w_scale.transpose(-1, -2) input_scale = preprocess_scale(input_scale, input_tensor.shape) inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) - return addmm_float8_unwrapped_inference( - inpt_data, - input_scale, - w_data, - w_scale, - output_dtype=input_tensor.dtype, - bias=bias, - use_fast_accum=scaled_mm_config.use_fast_accum, - ).reshape(out_shape) + if _is_128_128_scaled(weight_tensor): + # TODO(before land): ensure fast_accum is False for blockwise + # TODO(future PR): add testing for torch._scaled_mm with + # blockwise scaling on CUDA 12.9 + # TODO(future PR): add fbgemm_gpu_genai path if available + assert _is_1_128_scaled(input_tensor), "unsupported" + res = blockwise_fp8_gemm( + inpt_data, + input_scale, + w_data.t(), + w_scale, + block_size=128, + ) + else: + res = addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ) + return res.reshape(out_shape) else: assert not isinstance(input_tensor, TorchAOBaseTensor), ( "Expecting input_tensor to be unquantized" From f76e10b9de0cd139feb6deb587b2484b960716d6 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 07:07:09 -0700 Subject: [PATCH 04/14] Update [ghstack-poisoned] --- .github/workflows/1xL4_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/1xL4_tests.yml b/.github/workflows/1xL4_tests.yml index 58980d8504..7a1c293074 100644 --- a/.github/workflows/1xL4_tests.yml +++ b/.github/workflows/1xL4_tests.yml @@ -51,3 +51,4 @@ jobs: pytest test/dtypes/test_affine_quantized_float.py --verbose -s ./test/float8/test_everything_single_gpu.sh python 
test/quantization/quantize_/workflows/float8/test_float8_tensor.py + python test/kernel/test_blockwise_triton.py --verbose -s From 6994e20676b2308adf029fa73456cf4b13795091 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 08:05:42 -0700 Subject: [PATCH 05/14] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized_float.py | 4 +- .../workflows/float8/test_float8_tensor.py | 38 +- torchao/float8/inference.py | 3 +- torchao/kernel/blockwise_quantization.py | 565 +++++++++--------- torchao/quantization/quant_api.py | 6 +- .../workflows/float8/float8_tensor.py | 11 +- torchao/quantization/utils.py | 9 + 7 files changed, 340 insertions(+), 296 deletions(-) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 35870a5e6b..738e9b6164 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -152,7 +152,7 @@ def test_invalid_granularity(self): def test_mismatched_granularity(self): with pytest.raises( ValueError, - match="Different granularities for activation and weight are not supported", + match="Unsupported granularity types", ): Float8DynamicActivationFloat8WeightConfig( granularity=(PerTensor(), PerRow()) @@ -165,7 +165,7 @@ def test_unsupported_granularity(self): class UnsupportedGranularity: pass - with pytest.raises(ValueError, match="Invalid granularity types"): + with pytest.raises(ValueError, match="Unsupported granularity types"): Float8DynamicActivationFloat8WeightConfig( granularity=(UnsupportedGranularity(), UnsupportedGranularity()), ) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 7d8de03908..e7182be255 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -62,37 +62,23 @@ def setUp(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - # @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) + @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "dtype", - [ - torch.bfloat16, - ], - ) - # @common_utils.parametrize("mode", ["dynamic", "weight-only"]) - @common_utils.parametrize( - "mode", - [ - "dynamic", - ], - ) - # @common_utils.parametrize("compile", [True, False]) - @common_utils.parametrize("compile", [False]) - # @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - @common_utils.parametrize( - "granularity", [(PerBlock((1, 128)), PerBlock((128, 128)))] + "granularity", + [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))], ) @common_utils.parametrize( "kernel_preference", - # [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], - [KernelPreference.TORCH], + [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], ) # Inputs are (M,..), K, N @common_utils.parametrize( "sizes", [ ((128,), 256, 128), - # ((32, 128), 64, 256), + ((32, 128), 256, 512), ], ) def test_fp8_linear_variants( @@ -109,9 +95,17 @@ def test_fp8_linear_variants( and kernel_preference == KernelPreference.FBGEMM ): return unittest.skip( - "per tensor with fbgemm kernel preferece does not work yet" + "per tensor with fbgemm kernel preference does not work yet" ) + if granularity == (PerBlock((1, 
128)), PerBlock((128, 128))): + if dtype is torch.float32: + return unittest.skip("unimplemented") + elif mode == "weight-only": + return unittest.skip("unimplemented") + elif kernel_preference is KernelPreference.FBGEMM: + return unittest.skip("unimplemented") + error_message = None if isinstance(granularity, PerRow): if mode == "dynamic" and dtype != torch.bfloat16: diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index a65d51c019..19c77d43d7 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -7,6 +7,7 @@ Defines an nn module designed to be used during inference """ +import math from typing import List, NamedTuple, Optional, Tuple, Union import torch @@ -204,7 +205,7 @@ def _is_1_128_scaled(x: torch.Tensor) -> bool: """ assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" b = x.block_size - return len(b) == 2 and b[0] == 1 and b[1] == 128 + return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128 def _is_128_128_scaled(x: torch.Tensor) -> bool: diff --git a/torchao/kernel/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py index 1d296249f9..192f6d5887 100644 --- a/torchao/kernel/blockwise_quantization.py +++ b/torchao/kernel/blockwise_quantization.py @@ -8,272 +8,301 @@ from typing import Tuple import torch -import triton -import triton.language as tl -from triton import Config - -# Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py - -fp8_gemm_configs = [ - Config( - {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, - num_stages=num_stages, - num_warps=8, - ) - for block_m in [16, 32, 64, 128] - for block_n in [32, 64, 128] - for num_stages in [3, 4, 5, 6] -] - - -@triton.autotune(configs=fp8_gemm_configs, key=["N", "K", "M_BUCKET", "BLOCK_SIZE_K"]) -@triton.jit -def blockwise_fp8_gemm_kernel( - a_ptr, - b_ptr, - c_ptr, - a_s_ptr, - b_s_ptr, - M, - N: tl.constexpr, - K: tl.constexpr, - M_BUCKET: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - k = tl.cdiv(K, BLOCK_SIZE_K) - offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] - b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] - a_s_ptrs = a_s_ptr + offs_m * k - b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(k): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) - a_s = tl.load(a_s_ptrs) - b_s = tl.load(b_s_ptrs) - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - a_s_ptrs += 1 - b_s_ptrs += 1 - - c = accumulator.to(c_ptr.dtype.element_ty) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - tl.store(c_ptrs, c, mask=mask) - - -def blockwise_fp8_gemm( - a: torch.Tensor, - a_s: torch.Tensor, - b: torch.Tensor, - b_s: torch.Tensor, - block_size: int = 128, -): - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and 
b_s.is_contiguous() - K = a.size(-1) - M = a.numel() // K - N = b.size(0) - M_BUCKET = math.ceil(math.log2(M)) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - blockwise_fp8_gemm_kernel[grid]( - a, b, c, a_s, b_s, M, N, K, M_BUCKET, BLOCK_SIZE_K=block_size - ) - return c - - -@triton.jit -def fp8_blockwise_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): - """ - Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. - - Args: - x_ptr (triton.Pointer): Pointer to the input tensor. - y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. - - Returns: - None - """ - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 - y = x / s - y = y.to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + pid, s) - - -def fp8_blockwise_act_quant( - x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the input tensor `x` using block-wise quantization with block size being BLOCK_SIZEx1. - - Args: - x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. - block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. - dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Default is `torch.float8_e4m3fn`. - - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized tensor with dtype `dtype`. - - A tensor of scaling factors with dtype `torch.float32`. - """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.size(-1) % block_size == 0, ( - f"Last dimension size must be divisible by block_size (block_size={block_size})" - ) - assert dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" - y = torch.empty_like(x, dtype=dtype) - s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) - grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) - fp8_blockwise_act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) - return y, s - - -@triton.jit -def fp8_blockwise_weight_quant_kernel( - x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr -): - """ - Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`. - - Args: - x_ptr (tl.pointer): Pointer to the input tensor. - y_ptr (tl.pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (tl.pointer): Pointer to the output tensor where scaling factors will be stored. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. 
- """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 - y = x / s - y = y.to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y, mask=mask) - tl.store(s_ptr + pid_m * n + pid_n, s) - - -def fp8_blockwise_weight_quant( - x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the given weight tensor using block-wise quantization with block size being BLOCK_SIZExBLOCK_SIZE. - - Args: - x (torch.Tensor): The weight tensor to be quantized. - block_size (int, optional): The block size to use for quantization. Defaults to 128. - dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Defaults to `torch.float8_e4m3fn`. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized weight tensor with dtype `dtype`. - - A tensor of scaling factors with dtype `torch.float32`. - """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.dim() == 2, "Input tensor must have 2 dimensions" - assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( - f"Both dimensions of x must be divisible by block_size (block_size={block_size})" - ) - assert dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" - M, N = x.size() - y = torch.empty_like(x, dtype=dtype) - s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) - fp8_blockwise_weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) - return y, s - - -@triton.jit -def fp8_blockwise_weight_dequant_kernel( - x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr -): - """ - Dequantizes weights using the provided scaling factors and stores the result. - - Args: - x_ptr (tl.pointer): Pointer to the quantized weights. - s_ptr (tl.pointer): Pointer to the scaling factors. - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): Size of the block for tiling. - - Returns: - None - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.load(s_ptr + pid_m * n + pid_n) - y = x * s - tl.store(y_ptr + offs, y, mask=mask) - - -def fp8_blockwise_weight_dequant( - x: torch.Tensor, s: torch.Tensor, block_size: int = 128 -) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. - - Args: - x (torch.Tensor): The quantized weight tensor of shape (M, N). - s (torch.Tensor): The scale tensor of shape (M, N). - block_size (int, optional): The block size to use for dequantization. Defaults to 128. - - Returns: - torch.Tensor: The dequantized weight tensor of the same shape as `x`. 
- - Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. - """ - assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" - assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" - M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), +from torch.utils._triton import has_triton + +if has_triton(): + import triton + import triton.language as tl + from triton import Config + + # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py + + fp8_gemm_configs = [ + Config( + {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64, 128] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] + ] + + @triton.autotune( + configs=fp8_gemm_configs, key=["N", "K", "M_BUCKET", "BLOCK_SIZE_K"] ) - fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y + @triton.jit + def blockwise_fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + M_BUCKET: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + ): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + @torch.library.custom_op("ao::blockwise_fp8_gemm", mutates_args=()) + def blockwise_fp8_gemm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, + ) -> torch.Tensor: + assert a.is_contiguous() + assert b.is_contiguous() + assert a_s.is_contiguous() + assert b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + M_BUCKET = math.ceil(math.log2(M)) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + blockwise_fp8_gemm_kernel[grid]( + a, b, c, a_s, b_s, M, N, K, M_BUCKET, BLOCK_SIZE_K=block_size + ) + return c + + @blockwise_fp8_gemm.register_fake + def _(a, a_s, b, b_s, block_size=128): + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + return c + + @triton.jit + def fp8_blockwise_act_quant_kernel(x_ptr, 
y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + def fp8_blockwise_act_quant( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with block size being BLOCK_SIZEx1. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Default is `torch.float8_e4m3fn`. + + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `dtype`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" + y = torch.empty_like(x, dtype=dtype) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + fp8_blockwise_act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + @triton.jit + def fp8_blockwise_weight_quant_kernel( + x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr + ): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`. + + Args: + x_ptr (tl.pointer): Pointer to the input tensor. + y_ptr (tl.pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (tl.pointer): Pointer to the output tensor where scaling factors will be stored. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. 
+ """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + def fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the given weight tensor using block-wise quantization with block size being BLOCK_SIZExBLOCK_SIZE. + + Args: + x (torch.Tensor): The weight tensor to be quantized. + block_size (int, optional): The block size to use for quantization. Defaults to 128. + dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Defaults to `torch.float8_e4m3fn`. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized weight tensor with dtype `dtype`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" + M, N = x.size() + y = torch.empty_like(x, dtype=dtype) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + fp8_blockwise_weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + + @triton.jit + def fp8_blockwise_weight_dequant_kernel( + x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr + ): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + def fp8_blockwise_weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 + ) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. 
+ + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), ( + "Input tensors must be contiguous" + ) + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + +else: + + def blockwise_fp8_gemm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, + ) -> torch.Tensor: + raise AssertionError("unsupported without triton") + + def fp8_blockwise_act_quant( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise AssertionError("unsupported without triton") + + def fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise AssertionError("unsupported without triton") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 134de439c5..e9feae102d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1780,8 +1780,12 @@ def __post_init__(self): if _granularity_is_a_1_128_w_128_128(self.granularity): assert self.activation_value_lb is None, "unimplemented" assert self.activation_value_ub is None, "unimplemented" - assert self.kernel_preference is KernelPreference.TORCH, "unimplemented" + assert self.kernel_preference in ( + KernelPreference.AUTO, + KernelPreference.TORCH, + ), "unimplemented" assert self.mm_config is None, "unimplemented" + assert self.version >= 2, "unimplemented" default_use_fast_accum = False if self.mm_config is None: diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index b7474853f4..4d7fdce225 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -277,7 +277,11 @@ def _(func, types, args, kwargs): if weight_tensor.kernel_preference == KernelPreference.AUTO: kernel_choice = "torch" - if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + if ( + _is_fbgemm_gpu_genai_available() + and is_sm_at_least_90() + and (not _is_128_128_scaled(weight_tensor)) + ): kernel_choice = "fbgemm" elif weight_tensor.kernel_preference == KernelPreference.FBGEMM: kernel_choice = "fbgemm" @@ -294,6 +298,7 @@ def _(func, types, args, kwargs): assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai" mm_config = weight_tensor.mm_config assert mm_config is not None + assert not _is_128_128_scaled(weight_tensor), "unimplemented" out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) @@ -355,12 +360,14 @@ def _(func, types, args, kwargs): # TODO(future PR): add testing for torch._scaled_mm with # blockwise scaling on CUDA 12.9 # TODO(future PR): add fbgemm_gpu_genai path if available + # TODO(before land): proper out_dtype handling assert _is_1_128_scaled(input_tensor), "unsupported" + # breakpoint() res = blockwise_fp8_gemm( inpt_data, input_scale, w_data.t(), - w_scale, + w_scale.t(), block_size=128, ) else: diff --git a/torchao/quantization/utils.py 
b/torchao/quantization/utils.py index b4b1a1087d..db9a5149c3 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -706,6 +706,15 @@ def get_block_size( return tuple(block_size) elif isinstance(granularity, PerBlock): block_size = granularity.block_size + + # pad the start of `block_size` with 1s, to make 2d block_size + # handle tensors of rank 3+ + if len(block_size) < len(input_shape): + block_size_list = list(block_size) + while len(block_size_list) < len(input_shape): + block_size_list.insert(0, 1) + block_size = tuple(block_size_list) + assert len(block_size) == len(input_shape), ( f"Block size {block_size} must have the same number of dimensions as input shape {input_shape}" ) From 1aff468c4010182093a2151a5d8fcd35a1e7da26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 08:24:25 -0700 Subject: [PATCH 06/14] Update [ghstack-poisoned] --- .../workflows/float8/test_float8_tensor.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index e7182be255..06db73c704 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -90,15 +90,15 @@ def test_fp8_linear_variants( kernel_preference: KernelPreference, sizes: Tuple, ): - if ( - isinstance(granularity, PerTensor) - and kernel_preference == KernelPreference.FBGEMM - ): - return unittest.skip( - "per tensor with fbgemm kernel preference does not work yet" - ) + if isinstance(granularity, PerTensor): + if kernel_preference is KernelPreference.FBGEMM: + return unittest.skip( + "per tensor with fbgemm kernel preference does not work yet" + ) + elif mode == "weight-only": + return unittest.skip("unimplemented") - if granularity == (PerBlock((1, 128)), PerBlock((128, 128))): + elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))): if dtype is torch.float32: return unittest.skip("unimplemented") elif mode == "weight-only": @@ -149,6 +149,20 @@ def test_fp8_linear_variants( quantize_(quantized_model, config) + # ensure weight scaling is what we expect + qs1 = quantized_model.linear1.weight.scale + qs2 = quantized_model.linear2.weight.scale + if granularity == PerTensor(): + assert qs1.shape == (1, 1) + assert qs2.shape == (1, 1) + elif granularity == PerRow(): + assert qs1.shape == (N, 1) + assert qs2.shape == (K, 1) + else: + assert granularity == (PerBlock((1, 128)), PerBlock((128, 128))) + assert qs1.shape == (N // 128, K // 128) + assert qs2.shape == (K // 128, N // 128) + if compile: quantized_model = torch.compile(quantized_model, fullgraph=True) From f6fa134ea37827ea4af53461e8761250442ca2ea Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 11:56:09 -0700 Subject: [PATCH 07/14] Update [ghstack-poisoned] --- .../workflows/float8/test_float8_tensor.py | 23 +++++++++++++++---- .../workflows/float8/float8_tensor.py | 2 ++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 06db73c704..0a5a0c091f 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -39,10 +39,10 @@ class ToyLinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): + def 
__init__(self, in_features, out_features, bias): super().__init__() - self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) - self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) + self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias) + self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias) def forward(self, x): x = self.linear1(x) @@ -81,6 +81,8 @@ def setUp(self): ((32, 128), 256, 512), ], ) + @common_utils.parametrize("bias", [False, True]) + @torch.no_grad() def test_fp8_linear_variants( self, dtype: torch.dtype, @@ -89,6 +91,7 @@ def test_fp8_linear_variants( granularity, kernel_preference: KernelPreference, sizes: Tuple, + bias: bool, ): if isinstance(granularity, PerTensor): if kernel_preference is KernelPreference.FBGEMM: @@ -106,6 +109,16 @@ def test_fp8_linear_variants( elif kernel_preference is KernelPreference.FBGEMM: return unittest.skip("unimplemented") + if bias is True: + if ( + sizes != (128,), + 256, + 128, + ) or kernel_preference is not KernelPreference.TORCH: + return unittest.skip( + "cut down on number of options to save test time" + ) + error_message = None if isinstance(granularity, PerRow): if mode == "dynamic" and dtype != torch.bfloat16: @@ -134,7 +147,7 @@ def test_fp8_linear_variants( input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") # Create a linear layer with bfloat16 dtype - model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda") quantized_model = copy.deepcopy(model) @@ -257,7 +270,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes): dtype = torch.bfloat16 input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") # Create a linear layer with bfloat16 dtype - model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") + model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda") # reference kernel preference and results # we are using KerenelPreference.TORCH as the reference diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 4d7fdce225..d81083aad9 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -370,6 +370,8 @@ def _(func, types, args, kwargs): w_scale.t(), block_size=128, ) + if bias is not None: + res = res + bias else: res = addmm_float8_unwrapped_inference( inpt_data, From 191121226069826d63afa44dea06a75607c5cf72 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 11:56:48 -0700 Subject: [PATCH 08/14] Update [ghstack-poisoned] --- torchao/quantization/quantize_/workflows/float8/float8_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index d81083aad9..8b15b428bd 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -362,7 +362,6 @@ def _(func, types, args, kwargs): # TODO(future PR): add fbgemm_gpu_genai path if available # TODO(before land): proper out_dtype handling assert _is_1_128_scaled(input_tensor), "unsupported" - # breakpoint() res = blockwise_fp8_gemm( inpt_data, input_scale, From 9ec8ce131f097239bdbd20124e931f057a484410 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Oct 2025 12:13:23 -0700 Subject: [PATCH 09/14] 
Update [ghstack-poisoned] --- .../quantize_/workflows/float8/test_float8_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 0a5a0c091f..51a9749f74 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -110,11 +110,11 @@ def test_fp8_linear_variants( return unittest.skip("unimplemented") if bias is True: + sizes_to_keep = ((128,), 256, 128) if ( - sizes != (128,), - 256, - 128, - ) or kernel_preference is not KernelPreference.TORCH: + sizes != sizes_to_keep + or kernel_preference is not KernelPreference.TORCH + ): return unittest.skip( "cut down on number of options to save test time" ) From ce5a8ebbcfb4f5e9013f0217c80d49e36753bd47 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 30 Oct 2025 04:41:56 -0700 Subject: [PATCH 10/14] Update [ghstack-poisoned] --- .../quantization/quantize_/workflows/float8/float8_tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 4d7fdce225..964833d072 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -360,9 +360,8 @@ def _(func, types, args, kwargs): # TODO(future PR): add testing for torch._scaled_mm with # blockwise scaling on CUDA 12.9 # TODO(future PR): add fbgemm_gpu_genai path if available - # TODO(before land): proper out_dtype handling + # TODO(future PR): proper out_dtype handling assert _is_1_128_scaled(input_tensor), "unsupported" - # breakpoint() res = blockwise_fp8_gemm( inpt_data, input_scale, From 6a3684b8e69e0ab756f590c1e0020893e793ad67 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 30 Oct 2025 12:50:23 -0700 Subject: [PATCH 11/14] Update [ghstack-poisoned] --- test/quantization/test_quant_primitives.py | 46 ++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index bed8421671..5f7895b4ea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -14,9 +14,11 @@ MappingType, ZeroPointDomain, _choose_qparams_affine_tinygemm, + _choose_scale_float8, _fake_quantize_affine, _fake_quantize_affine_cachemask, _maybe_expand_scale_to_tensor_shape, + _quantize_affine_float8, choose_qparams_affine, dequantize_affine, quantize_affine, @@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs): return output1 +# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100 +# with scale modified to be the inverse of the version in PT core +def _tensor_to_scale_block( + x: torch.Tensor, + float8_dtype: torch.dtype, + block_outer: int, + block_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = amax / torch.finfo(float8_dtype).max + x = x.div(scale).to(float8_dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + return x, scale + + # Legacy tinygemm ops def _get_groupwise_affine_qparams( w, @@ -798,6 +817,33 @@ def 
test_maybe_expand_scale_to_tensor_shape(self): self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8])) self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2])) + def test_float8_blockwise_scaling(self): + M, K = 512, 1024 + hp_tensor = torch.randn(M, K, dtype=torch.float) + # make the scales from some of the blocks obviously different + hp_tensor[0:128, 0:128] *= 3.0 + hp_tensor[0:128, 128:256] *= 7.0 + hp_tensor[128:256, 0:128] *= 2.0 + hp_tensor[128:256, 128:256] *= 100.0 + + block_size = (128, 128) + + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=torch.float8_e4m3fn, + block_size=block_size, + hp_value_lb=None, + hp_value_ub=None, + ) + data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn) + + ref_data, ref_scale = _tensor_to_scale_block( + hp_tensor, torch.float8_e4m3fn, 128, 128 + ) + + torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0) + torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0) + if __name__ == "__main__": unittest.main() From d28b0aedf6a56317af590e5cd0c24364964c6019 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 30 Oct 2025 12:55:57 -0700 Subject: [PATCH 12/14] Update [ghstack-poisoned] --- test/kernel/test_blockwise_triton.py | 1 + torchao/kernel/blockwise_quantization.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/kernel/test_blockwise_triton.py b/test/kernel/test_blockwise_triton.py index 5de88ab7d9..ba377c560f 100644 --- a/test/kernel/test_blockwise_triton.py +++ b/test/kernel/test_blockwise_triton.py @@ -66,6 +66,7 @@ def test_blockwise_fp8_gemm(M, N, K, dtype): A_q, A_s = fp8_blockwise_act_quant(A, dtype=dtype) B_q, B_s = fp8_blockwise_weight_quant(B, dtype=dtype) C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) + assert C_q.dtype == torch.bfloat16, "unsupported" error = torch.linalg.vector_norm(C - C_q) / torch.linalg.vector_norm(C) print(f"Relative Error: {error.item():.6f}") diff --git a/torchao/kernel/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py index 192f6d5887..1a43e71a97 100644 --- a/torchao/kernel/blockwise_quantization.py +++ b/torchao/kernel/blockwise_quantization.py @@ -92,7 +92,7 @@ def blockwise_fp8_gemm( M = a.numel() // K N = b.size(0) M_BUCKET = math.ceil(math.log2(M)) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16) grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), @@ -105,7 +105,7 @@ def blockwise_fp8_gemm( @blockwise_fp8_gemm.register_fake def _(a, a_s, b, b_s, block_size=128): N = b.size(0) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.bfloat16) return c @triton.jit From 6c087b47af3007b4e8f3be2022f44382c8639b76 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 31 Oct 2025 06:26:43 -0700 Subject: [PATCH 13/14] Update [ghstack-poisoned] --- .../workflows/float8/test_float8_tensor.py | 9 ++++++--- torchao/quantization/granularity.py | 19 ++++++++++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 06db73c704..5e2a125d15 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -99,11 +99,14 @@ def test_fp8_linear_variants( return unittest.skip("unimplemented") 
elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))): - if dtype is torch.float32: + if dtype is not torch.bfloat16: return unittest.skip("unimplemented") - elif mode == "weight-only": + elif mode != "dynamic": return unittest.skip("unimplemented") - elif kernel_preference is KernelPreference.FBGEMM: + elif kernel_preference not in ( + KernelPreference.AUTO, + KernelPreference.TORCH, + ): return unittest.skip("unimplemented") error_message = None diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 6c7b582fe5..75a8ba3724 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -106,11 +106,24 @@ class PerToken(Granularity): @dataclass(frozen=True) class PerBlock(Granularity): """ - Represents per-block granularity in quantization. See - :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for - `block_size` + Represents multidimensional per-block granularity in quantization. + + Example: + * block_size has shape [X, Y] + * input_tensor shape [A] -> scaling undefined + * input_tensor shape [A, B] -> scale shape [A // X, B // Y] + * input_tensor shape [A, B, C] -> scale shape [1, B // X, C // Y] + * input_tensor shape [A, B, C, D] -> scale shape [1, 1, C // X, D // Y], and so on + + Note that `PerBlock((1, Y))` is equivalent to `PerGroup(Y)` + Attributes: block_size (tuple[int, ...]): The size of each quantization group """ + # TODO(future PR): consider renaming this attribute to make the meaning + # of `block_size` consistent. + # 1. `block_size` in this class can support tensors of multiple ranks + # 2. `block_size` in other places in the codebase has rank equal to the + # corresponding tensor block_size: tuple[int, ...] From c4769a67961756a688300f18c2e2c3199428d24c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 31 Oct 2025 09:59:22 -0700 Subject: [PATCH 14/14] Update [ghstack-poisoned] --- torchao/quantization/granularity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 75a8ba3724..ccf7099c54 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -112,8 +112,8 @@ class PerBlock(Granularity): * block_size has shape [X, Y] * input_tensor shape [A] -> scaling undefined * input_tensor shape [A, B] -> scale shape [A // X, B // Y] - * input_tensor shape [A, B, C] -> scale shape [1, B // X, C // Y] - * input_tensor shape [A, B, C, D] -> scale shape [1, 1, C // X, D // Y], and so on + * input_tensor shape [A, B, C] -> scale shape [A, B // X, C // Y] + * input_tensor shape [A, B, C, D] -> scale shape [A, B, C // X, D // Y], and so on Note that `PerBlock((1, Y))` is equivalent to `PerGroup(Y)`
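
Taken together, this stack enables float8 inference with 1x128 activation scaling and 128x128 weight scaling, routed through the triton blockwise_fp8_gemm kernel that patch 02 moves into torchao.kernel. The sketch below is illustrative only and mirrors the usage exercised by test_fp8_linear_variants in these patches; the module shape and input sizes are arbitrary placeholders, and it assumes a CUDA GPU with compute capability >= 8.9, triton available, and K/N divisible by 128.

    import torch

    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerBlock,
        quantize_,
    )

    # bfloat16 + dynamic activation quantization is the combination covered by
    # the tests in this stack; float32 and weight-only mode are skipped as
    # unimplemented for this granularity.
    model = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False))
    model = model.to(torch.bfloat16).cuda().eval()

    # 1x128 blocks for activations (along K), 128x128 blocks for weights; the
    # default KernelPreference.AUTO routes this granularity to the triton
    # blockwise_fp8_gemm path rather than fbgemm.
    config = Float8DynamicActivationFloat8WeightConfig(
        granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
    )
    quantize_(model, config)

    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        y = model(x)

    # weight is (512, 256), so the 128x128 scale grid is (512 // 128, 256 // 128)
    print(model[0].weight.scale.shape)  # torch.Size([4, 2])

KernelPreference.TORCH is also accepted for this granularity; FBGEMM and weight-only mode are rejected for now, matching the test skips and the asserts added to quant_api.py in patch 05.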