-
Notifications
You must be signed in to change notification settings - Fork 40
/
flashinfer_all.cu
186 lines (168 loc) · 8.17 KB
/
flashinfer_all.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include "flashinfer/page.cuh"
#include "flashinfer_config.h"
#include "flashinfer_decl.h"
#include "generated/dispatch.inc"
using flashinfer::paged_kv_t;
using flashinfer::PageStorage;
using flashinfer::RotaryMode;
// Runtime-to-compile-time dispatch machinery.
//
// _DISPATCH_SWITCH wraps a switch over a runtime value in an immediately
// invoked lambda; if no case matches (unsupported configuration) the lambda
// returns false.
#define _DISPATCH_SWITCH(cond, ...) \
  [&]() -> bool {                   \
    switch (cond) {                 \
      __VA_ARGS__                   \
      default:                      \
        return false;               \
    }                               \
  }()
// One case of the dispatch switch: binds the matched value to a constexpr
// name `var` and invokes the continuation (__VA_ARGS__ is a callable).
#define _DISPATCH_CASE(case_expr, var, ...) \
  case case_expr: {                         \
    constexpr auto var = case_expr;         \
    return __VA_ARGS__();                   \
  }
// Dispatchers over the supported group sizes, page sizes and head dims; the
// _DISPATCH_CASES_* case lists come from generated/dispatch.inc.
#define DISPATCH_group_size(expr, ...) \
  _DISPATCH_SWITCH(expr, _DISPATCH_CASES_group_size(__VA_ARGS__))
#define DISPATCH_page_size(expr, ...) \
  _DISPATCH_SWITCH(expr, _DISPATCH_CASES_page_size(__VA_ARGS__))
#define DISPATCH_head_dim(expr, ...) \
  _DISPATCH_SWITCH(expr, _DISPATCH_CASES_head_dim(__VA_ARGS__))
