Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,40 +142,40 @@ def _slice_scale_for_dimension(
"""
aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
if scale.shape == data_shape:
return aten.slice.Tensor(scale, dim, start, end, step)

# Reconstruct block sizes based on data shape and scale shape
block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape)))

if dim >= len(block_sizes):
# Slicing beyond the dimensions we care about
# Case 1: Per-tensor quantization (scalar scale)
if scale.numel() <= 1:
return scale

# Case 2: Per-row quantization (1D scale)
# Scale is per-element along this dimension
if scale.ndim == 1:
if dim == 0:
return aten.slice.Tensor(scale, 0, start, end, step)
else:
return scale

# Case 3: Per-block quantization (2D scale)
block_sizes = tuple(
data_shape[i] // scale.shape[i] for i in range(len(scale.shape))
)

block_size_for_dim = block_sizes[dim]

if block_size_for_dim == 1:
# Scale is per-element along this dimension
# Slice away as normal
return aten.slice.Tensor(scale, dim, start, end, step)
else:
# There is blocking in this dimension
# Calculate which scale elements correspond to the sliced data
scale_start = start // block_size_for_dim if start is not None else None
scale_end = (
(end + block_size_for_dim - 1) // block_size_for_dim
if end is not None
else None
if step > 1:
raise NotImplementedError(
"Slicing with step > 1 is not implemented for scale tensors."
)

# Error on Step > 1
if step > 1:
raise NotImplementedError(
"Slicing with step > 1 is not implemented for scale tensors."
)
# There is blocking in this dimension
# Calculate which scale elements correspond to the sliced data
scale_start = start // block_size_for_dim if start is not None else None
scale_end = (
(end + block_size_for_dim - 1) // block_size_for_dim
if end is not None
else None
)

return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)


def _is_rowwise_scaled(x: torch.Tensor) -> bool:
Expand Down