@@ -59,6 +59,10 @@ struct Sm120CollectiveFMhaWs {
static constexpr int kRowsPerMMA = 2;
static constexpr int kMmaThreads = size(TiledMma{});

// TODO: tune the number of stages based on smem size
static constexpr int StageCountQ = 1;
static constexpr int StageCountKV = 3;
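// Illustrative smem footprint, assuming BLK_M = BLK_N = 64, HEAD_DIM = 64 and fp16
// elements (the real tile sizes are template parameters): Q uses 1 * 64 * 64 * 2 B
// = 8 KiB, and the three KV stages use 3 * 64 * 64 * 2 B = 24 KiB of the single
// K/V buffer they share (see TensorStorage below).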

// Atom layout: (8, BLK_K):(BLK_K, 1) k-major
using SmemLayoutAtom_ =
decltype(composition(Swizzle<3, 3, 3>{},
@@ -68,14 +72,16 @@ struct Sm120CollectiveFMhaWs {
using SmemLayoutQ =
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_M, HEAD_DIM>{}));

// KV smem: (BLK_N, HEAD_DIM)
// KV smem: (BLK_N, HEAD_DIM, KVStages)
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
decltype(tile_to_shape(SmemLayoutAtom_{},
Shape<BLK_N, HEAD_DIM, Int<StageCountKV>>{}));
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtom_{}, Shape<BLK_N, HEAD_DIM>{}));
decltype(tile_to_shape(SmemLayoutAtom_{},
Shape<BLK_N, HEAD_DIM, Int<StageCountKV>>{}));

// V^T smem: (HEAD_DIM, BLK_N)
using SmemLayoutVt = decltype(select<1, 0>(SmemLayoutV{}));
// V^T smem: (HEAD_DIM, BLK_N, KVStages)
using SmemLayoutVt = decltype(select<1, 0, 2>(SmemLayoutV{}));
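// Note: select<1, 0, 2> only permutes layout modes, so sVt is a transposed view of
// the same staged smem as sV; no data is moved to form V^T.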

// s2r tiled copy for gemm-I
using SmemTiledCopyQ =
@@ -92,17 +98,13 @@ struct Sm120CollectiveFMhaWs {

struct TensorStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
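// With the staged layouts above, K, V and V^T all alias one 3-stage ring buffer:
// the producer writes K and V tiles into alternating stages, so a given stage holds
// either a K tile or a V tile at any time, never both.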
union {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>> smem_vt;
};
};

// TODO: tune the number of stages based on smem size
static constexpr int StageCountQ = 1;
static constexpr int StageCountKV = 2;

using PipelineQ = cutlass::PipelineAsync<StageCountQ>;
using PipelineKV = cutlass::PipelineAsync<StageCountKV>;
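// Sketch of how these pipelines are driven (producer_acquire is part of
// cutlass::PipelineAsync; only the commit/consume side is visible in this diff):
//   // producer
//   kv_pipeline.producer_acquire(kv_state);
//   /* cp.async the K or V tile into smem stage kv_state.index() */
//   kv_pipeline.producer_commit(kv_state, cutlass::arch::cpasync_barrier_arrive);
//   ++kv_state;
//   // consumer
//   kv_pipeline.consumer_wait(kv_state);
//   /* read smem stage kv_state.index() */
//   kv_pipeline.consumer_release(kv_state);
//   ++kv_state;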

@@ -201,80 +203,91 @@ struct Sm120CollectiveFMhaWs {
// Construct smem tensors
// (BLK_M, HEAD_DIM), k-major
Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
// (BLK_N, HEAD_DIM), k-major
// (BLK_N, HEAD_DIM, KVStages), k-major
Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
// Tensor for V^t; used in GEMM-II.
// (HEAD_DIM, BLK_N), k-major
// (HEAD_DIM, BLK_N, KVStages), k-major
Tensor sVt = make_tensor(make_smem_ptr(ss.smem_vt.data()), SmemLayoutVt{});

TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(tidx);
// GEMM-I: S = Q@K.T
auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
// (MMA,MMA_M,MMA_K)
auto tSrQ = thr_mma.partition_fragment_A(sQ);
// (MMA,MMA_N,MMA_K)
auto tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));
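// Stage 0 is sliced here (and for tOrVt below) only to give partition_fragment_B a
// per-stage shape; the stage actually read is selected later via the `stage`
// argument of compute_qk / compute_sv.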

// s2r tiled copy for qkv
// copy query to rmem
SmemTiledCopyQ smem_tiled_copy_Q;
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
// (CPY,CPY_M,CPY_K)
auto tSsQ = smem_thr_copy_Q.partition_S(sQ);
// (CPY,CPY_M,CPY_K)
auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);

SmemTiledCopyK smem_tiled_copy_K;
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
// (CPY,CPY_N,CPY_K, KVStages)
auto tSsK = smem_thr_copy_K.partition_S(sK);
// (CPY,CPY_N,CPY_K)
auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);

// S = Q@K.T
// tSrAccS: (MMA,MMA_M,MMA_N)
auto compute_qk = [&](auto& tSrAccS) {
auto compute_qk = [&](auto& tSrAccS, int stage) {
auto tSsK_s = tSsK(_, _, _, stage);
// prefetch key
cute::copy(
smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{}));
smem_tiled_copy_K, tSsK_s(_, _, _0{}), tSrK_copy_view(_, _, _0{}));

CUTE_UNROLL
for (int ki = 0; ki < size<2>(tSrQ); ++ki) {
// prefetch next key
if (ki != size<2>(tSrQ) - 1) {
const auto next_ki = ki + 1;
cute::copy(smem_tiled_copy_K,
tSsK(_, _, next_ki),
tSsK_s(_, _, next_ki),
tSrK_copy_view(_, _, next_ki));
}
cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
}
};
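// The interleaving above prefetches the next K k-slice from smem into registers
// while the MMA consumes the current slice (s2r software pipelining).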

// GEMM-II: O = softmax(S)@V
auto tOrVt = thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N)
// (MMA,MMA_K,MMA_N)
auto tOrVt = thr_mma.partition_fragment_B(sVt(_, _, _0{}));

SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
// (CPY,CPY_K,CPY_N, KVStages)
auto tOsVt = smem_thr_copy_Vt.partition_S(sVt);
// (CPY,CPY_K,CPY_N)
auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt);

// O = softmax(S)*V
// tSrAccS: (MMA,MMA_M,MMA_N)
// tOrAccO: (MMA,MMA_M,MMA_K)
auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) {
auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO, int stage) {
// cast scores from Accumulator to Element
auto tSrS = make_tensor_like<Element>(tSrAccS);
fast_cast(tSrAccS, tSrS);

// convert layout from gemm-I C to gemm-II A
auto tOrS =
make_tensor(tSrS.data(), LayoutConvertor::to_mma_a(tSrS.layout()));

// (CPY,CPY_M,CPY_K)
auto tOsVt_s = tOsVt(_, _, _, stage);
// prefetch V^t
cute::copy(
smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
smem_tiled_copy_Vt, tOsVt_s(_, _, _0{}), tOrVt_copy_view(_, _, _0{}));
CUTE_UNROLL
for (int ki = 0; ki < size<2>(tOrS); ++ki) {
// prefetch next V^t
if (ki != size<2>(tOrS) - 1) {
const auto next_ki = ki + 1;
cute::copy(smem_tiled_copy_Vt,
tOsVt(_, _, next_ki),
tOsVt_s(_, _, next_ki),
tOrVt_copy_view(_, _, next_ki));
}
cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
@@ -321,7 +334,7 @@ struct Sm120CollectiveFMhaWs {
kv_pipeline.consumer_wait(kv_state);

// 1> S = Q@K.T
compute_qk(tSrS);
compute_qk(tSrS, kv_state.index());

// release key smem
kv_pipeline.consumer_release(kv_state);
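// The pipeline state is advanced (outside this hunk) before the next consumer_wait,
// so compute_sv below reads the stage the producer filled with the matching V tile.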
@@ -345,7 +358,7 @@ struct Sm120CollectiveFMhaWs {
kv_pipeline.consumer_wait(kv_state);

// 2> O = softmax(S)*V
compute_sv(tSrS, tOrO);
compute_sv(tSrS, tOrO, kv_state.index());

// release value smem
kv_pipeline.consumer_release(kv_state);
@@ -58,7 +58,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
// Construct smem tensors
// (BLK_M, HEAD_DIM), k-major
Tensor sQ = make_tensor(make_smem_ptr(ss.smem_q.data()), SmemLayoutQ{});
// (BLK_N, HEAD_DIM), k-major
// (BLK_N, HEAD_DIM, KVStages), k-major
Tensor sK = make_tensor(make_smem_ptr(ss.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(ss.smem_v.data()), SmemLayoutV{});

@@ -74,15 +74,15 @@ struct Sm120CollectiveLoadCpAsyncWs {
GmemCopyThrLayout_{}, // Thr layout: (_16,_8)/(_32, _4)
Layout<Shape<_1, _8>>{} // Val layout: 8 vals per read
);
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(tidx);
auto gmem_thr_copy = gmem_tiled_copy.get_slice(tidx);

// (CPY, CPY_N, CPY_K, n) => (N, K)
Tensor tGcKV = gmem_thr_copy.partition_S(cKV);
// (CPY, CPY_N, CPY_K, n)
Tensor tGgK = gmem_thr_copy.partition_S(gK);
Tensor tGgV = gmem_thr_copy.partition_S(gV);

// (CPY, CPY_N, CPY_K)
// (CPY, CPY_N, CPY_K, KVStages)
Tensor tGsK = gmem_thr_copy.partition_D(sK);
Tensor tGsV = gmem_thr_copy.partition_D(sV);

@@ -108,7 +108,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
safe_copy</*EVEN_N=*/false, EVEN_K, /*ZFILL_N=*/false, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tGgK(_, _, _, ni),
tGsK,
tGsK(_, _, _, state.index()),
tGcKV(_, _, _, ni),
residue_nk);
kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -121,7 +121,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
safe_copy</*EVEN_N=*/true, EVEN_K, /*ZFILL_N=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy,
tGgK(_, _, _, ni),
tGsK,
tGsK(_, _, _, state.index()),
tGcKV(_, _, _, ni),
residue_nk);
kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -134,7 +134,7 @@ struct Sm120CollectiveLoadCpAsyncWs {
safe_copy</*EVEN_N=*/false, EVEN_K, /*ZFILL_N=*/true, /*ZFILL_K=*/true>(
gmem_tiled_copy,
tGgV(_, _, _, ni),
tGsV,
tGsV(_, _, _, state.index()),
tGcKV(_, _, _, ni),
residue_nk);
kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
@@ -147,15 +147,15 @@ struct Sm120CollectiveLoadCpAsyncWs {
safe_copy</*EVEN_N=*/true, EVEN_K, /*ZFILL_N=*/false, /*ZFILL_K=*/false>(
gmem_tiled_copy,
tGgV(_, _, _, ni),
tGsV,
tGsV(_, _, _, state.index()),
tGcKV(_, _, _, ni),
residue_nk);
kv_pipeline.producer_commit(state, cutlass::arch::cpasync_barrier_arrive);
++state;
};

// async copy gmem to smem in following order:
// Q1, Kn-1, Vn-1, ..., K2, V2, K1, V1
// Q0, Kn-1, Vn-1, ..., K1, V1, K0, V0
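// With StageCountKV = 3 the shared ring buffer therefore cycles K, V, K, V, ...,
// e.g. Kn-1 -> stage 0, Vn-1 -> stage 1, Kn-2 -> stage 2, Vn-2 -> stage 0
// (illustrative; the starting stage comes from the initial pipeline state).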

// produce Q1
produce_query(q_state);
44 changes: 24 additions & 20 deletions src/kernels/attention/tests/sm120_fmha_test.cu
@@ -4,6 +4,7 @@
#include <cstdint>
#include <cute/layout.hpp>

#include "common/static_dispatch.h"
#include "device/sm120_fmha_launch.cuh"
#include "mha_params.h"
#include "tests/mha_ref.h"
@@ -81,15 +82,17 @@ torch::Tensor sm120_fmha(
// normalize params for performance optimization
params.normalize();

DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] {
DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
sm120_launch_mha_kernel<DTYPE,
HEAD_DIM,
/*EVEN_K*/ true,
/*ALIBI*/ false,
/*SOFT_CAP*/ false,
/*LOCAL*/ false,
MHAParams>(params, nullptr);
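// EVEN_K is now dispatched at runtime: it is true only when the runtime head_dim
// matches the compile-time HEAD_DIM tile exactly; otherwise the kernel takes the
// bounds-checked / zero-filling path along the K dimension.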
DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
sm120_launch_mha_kernel<Dtype,
HEAD_DIM,
EVEN_K,
/*ALIBI*/ false,
/*SOFT_CAP*/ false,
/*LOCAL*/ false,
MHAParams>(params, nullptr);
});
});
});
return out;
Expand Down Expand Up @@ -140,7 +143,7 @@ TEST_P(MHAKernelTest, FMHA) {

torch::optional<torch::Tensor> alibi_slopes;
if (alibi) {
alibi_slopes = torch::rand(
alibi_slopes = torch::randn(
{n_heads}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
}

@@ -159,16 +162,17 @@
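// The parameter grid below now spans 3 * 3 * 3 * 3 * 2 = 162 combinations
// (batch sizes, ragged q/kv lengths, MHA/GQA/MQA head configs, two head dims).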
INSTANTIATE_TEST_SUITE_P(
SM120,
MHAKernelTest,
::testing::Combine(::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1), // batch_size
::testing::Values(62), // q_len
::testing::Values(127), // kv_len
::testing::Values(6), // n_heads
::testing::Values(6), // n_kv_heads
::testing::Values(64), // head_dim
::testing::Values(0.0), // logits_soft_cap
::testing::Values(false), // alibi slope
::testing::Values(-1) // sliding window
));
::testing::Combine(
::testing::Values(torch::kHalf), // q_dtype
::testing::Values(1, 2, 4), // batch_size
::testing::Values(1, 62, 125), // q_len
::testing::Values(127, 287, 1000), // kv_len
::testing::Values(6), // n_heads
::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/), // n_kv_heads
::testing::Values(32, 64), // head_dim
::testing::Values(0.0), // logits_soft_cap
::testing::Values(false), // alibi slope
::testing::Values(-1) // sliding window
));

} // namespace llm