diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
index cd77ca90f9..7f79409a9c 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
@@ -354,6 +354,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     MainloopParams mainloop_params;
     EpilogueArguments epilogue;
     KernelHardwareInfo hw_info;
+    // L2 cache swizzle parameters (calculated on host)
+    int kSwizzle;
+    int num_blocks_k;
+    int num_hb_quotient;
+    int num_hb_remainder;
+    cutlass::FastDivmod l2_major_divmod;
+    cutlass::FastDivmod l2_minor_divmod;
+    cutlass::FastDivmod head_divmod;
+    cutlass::FastDivmod l2_minor_residual_divmod;
   };
 
   template
@@ -436,6 +445,34 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
         SmemLayoutDQ{}(_, _, _0{})
     );
 
+    auto [H, B] = HB;
+    auto [H_R, H_K] = H;
+
+    long long const size_one_qdo_head =
+        long(K_) * long(D + D_VO) * long(sizeof(Element));
+    long long const size_one_dqaccum_head = long(K_) * long(D) * sizeof(float);
+    long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
+
+    int l2_cache_size = 0;
+    cudaDeviceGetAttribute(
+        &l2_cache_size, cudaDevAttrL2CacheSize, args.hw_info.device_id);
+    int const size_l2_reserved = static_cast<int>(l2_cache_size * 0.8);
+
+    auto find_log2_floor = [](int n) { return 31 - cutlass::clz(n); };
+    int const kSwizzle = size_l2_reserved < size_one_head
+        ? 1
+        : (1 << find_log2_floor(size_l2_reserved / size_one_head));
+    int num_blocks_k = ceil_div(K_, TileShapeK{});
+    int total_heads_batches = H_K * B;
+    int num_hb_quotient = total_heads_batches / kSwizzle;
+    int num_hb_remainder = total_heads_batches % kSwizzle;
+
+    cutlass::FastDivmod l2_major_divmod(kSwizzle * num_blocks_k);
+    cutlass::FastDivmod l2_minor_divmod(kSwizzle);
+    cutlass::FastDivmod head_divmod(H_K);
+    cutlass::FastDivmod l2_minor_residual_divmod(
+        num_hb_remainder > 0 ? num_hb_remainder : 1);
+
     return Params{
         args.problem_shape,
         args.mainloop,
@@ -448,9 +485,17 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
           args.mainloop.window_size_left,
           args.mainloop.window_size_right
         },
-        args.epilogue,
-        args.hw_info
-    };
+        args.epilogue,
+        args.hw_info,
+        kSwizzle,
+        num_blocks_k,
+        num_hb_quotient,
+        num_hb_remainder,
+        l2_major_divmod,
+        l2_minor_divmod,
+        head_divmod,
+        l2_minor_residual_divmod
+    };
   }
 
@@ -1813,7 +1858,45 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     pipeline_init_wait(size(ClusterShape{}));
 
-    auto blk_coord = make_coord(_0{}, int(gridDim.x) -1 - blockIdx.x, _0{}, _0{}, make_coord(make_coord(0, blockIdx.y), blockIdx.z));
+    // Head swizzling: Decompose tile_idx into (block_k, head_idx, batch_idx)
+
+    int tile_idx = blockIdx.x;
+    auto [Q, K, D, D_VO, HB] = params.problem_shape;
+    auto [H, B] = HB;
+    auto [H_R, H_K] = H;
+
+    // Use pre-calculated swizzle parameters from host
+    int const kSwizzle = params.kSwizzle;
+    int num_blocks_k = params.num_blocks_k;
+    int num_hb_quotient = params.num_hb_quotient;
+
+    // Step 1: Which section (bidhb) and position within section (l2_mod)
+    int bidhb, l2_mod;
+    bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
+
+    // Step 2: Within section, get block_k and head-batch offset
+    int block_k, bidhb_residual;
+
+    if (bidhb < num_hb_quotient) {
+      block_k = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
+    } else {
+      block_k = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
+    }
+
+    // Step 3: Convert to actual head and batch indices
+    int head_batch_idx = bidhb * kSwizzle + bidhb_residual;
+    int batch_idx, head_idx;
+    batch_idx = params.head_divmod.divmod(head_idx, head_batch_idx);
+
+    if constexpr (
+        std::is_base_of_v, Mask> ||
+        std::is_base_of_v, Mask> ||
+        std::is_base_of_v, Mask> ||
+        std::is_base_of_v, Mask>) {
+      // Reverse block_k ordering (for SPT scheduling)
+      block_k = num_blocks_k - 1 - block_k;
+    }
+
+    auto blk_coord = make_coord(_0{}, block_k, _0{}, _0{}, make_coord(make_coord(0, head_idx), batch_idx));
 
     auto [problem_shape, blk_offset] = apply_variable_length_offset(
         params.problem_shape,
         blk_coord
@@ -1974,7 +2057,12 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     auto [Q, K, D, D_VO, HB] = params.problem_shape;
     auto [H, B] = HB;
     auto [H_R, H_K] = H;
-    dim3 grid(ceil_div(K, TileShapeK{}), H_K, B);
+
+    int num_blocks_k = ceil_div(K, TileShapeK{});
+    int total_heads_batches = H_K * B;
+    int total_tiles = num_blocks_k * total_heads_batches;
+
+    dim3 grid(total_tiles, 1, 1);
     return grid;
   }
 };
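
Note on the sizing logic in to_underlying_arguments: below is a standalone worked example of the host-side kSwizzle computation, a sketch for review rather than part of the patch. The L2 size (96 MiB) and problem shape (K_ = 4096, D = D_VO = 128, bf16 elements) are hypothetical values, and __builtin_clz (GCC/Clang) stands in for cutlass::clz.

// Worked example of the host-side kSwizzle sizing (sketch, hypothetical values).
#include <cstdio>

int find_log2_floor(int n) { return 31 - __builtin_clz(n); }

int main() {
  int const l2_cache_size = 96 * 1024 * 1024;     // pretend cudaDevAttrL2CacheSize
  long long const K_ = 4096, D = 128, D_VO = 128; // hypothetical problem shape
  long long const sizeof_element = 2;             // bf16

  long long const size_one_qdo_head = K_ * (D + D_VO) * sizeof_element;
  long long const size_one_dqaccum_head = K_ * D * (long long)sizeof(float);
  long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head;

  int const size_l2_reserved = static_cast<int>(l2_cache_size * 0.8);
  int const kSwizzle = size_l2_reserved < size_one_head
      ? 1
      : (1 << find_log2_floor(static_cast<int>(size_l2_reserved / size_one_head)));

  // 80530636 B reserved / 4194304 B per head = 19 -> floor(log2 19) = 4 -> 16
  std::printf("size_one_head = %lld B, kSwizzle = %d\n", size_one_head, kSwizzle);
  return 0;
}

With these numbers, size_one_head is 4 MiB, the 80% L2 reservation holds roughly 19 heads' working sets, and rounding down to a power of two gives kSwizzle = 16, i.e. 16 (head, batch) pairs share L2 residency while the K blocks are swept.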
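And a second sketch that replays the kernel-side decomposition behind the blk_coord change: it applies the three divmod steps with plain / and % standing in for cutlass::FastDivmod and asserts that tile_idx in [0, total_tiles) maps onto every (block_k, head_idx, batch_idx) exactly once, including the ragged last section when kSwizzle does not divide H_K * B. The test sizes are arbitrary and not from the kernel.

// Sanity check of the tile_idx decomposition (sketch, arbitrary test sizes).
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  int const num_blocks_k = 7, H_K = 5, B = 3, kSwizzle = 4;

  int const total_heads_batches = H_K * B; // 15: not a multiple of kSwizzle
  int const num_hb_quotient = total_heads_batches / kSwizzle;
  int const num_hb_remainder = total_heads_batches % kSwizzle;
  int const total_tiles = num_blocks_k * total_heads_batches;

  std::vector<int> visits(total_tiles, 0);
  for (int tile_idx = 0; tile_idx < total_tiles; ++tile_idx) {
    // Step 1: section index and offset within the section
    int const bidhb = tile_idx / (kSwizzle * num_blocks_k);
    int const l2_mod = tile_idx % (kSwizzle * num_blocks_k);

    // Step 2: the last section holds num_hb_remainder head-batches, not kSwizzle
    int const width = (bidhb < num_hb_quotient) ? kSwizzle : num_hb_remainder;
    int const block_k = l2_mod / width;
    int const bidhb_residual = l2_mod % width;

    // Step 3: recover head and batch indices
    int const head_batch_idx = bidhb * kSwizzle + bidhb_residual;
    int const batch_idx = head_batch_idx / H_K;
    int const head_idx = head_batch_idx % H_K;

    assert(block_k < num_blocks_k && head_idx < H_K && batch_idx < B);
    ++visits[(batch_idx * H_K + head_idx) * num_blocks_k + block_k];
  }
  for (int const v : visits) assert(v == 1); // each tile produced exactly once
  std::printf("decomposition is a bijection over %d tiles\n", total_tiles);
  return 0;
}

The property this checks is the point of the swizzle: each section of kSwizzle * num_blocks_k consecutive tile_idx values keeps the same group of kSwizzle head-batches while varying block_k, which is exactly the reuse window the 0.8 * L2 reservation is sized against; the remainder section runs the same scheme over num_hb_remainder head-batches.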