diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index 2cbc2379579e..1f62c37ba4b7 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) - continue; + if (local_split_kv <= get<3>(blk_coord)) + continue; load_page_table( blk_coord, problem_shape, params.mainloop, shared_storage.tensors, pipeline_page_table, pipeline_pt_producer_state, - local_split_kv + local_split_kv ); } } @@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_cpasync( blk_coord, @@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { params.mainloop_params, shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv, + local_split_kv, /* must be shared pipe */ pipeline_page_table, pipeline_pt_consumer_state ); @@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { - local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( blk_coord, @@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv + local_split_kv ); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); } @@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( blk_coord, @@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { shared_storage.tensors, pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv + local_split_kv ); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); } @@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; mma(blk_coord, problem_shape, @@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_mma_s, pipeline_mma_s_producer_state, pipeline_p_mma, pipeline_p_mma_consumer_state, pipeline_mma_o, pipeline_mma_o_producer_state, - local_split_kv + local_split_kv ); } } @@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto split_kv = params.split_kv; - auto local_split_kv = split_kv; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; compute( blk_coord, @@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_mma_s, pipeline_mma_s_consumer_state, pipeline_p_mma, pipeline_p_mma_producer_state, pipeline_mma_o, pipeline_mma_o_consumer_state, - local_split_kv + local_split_kv ); } @@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { cutlass::arch::NamedBarrier( (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue - ).arrive(); + ).arrive_and_wait(); return; }