From 319c8568de64c78aff6b0e1fba5c479d3da5c7e1 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 16 Nov 2025 22:36:52 +0900 Subject: [PATCH] update logics using early stopping --- torchao/float8/inference.py | 54 ++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 212df9c5db..004adf952b 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -142,40 +142,40 @@ def _slice_scale_for_dimension( """ aten = torch.ops.aten - # Unsupported case for now, this would be 1 scale per data element - if scale.shape == data_shape: - return aten.slice.Tensor(scale, dim, start, end, step) - - # Reconstruct block sizes based on data shape and scale shape - block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) - - if dim >= len(block_sizes): - # Slicing beyond the dimensions we care about + # Case 1: Per-tensor quantization (scalar scale) + if scale.numel() <= 1: return scale + # Case 2: Per-row quantization (1D scale) + # Scale is per-element along this dimension + if scale.ndim == 1: + if dim == 0: + return aten.slice.Tensor(scale, 0, start, end, step) + else: + return scale + + # Case 3: Per-block quantization (2D scale) + block_sizes = tuple( + data_shape[i] // scale.shape[i] for i in range(len(scale.shape)) + ) + block_size_for_dim = block_sizes[dim] - if block_size_for_dim == 1: - # Scale is per-element along this dimension - # Slice away as normal - return aten.slice.Tensor(scale, dim, start, end, step) - else: - # There is blocking in this dimension - # Calculate which scale elements correspond to the sliced data - scale_start = start // block_size_for_dim if start is not None else None - scale_end = ( - (end + block_size_for_dim - 1) // block_size_for_dim - if end is not None - else None + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." ) - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." - ) + # There is blocking in this dimension + # Calculate which scale elements correspond to the sliced data + scale_start = start // block_size_for_dim if start is not None else None + scale_end = ( + (end + block_size_for_dim - 1) // block_size_for_dim + if end is not None + else None + ) - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) def _is_rowwise_scaled(x: torch.Tensor) -> bool: