test/prototype/mx_formats/test_kernels.py (3 changes: 1 addition & 2 deletions)
@@ -440,12 +440,11 @@ def triton_to_mxfp8_dim0_reference(
     """
     from torchao.prototype.mx_formats.mx_tensor import to_mx

-    # cast across dim0 (rowwise) - no transpose needed
     scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
     scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
     return (
         x_hp_d0_normalized,
-        scale_e8m0_dim0.unsqueeze(-1),
+        scale_e8m0_dim0,
     )

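With the trailing unit dimension dropped, the Triton kernel and the eager reference now agree on the scale's shape directly, with nothing left to unsqueeze. A minimal shape check along these lines (a sketch, assuming a CUDA device and a bfloat16 input; the shapes follow from the diff, not from running this PR's test suite):

    import torch
    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    data, scale = triton_to_mxfp8_dim0(x, inner_block_size=32)
    assert data.shape == (128, 64)           # same shape as the input
    assert scale.shape == (128, 64 // 32)    # one scale per 32-wide block, no trailing unit dim
    assert scale.dtype == torch.float8_e8m0fnu
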
torchao/prototype/moe_training/scaled_grouped_mm.py (16 changes: 8 additions & 8 deletions)
@@ -32,7 +32,7 @@
     MXGemmKernelChoice,
     ScaleCalculationMode,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper

 logger: logging.Logger = logging.getLogger(__name__)
@@ -303,16 +303,16 @@ def forward(

         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_data = to_mx(
-            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        A_data, A_scale = triton_to_mxfp8_dim0(
+            A,
+            inner_block_size=block_size,
         )

         # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_data = to_mx(
+        B_data, B_scales = triton_to_mxfp8_dim0(
             B_t.transpose(-2, -1),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            inner_block_size=block_size,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
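A note on the B path, inferred from the kernel's contiguity assert rather than stated in the diff: B_t is taken to be a transposed view of contiguous (E, N, K) expert weights, so B_t.transpose(-2, -1) restores a contiguous layout while putting the contraction dim K last, where the rowwise cast scales it blockwise. A small sketch of that layout reasoning:

    import torch

    E, N, K = 4, 16, 64
    B = torch.randn(E, N, K)                      # contiguous expert weights
    B_t = B.transpose(-2, -1)                     # (E, K, N) non-contiguous view
    assert B_t.transpose(-2, -1).is_contiguous()  # transposing back recovers B's layout
    # the rowwise cast over the last dim then yields, per the comments above:
    #   B_data:   (E, N, K)        in float8_e4m3fn
    #   B_scales: (E, N, K // 32)  in float8_e8m0fnu (inner_block_size = 32)
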
@@ -351,8 +351,8 @@ def backward(ctx, grad_out: torch.Tensor):

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_data = to_mx(
-            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+            grad_out, inner_block_size=block_size
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
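The call-site migration is identical in forward and backward: the return order flips from (scale, data) to (data, scale), elem_dtype goes away (the Triton kernel always emits float8_e4m3fn), and block_size is renamed inner_block_size. Side by side (a sketch, not lines from the diff; assumes a CUDA device):

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx
    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

    t = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

    # before: eager to_mx, scale first, explicit element dtype
    scale, data = to_mx(t, elem_dtype=torch.float8_e4m3fn, block_size=32)

    # after: fused Triton kernel, data first, e4m3 implied, kwarg renamed
    data, scale = triton_to_mxfp8_dim0(t, inner_block_size=32)
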
torchao/prototype/mx_formats/kernels.py (10 changes: 8 additions & 2 deletions)
@@ -1162,7 +1162,9 @@ def triton_to_mxfp8_dim0(
     assert x.is_contiguous(), "`x` must be contiguous"
     assert inner_block_size <= 32

-    # Get tensor shape
+    # Reshape tensor to 2d if necessary and get shape
+    x_orig_shape = x.shape
+    x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape

     # Masking of loads and stores is not well tested yet, so for now enforce
@@ -1181,7 +1183,7 @@

     # Create scale tensors for rowwise scaling
     row_scale = torch.empty(
-        (n_rows, n_cols // inner_block_size, 1),
+        (n_rows, n_cols // inner_block_size),
         dtype=torch.uint8,
         device=x.device,
     )
@@ -1202,6 +1204,10 @@
         INNER_BLOCK_SIZE=inner_block_size,
     )

+    # Reshape output back to original shape
+    output = output.reshape(x_orig_shape)
+    row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+
     return (
         output,
         row_scale.view(torch.float8_e8m0fnu),
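This reshape bracketing is what lets the 2d kernel accept the 3d (E, N, K) expert weights quantized in scaled_grouped_mm.py above: leading dims are flattened into rows, the kernel runs rowwise over the last dim, and both outputs are restored at the end. A self-contained sketch of the same pattern, with a placeholder where the kernel launch would go (rowwise_quantize_nd and its internals are illustrative, not this file's API):

    import torch

    def rowwise_quantize_nd(x: torch.Tensor, inner_block_size: int = 32):
        # Flatten all leading dims into rows; the kernel only ever sees 2d.
        x_orig_shape = x.shape
        x2d = x.reshape(-1, x.shape[-1])
        n_rows, n_cols = x2d.shape

        # Stand-ins for the kernel outputs: one scale per block per row.
        output = torch.empty_like(x2d)
        row_scale = torch.empty(n_rows, n_cols // inner_block_size, dtype=torch.uint8)

        # ... the Triton kernel launch would fill output and row_scale here ...

        # Restore leading dims; the scale keeps its per-row block count last.
        output = output.reshape(x_orig_shape)
        row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
        return output, row_scale

    out, scale = rowwise_quantize_nd(torch.randn(8, 16, 64))
    assert out.shape == (8, 16, 64) and scale.shape == (8, 16, 2)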