From 35f0a6b3a672715c81bc0cf598eef264b1019caf Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 6 Aug 2025 19:04:51 +0000 Subject: [PATCH 01/17] Add files from https://github.com/deepseek-ai/FlashMLA/pull/54 Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu | 3 + csrc/kernels_fp8/flash_fwd_mla_kernel.h | 705 +++++++++++++++++++++ csrc/kernels_fp8/flash_fwd_mla_metadata.cu | 77 +++ csrc/kernels_fp8/flash_mla.h | 66 ++ csrc/kernels_fp8/fp8_transpose_v.h | 83 +++ csrc/kernels_fp8/named_barrier.h | 16 + csrc/kernels_fp8/softmax.h | 197 ++++++ csrc/kernels_fp8/static_switch.h | 65 ++ csrc/kernels_fp8/utils.h | 274 ++++++++ 9 files changed, 1486 insertions(+) create mode 100644 csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu create mode 100644 csrc/kernels_fp8/flash_fwd_mla_kernel.h create mode 100644 csrc/kernels_fp8/flash_fwd_mla_metadata.cu create mode 100644 csrc/kernels_fp8/flash_mla.h create mode 100644 csrc/kernels_fp8/fp8_transpose_v.h create mode 100644 csrc/kernels_fp8/named_barrier.h create mode 100644 csrc/kernels_fp8/softmax.h create mode 100644 csrc/kernels_fp8/static_switch.h create mode 100644 csrc/kernels_fp8/utils.h diff --git a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu new file mode 100644 index 0000000..b678962 --- /dev/null +++ b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h new file mode 100644 index 0000000..6e92b5e --- /dev/null +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -0,0 +1,705 @@ +#pragma once + +#include +#include +#include +#include + +using namespace cute; + +#include "named_barrier.h" +#include "utils.h" +#include "softmax.h" +#include "static_switch.h" +#include "flash_mla.h" +#include "fp8_transpose_v.h" + + +template +constexpr auto getSmemLayoutK() { + constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; + + if constexpr (major == GMMA::Major::K) { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } + } else { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } else { + return GMMA::Layout_MN_SW32_Atom{}; + } + } + +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementO = elem_type_o; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr bool Is_FP8 = cute::is_same_v; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? 
kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + + static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3; + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, MmaMajorV>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = std::conditional_t< + Is_FP8, + Layout, Int, _1, _2, Int>>, + Layout, Int, _1, _2, Int>> + >; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO); + static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO"); + static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO; + static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + + + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, ElementO>{}, + GmemLayoutAtomO{}, + Layout>>{})); // Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>>{})); // Val layout, 4 vals per store + + + // ------ for f8 ------ + using SmemFp8Tranpose = SmemTransposeFp8_64x64; + using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; +}; + +namespace flash { + +using namespace cute; + +template +struct 
SharedStorageMLA { + using SmemV_t = std::conditional_t>, + cute::array_aligned>; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + SmemV_t smem_vt; + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::ElementO; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, scale_softmax, descale_k); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? 
row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + ); +} + +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + + auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + auto sVt = [&](){ + if constexpr(Kernel_traits::Is_FP8){ + return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + } + }(); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, 
tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, scale_softmax_log2) + : is_masking_step ? 
+ softmax.template softmax(tSrS, scale_softmax_log2) + : softmax.template softmax(tSrS, scale_softmax_log2); + + if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + + cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0 + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + __syncthreads(); + } + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + if constexpr (!Kernel_traits::Is_FP8) { + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need to clear the sK smem tiles because K is V. 
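+        // Prefetch the K block for the first (highest-indexed) iteration of this tile: look up its
+        // physical page via the block table, issue the async gmem->smem copy, then rewind the pointer.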
+ const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + if constexpr (Kernel_traits::Is_FP8) { + auto TransV = [&]() { + using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose; + SmemFp8Tranpose smem_transpose_V; + Tensor sV_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{})); + Tensor sVt_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{})); + + if (n_block % 2 == 1) { + sV_divide.data() = sV_divide.data() + size(sK); + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) { + smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j))); + } + } + }; + + TransV(); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::TransVReady)); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = flash::convert_layout_acc_Aregs(partition_fragment_C(tiled_mma, Shape, Int>{}).layout()); + Tensor tOrP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, tOrP); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) __syncthreads(); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + if constexpr (!Kernel_traits::Is_FP8) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + float descale_k = 1.f; + float scale_softmax = params.scale_softmax; + float scale_softmax_log2 = params.scale_softmax_log2; + if constexpr (Kernel_traits::Is_FP8) { + float descale_q = __ldg(params.descale_q_ptr); + descale_k = __ldg(params.descale_k_ptr); + scale_softmax = scale_softmax * descale_q * descale_k; + scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; + } + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. 
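+            // All threads must be done with the previous tile before its shared-memory buffers are reused.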
+ } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h * params.seqlen_q; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits == 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; + const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h * params.seqlen_q); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} diff --git a/csrc/kernels_fp8/flash_fwd_mla_metadata.cu b/csrc/kernels_fp8/flash_fwd_mla_metadata.cu new file mode 100644 index 0000000..82f5b5a --- /dev/null +++ b/csrc/kernels_fp8/flash_fwd_mla_metadata.cu @@ -0,0 +1,77 @@ +#include "flash_fwd_mla_kernel.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) 
+get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/csrc/kernels_fp8/flash_mla.h b/csrc/kernels_fp8/flash_mla.h new file mode 100644 index 0000000..b7e2fed --- /dev/null +++ b/csrc/kernels_fp8/flash_mla.h @@ -0,0 +1,66 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_mla_params { + using index_t = int64_t; + + int b, seqlen_q, d, d_v; + int h, h_h_k_ratio, ngroups; + bool is_causal; + float scale_softmax, scale_softmax_log2; + int *__restrict__ cu_seqlens_k; + + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ o_ptr; + void *__restrict__ softmax_lse_ptr; + + float* __restrict__ descale_q_ptr = nullptr; + float* __restrict__ descale_k_ptr = nullptr; + + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t o_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t o_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t o_head_stride; + + int *__restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + int *__restrict__ tile_scheduler_metadata_ptr; + int num_sm_parts; + int *__restrict__ num_splits_ptr; + + void *__restrict__ softmax_lseaccum_ptr; + void *__restrict__ oaccum_ptr; +}; + +static constexpr int TileSchedulerMetaDataSize = 8; +// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +struct Mla_metadata_params { + int *__restrict__ seqlens_k_ptr; + int *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; + int batch_size; + int block_size_n; + int fixed_overhead_num_blocks; + int num_sm_parts; +}; + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/csrc/kernels_fp8/fp8_transpose_v.h b/csrc/kernels_fp8/fp8_transpose_v.h new file mode 100644 index 0000000..40bb4d5 --- /dev/null +++ b/csrc/kernels_fp8/fp8_transpose_v.h @@ -0,0 +1,83 @@ +/** + * ref to Fa3's SmemTranspose64x64: + * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 +*/ + +#pragma once + +template +struct SmemTransposeFp8_64x64 { + static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); + + using Element = cutlass::float_e4m3_t; + using TransposeShapeAtomV = Shape<_64, _64>; + using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + Shape, Int>{})); + + // for fp8 
in-kernel transpose -- src layout + using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}))); + using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + + // For fp8, this is the memory transpose. + using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- dst layout + using SmemLayoutVtTrans = decltype(composition( + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); + using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); + using SmemShapeSTSM = Shape, Shape<_16, _4>>; + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); + using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; + + using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; + diff --git a/csrc/kernels_fp8/named_barrier.h b/csrc/kernels_fp8/named_barrier.h new file mode 100644 index 0000000..940c934 --- /dev/null +++ b/csrc/kernels_fp8/named_barrier.h @@ -0,0 +1,16 @@ +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, + TransVReady = 3, +}; + +} // flash diff --git a/csrc/kernels_fp8/softmax.h b/csrc/kernels_fp8/softmax.h new file mode 100644 index 0000000..bcb8cac --- /dev/null +++ b/csrc/kernels_fp8/softmax.h @@ -0,0 +1,197 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include 
+#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? 
tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. 
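+            // The returned scale_o holds the per-row factor exp2((max_prev - max_cur) * scale);
+            // the caller uses it (flash::rescale_o) to rescale the O accumulator before the next P*V GEMM.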
+ flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/csrc/kernels_fp8/static_switch.h b/csrc/kernels_fp8/static_switch.h new file mode 100644 index 0000000..f156adc --- /dev/null +++ b/csrc/kernels_fp8/static_switch.h @@ -0,0 +1,65 @@ +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() diff --git a/csrc/kernels_fp8/utils.h b/csrc/kernels_fp8/utils.h new file mode 100644 index 0000000..716c50c --- /dev/null +++ b/csrc/kernels_fp8/utils.h @@ -0,0 +1,274 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + 
static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. 
+ // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag_64b = group_modes<1, 
3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag_64b); ++mi) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { + cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash From bded0bbd450980159ab23e42e00496cc5f0ea02f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 6 Aug 2025 20:32:54 +0000 Subject: [PATCH 02/17] FP8 now extends base implementation Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 38 ++++++++++- csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu | 2 +- csrc/kernels_fp8/flash_fwd_mla_kernel.h | 45 +++++++------ csrc/kernels_fp8/flash_fwd_mla_metadata.cu | 77 ---------------------- csrc/kernels_fp8/flash_mla.h | 61 ++--------------- flash_mla/flash_mla_interface.py | 6 ++ 6 files changed, 71 insertions(+), 158 deletions(-) delete mode 100644 csrc/kernels_fp8/flash_fwd_mla_metadata.cu diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index d6b96c4..e4b1951 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -16,6 +16,8 @@ #include "kernels/params.h" #include "kernels/splitkv_mla.h" +#include "kernels_fp8/flash_mla.h" + #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") @@ -68,7 +70,9 @@ mha_fwd_kvcache_mla( const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits // batch_size + 1 + const at::Tensor &num_splits, // batch_size + 1 + c10::optional &descale_q, // batch_size + c10::optional &descale_k // batch_size ) { // Check the architecture auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -77,7 +81,7 @@ mha_fwd_kvcache_mla( // Check data types auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); @@ -115,6 +119,20 @@ mha_fwd_kvcache_mla( TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (q_dtype == torch::kFloat8_e4m3fn) { + TORCH_CHECK(descale_q.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8"); + auto descale_q_ = descale_q.value(); + auto descale_k_ = descale_k.value(); + CHECK_DEVICE(descale_q_); + CHECK_DEVICE(descale_k_); + TORCH_CHECK(descale_q_.stride(-1) == 1); + TORCH_CHECK(descale_k_.stride(-1) == 1); + TORCH_CHECK(descale_q_.dtype() == torch::kFloat); + TORCH_CHECK(descale_k_.dtype() == torch::kFloat); + CHECK_SHAPE(descale_q_, 1); + CHECK_SHAPE(descale_k_, 1); + } + if (seqlen_q_ori == 1) { is_causal = false; } const int num_q_heads_per_hk = num_heads_q / num_heads_k; @@ -196,6 +214,22 @@ mha_fwd_kvcache_mla( #else run_flash_splitkv_mla_kernel(params, stream); run_flash_mla_combine_kernel(params, stream); +#endif + } else if (q_dtype == torch::kFloat8_e4m3fn) { +#ifdef FLASH_MLA_DISABLE_FP8 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP8. 
Please remove this flag from your environment and re-compile FlashMLA."); +#else + // Create FP8-specific params by copying base params and setting FP8 fields + Flash_fwd_mla_params_fp8 fp8_params; + // Copy all base fields + static_cast(fp8_params) = params; + + // Set FP8-specific fields + fp8_params.h_h_k_ratio = 1; + fp8_params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); + fp8_params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); + + run_mha_fwd_splitkv_mla(fp8_params, stream); #endif } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); diff --git a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu index b678962..a7eb3db 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu +++ b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu @@ -1,3 +1,3 @@ #include "flash_fwd_mla_kernel.h" -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h index 6e92b5e..c33e6db 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -182,7 +182,7 @@ struct SharedStorageMLA { //////////////////////////////////////////////////////////////////////////////////////////////////// template -__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, +__forceinline__ __device__ void store(const Flash_fwd_mla_params_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; @@ -221,9 +221,9 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h_q + bidh) * params.s_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h_q + bidh) * params.s_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h_q + bidh) * params.s_q + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? 
row_offset_oaccum : row_offset_o)), Shape, Int>{}, @@ -252,7 +252,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + if (row < params.s_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } @@ -263,12 +263,12 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.s_q - m_block * kBlockM ); } template -__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, @@ -358,10 +358,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (!Is_causal) { // Just masking based on col if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; } else { - // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups - // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.s_q - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; } } @@ -423,7 +423,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.seqlen_q - m_block * kBlockM); + params.s_q - m_block * kBlockM); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), @@ -539,7 +539,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f template __global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params_fp8 params) { constexpr int kBlockN = Kernel_traits::kBlockN; const int m_block = blockIdx.x; const int bidh = blockIdx.y; @@ -570,7 +570,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { const int n_split_idx = batch_id == begin_idx ? 
begin_n_split_idx : 0; - const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int seqlen_k = __ldg(params.seqlens_k_ptr + batch_id); const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); @@ -585,12 +585,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params template __global__ void __launch_bounds__(256, 1, 1) -flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params_fp8 params) { constexpr int kNThreads = 128; const int tidx = threadIdx.x; const int bidx = blockIdx.x; - const int hs = params.h * params.seqlen_q; + const int hs = params.h_q * params.s_q; const int batch_idx = bidx / hs; const int hs_idx = bidx % hs; @@ -663,8 +663,8 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param } Tensor rO = flash::convert_type(tOrO); - const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; - const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; + const int head_idx = (bidx - batch_idx * hs) / params.s_q; + const int row = bidx - batch_idx * hs - head_idx * params.s_q; auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); cute::copy(rO, gO); @@ -675,18 +675,18 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param //////////////////////////////////////////////////////////////////////////////////////////////////// template -void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { +void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); - const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + const int num_m_block = cute::ceil_div(params.s_q, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel; constexpr size_t smem_size = sizeof(SharedStorage); CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - kernel<<>>(params); + kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); - dim3 grid_combine(params.b * params.h * params.seqlen_q); + dim3 grid_combine(params.b * params.h_q * params.s_q); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; @@ -696,10 +696,9 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream } template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream) { static_assert(Headdim == 576); FLASH_ASSERT(params.d_v == 512); - FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; run_flash_splitkv_fwd_mla>(params, stream); } diff --git 
a/csrc/kernels_fp8/flash_fwd_mla_metadata.cu b/csrc/kernels_fp8/flash_fwd_mla_metadata.cu deleted file mode 100644 index 82f5b5a..0000000 --- a/csrc/kernels_fp8/flash_fwd_mla_metadata.cu +++ /dev/null @@ -1,77 +0,0 @@ -#include "flash_fwd_mla_kernel.h" - -static constexpr int MaxBatchSize = 4096; - -__global__ void __launch_bounds__(256, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { - int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; - int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; - int block_size_n = params.block_size_n; - int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; - int num_sm_parts = params.num_sm_parts; - - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; - - int total_num_blocks = 0; - for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); - total_num_blocks += num_blocks + fixed_overhead_num_blocks; - num_blocks_shared[i] = num_blocks; - } - for (int offset = 16; offset >= 1; offset /= 2) { - total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); - } - __syncwarp(); - - if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; - num_splits_shared[0] = 0; - for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; - tile_scheduler_metadata1 = now_n_split_idx; - int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; - int now_remain_blocks = num_blocks - now_block; - if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { - cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; - remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; - now_block = 0; - now_n_split_idx = 0; - } else { - if (remain_payload - fixed_overhead_num_blocks > 0) { - now_block += remain_payload - fixed_overhead_num_blocks; - ++now_n_split_idx; - remain_payload = 0; - } - break; - } - } - tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; - } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); - } - __syncwarp(); - - for (int i = threadIdx.x; i <= batch_size; i += 32) { - num_splits_ptr[i] = num_splits_shared[i]; - } -} - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); - CHECK_CUDA_KERNEL_LAUNCH(); -} \ No newline at end of file diff --git a/csrc/kernels_fp8/flash_mla.h b/csrc/kernels_fp8/flash_mla.h index b7e2fed..0ae85a9 100644 --- a/csrc/kernels_fp8/flash_mla.h +++ b/csrc/kernels_fp8/flash_mla.h @@ -1,66 +1,17 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Flash_fwd_mla_params { - using index_t = int64_t; - - int b, seqlen_q, d, d_v; - int h, h_h_k_ratio, ngroups; - bool is_causal; - float scale_softmax, scale_softmax_log2; - int *__restrict__ cu_seqlens_k; +#include "../kernels/params.h" - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - void *__restrict__ o_ptr; - void *__restrict__ softmax_lse_ptr; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FP8-specific extension of the original Flash_fwd_mla_params +struct Flash_fwd_mla_params_fp8 : public Flash_fwd_mla_params { + int h_h_k_ratio; float* __restrict__ descale_q_ptr = nullptr; float* __restrict__ descale_k_ptr = nullptr; - - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t o_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t o_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; - index_t o_head_stride; - - int *__restrict__ block_table; - index_t block_table_batch_stride; - int page_block_size; - - int *__restrict__ tile_scheduler_metadata_ptr; - int num_sm_parts; - int *__restrict__ num_splits_ptr; - - void *__restrict__ softmax_lseaccum_ptr; - void *__restrict__ oaccum_ptr; }; -static constexpr int TileSchedulerMetaDataSize = 8; -// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] - //////////////////////////////////////////////////////////////////////////////////////////////////// template -void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); - -struct Mla_metadata_params { - int *__restrict__ seqlens_k_ptr; - int *__restrict__ tile_scheduler_metadata_ptr; - int *__restrict__ num_splits_ptr; - int batch_size; - int block_size_n; - int fixed_overhead_num_blocks; - int num_sm_parts; -}; - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream); diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 47637f8..d25105a 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -33,6 +33,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ 
Arguments: @@ -45,6 +47,8 @@ def flash_mla_with_kvcache( num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q_: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. + descale_k_: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -62,5 +66,7 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse From 09a7be5b2b0507fd9ea0973df80385f786163011 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 6 Aug 2025 20:44:39 +0000 Subject: [PATCH 03/17] Fix typo Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index e4b1951..0f30baf 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -120,7 +120,7 @@ mha_fwd_kvcache_mla( TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (q_dtype == torch::kFloat8_e4m3fn) { - TORCH_CHECK(descale_q.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8"); + TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8"); auto descale_q_ = descale_q.value(); auto descale_k_ = descale_k.value(); CHECK_DEVICE(descale_q_); From 8088f76bac762aef53eddd86521c81097a7a56dc Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 6 Aug 2025 20:44:45 +0000 Subject: [PATCH 04/17] Update tests Signed-off-by: Matthew Bonanni --- tests/test_flash_mla.py | 66 +++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 67c9d93..bc152d0 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -28,21 +28,25 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): return attn_weight @ value, lse -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None: x, y = x.double(), y.double() RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 + if use_fp8: + assert cos_diff < 1e-2 + else: + assert cos_diff < 1e-5 @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype): print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" ) + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -68,7 +72,31 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv ) - + + init_dtype = q.dtype + def prepare_fp8_input(): + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None + 
+ if use_fp8: + nonlocal q, blocked_k, blocked_v + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q_fp8 = q.to(fp8_dtype) + blocked_k_fp8 = blocked_k.to(fp8_dtype) + blocked_v_fp8 = blocked_v.to(fp8_dtype) + + return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k + + + + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() + if use_fp8: + q = q_fp8 + blocked_k = blocked_k_fp8 + blocked_v = blocked_v_fp8 + def flash_mla(): return flash_mla_with_kvcache( q, @@ -79,18 +107,23 @@ def flash_mla(): tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), h_q=h_q, h_kv=h_kv, is_causal=causal, @@ -101,14 +134,12 @@ def ref_mla(): out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) print( f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" ) @@ -116,7 +147,8 @@ def ref_mla(): def main(torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(torch_dtype) + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) @@ -131,7 +163,7 @@ def main(torch_dtype): for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) + test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype) if __name__ == "__main__": @@ -139,9 +171,9 @@ def main(torch_dtype): parser.add_argument( "--dtype", type=str, - choices=["bf16", "fp16"], + choices=["bf16", "fp16", "e4m3"], default="bf16", - help="Data type to use for testing (bf16 or fp16)", + help="Data type to use for testing (bf16/fp16/e4m3)", ) args = parser.parse_args() @@ -149,5 +181,7 @@ def main(torch_dtype): torch_dtype = torch.bfloat16 if args.dtype == "fp16": torch_dtype = torch.float16 + elif args.dtype == "e4m3": + torch_dtype = torch.float8_e4m3fn main(torch_dtype) From 94e1e8c35be8701e6a2ab4973b6743e5ddf35dc1 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 6 Aug 2025 21:01:30 +0000 Subject: [PATCH 05/17] 
Add to build Signed-off-by: Matthew Bonanni --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 131ceff..7c6266a 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ def get_features_args(): "csrc/kernels/get_mla_metadata.cu", "csrc/kernels/mla_combine.cu", "csrc/kernels/splitkv_mla.cu", + "csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), From 4dcd3921922f91117c053f8d3e55dbb15d5ab994 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 7 Aug 2025 15:27:29 +0000 Subject: [PATCH 06/17] Fix installation Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 6 +++--- flash_mla/flash_mla_interface.py | 6 +++--- setup.py | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 0f30baf..dbfa8f4 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -254,8 +254,8 @@ TORCH_LIBRARY(_flashmla_C, m) { m.impl("fwd_kvcache_mla", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache_mla)); } -PyMODINIT_FUNC PyInit__flashmla_C() { +PyMODINIT_FUNC PyInit_flash_mla_cuda() { static struct PyModuleDef module = { - PyModuleDef_HEAD_INIT, "_flashmla_C", nullptr, 0, nullptr}; - return PyModule_Create(&module); + PyModuleDef_HEAD_INIT, "flash_mla_cuda", nullptr, 0, nullptr}; + return PyModule_Create(&module); } \ No newline at end of file diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index d25105a..2bb00c0 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,7 +2,7 @@ import torch -import flash_mla_cuda +from . import flash_mla_cuda def get_mla_metadata( @@ -20,7 +20,7 @@ def get_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. 
""" - return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) def flash_mla_with_kvcache( @@ -56,7 +56,7 @@ def flash_mla_with_kvcache( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, head_dim_v, diff --git a/setup.py b/setup.py index 7c6266a..e1f82e3 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_cuda", + name="flash_mla.flash_mla_cuda", sources=[ "csrc/flash_api.cpp", "csrc/kernels/get_mla_metadata.cu", @@ -91,4 +91,8 @@ def get_features_args(): packages=find_packages(include=['flash_mla']), ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, + package_data={ + 'flash_mla': ['*.so'], # Include any .so files in the flash_mla package + }, + zip_safe=False, # Important for extensions ) From 831cec650304996e9c6d7352f9450a2f55394e0b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 7 Aug 2025 15:35:17 +0000 Subject: [PATCH 07/17] Fix FLASH_MLA_DISABLE_FP8 flag Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu | 4 +++- setup.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu index a7eb3db..dda6aa5 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu +++ b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu @@ -1,3 +1,5 @@ #include "flash_fwd_mla_kernel.h" -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream); \ No newline at end of file +#ifndef FLASH_MLA_DISABLE_FP8 +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream); +#endif \ No newline at end of file diff --git a/setup.py b/setup.py index e1f82e3..5022cf5 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,9 @@ def get_features_args(): DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] if DISABLE_FP16: features_args.append("-DFLASH_MLA_DISABLE_FP16") + DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") in ["TRUE", "1"] + if DISABLE_FP8: + features_args.append("-DFLASH_MLA_DISABLE_FP8") return features_args From 8e0d857b3eba7f8711605e4fc410b597b68b7566 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 7 Aug 2025 17:55:26 +0000 Subject: [PATCH 08/17] Fix param matchup Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_kernel.h | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h index c33e6db..ebb0306 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -221,9 +221,9 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params_fp8 ¶ms, co cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h_q + bidh) * params.s_q + m_block * kBlockM) * params.d_v; - const index_t row_offset_lse = (bidb * params.h_q + bidh) * params.s_q + m_block * kBlockM; - const index_t row_offset_lseaccum = ((split_offset + 
n_split_idx) * params.h_q + bidh) * params.s_q + m_block * kBlockM; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, @@ -252,7 +252,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params_fp8 ¶ms, co #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < params.s_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } @@ -263,7 +263,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params_fp8 ¶ms, co Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.s_q - m_block * kBlockM + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.q_seq_per_hk - m_block * kBlockM ); } @@ -361,7 +361,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk int row = int(get<0>(tScS(i))); - int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.s_q - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; } } @@ -423,7 +423,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, - params.s_q - m_block * kBlockM); + params.q_seq_per_hk - m_block * kBlockM); const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), @@ -590,7 +590,7 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param const int tidx = threadIdx.x; const int bidx = blockIdx.x; - const int hs = params.h_q * params.s_q; + const int hs = params.h_k * params.q_seq_per_hk; const int batch_idx = bidx / hs; const int hs_idx = bidx % hs; @@ -663,8 +663,8 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param } Tensor rO = flash::convert_type(tOrO); - const int head_idx = (bidx - batch_idx * hs) / params.s_q; - const int row = bidx - batch_idx * hs - head_idx * params.s_q; + const int head_idx = (bidx - batch_idx * hs) / params.q_seq_per_hk; + const int row = bidx - batch_idx * hs - head_idx * params.q_seq_per_hk; auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); 
cute::copy(rO, gO); @@ -677,16 +677,16 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param template void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params_fp8 ¶ms, cudaStream_t stream) { FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); - const int num_m_block = cute::ceil_div(params.s_q, Kernel_traits::kBlockM); + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, Kernel_traits::kBlockM); BOOL_SWITCH(params.is_causal, Is_causal, [&] { auto kernel = &flash::flash_fwd_splitkv_mla_kernel; constexpr size_t smem_size = sizeof(SharedStorage); CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - kernel<<>>(params); + kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); - dim3 grid_combine(params.b * params.h_q * params.s_q); + dim3 grid_combine(params.b * params.h_k * params.q_seq_per_hk); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; From 2b0735fe3dc78aae1b3cc42cd5920e508703e118 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 7 Aug 2025 20:34:02 +0000 Subject: [PATCH 09/17] typo Signed-off-by: Matthew Bonanni --- flash_mla/flash_mla_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 2bb00c0..e304650 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -47,8 +47,8 @@ def flash_mla_with_kvcache( num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - descale_q_: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. - descale_k_: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. + descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. + descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 
From 1e3ad7e42856de275a26ef2cdc95e5143d0f1646 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 8 Aug 2025 14:14:08 +0000 Subject: [PATCH 10/17] Fix out dtype Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index dbfa8f4..8f95c1e 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -151,7 +151,13 @@ mha_fwd_kvcache_mla( at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + caffe2::TypeMeta out_type; + if (q_dtype == torch::kFloat8_e4m3fn) { + out_type = torch::kBFloat16; + } else { + out_type = q_dtype; + } + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type)); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); CHECK_CONTIGUOUS(softmax_lse); From 18554fc5a0925cb917207cffee157fff7d2466b7 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 8 Aug 2025 15:51:07 +0000 Subject: [PATCH 11/17] Fix IMA Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h index ebb0306..a403435 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -554,7 +554,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params_fp8 pa int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; + if (begin_idx >= params.b || begin_idx < 0) return; int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); float descale_k = 1.f; From 902064e250fb5e8c72b734ddc68bf860e5def572 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 8 Aug 2025 20:48:07 +0000 Subject: [PATCH 12/17] Extension name should be _flashmla_C Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 4 ++-- flash_mla/flash_mla_interface.py | 3 +-- setup.py | 2 +- tests/test_flash_mla.py | 5 +++++ 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 8f95c1e..680c000 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -260,8 +260,8 @@ TORCH_LIBRARY(_flashmla_C, m) { m.impl("fwd_kvcache_mla", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache_mla)); } -PyMODINIT_FUNC PyInit_flash_mla_cuda() { +PyMODINIT_FUNC PyInit__flashmla_C() { static struct PyModuleDef module = { - PyModuleDef_HEAD_INIT, "flash_mla_cuda", nullptr, 0, nullptr}; + PyModuleDef_HEAD_INIT, "_flashmla_C", nullptr, 0, nullptr}; return PyModule_Create(&module); } \ No newline at end of file diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index e304650..9b63ac4 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,8 +2,7 @@ import torch -from . 
import flash_mla_cuda - +import flash_mla._flashmla_C # noqa: F401 def get_mla_metadata( cache_seqlens: torch.Tensor, diff --git a/setup.py b/setup.py index 5022cf5..108eacc 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla.flash_mla_cuda", + name="flash_mla._flashmla_C", sources=[ "csrc/flash_api.cpp", "csrc/kernels/get_mla_metadata.cu", diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index bc152d0..c49e7e3 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -1,6 +1,11 @@ import argparse import math +import os import random +import sys + +# Add the parent directory to the path so we can import flash_mla +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) import torch import triton From b77c35cc66638a50fd00e20bc2b224b4758569b9 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 15:08:21 +0000 Subject: [PATCH 13/17] Clean up Signed-off-by: Matthew Bonanni --- csrc/flash_api.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 680c000..bfdcfa3 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -225,16 +225,11 @@ mha_fwd_kvcache_mla( #ifdef FLASH_MLA_DISABLE_FP8 TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP8. Please remove this flag from your environment and re-compile FlashMLA."); #else - // Create FP8-specific params by copying base params and setting FP8 fields Flash_fwd_mla_params_fp8 fp8_params; - // Copy all base fields static_cast(fp8_params) = params; - - // Set FP8-specific fields fp8_params.h_h_k_ratio = 1; fp8_params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); fp8_params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); - run_mha_fwd_splitkv_mla(fp8_params, stream); #endif } else { From f1420c35f5c602e216b052acd0089734cfda5924 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 15:37:41 +0000 Subject: [PATCH 14/17] Tighten FP8 error tolerance Signed-off-by: Matthew Bonanni --- tests/test_flash_mla.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index c49e7e3..931b798 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -38,9 +38,8 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) - RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") if use_fp8: - assert cos_diff < 1e-2 + assert cos_diff < 1e-3 else: assert cos_diff < 1e-5 From ec3ce57b36e8fe94ec86105bd50c44c09466b35b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 12 Aug 2025 20:41:51 +0000 Subject: [PATCH 15/17] Add attribution to copied files Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu | 5 +++++ csrc/kernels_fp8/flash_fwd_mla_kernel.h | 5 +++++ csrc/kernels_fp8/flash_mla.h | 5 +++++ csrc/kernels_fp8/fp8_transpose_v.h | 6 ++++++ csrc/kernels_fp8/named_barrier.h | 5 +++++ csrc/kernels_fp8/softmax.h | 5 +++++ csrc/kernels_fp8/static_switch.h | 5 +++++ csrc/kernels_fp8/utils.h | 5 +++++ tests/test_flash_mla.py | 1 + 9 files changed, 42 insertions(+) diff --git a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu index dda6aa5..e5f2b56 100644 --- 
a/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu +++ b/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + #include "flash_fwd_mla_kernel.h" #ifndef FLASH_MLA_DISABLE_FP8 diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h index a403435..4275c8c 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + #pragma once #include diff --git a/csrc/kernels_fp8/flash_mla.h b/csrc/kernels_fp8/flash_mla.h index 0ae85a9..857f128 100644 --- a/csrc/kernels_fp8/flash_mla.h +++ b/csrc/kernels_fp8/flash_mla.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + #pragma once #include "../kernels/params.h" diff --git a/csrc/kernels_fp8/fp8_transpose_v.h b/csrc/kernels_fp8/fp8_transpose_v.h index 40bb4d5..082ae0a 100644 --- a/csrc/kernels_fp8/fp8_transpose_v.h +++ b/csrc/kernels_fp8/fp8_transpose_v.h @@ -1,3 +1,9 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + + /** * ref to Fa3's SmemTranspose64x64: * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 diff --git a/csrc/kernels_fp8/named_barrier.h b/csrc/kernels_fp8/named_barrier.h index 940c934..572516a 100644 --- a/csrc/kernels_fp8/named_barrier.h +++ b/csrc/kernels_fp8/named_barrier.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + #pragma once #include "cutlass/barrier.h" diff --git a/csrc/kernels_fp8/softmax.h b/csrc/kernels_fp8/softmax.h index bcb8cac..1d6d553 100644 --- a/csrc/kernels_fp8/softmax.h +++ b/csrc/kernels_fp8/softmax.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h #pragma once diff --git a/csrc/kernels_fp8/static_switch.h b/csrc/kernels_fp8/static_switch.h index f156adc..58e0fe2 100644 --- a/csrc/kernels_fp8/static_switch.h +++ b/csrc/kernels_fp8/static_switch.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + #pragma once #define CHECK_CUDA(call) \ diff --git a/csrc/kernels_fp8/utils.h b/csrc/kernels_fp8/utils.h index 716c50c..cd6f95b 100644 --- a/csrc/kernels_fp8/utils.h +++ b/csrc/kernels_fp8/utils.h @@ -1,3 +1,8 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h #pragma once diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 931b798..e613712 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -102,6 +102,7 @@ def prepare_fp8_input(): blocked_v = blocked_v_fp8 def flash_mla(): + breakpoint() return flash_mla_with_kvcache( q, blocked_k, From cf90884c75196db7be78c726d3868a4eda0aee48 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 12 Aug 2025 20:43:17 +0000 Subject: 
[PATCH 16/17] Remove breakpoint Signed-off-by: Matthew Bonanni --- tests/test_flash_mla.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index e613712..931b798 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -102,7 +102,6 @@ def prepare_fp8_input(): blocked_v = blocked_v_fp8 def flash_mla(): - breakpoint() return flash_mla_with_kvcache( q, blocked_k, From ce68f28129f6a71f596a9703732b1ebf7a7066dd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 12 Aug 2025 21:18:28 +0000 Subject: [PATCH 17/17] Port cudagraph fix from vllm-project/FlashMLA#3 Signed-off-by: Matthew Bonanni --- csrc/kernels_fp8/flash_fwd_mla_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels_fp8/flash_fwd_mla_kernel.h b/csrc/kernels_fp8/flash_fwd_mla_kernel.h index 4275c8c..9ba4035 100644 --- a/csrc/kernels_fp8/flash_fwd_mla_kernel.h +++ b/csrc/kernels_fp8/flash_fwd_mla_kernel.h @@ -602,7 +602,7 @@ flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_param const int split_offset = __ldg(params.num_splits_ptr + batch_idx); const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); - if (actual_num_splits == 1) return; + if (actual_num_splits <= 1) return; __shared__ ElementAccum sLseScale[kMaxSplits];
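
For reference, below is a minimal FP8 call sequence condensed from the test changes in PATCH 04 and the interface changes in PATCH 02. It is a sketch, not a verbatim excerpt: the `from flash_mla import ...` names, the shapes, and the unit descale factors are assumptions drawn from the checks added in this series (head_dim 576 / head_dim_v 512, page block size 64, single-element float32 descale tensors per CHECK_SHAPE(descale_*, 1)).

    # Minimal FP8 decode call; import names assumed from the flash_mla package layout used by tests/test_flash_mla.py.
    import torch
    from flash_mla import get_mla_metadata, flash_mla_with_kvcache

    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.bfloat16)

    b, s_q, h_q, h_kv, d, dv = 2, 1, 16, 1, 576, 512   # d=576 / dv=512 are required by the kernel
    block_size, s_k = 64, 256                           # page_block_size must equal kBlockN (64)

    cache_seqlens = torch.full((b,), s_k, dtype=torch.int32)
    block_table = torch.arange(b * s_k // block_size, dtype=torch.int32).view(b, -1)

    q = torch.randn(b, s_q, h_q, d)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)

    # Quantize to e4m3 and pass per-tensor descale factors (ones here for simplicity);
    # flash_api.cpp checks that both are single-element float32 CUDA tensors.
    q_fp8 = q.to(torch.float8_e4m3fn)
    k_fp8 = blocked_k.to(torch.float8_e4m3fn)
    descale_q = torch.ones(1, dtype=torch.float32)
    descale_k = torch.ones(1, dtype=torch.float32)

    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

    # With float8_e4m3fn inputs the output comes back in bfloat16 (see the out-dtype fix in PATCH 10).
    out, lse = flash_mla_with_kvcache(
        q_fp8, k_fp8, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits,
        causal=True, descale_q=descale_q, descale_k=descale_k,
    )

If FlashMLA is built with FLASH_MLA_DISABLE_FP8, the same call hits the TORCH_CHECK error added in PATCH 02 instead of dispatching to run_mha_fwd_splitkv_mla.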