14 changes: 5 additions & 9 deletions torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -11,7 +11,7 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx


@@ -473,11 +473,9 @@ def forward(
"""
# Quantize input
block_size = 32
input_scales, input_data = to_mx(
input_data, input_scales = triton_to_mxfp8_dim0(
input,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=ScaleCalculationMode.RCEIL,
inner_block_size=block_size,
)

# Dispatch data (async)
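For reference, a minimal sketch of the new call pattern used in `forward` above; the input shape, dtype, and device are assumptions for illustration. Note that `triton_to_mxfp8_dim0` returns `(data, scales)`, the reverse of `to_mx`, which returns `(scales, data)`.

```python
# Sketch only: shape, dtype, and device of the input are assumed for illustration.
import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

block_size = 32
x = torch.randn(128, 8192, dtype=torch.bfloat16, device="cuda")

# Returns (fp8 data, e8m0 scales) -- the reverse of to_mx, which returns
# (scales, data); hence the swapped unpacking order in the diff above.
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)
```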
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):

         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )

         # Dispatch data (async)
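Both `forward` and `backward` swap `to_mx` for the Triton kernel, so a sanity check along the following lines could be used to compare the two paths. This is a hypothetical sketch: whether the kernel's scale rounding matches `ScaleCalculationMode.RCEIL` bit-for-bit is an assumption here, not something this diff asserts.

```python
# Hypothetical parity check between the old (to_mx) and new (Triton) paths.
# Bit-identical results assume the kernel rounds scales the same way as RCEIL.
import torch

from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")

ref_scales, ref_data = to_mx(
    x,
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
    scaling_mode=ScaleCalculationMode.RCEIL,
)
tri_data, tri_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)

# Compare raw bytes to sidestep limited fp8/e8m0 comparison support.
data_match = torch.equal(tri_data.view(torch.uint8), ref_data.view(torch.uint8))
scale_match = torch.equal(
    tri_scales.view(torch.uint8).reshape(-1),
    ref_scales.view(torch.uint8).reshape(-1),
)
print(f"data bit-identical: {data_match}, scales bit-identical: {scale_match}")
```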
53 changes: 23 additions & 30 deletions torchao/prototype/mx_formats/kernels.py
@@ -1040,7 +1040,7 @@ def to_mxfp8_dim1_kernel(

 @triton.autotune(
     configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-    key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    key=["n_cols", "INNER_BLOCK_SIZE"],
 )
 @triton.jit
 def to_mxfp8_dim0_kernel(
@@ -1118,33 +1118,31 @@ def to_mxfp8_dim0_kernel(
     # Store the row-normalized result in row-major format
     tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)

-    # reshape row_scale_e8m0_r for proper storage
-    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-    row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # For rowwise quantization, the scale tensor has shape
+    # (n_rows, n_cols // INNER_BLOCK_SIZE), stored row-major
+    scales_per_row = n_cols // INNER_BLOCK_SIZE

-    row_scale_start_offsets = (
-        (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
-        * BLOCKS_PER_COL_TILE  # number of blocks seen so far
-        + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
-    )
+    # Create row and column indices for scale storage
+    scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+        pid_row * ROW_TILE_SIZE
+    )
+    scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+        pid_col * BLOCKS_PER_COL_TILE
+    )

-    row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
-
-    # calculate row_scale_indices
-    row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # Calculate linear indices into the scale tensor
+    scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

-    # How many values are in all the other rows for this col_pid, need to jump
-    # over them for every BLOCKS_PER_COL_TILE values
-    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+    # Create masks for valid scale indices
+    scale_row_mask = scale_row_indices < n_rows
+    scale_col_mask = scale_col_indices < scales_per_row
+    scale_mask = scale_row_mask & scale_col_mask

-    # example transformation (specifics depend on tile sizes):
-    # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
-    row_scale_indices = row_scale_indices + (
-        (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
-    )
+    # Reshape the scale values to match the 2D index layout
+    row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)

-    # Store the scales
-    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+    # Store the scales with out-of-bounds positions masked off
+    tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)


 @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
 def triton_to_mxfp8_dim0(
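To make the new scale-offset math concrete, here is a host-side illustration of the indexing the kernel performs; the concrete tile sizes and program ids are assumptions, since the real values come from autotuning and the launch grid.

```python
# Host-side illustration of the kernel's scale-offset arithmetic.
# Tile sizes and program ids are assumed values for illustration only.
import torch

n_rows, n_cols = 256, 1024
INNER_BLOCK_SIZE = 32
ROW_TILE_SIZE, COL_TILE_SIZE = 64, 128
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE  # 4 scale blocks per column tile

# Width of the row-major (n_rows, n_cols // INNER_BLOCK_SIZE) scale tensor
scales_per_row = n_cols // INNER_BLOCK_SIZE  # 32

pid_row, pid_col = 1, 2  # one tile of the launch grid, chosen arbitrarily

scale_row_indices = torch.arange(ROW_TILE_SIZE)[:, None] + pid_row * ROW_TILE_SIZE
scale_col_indices = torch.arange(BLOCKS_PER_COL_TILE)[None, :] + pid_col * BLOCKS_PER_COL_TILE

# Linear offsets into the flat scale buffer, mirroring the kernel
scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

# Positions past the tensor edge (possible when shapes do not divide the tile
# sizes evenly) are masked out instead of stored
scale_mask = (scale_row_indices < n_rows) & (scale_col_indices < scales_per_row)

print(scale_offsets[0])  # tensor([2056, 2057, 2058, 2059]) for this tile
print(scale_mask.all())  # tensor(True), since these shapes divide evenly
```

Each program therefore writes one contiguous ROW_TILE_SIZE x BLOCKS_PER_COL_TILE patch of the 2D scale tensor, rather than the strided flat indexing the old code computed via `jump_vals_per_row`.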
@@ -1167,14 +1165,9 @@ def triton_to_mxfp8_dim0(
     x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape

-    # Masking of loads and stores is not well tested yet, so for now enforce
-    # shapes which do not need masking. Note that this condition depends on max values of
-    # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
-    # TODO(future): implement and test masking and remove this restriction
-    max_row_tile_size = 128
-    max_col_tile_size = 128
-    assert n_rows % max_row_tile_size == 0, "unsupported"
-    assert n_cols % max_col_tile_size == 0, "unsupported"
+    assert n_cols % inner_block_size == 0, (
+        "columns must be divisible by inner block size"
+    )

     # Create output tensors
     output = torch.empty(
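The relaxed assertion above is the user-visible effect of the masking: row counts no longer have to be multiples of the 128-row maximum tile size, and only `n_cols % inner_block_size == 0` is enforced. Below is a sketch of a newly-accepted shape; the device is an assumption, and the expected scale shape is taken from the kernel comment rather than verified output.

```python
# Sketch of a shape the old asserts rejected (n_rows = 40 is not a multiple of
# 128) but the masked kernel now accepts; shape and device are assumed here.
import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

x = torch.randn(40, 8192, dtype=torch.bfloat16, device="cuda")
data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

print(data.shape)    # expected torch.Size([40, 8192])
print(scales.shape)  # expected 40 x 256 -- one e8m0 scale per 32-column block
```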