diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 7d47f2edef..121f996d70 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -440,12 +440,11 @@ def triton_to_mxfp8_dim0_reference(
     """
     from torchao.prototype.mx_formats.mx_tensor import to_mx

-    # cast across dim0 (rowwise) - no transpose needed
     scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
     scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
     return (
         x_hp_d0_normalized,
-        scale_e8m0_dim0.unsqueeze(-1),
+        scale_e8m0_dim0,
     )
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 2d28ade35d..9fba36d98a 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -32,7 +32,7 @@
     MXGemmKernelChoice,
     ScaleCalculationMode,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper

 logger: logging.Logger = logging.getLogger(__name__)
@@ -303,16 +303,16 @@ def forward(

         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_data = to_mx(
-            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        A_data, A_scale = triton_to_mxfp8_dim0(
+            A,
+            inner_block_size=block_size,
         )

         # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_data = to_mx(
+        B_data, B_scales = triton_to_mxfp8_dim0(
             B_t.transpose(-2, -1),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            inner_block_size=block_size,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
@@ -351,8 +351,8 @@ def backward(ctx, grad_out: torch.Tensor):

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_data = to_mx(
-            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+            grad_out, inner_block_size=block_size
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 45263c2884..e14a79e774 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1162,7 +1162,9 @@ def triton_to_mxfp8_dim0(
     assert x.is_contiguous(), "`x` must be contiguous"
     assert inner_block_size <= 32

-    # Get tensor shape
+    # Reshape tensor to 2d if necessary and get shape
+    x_orig_shape = x.shape
+    x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape

     # Masking of loads and stores is not well tested yet, so for now enforce
@@ -1181,7 +1183,7 @@

     # Create scale tensors for rowwise scaling
     row_scale = torch.empty(
-        (n_rows, n_cols // inner_block_size, 1),
+        (n_rows, n_cols // inner_block_size),
         dtype=torch.uint8,
         device=x.device,
     )
@@ -1202,6 +1204,10 @@
         INNER_BLOCK_SIZE=inner_block_size,
     )

+    # Reshape output back to original shape
+    output = output.reshape(x_orig_shape)
+    row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+
    return (
        output,
        row_scale.view(torch.float8_e8m0fnu),
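
For context, a minimal usage sketch of the rowwise (dim0) MXFP8 cast as it is called from the grouped-mm path after this change, reflecting the API shown in the diff (data and scale returned in that order, scale shape trailing dim divided by the block size). The tensor sizes, bfloat16 dtype, and CUDA device below are illustrative assumptions, not taken from the PR.

    import torch

    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

    block_size = 32

    # 2d activation (like A in forward()): the fp8 data keeps the input shape,
    # the e8m0 scale has its last dim divided by the block size. Per the kernel
    # asserts, the input must be contiguous and evenly divisible by the block size.
    A = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)  # assumed sizes/dtype
    A_data, A_scale = triton_to_mxfp8_dim0(A, inner_block_size=block_size)
    assert A_data.shape == (128, 256)
    assert A_scale.shape == (128, 256 // block_size)
    assert A_scale.dtype == torch.float8_e8m0fnu

    # 3d expert weights (like B_t.transpose(-2, -1) in forward()): the kernel now
    # flattens leading dims to 2d internally and restores them on the way out.
    B = torch.randn(4, 64, 256, device="cuda", dtype=torch.bfloat16)  # assumed sizes/dtype
    B_data, B_scales = triton_to_mxfp8_dim0(B, inner_block_size=block_size)
    assert B_data.shape == (4, 64, 256)
    assert B_scales.shape == (4, 64, 256 // block_size)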