From 9c1110a06722afbee474ff3c640bc2435a41b12d Mon Sep 17 00:00:00 2001
From: Randy Shuai
Date: Wed, 24 Sep 2025 14:53:40 -0700
Subject: [PATCH] Fix torchAO shape check on fp8 tensors

Summary:
Previously, a shape mismatch between the input and weight scales caused the
semi-structured sparse kernel to fall back to the dense version, dropping QPS
from 29k to 20k. This fix prevents the fallback so the semi-structured sparse
kernel is used during inference, boosting QPS to
[32k](https://www.internalfb.com/intern/paste/P1962815637/) on 10 linears of
shape (10k+, 1k+).

Differential Revision: D83184235
---
 .../floatx/cutlass_semi_sparse_layout.py      | 60 ++++++++++++++-----
 1 file changed, 45 insertions(+), 15 deletions(-)

diff --git a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
index e49e8e8129..525ccd073b 100644
--- a/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
+++ b/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
@@ -190,23 +190,53 @@ def _apply_fn_to_data(self, fn):
 def _linear_fp8_act_fp8_weight_sparse_cutlass_check(input_tensor, weight_tensor, bias):
     from torchao.dtypes.floatx import Float8Layout
 
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, Float8Layout)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassSemiSparseLayout)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-        and (bias is None or bias.dtype == input_tensor.dtype)
-        and (bias is None or len(bias.shape) == 1)
+    base_check = (
+        isinstance(input_tensor, AffineQuantizedTensor) and
+        isinstance(input_tensor._layout, Float8Layout) and
+        input_tensor.dtype in (torch.float16, torch.bfloat16) and
+        len(input_tensor.shape) >= 2 and
+        input_tensor.tensor_impl.scale.dtype == torch.float32 and
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor._layout, CutlassSemiSparseLayout) and
+        weight_tensor.dtype == input_tensor.dtype and
+        len(weight_tensor.shape) == 2 and
+        weight_tensor.tensor_impl.scale.dtype == torch.float32 and
+        (bias is None or bias.dtype == input_tensor.dtype) and
+        (bias is None or len(bias.shape) == 1)
     )
+    if base_check:
+
+        # do extra check and reshape if needed
+        input_tensor_squeezed = False
+        if len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) and \
+                len(input_tensor.tensor_impl.scale.shape) > 1 and \
+                input_tensor.tensor_impl.scale.shape[-1] == 1:
+            input_tensor.tensor_impl.scale = torch.squeeze(input_tensor.tensor_impl.scale, dim=-1)
+            input_tensor_squeezed = True
+
+        weight_tensor_squeezed = False
+        if len(weight_tensor.tensor_impl.scale.shape) == 2 and \
+                weight_tensor.tensor_impl.scale.shape[-1] == 1:
+            weight_tensor.tensor_impl.scale = torch.squeeze(weight_tensor.tensor_impl.scale, dim=-1)
+            weight_tensor_squeezed = True
+
+        extra_check = (
+            len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
+            and len(weight_tensor.tensor_impl.scale.shape) == 1
+        )
+
+        if not extra_check:  # revert if extra check failed
+            if input_tensor_squeezed:
+                input_tensor.tensor_impl.scale = torch.unsqueeze(input_tensor.tensor_impl.scale, dim=-1)
+            if weight_tensor_squeezed:
+                weight_tensor.tensor_impl.scale = torch.unsqueeze(weight_tensor.tensor_impl.scale, dim=-1)
+
+        return extra_check
+
+    else:
+        return False
+
 
 def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
     from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
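For reviewers, here is a minimal standalone sketch of the squeeze-and-revert logic the new check performs. The `normalize_scale` helper and the tensor shapes are hypothetical, chosen only to illustrate the scale-shape mismatch described in the Summary; the patch itself mutates `tensor_impl.scale` in place on the quantized tensors.

```python
import torch

# Hypothetical helper mirroring the patch's squeeze logic: row-wise scales
# sometimes arrive with a trailing singleton dim (e.g. (M, 1) from a
# keepdim=True reduction), while the check expects rank input.ndim - 1.
def normalize_scale(scale: torch.Tensor, expected_ndim: int):
    squeezed = False
    if scale.ndim == expected_ndim + 1 and scale.ndim > 1 and scale.shape[-1] == 1:
        scale = torch.squeeze(scale, dim=-1)  # (M, 1) -> (M,)
        squeezed = True
    return scale, squeezed

# Illustrative shapes for a 2D linear: per-row scales for input and weight.
input_scale = torch.rand(128, 1, dtype=torch.float32)   # (M, 1)
weight_scale = torch.rand(256, 1, dtype=torch.float32)  # (N, 1)

input_scale, in_squeezed = normalize_scale(input_scale, expected_ndim=1)
weight_scale, w_squeezed = normalize_scale(weight_scale, expected_ndim=1)
assert input_scale.shape == (128,) and weight_scale.shape == (256,)

# Had the final shape check still failed, the patch would undo the squeeze:
if in_squeezed:
    restored = torch.unsqueeze(input_scale, dim=-1)  # back to (M, 1)
```

Reverting on failure keeps the check free of side effects when it returns False, so other dispatch paths still see the scales in their original shape.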