diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 35870a5e6b..738e9b6164 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -152,7 +152,7 @@ def test_invalid_granularity(self): def test_mismatched_granularity(self): with pytest.raises( ValueError, - match="Different granularities for activation and weight are not supported", + match="Unsupported granularity types", ): Float8DynamicActivationFloat8WeightConfig( granularity=(PerTensor(), PerRow()) @@ -165,7 +165,7 @@ def test_unsupported_granularity(self): class UnsupportedGranularity: pass - with pytest.raises(ValueError, match="Invalid granularity types"): + with pytest.raises(ValueError, match="Unsupported granularity types"): Float8DynamicActivationFloat8WeightConfig( granularity=(UnsupportedGranularity(), UnsupportedGranularity()), ) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 010682474e..5e2a125d15 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -18,6 +18,7 @@ from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, + PerBlock, PerRow, PerTensor, quantize_, @@ -64,7 +65,10 @@ def setUp(self): @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only"]) @common_utils.parametrize("compile", [True, False]) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + @common_utils.parametrize( + "granularity", + [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))], + ) @common_utils.parametrize( "kernel_preference", [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], @@ -74,7 +78,7 @@ def setUp(self): "sizes", [ ((128,), 256, 128), - ((32, 128), 64, 256), + ((32, 128), 256, 512), ], ) def test_fp8_linear_variants( @@ -86,13 +90,24 @@ def test_fp8_linear_variants( kernel_preference: KernelPreference, sizes: Tuple, ): - if ( - isinstance(granularity, PerTensor) - and kernel_preference == KernelPreference.FBGEMM - ): - return unittest.skip( - "per tensor with fbgemm kernel preferece does not work yet" - ) + if isinstance(granularity, PerTensor): + if kernel_preference is KernelPreference.FBGEMM: + return unittest.skip( + "per tensor with fbgemm kernel preference does not work yet" + ) + elif mode == "weight-only": + return unittest.skip("unimplemented") + + elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))): + if dtype is not torch.bfloat16: + return unittest.skip("unimplemented") + elif mode != "dynamic": + return unittest.skip("unimplemented") + elif kernel_preference not in ( + KernelPreference.AUTO, + KernelPreference.TORCH, + ): + return unittest.skip("unimplemented") error_message = None if isinstance(granularity, PerRow): @@ -137,6 +152,20 @@ def test_fp8_linear_variants( quantize_(quantized_model, config) + # ensure weight scaling is what we expect + qs1 = quantized_model.linear1.weight.scale + qs2 = quantized_model.linear2.weight.scale + if granularity == PerTensor(): + assert qs1.shape == (1, 1) + assert qs2.shape == (1, 1) + elif granularity == PerRow(): + assert qs1.shape == (N, 1) + assert qs2.shape == (K, 1) + else: + assert granularity == (PerBlock((1, 128)), PerBlock((128, 128))) + assert qs1.shape == (N // 128, K // 128) + assert 
qs2.shape == (K // 128, N // 128) + if compile: quantized_model = torch.compile(quantized_model, fullgraph=True) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index bed8421671..5f7895b4ea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -14,9 +14,11 @@ MappingType, ZeroPointDomain, _choose_qparams_affine_tinygemm, + _choose_scale_float8, _fake_quantize_affine, _fake_quantize_affine_cachemask, _maybe_expand_scale_to_tensor_shape, + _quantize_affine_float8, choose_qparams_affine, dequantize_affine, quantize_affine, @@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs): return output1 +# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100 +# with scale modified to be the inverse of the version in PT core +def _tensor_to_scale_block( + x: torch.Tensor, + float8_dtype: torch.dtype, + block_outer: int, + block_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = amax / torch.finfo(float8_dtype).max + x = x.div(scale).to(float8_dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + return x, scale + + # Legacy tinygemm ops def _get_groupwise_affine_qparams( w, @@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self): self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8])) self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2])) + def test_float8_blockwise_scaling(self): + M, K = 512, 1024 + hp_tensor = torch.randn(M, K, dtype=torch.float) + # make the scales from some of the blocks obviously different + hp_tensor[0:128, 0:128] *= 3.0 + hp_tensor[0:128, 128:256] *= 7.0 + hp_tensor[128:256, 0:128] *= 2.0 + hp_tensor[128:256, 128:256] *= 100.0 + + block_size = (128, 128) + + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=torch.float8_e4m3fn, + block_size=block_size, + hp_value_lb=None, + hp_value_ub=None, + ) + data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn) + + ref_data, ref_scale = _tensor_to_scale_block( + hp_tensor, torch.float8_e4m3fn, 128, 128 + ) + + torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0) + torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f15d38576c..19c77d43d7 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -7,6 +7,7 @@ Defines an nn module designed to be used during inference """ +import math from typing import List, NamedTuple, Optional, Tuple, Union import torch @@ -14,6 +15,7 @@ from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul from torchao.float8.types import FP8Granularity from torchao.quantization.granularity import ( + PerBlock, PerRow, PerTensor, ) @@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool: ) +def _is_1_128_scaled(x: torch.Tensor) -> bool: + """Checks if a quantized tensor is scaled with a block size of 1x128 + Args: + x: quantized tensor (should have `block_size` attribute) + """ + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + b = x.block_size + return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128 + + +def _is_128_128_scaled(x: torch.Tensor) -> bool: + 
"""Checks if a quantized tensor is scaled with a block size of 128x128 + Args: + x: quantized tensor (should have `block_size` attribute) + """ + assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute" + b = x.block_size + return len(b) == 2 and b[0] == 128 and b[1] == 128 + + +def _granularity_is_a_1_128_w_128_128( + g: Union[ + FP8Granularity, + Tuple[FP8Granularity, FP8Granularity], + list[FP8Granularity], + ], +) -> bool: + return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128)) + + def _normalize_granularity( granularity: Optional[ Union[ @@ -211,22 +243,23 @@ def _normalize_granularity( elif isinstance(granularity, (PerTensor, PerRow)): processed_granularity = (granularity, granularity) elif isinstance(granularity, (tuple, list)) and len(granularity) == 2: - if not ( - isinstance(granularity[0], (PerTensor, PerRow)) - and isinstance(granularity[1], (PerTensor, PerRow)) - ): - raise ValueError( - f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." - ) + is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance( + granularity[1], PerTensor + ) + is_per_row = isinstance(granularity[0], PerRow) and isinstance( + granularity[1], PerRow + ) + is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity) + + if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128): + raise ValueError(f"Unsupported granularity types: {granularity}.") if not isinstance(granularity[0], type(granularity[1])): raise ValueError( - f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + f"Different granularities for activation and weight are not supported: {granularity}." ) processed_granularity = tuple(granularity) else: - raise ValueError( - f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported." - ) + raise ValueError(f"Invalid granularity specification: {granularity}.") return processed_granularity @@ -243,12 +276,22 @@ def _check_hardware_support( AssertionError: If hardware doesn't support the requested granularity ValueError: If invalid granularity type is provided """ - for _granularity in granularities: - if not isinstance(_granularity, (PerTensor, PerRow)): - raise ValueError( - f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported." - ) + is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance( + granularities[1], PerTensor + ) + is_per_row = isinstance(granularities[0], PerRow) and isinstance( + granularities[1], PerRow + ) + is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities) + if is_per_tensor or is_per_row: assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+." ) + elif is_a_1_128_w_128_128: + # TODO(future PR): look into AMD support + assert is_sm_at_least_89(), ( + "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9." 
+ ) + else: + raise ValueError(f"Invalid granularities {granularities}.") diff --git a/torchao/kernel/blockwise_quantization.py b/torchao/kernel/blockwise_quantization.py index 1d296249f9..192f6d5887 100644 --- a/torchao/kernel/blockwise_quantization.py +++ b/torchao/kernel/blockwise_quantization.py @@ -8,272 +8,301 @@ from typing import Tuple import torch -import triton -import triton.language as tl -from triton import Config - -# Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py - -fp8_gemm_configs = [ - Config( - {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, - num_stages=num_stages, - num_warps=8, - ) - for block_m in [16, 32, 64, 128] - for block_n in [32, 64, 128] - for num_stages in [3, 4, 5, 6] -] - - -@triton.autotune(configs=fp8_gemm_configs, key=["N", "K", "M_BUCKET", "BLOCK_SIZE_K"]) -@triton.jit -def blockwise_fp8_gemm_kernel( - a_ptr, - b_ptr, - c_ptr, - a_s_ptr, - b_s_ptr, - M, - N: tl.constexpr, - K: tl.constexpr, - M_BUCKET: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - k = tl.cdiv(K, BLOCK_SIZE_K) - offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] - b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] - a_s_ptrs = a_s_ptr + offs_m * k - b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(k): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) - a_s = tl.load(a_s_ptrs) - b_s = tl.load(b_s_ptrs) - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - a_s_ptrs += 1 - b_s_ptrs += 1 - - c = accumulator.to(c_ptr.dtype.element_ty) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - tl.store(c_ptrs, c, mask=mask) - - -def blockwise_fp8_gemm( - a: torch.Tensor, - a_s: torch.Tensor, - b: torch.Tensor, - b_s: torch.Tensor, - block_size: int = 128, -): - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() - K = a.size(-1) - M = a.numel() // K - N = b.size(0) - M_BUCKET = math.ceil(math.log2(M)) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]), - triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - blockwise_fp8_gemm_kernel[grid]( - a, b, c, a_s, b_s, M, N, K, M_BUCKET, BLOCK_SIZE_K=block_size - ) - return c - - -@triton.jit -def fp8_blockwise_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): - """ - Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. - - Args: - x_ptr (triton.Pointer): Pointer to the input tensor. - y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. 
- - Returns: - None - """ - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 - y = x / s - y = y.to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + pid, s) - - -def fp8_blockwise_act_quant( - x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the input tensor `x` using block-wise quantization with block size being BLOCK_SIZEx1. - - Args: - x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. - block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. - dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Default is `torch.float8_e4m3fn`. - - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized tensor with dtype `dtype`. - - A tensor of scaling factors with dtype `torch.float32`. - """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.size(-1) % block_size == 0, ( - f"Last dimension size must be divisible by block_size (block_size={block_size})" - ) - assert dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" - y = torch.empty_like(x, dtype=dtype) - s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) - grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) - fp8_blockwise_act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) - return y, s - - -@triton.jit -def fp8_blockwise_weight_quant_kernel( - x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr -): - """ - Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`. - - Args: - x_ptr (tl.pointer): Pointer to the input tensor. - y_ptr (tl.pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (tl.pointer): Pointer to the output tensor where scaling factors will be stored. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.max(tl.abs(x)) / 448.0 - y = x / s - y = y.to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y, mask=mask) - tl.store(s_ptr + pid_m * n + pid_n, s) - - -def fp8_blockwise_weight_quant( - x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the given weight tensor using block-wise quantization with block size being BLOCK_SIZExBLOCK_SIZE. - - Args: - x (torch.Tensor): The weight tensor to be quantized. - block_size (int, optional): The block size to use for quantization. Defaults to 128. - dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Defaults to `torch.float8_e4m3fn`. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized weight tensor with dtype `dtype`. 
- - A tensor of scaling factors with dtype `torch.float32`. - """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.dim() == 2, "Input tensor must have 2 dimensions" - assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( - f"Both dimensions of x must be divisible by block_size (block_size={block_size})" - ) - assert dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" - M, N = x.size() - y = torch.empty_like(x, dtype=dtype) - s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), - ) - fp8_blockwise_weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) - return y, s - - -@triton.jit -def fp8_blockwise_weight_dequant_kernel( - x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr -): - """ - Dequantizes weights using the provided scaling factors and stores the result. - - Args: - x_ptr (tl.pointer): Pointer to the quantized weights. - s_ptr (tl.pointer): Pointer to the scaling factors. - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): Size of the block for tiling. - - Returns: - None - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.load(s_ptr + pid_m * n + pid_n) - y = x * s - tl.store(y_ptr + offs, y, mask=mask) - - -def fp8_blockwise_weight_dequant( - x: torch.Tensor, s: torch.Tensor, block_size: int = 128 -) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. - - Args: - x (torch.Tensor): The quantized weight tensor of shape (M, N). - s (torch.Tensor): The scale tensor of shape (M, N). - block_size (int, optional): The block size to use for dequantization. Defaults to 128. - - Returns: - torch.Tensor: The dequantized weight tensor of the same shape as `x`. - - Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. 
- """ - assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" - assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" - M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE"]), - triton.cdiv(N, meta["BLOCK_SIZE"]), +from torch.utils._triton import has_triton + +if has_triton(): + import triton + import triton.language as tl + from triton import Config + + # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py + + fp8_gemm_configs = [ + Config( + {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64, 128] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] + ] + + @triton.autotune( + configs=fp8_gemm_configs, key=["N", "K", "M_BUCKET", "BLOCK_SIZE_K"] ) - fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y + @triton.jit + def blockwise_fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + M_BUCKET: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + ): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + @torch.library.custom_op("ao::blockwise_fp8_gemm", mutates_args=()) + def blockwise_fp8_gemm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, + ) -> torch.Tensor: + assert a.is_contiguous() + assert b.is_contiguous() + assert a_s.is_contiguous() + assert b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + M_BUCKET = math.ceil(math.log2(M)) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + blockwise_fp8_gemm_kernel[grid]( + a, b, c, a_s, b_s, M, N, K, M_BUCKET, BLOCK_SIZE_K=block_size + ) + return c + + @blockwise_fp8_gemm.register_fake + def _(a, a_s, b, b_s, block_size=128): + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + return c + + @triton.jit + def fp8_blockwise_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores 
the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + def fp8_blockwise_act_quant( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with block size being BLOCK_SIZEx1. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Default is `torch.float8_e4m3fn`. + + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `dtype`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" + y = torch.empty_like(x, dtype=dtype) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + fp8_blockwise_act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + + @triton.jit + def fp8_blockwise_weight_quant_kernel( + x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr + ): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factors in `s_ptr`. + + Args: + x_ptr (tl.pointer): Pointer to the input tensor. + y_ptr (tl.pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (tl.pointer): Pointer to the output tensor where scaling factors will be stored. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.max(tl.abs(x)) / 448.0 + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + def fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the given weight tensor using block-wise quantization with block size being BLOCK_SIZExBLOCK_SIZE. 
+ + Args: + x (torch.Tensor): The weight tensor to be quantized. + block_size (int, optional): The block size to use for quantization. Defaults to 128. + dtype (torch.dtype, optional): The dtype to use for the quantized tensor. Defaults to `torch.float8_e4m3fn`. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized weight tensor with dtype `dtype`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.dim() == 2, "Input tensor must have 2 dimensions" + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, ( + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + ) + assert dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], "dtype must be torch.float8_e4m3fn or torch.float8_e5m2" + M, N = x.size() + y = torch.empty_like(x, dtype=dtype) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + fp8_blockwise_weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + + @triton.jit + def fp8_blockwise_weight_dequant_kernel( + x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr + ): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + def fp8_blockwise_weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 + ) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. 
+ """ + assert x.is_contiguous() and s.is_contiguous(), ( + "Input tensors must be contiguous" + ) + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) + fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y + +else: + + def blockwise_fp8_gemm( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, + ) -> torch.Tensor: + raise AssertionError("unsupported without triton") + + def fp8_blockwise_act_quant( + x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise AssertionError("unsupported without triton") + + def fp8_blockwise_weight_quant( + x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise AssertionError("unsupported without triton") diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 6c7b582fe5..ccf7099c54 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -106,11 +106,24 @@ class PerToken(Granularity): @dataclass(frozen=True) class PerBlock(Granularity): """ - Represents per-block granularity in quantization. See - :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for - `block_size` + Represents multidimensional per-block granularity in quantization. + + Example: + * block_size has shape [X, Y] + * input_tensor shape [A] -> scaling undefined + * input_tensor shape [A, B] -> scale shape [A // X, B // Y] + * input_tensor shape [A, B, C] -> scale shape [A, B // X, C // Y] + * input_tensor shape [A, B, C, D] -> scale shape [A, B, C // X, D // Y], and so on + + Note that `PerBlock((1, Y))` is equivalent to `PerGroup(Y)` + Attributes: block_size (tuple[int, ...]): The size of each quantization group """ + # TODO(future PR): consider renaming this attribute to make the meaning + # of `block_size` consistent. + # 1. `block_size` in this class can support tensors of multiple ranks + # 2. `block_size` in other places in the codebase has rank equal to the + # corresponding tensor block_size: tuple[int, ...] 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 417150229c..e9feae102d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -62,6 +62,7 @@ Float8MMConfig, FP8Granularity, _check_hardware_support, + _granularity_is_a_1_128_w_128_128, _normalize_granularity, ) from torchao.quantization.linear_activation_weight_observed_tensor import ( @@ -1770,13 +1771,26 @@ def __post_init__(self): torch._C._log_api_usage_once( "torchao.quantization.Float8DynamicActivationFloat8WeightConfig" ) - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) activation_granularity, weight_granularity = _normalize_granularity( self.granularity ) self.granularity = [activation_granularity, weight_granularity] + default_use_fast_accum = True + if _granularity_is_a_1_128_w_128_128(self.granularity): + assert self.activation_value_lb is None, "unimplemented" + assert self.activation_value_ub is None, "unimplemented" + assert self.kernel_preference in ( + KernelPreference.AUTO, + KernelPreference.TORCH, + ), "unimplemented" + assert self.mm_config is None, "unimplemented" + assert self.version >= 2, "unimplemented" + default_use_fast_accum = False + + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum) + # for bc float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 47395a15af..964833d072 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -15,6 +15,8 @@ from torchao.float8.inference import ( Float8MMConfig, FP8Granularity, + _is_1_128_scaled, + _is_128_128_scaled, _is_rowwise_scaled, _is_tensorwise_scaled, _slice_scale_for_dimension, @@ -22,6 +24,9 @@ preprocess_data, preprocess_scale, ) +from torchao.kernel.blockwise_quantization import ( + blockwise_fp8_gemm, +) from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.quant_primitives import ( _choose_scale_float8, @@ -272,7 +277,11 @@ def _(func, types, args, kwargs): if weight_tensor.kernel_preference == KernelPreference.AUTO: kernel_choice = "torch" - if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90(): + if ( + _is_fbgemm_gpu_genai_available() + and is_sm_at_least_90() + and (not _is_128_128_scaled(weight_tensor)) + ): kernel_choice = "fbgemm" elif weight_tensor.kernel_preference == KernelPreference.FBGEMM: kernel_choice = "fbgemm" @@ -289,6 +298,7 @@ def _(func, types, args, kwargs): assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai" mm_config = weight_tensor.mm_config assert mm_config is not None + assert not _is_128_128_scaled(weight_tensor), "unimplemented" out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1]) @@ -337,19 +347,39 @@ def _(func, types, args, kwargs): "Input tensor must be rowwise block size" ) w_scale = w_scale.transpose(-1, -2) + elif _is_128_128_scaled(weight_tensor): + assert _is_1_128_scaled(input_tensor), ( + "input_tensor must be 1x128 scaled" + ) + w_scale = w_scale.transpose(-1, -2) input_scale = preprocess_scale(input_scale, input_tensor.shape) inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) - return addmm_float8_unwrapped_inference( - inpt_data, - input_scale, - 
w_data, - w_scale, - output_dtype=input_tensor.dtype, - bias=bias, - use_fast_accum=scaled_mm_config.use_fast_accum, - ).reshape(out_shape) + if _is_128_128_scaled(weight_tensor): + # TODO(future PR): add testing for torch._scaled_mm with + # blockwise scaling on CUDA 12.9 + # TODO(future PR): add fbgemm_gpu_genai path if available + # TODO(future PR): proper out_dtype handling + assert _is_1_128_scaled(input_tensor), "unsupported" + res = blockwise_fp8_gemm( + inpt_data, + input_scale, + w_data.t(), + w_scale.t(), + block_size=128, + ) + else: + res = addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ) + return res.reshape(out_shape) else: assert not isinstance(input_tensor, TorchAOBaseTensor), ( "Expecting input_tensor to be unquantized" diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index b4b1a1087d..db9a5149c3 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -706,6 +706,15 @@ def get_block_size( return tuple(block_size) elif isinstance(granularity, PerBlock): block_size = granularity.block_size + + # pad the start of `block_size` with 1s, to make 2d block_size + # handle tensors of rank 3+ + if len(block_size) < len(input_shape): + block_size_list = list(block_size) + while len(block_size_list) < len(input_shape): + block_size_list.insert(0, 1) + block_size = tuple(block_size_list) + assert len(block_size) == len(input_shape), ( f"Block size {block_size} must have the same number of dimensions as input shape {input_shape}" )
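A short illustration of the `get_block_size` padding added above, assuming `get_block_size(input_shape, granularity)` is imported from `torchao.quantization.utils` (the shapes are arbitrary):

    from torchao.quantization.granularity import PerBlock
    from torchao.quantization.utils import get_block_size

    g = PerBlock((1, 128))
    # 2D input: the granularity's block_size is used as-is
    assert get_block_size((32, 256), g) == (1, 128)
    # 3D+ input: block_size is left-padded with 1s so only the trailing two
    # dims are blocked, matching the scale shapes in the PerBlock docstring
    assert get_block_size((4, 32, 256), g) == (1, 1, 128)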