From e776c4194561e98c6b63d7f03bb8e30caa05d179 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 12 Sep 2025 11:09:51 -0400
Subject: [PATCH 1/2] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe

Signed-off-by: mgoin
---
 tests/kernels/moe/test_mxfp4_moe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py
index 9fd72ee152b5..a3b8f07638d9 100644
--- a/tests/kernels/moe/test_mxfp4_moe.py
+++ b/tests/kernels/moe/test_mxfp4_moe.py
@@ -771,11 +771,11 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
     w13_ref = dequant_mxfp4_batches(
         w13_q.view(torch.uint8),
         w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, 2 * intermediate_size, hidden_size)
+            num_experts, 2 * intermediate_size, hidden_size).to(device)
     w2_ref = dequant_mxfp4_batches(
         w2_q.view(torch.uint8),
         w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
-            num_experts, hidden_size, intermediate_size)
+            num_experts, hidden_size, intermediate_size).to(device)
 
     # Quantize activations for SM100 path and dequantize for reference
     hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)

From f84512c59aee0c7421667dd9972bcdc0b7d223b3 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 12 Sep 2025 18:38:11 +0000
Subject: [PATCH 2/2] Fix dummy def for sm100_cutlass_mla_decode

Signed-off-by: mgoin
---
 csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 1 +
 1 file changed, 1 insertion(+)

diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
index c60f1823b8a1..d1874515cc8f 100644
--- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
 }
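
Note on PATCH 1/2: the change appends .to(device) to the dequantized reference weights so they live on the same device as the kernel output before comparison. The sketch below is a minimal, standalone illustration (not the vLLM test itself, and all names are placeholders) of why this matters: torch.testing.assert_close checks device placement by default and fails when a CPU-side reference is compared against a CUDA tensor.

# Minimal sketch, assuming a CUDA-capable machine; stand-ins for the real test tensors.
import torch

if torch.cuda.is_available():
    device = "cuda"
    kernel_out = torch.randn(4, 8, device=device)   # stands in for the fused-MoE kernel output
    reference = kernel_out.to("cpu")                # stands in for the CPU dequant reference

    try:
        torch.testing.assert_close(kernel_out, reference)
    except AssertionError as exc:
        print("device mismatch:", exc)              # fails: cuda vs cpu

    # Moving the reference onto the kernel's device, as the patch does, resolves it.
    torch.testing.assert_close(kernel_out, reference.to(device))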
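
Note on PATCH 2/2: the dummy definition is the fallback compiled when CUDA < 12.4, and it must mirror the real kernel's parameter list now that sm_scale exists, otherwise the two definitions no longer share one signature. The Python analogue below is only illustrative (the actual fix is in C++); parameters not visible in the diff are omitted, and the function name is copied from the patch purely for readability.

# Illustrative sketch: a stub that rejects unsupported builds must still accept
# every argument the real implementation takes, including the new sm_scale.
import torch

def sm100_cutlass_mla_decode_stub(
        seq_lens: torch.Tensor,
        page_table: torch.Tensor,
        workspace: torch.Tensor,
        sm_scale: float,    # newly added parameter the stub must also accept
        num_kv_splits: int,
) -> None:
    # Mirrors the TORCH_CHECK(false, ...) in the C++ dummy definition.
    raise RuntimeError("CUDA version must be >= 12.4 for cutlass_mla_decode")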