namespace {
// Bump-allocates `n` elements of T from the front of *buf and advances *buf
// past the carved-out region. Performs no alignment adjustment: the caller
// must ensure *buf is already suitably aligned for T.
template <typename T>
inline T* alloc_from_buf(void** buf, int n) {
  // static_cast instead of C-style casts: same behavior, checked at compile
  // time and consistent with modern C++ style.
  T* p = static_cast<T*>(*buf);
  *buf = static_cast<void*>(p + n);
  return p;
}
}  // namespace
// Runs the batched prefill attention kernel over a paged KV cache.
//
// o/q: output and query tensors; qo_indptr: per-request offsets into q/o.
// kv_ptrs / kv_indptr / last_page_offset: paged KV-cache layout (pointer
// storage). tmpbuf: device scratch buffer; its first 4*(batch_size+1) int32
// slots are carved off as auxiliary arrays for paged_kv_t, and the remainder
// is handed to the kernel as float workspace.
//
// Returns false when the (page_size, group_size, head_dim) combination has no
// generated instantiation (the dispatch switch falls through to default).
// NOTE(review): a CUDA error from the launched kernel is only printed to
// stderr -- the function still returns true; confirm callers do not use the
// return value for error detection.
template <typename T>
bool FlashInferBatchPrefillKernel(T* o, T* q, int32_t* qo_indptr, T** kv_ptrs,
                                  int32_t* kv_indptr, int32_t* last_page_offset,
                                  void* tmpbuf, int head_dim, int num_layers,
                                  int layer_idx, int group_size,
                                  int num_kv_heads, int page_size,
                                  int batch_size) {
  // Map the runtime sizes onto compile-time constants; PAGE_SIZE, GROUP_SIZE
  // and HEAD_DIM are the constexpr names bound by the generated
  // _DISPATCH_CASES_* lists (see generated/dispatch.inc).
  return DISPATCH_page_size(page_size, [&] {
    return DISPATCH_group_size(group_size, [&] {
      return DISPATCH_head_dim(head_dim, [&] {
        // Auxiliary int32 arrays for paged_kv_t, taken from the head of tmpbuf.
        auto kv_aux = alloc_from_buf<int32_t>(&tmpbuf, 4 * (batch_size + 1));
        paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
            num_layers, layer_idx, num_kv_heads, page_size, head_dim,
            batch_size, kv_ptrs, kv_indptr, last_page_offset, kv_aux);
        int num_qo_heads = num_kv_heads * group_size;
        constexpr bool allow_fp16_qk_reduction = false;
        constexpr bool causal = true;  // prefill uses causal masking
        constexpr auto rotary = RotaryMode::kLlama;
        float rope_scale = 1.f;
        float rope_theta = 1e4;
        cudaStream_t stream = nullptr;  // default stream
        auto status = flashinfer::BatchPrefillWithPagedKVCacheDispatched<
            PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, rotary,
            allow_fp16_qk_reduction, causal>(q, paged_kv, qo_indptr, o,
                                             (float*)tmpbuf, num_qo_heads,
                                             rope_scale, rope_theta, stream);
        if (status != cudaSuccess) {
          fprintf(stderr, "batch_prefill failed: %s\n",
                  cudaGetErrorString(status));
        }
        return true;
      });
    });
  });
}
// Runs the batched single-step decode attention kernel over a paged KV cache.
//
// o/q: output and query tensors. kv_ptrs / kv_indptr / last_page_offset:
// paged KV-cache layout (pointer storage). tmpbuf: device scratch; its first
// 4*(batch_size+1) int32 slots become paged_kv_t auxiliary arrays (the rest
// is unused on this path -- the kernel receives nullptr as its temp pointer).
//
// Returns false when (page_size, group_size, head_dim) has no generated
// instantiation. NOTE(review): like the prefill path, a CUDA error from the
// kernel is only printed to stderr and the function still returns true.
template <typename T>
bool FlashInferBatchDecodeKernel(T* o, T* q, T** kv_ptrs, int32_t* kv_indptr,
                                 int32_t* last_page_offset, void* tmpbuf,
                                 int head_dim, int num_layers, int layer_idx,
                                 int group_size, int num_kv_heads,
                                 int page_size, int batch_size) {
  // PAGE_SIZE / GROUP_SIZE / HEAD_DIM are constexpr names bound by the
  // generated _DISPATCH_CASES_* lists (see generated/dispatch.inc).
  return DISPATCH_page_size(page_size, [&] {
    return DISPATCH_group_size(group_size, [&] {
      return DISPATCH_head_dim(head_dim, [&] {
        // Auxiliary int32 arrays for paged_kv_t, carved from tmpbuf.
        auto kv_aux = alloc_from_buf<int32_t>(&tmpbuf, 4 * (batch_size + 1));
        paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
            num_layers, layer_idx, num_kv_heads, page_size, head_dim,
            batch_size, kv_ptrs, kv_indptr, last_page_offset, kv_aux);
        constexpr auto rotary = RotaryMode::kLlama;
        float rope_scale = 1.f;
        float rope_theta = 1e4;
        cudaStream_t stream = nullptr;  // default stream
        auto status = flashinfer::BatchDecodeWithPagedKVCacheDispatched<
            PAGE_SIZE, GROUP_SIZE, HEAD_DIM, PageStorage::kPointer, rotary>(
            q, paged_kv, o, nullptr, rope_scale, rope_theta, stream);
        if (status != cudaSuccess) {
          fprintf(stderr, "batch_decode failed: %s\n",
                  cudaGetErrorString(status));
        }
        return true;
      });
    });
  });
}
// Writes whole initial sequences of key/value tensors into the paged KV cache
// for one layer (prefill-time append).
//
// kv_ptrs / kv_indptr / last_page_offset: paged cache layout (pointer
// storage). key/value: source tensors; seqlen_indptr: per-sequence offsets
// into key/value.
template <int head_dim, typename T>
void FlashInferInitKvKernel(T** kv_ptrs, int32_t* kv_indptr,
                            int32_t* last_page_offset, T* key, T* value,
                            int32_t* seqlen_indptr, int num_layers,
                            int layer_idx, int num_kv_heads, int page_size,
                            int batch_size) {
  paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
      num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
      kv_ptrs, kv_indptr, last_page_offset);
  // Per-thread vector width: at least a 16-byte load and at least head_dim/32
  // elements; bdx threads together span one head_dim.
  constexpr size_t vec_size =
      std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
  constexpr size_t bdx = head_dim / vec_size;
  constexpr size_t bdy = 1;
  // One block per (sequence, head) pair (bdy == 1).
  dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
  dim3 nthrs(bdx, bdy);
  flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy,
                                              PageStorage::kPointer, T, int32_t>
      <<<nblks, nthrs>>>(paged_kv, key, value, seqlen_indptr);
  // Kernel launches return no status directly: surface launch-configuration
  // errors here, using the same stderr reporting style as the batch kernels.
  cudaError_t status = cudaGetLastError();
  if (status != cudaSuccess) {
    fprintf(stderr, "init_kv launch failed: %s\n", cudaGetErrorString(status));
  }
}
// Appends one new key/value token per sequence to the paged KV cache for one
// layer (decode-time append).
//
// kv_ptrs / kv_indptr / last_page_offset: paged cache layout (pointer
// storage). key/value: the new per-sequence K/V entries for `layer_idx`.
template <int head_dim, typename T>
void FlashInferAppendKvKernel(T** kv_ptrs, int32_t* kv_indptr,
                              int32_t* last_page_offset, T* key, T* value,
                              int num_layers, int layer_idx, int num_kv_heads,
                              int page_size, int batch_size) {
  paged_kv_t<PageStorage::kPointer, T, int32_t> paged_kv(
      num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
      kv_ptrs, kv_indptr, last_page_offset);
  // Per-thread vector width: at least a 16-byte load and at least head_dim/32
  // elements; bdx threads together span one head_dim.
  constexpr size_t vec_size =
      std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
  constexpr size_t bdx = head_dim / vec_size;
  constexpr size_t bdy = 1;
  // One block per (sequence, head) pair (bdy == 1).
  dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
  dim3 nthrs(bdx, bdy);
  flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy,
                                             PageStorage::kPointer, T, int32_t>
      <<<nblks, nthrs>>>(paged_kv, key, value);
  // Kernel launches return no status directly: surface launch-configuration
  // errors here, using the same stderr reporting style as the batch kernels.
  cudaError_t status = cudaGetLastError();
  if (status != cudaSuccess) {
    fprintf(stderr, "append_kv launch failed: %s\n",
            cudaGetErrorString(status));
  }
}
// Explicit template instantiations: emit the kernels in this translation unit
// for the supported element types (fp16 and bf16).
#define INST_FlashInferBatchPrefillKernel(T)                                  \
  template bool FlashInferBatchPrefillKernel<T>(                              \
      T * o, T * q, int32_t * qo_indptr, T * *kv_ptrs, int32_t * kv_indptr,   \
      int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \
      int layer_idx, int group_size, int num_kv_heads, int page_size,         \
      int batch_size);
