43 changes: 39 additions & 4 deletions csrc/flash_api.cpp
@@ -16,6 +16,8 @@
#include "kernels/params.h"
#include "kernels/splitkv_mla.h"

#include "kernels_fp8/flash_mla.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -68,7 +70,9 @@ mha_fwd_kvcache_mla(
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
const at::Tensor &num_splits, // batch_size + 1
c10::optional<const at::Tensor> &descale_q, // batch_size
c10::optional<const at::Tensor> &descale_k // batch_size
) {
// Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -77,7 +81,7 @@ mha_fwd_kvcache_mla(

// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
@@ -115,6 +119,20 @@ mha_fwd_kvcache_mla(
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (q_dtype == torch::kFloat8_e4m3fn) {
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_ = descale_q.value();
auto descale_k_ = descale_k.value();
CHECK_DEVICE(descale_q_);
CHECK_DEVICE(descale_k_);
TORCH_CHECK(descale_q_.stride(-1) == 1);
TORCH_CHECK(descale_k_.stride(-1) == 1);
TORCH_CHECK(descale_q_.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_.dtype() == torch::kFloat);
CHECK_SHAPE(descale_q_, 1);
CHECK_SHAPE(descale_k_, 1);
}

if (seqlen_q_ori == 1) { is_causal = false; }

const int num_q_heads_per_hk = num_heads_q / num_heads_k;
@@ -133,7 +151,13 @@
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
caffe2::TypeMeta out_type;
if (q_dtype == torch::kFloat8_e4m3fn) {
out_type = torch::kBFloat16;
} else {
out_type = q_dtype;
}
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse);

@@ -196,6 +220,17 @@ mha_fwd_kvcache_mla(
#else
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
} else if (q_dtype == torch::kFloat8_e4m3fn) {
#ifdef FLASH_MLA_DISABLE_FP8
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP8. Please remove this flag from your environment and re-compile FlashMLA.");
#else
Flash_fwd_mla_params_fp8 fp8_params;
static_cast<Flash_fwd_mla_params&>(fp8_params) = params;
fp8_params.h_h_k_ratio = 1;
fp8_params.descale_q_ptr = reinterpret_cast<float *>(descale_q.value().data_ptr());
fp8_params.descale_k_ptr = reinterpret_cast<float *>(descale_k.value().data_ptr());
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(fp8_params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
@@ -223,5 +258,5 @@ TORCH_LIBRARY(_flashmla_C, m) {
PyMODINIT_FUNC PyInit__flashmla_C() {
static struct PyModuleDef module = {
PyModuleDef_HEAD_INIT, "_flashmla_C", nullptr, 0, nullptr};
    return PyModule_Create(&module);
}
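Note on the fp8 branch above: static_cast<Flash_fwd_mla_params&>(fp8_params) = params copies the shared fields by assigning to the base subobject, which only works because Flash_fwd_mla_params_fp8 is assumed to derive from Flash_fwd_mla_params. A minimal sketch of that pattern follows; the field names are illustrative placeholders, not the real definitions from csrc/kernels/params.h and csrc/kernels_fp8/flash_mla.h.

// Sketch only: illustrative fields, not the actual FlashMLA structs.
struct Flash_fwd_mla_params {
    int b = 0;               // batch size
    int h = 0;               // number of query heads
    int h_h_k_ratio = 0;     // query heads per kv head
    void *q_ptr = nullptr;
    void *k_ptr = nullptr;
    void *o_ptr = nullptr;
};

// Assumed (as the static_cast implies) to publicly derive from the shared params.
struct Flash_fwd_mla_params_fp8 : Flash_fwd_mla_params {
    float *descale_q_ptr = nullptr;
    float *descale_k_ptr = nullptr;
};

Flash_fwd_mla_params_fp8 make_fp8_params(const Flash_fwd_mla_params &params,
                                         float *descale_q, float *descale_k) {
    Flash_fwd_mla_params_fp8 fp8_params;
    // One assignment copies every field of the common base; only the
    // fp8-specific members still need to be filled in afterwards.
    static_cast<Flash_fwd_mla_params &>(fp8_params) = params;
    fp8_params.descale_q_ptr = descale_q;
    fp8_params.descale_k_ptr = descale_k;
    return fp8_params;
}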
10 changes: 10 additions & 0 deletions csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu
@@ -0,0 +1,10 @@
/*
* Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54
* originally authored by @endurehero
*/

#include "flash_fwd_mla_kernel.h"

#ifndef FLASH_MLA_DISABLE_FP8
template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params_fp8 &params, cudaStream_t stream);
#endif
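The new .cu file above relies on explicit template instantiation: flash_fwd_mla_kernel.h supplies the template definition, and this translation unit compiles the single fp8-input / bf16-output, head-dim-576 specialization that flash_api.cpp links against. A compilable toy sketch of the same header/translation-unit split, with placeholder names rather than the FlashMLA kernel:

// Toy illustration of the explicit-instantiation pattern; names are placeholders.
#include <cstdio>

struct Params { int batch; };

// In FlashMLA the analogous template definition lives in flash_fwd_mla_kernel.h,
// which the .cu file includes before instantiating it.
template <typename InputT, typename OutputT, int HeadDim>
void run_splitkv(Params &p) {
    std::printf("would launch kernel: head_dim=%d, batch=%d\n", HeadDim, p.batch);
}

// Compile exactly one specialization in this translation unit; callers only need
// a declaration of the template and link against this symbol, so the heavy kernel
// code is not instantiated again in every file that uses it.
template void run_splitkv<unsigned char, float, 576>(Params &p);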