18 changes: 18 additions & 0 deletions benchmarks/mx_formats/cast_bench.py
@@ -13,6 +13,7 @@

from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.kernels import (
triton_to_mxfp8_dim0,
triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
"dim0_mxfp8_floor",
"dim0_mxfp4_floor",
"dim0_mxfp8_rceil",
"dim0_mxfp8_triton_floor",
"dim1_mxfp8_floor",
"dim1_mxfp8_rceil",
"dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim0_mxfp8_triton_floor":
y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)

for _ in range(2):
__ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
time_us = benchmark_cuda_function_in_microseconds(
lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=b),
x,
BLOCK_SIZE,
)
assert y_d0.dtype == torch.float8_e4m3fn
assert s_d0.dtype == torch.float8_e8m0fnu
bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mxfp8_floor":
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
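
For reference, the bandwidth number reported by the new `dim0_mxfp8_triton_floor` branch above follows a simple read/write byte model. A minimal standalone sketch of that arithmetic (the sizes and kernel time below are made-up placeholders for illustration, not measurements):

# Back-of-the-envelope check of the bandwidth model used in the benchmark above.
# Assumes M = K = 16384 and BLOCK_SIZE = 32; `time_us` is a placeholder, not a measurement.
M, K, BLOCK_SIZE = 16384, 16384, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_r = M * K * bytes_per_el_bf16                           # bf16 input read
bytes_w = (M * K + (M * K) // BLOCK_SIZE) * bytes_per_el_fp8  # fp8 data + e8m0 scales written

time_us = 1000.0
bps = (bytes_r + bytes_w) / (time_us / 1e6)
print(f"{bps / 1e9:.1f} GB/s")  # ~813.7 GB/s with these placeholder numbers
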
33 changes: 33 additions & 0 deletions test/prototype/mx_formats/test_kernels.py
@@ -37,6 +37,7 @@
pack_uint6,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
triton_to_mxfp8_dim0,
triton_to_mxfp8_dim1,
triton_to_mxfp8_dim1_reference,
unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


def triton_to_mxfp8_dim0_reference(
x_hp: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
"""
from torchao.prototype.mx_formats.mx_tensor import to_mx

# cast across dim0 (rowwise) - no transpose needed
scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
return (
x_hp_d0_normalized,
scale_e8m0_dim0.unsqueeze(-1),
)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
not is_sm_at_least_89(),
reason="float8 in triton requires CUDA capability 8.9 or greater",
)
@pytest.mark.parametrize("M", (256, 2048))
@pytest.mark.parametrize("K", (256, 2048))
def test_triton_mxfp8_dim0_randn(M, K):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"shape",
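
The new `test_triton_mxfp8_dim0_randn` test above checks the triton kernel against the `to_mx`-based reference. As a quick illustration of the expected output layout, a hedged eager-mode sketch (assumes a CUDA device with compute capability 8.9+ and triton installed; shapes follow the allocations in `triton_to_mxfp8_dim0` below):

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# 1x32 rowwise quantization: every row is split into blocks of 32 elements,
# and each block gets one shared e8m0 scale.
x = torch.randn(256, 2048, dtype=torch.bfloat16, device="cuda")
data, scale = triton_to_mxfp8_dim0(x, inner_block_size=32)

print(data.shape, data.dtype)    # torch.Size([256, 2048]) torch.float8_e4m3fn
print(scale.shape, scale.dtype)  # torch.Size([256, 64, 1]) torch.float8_e8m0fnu
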
169 changes: 169 additions & 0 deletions torchao/prototype/mx_formats/kernels.py
@@ -1038,6 +1038,175 @@ def to_mxfp8_dim1_kernel(
# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

@triton.autotune(
configs=_get_mxfp8_dim1_kernel_autotune_configs(),
key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
)
@triton.jit
def to_mxfp8_dim0_kernel(
x_ptr, # pointer to input tensor
output_ptr, # pointer to output tensor (row-normalized)
row_scale_ptr, # pointer to store row-wise e8m0 scales
n_rows, # number of rows in the tensor
n_cols, # number of columns in the tensor
ROW_TILE_SIZE: tl.constexpr,
COL_TILE_SIZE: tl.constexpr,
INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX
):
"""
Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).

This is the counterpart to to_mxfp8_dim1_kernel, which does columnwise quantization.
Instead of transposing and scaling across columns, this kernel scales across rows.
"""

BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE

# Get program ID
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)

# Calculate starting row and column for this tile
start_row = pid_row * ROW_TILE_SIZE
start_col = pid_col * COL_TILE_SIZE

# Create offsets for the block
row_offsets = tl.arange(0, ROW_TILE_SIZE)
col_offsets = tl.arange(0, COL_TILE_SIZE)

# Compute global row/col positions
rows = start_row + row_offsets[:, None]
cols = start_col + col_offsets[None, :]

# Create masks for out-of-bounds accesses
row_mask = rows < n_rows
col_mask = cols < n_cols
mask = row_mask & col_mask

# Compute memory offsets for row-major layout (rows, cols)
row_major_offsets = (rows * n_cols + cols).to(tl.int32)

# Load the entire block in a single operation
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

# Reshape to inner tile size for rowwise scaling
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
x_block_r = x_block.reshape(
ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
)

# Calculate the absolute values of elements in the block
x_block_abs_r = tl.abs(x_block_r)

# Compute the scale for each inner block from its maximum absolute value
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)

# Divide each row by scale
# Broadcasting row_scale to match x_block's shape
# x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
# row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
row_normalized_r = x_block_r / row_scale_r[:, None]

# Reshape back to original tile size
row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)

# Quantize to float8
row_normalized = row_normalized.to(tl.float8e4nv)

# Store the row-normalized result in row-major format
tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)

# reshape row_scale_e8m0_r for proper storage
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)

row_scale_start_offsets = (
(pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
* BLOCKS_PER_COL_TILE # scales written by all previous row tiles
+ pid_col * BLOCKS_PER_COL_TILE # scales to the left of this tile in its first row
)

row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets

# calculate row_scale_indices
row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)

# How many values are in all the other rows for this col_pid, need to jump
# over them for every BLOCKS_PER_COL_TILE values
jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE

# example transformation (specifics depend on tile sizes):
# [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
row_scale_indices = row_scale_indices + (
(row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
)

# Store the scales
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)

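A quick host-side illustration of the scale-index mapping computed at the end of the kernel above (plain Python, not part of the kernel; the sizes are chosen to reproduce the example in the comment, i.e. COL_TILE_SIZE=64, INNER_BLOCK_SIZE=32, n_cols=128):

# Illustration only: reproduce the kernel's index transformation on the host.
COL_TILE_SIZE, INNER_BLOCK_SIZE, n_cols = 64, 32, 128
blocks_per_col_tile = COL_TILE_SIZE // INNER_BLOCK_SIZE           # 2
jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE  # 2

idx = list(range(8))
mapped = [i + (i // blocks_per_col_tile) * jump_vals_per_row for i in idx]
print(mapped)  # [0, 1, 4, 5, 8, 9, 12, 13]
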
@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
def triton_to_mxfp8_dim0(
x: torch.Tensor, inner_block_size: int = 32
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Input:
* `x` - input tensor, in row major memory layout
* `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes

Output:
* `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
* `row_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim0
"""
assert x.is_contiguous(), "`x` must be contiguous"
assert inner_block_size <= 32

# Get tensor shape
n_rows, n_cols = x.shape

# Masking of loads and stores is not well tested yet, so for now enforce
# shapes which do not need masking. Note that this condition depends on max values of
# ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
# TODO(future): implement and test masking and remove this restriction
max_row_tile_size = 128
max_col_tile_size = 128
assert n_rows % max_row_tile_size == 0, "unsupported"
assert n_cols % max_col_tile_size == 0, "unsupported"

# Create output tensors
output = torch.empty(
(n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
)

# Create scale tensors for rowwise scaling
row_scale = torch.empty(
(n_rows, n_cols // inner_block_size, 1),
dtype=torch.uint8,
device=x.device,
)

# Calculate grid dimensions based on tile size
grid = lambda META: (
triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
)

# Launch the kernel
wrap_triton(to_mxfp8_dim0_kernel)[grid](
x_ptr=x,
output_ptr=output,
row_scale_ptr=row_scale,
n_rows=n_rows,
n_cols=n_cols,
INNER_BLOCK_SIZE=inner_block_size,
)

return (
output,
row_scale.view(torch.float8_e8m0fnu),
)

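As a hedged aside, the returned pair can be dequantized back to float32 with a reshape and a broadcast multiply, since each 1x32 block of `output` shares one e8m0 scale. A hypothetical helper (not part of this PR; assumes the installed PyTorch supports casting `float8_e8m0fnu` to `float32`):

def dequant_mxfp8_dim0(data: torch.Tensor, scale: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Hypothetical dequantization of (data, scale) from triton_to_mxfp8_dim0.
    m, k = data.shape
    data_f32 = data.to(torch.float32).reshape(m, k // block_size, block_size)
    scale_f32 = scale.to(torch.float32)  # (m, k // block_size, 1), broadcasts over the block
    return (data_f32 * scale_f32).reshape(m, k)
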
@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
def triton_to_mxfp8_dim1(
x: torch.Tensor, inner_block_size: int = 32