INST_FlashInferBatchPrefillKernel(nv_half);
INST_FlashInferBatchPrefillKernel(nv_bfloat16);
#define INST_FlashInferBatchDecodeKernel(T)                                   \
  template bool FlashInferBatchDecodeKernel<T>(                               \
      T * o, T * q, T * *kv_ptrs, int32_t * kv_indptr,                        \
      int32_t * last_page_offset, void* tmpbuf, int head_dim, int num_layers, \
      int layer_idx, int group_size, int num_kv_heads, int page_size,         \
      int batch_size);
INST_FlashInferBatchDecodeKernel(nv_half);
INST_FlashInferBatchDecodeKernel(nv_bfloat16);
#define INST_FlashInferInitKvKernel(head_dim, T)                              \
  template void FlashInferInitKvKernel<head_dim, T>(                          \
      T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
      T * value, int32_t * seqlen_indptr, int num_layers, int layer_idx,      \
      int num_kv_heads, int page_size, int batch_size);
// FOR_FlashInferBatchDecode_D invokes its first argument once per supported
// head_dim value -- presumably generated/declared alongside the dispatch
// lists; verify in flashinfer_decl.h.
FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_half);
FOR_FlashInferBatchDecode_D(INST_FlashInferInitKvKernel, nv_bfloat16);
#define INST_FlashInferAppendKvKernel(head_dim, T)                            \
  template void FlashInferAppendKvKernel<head_dim, T>(                        \
      T * *kv_ptrs, int32_t * kv_indptr, int32_t * last_page_offset, T * key, \
      T * value, int num_layers, int layer_idx, int num_kv_heads,             \
      int page_size, int batch_size);
FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_half);
FOR_FlashInferBatchDecode_D(INST_FlashInferAppendKvKernel, nv_bfloat16);