From e25cced310210d4fca0e1e5bbf2ff3ac177054c4 Mon Sep 17 00:00:00 2001
From: Armand Sauzay
Date: Fri, 21 Nov 2025 08:08:37 -0800
Subject: [PATCH] Enable specifying output dtype for fp8 quantized communication (#5154)

Summary:
X-link: https://github.com/meta-pytorch/torchrec/pull/3568

X-link: https://github.com/facebookresearch/FBGEMM/pull/2154

Adds an fp8_output_dtype parameter to the qcomms config, allowing FP8
quantized communication to dequantize into different float formats
instead of only FP32.

Reviewed By: spcyppt

Differential Revision: D86890315
---
 fbgemm_gpu/fbgemm_gpu/quantize_comm.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index 3b5c6dfccc..806dcb4959 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -123,6 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
+    fp8_output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -137,8 +138,14 @@
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
+            # use provided fp8_output_dtype or default to FP32 (0)
+            output_dtype_int = (
+                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
+            )
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
-                quantized_tensor_2d, is_fwd
+                quantized_tensor_2d,
+                is_fwd,
+                output_dtype_int,
             )
             return dequant_tensor.view(-1)
         else:
@@ -168,6 +175,7 @@ def __init__(
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
+        fp8_output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -185,6 +193,7 @@
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
+        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -216,7 +225,11 @@
             f"## decoder {self._comm_precision} {self._loss_scale} ##"
         ):
             dequantized_tensor = _dequantize_tensor(
-                input_tensor, self._comm_precision, ctx, self._is_fwd
+                input_tensor,
+                self._comm_precision,
+                ctx,
+                self._is_fwd,
+                fp8_output_dtype=self._fp8_output_dtype,
             )
             return dequantized_tensor
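
Usage sketch (illustrative, not part of the diff): the snippet below shows
how the new fp8_output_dtype knob might be used at the codec level.
QuantizedCommCodec, SparseType, encode/decode, and create_context are
existing fbgemm_gpu names; the specific row_dim value, the tensor shapes,
and the CUDA requirement are assumptions made for illustration only.

    # Minimal sketch, assuming fbgemm_gpu with this patch applied and a
    # CUDA device (the FP8 rowwise quantize/dequantize kernels are GPU ops).
    import torch
    from fbgemm_gpu.quantize_comm import QuantizedCommCodec
    from fbgemm_gpu.split_embedding_configs import SparseType

    codec = QuantizedCommCodec(
        comm_precision=SparseType.FP8,
        row_dim=256,                       # assumed rowwise block size; the
                                           # FP8 rowwise path needs row_dim > 0
        fp8_output_dtype=SparseType.BF16,  # new: dequantize to BF16, not FP32
    )

    ctx = codec.create_context()                   # carries row_dim for FP8
    payload = torch.randn(4 * 256, device="cuda")  # flat, multiple of row_dim
    encoded = codec.encode(payload, ctx)           # FP8 rowwise-quantized data
    decoded = codec.decode(encoded, ctx)           # BF16 output instead of FP32
    assert decoded.dtype == torch.bfloat16

When fp8_output_dtype is left as None, output_dtype_int falls back to 0
(FP32), so existing callers keep the previous dequantize-to-FP32 behavior.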