diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 2bb3166d16..95cb5f4ce1 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -296,6 +296,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
         use_triton_for_dim0_cast: bool = False,
+        scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -321,11 +322,13 @@ def forward(
             A,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )
         B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
@@ -355,6 +358,7 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
         ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
+        ctx.scale_calculation_mode = scale_calculation_mode
         return out

     @staticmethod
@@ -363,6 +367,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
+        scale_calculation_mode = ctx.scale_calculation_mode

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
@@ -375,13 +380,16 @@ def backward(ctx, grad_out: torch.Tensor):
             grad_out,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=scale_calculation_mode,
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
         B = B_t.transpose(-2, -1)
         B_data, B_scales = mxfp8_quantize_cuda_3d(
-            B._data if hasattr(B, "_data") else B, block_size=block_size
+            B._data if hasattr(B, "_data") else B,
+            block_size=block_size,
+            scaling_mode=scale_calculation_mode.value.lower(),
         )
         # (E, N//block_size, K) -> (E, K, N//block_size)
         B_scales = B_scales.transpose(-2, -1)
@@ -413,7 +421,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=grad_out.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
         )
         grad_out_t_data = grad_out_t_mx.qdata
         grad_out_t_scales = grad_out_t_mx.scale
@@ -429,7 +437,7 @@ def backward(ctx, grad_out: torch.Tensor):
             hp_dtype=A.dtype,
             gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,  # Not used
             cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=ScaleCalculationMode.FLOOR,
+            scale_calculation_mode=scale_calculation_mode,
        )
         A_t_data = A_t_mx.qdata
         A_t_scales = A_t_mx.scale
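
For context, a minimal usage sketch of the new knob. The entry point name `_quantize_then_scaled_grouped_mm` comes from the comment in this diff, but its exact signature, the import paths, and the tensor shapes below are assumptions rather than confirmed API:

# Hypothetical sketch: exercising the scale_calculation_mode plumbing above.
# Import paths are assumed; prototype module layouts may differ.
import torch
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _quantize_then_scaled_grouped_mm,
)
from torchao.prototype.mx_formats.config import ScaleCalculationMode

# Shapes follow the asserts in forward(): A must be 2D, B_t must be 3D.
A = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = torch.randn(4, 512, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# Group offsets along M for the 2d-3d grouped mm (assumed layout).
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

out = _quantize_then_scaled_grouped_mm(
    A,
    B_t,
    offs=offs,
    out_dtype=torch.bfloat16,
    # New in this diff: defaults to RCEIL; FLOOR restores prior behavior.
    scale_calculation_mode=ScaleCalculationMode.RCEIL,
)
out.sum().backward()  # backward now reads the mode from ctx instead of hardcoding FLOOR

Passing ScaleCalculationMode.FLOOR should reproduce the previous behavior, since the backward dim1 casts hardcoded FLOOR before this change.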