[V1][Kernel] Add triton implementation for reshape_and_cache_flash #24503
Merged: mgoin merged 21 commits into vllm-project:main from bringlein:ngl_triton_reshape_and_cache_pr on Sep 23, 2025.
Changes from all commits (21 commits):
af04384  adding reshape and cache kernel, adapting and checking unit tests (bringlein)
6fc04db  cosmetics (bringlein)
3fd8b21  integrate triton reshape and cache into triton_attn (bringlein)
aa227e6  Merge remote-tracking branch 'origin' into ngl_triton_reshape_and_cac… (bringlein)
2ef6acb  ruff... (bringlein)
b17d8a8  fix unittest (bringlein)
b8ccbac  ruff... (bringlein)
5ab8eef  Merge remote-tracking branch 'origin/main' into ngl_triton_reshape_an… (bringlein)
fb4f8fb  fixing stupidity (bringlein)
90425db  ensuring co-authors (cyang49)
c679da1  disable triton kernel if using fp8 on rocm (bringlein)
55e829e  fmt... (bringlein)
ebd565c  simplifying fp8 logic, supporting now float8_e4m3fnuz (bringlein)
4243b54  fmt... (bringlein)
3f36c80  Merge branch 'main' into ngl_triton_reshape_and_cache_pr (bringlein)
6962704  Merge branch 'main' into ngl_triton_reshape_and_cache_pr (yewentao256)
12dbbec  Merge remote-tracking branch 'origin/main' into ngl_triton_reshape_an… (bringlein)
6316e22  fmt... (bringlein)
10c6af4  update reshape and cache benchmark (bringlein)
d7247a4  Merge branch 'ngl_triton_reshape_and_cache_pr' of github.com:bringlei… (bringlein)
0b2c4d6  mute ruff false-positive (bringlein)
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform


@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):

    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)
    tile_pos = tile_i * TILE_SIZE + tile_offs

    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    src_key_idx = token_idx * key_stride
    src_value_idx = token_idx * value_stride

    tgt_idx = block_idx * block_stride + block_offset * page_stride

    # [TILE_SIZE]
    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
                       mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if key_load.dtype.is_fp8():
            key_tile = key_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            key_tile = key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]
    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
                         mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    tl.store(
        key_cache_ptr + tgt_idx + tile_pos,
        key_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    tl.store(
        value_cache_ptr + tgt_idx + tile_pos,
        value_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    return


def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    num_tokens = key.shape[0]
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only contiguous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
        kv_cache_dtype.startswith("fp8") else key_cache.dtype

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
            "fp8"):
        # to avoid an erroneous implicit cast in the triton kernel
        # (tl.store to uint8; e.g. an explicit cast to fp8e4m3fnuz is not
        # supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
        assert kv_cache_torch_dtype != torch.uint8, \
            "explicit fp8 cast and store to uint8 is not supported by " \
            "triton reshape_and_cache_flash"

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
        torch.float8_e4m3fnuz], \
        "unsupported dtype of KV cache tensor, got " \
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if torch.version.hip:
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
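
For reference, here is a minimal usage sketch of the new Python entry point. The import path, tensor shapes, block size, and dtypes below are illustrative assumptions and are not taken from this diff.

```python
# Hedged sketch: exercising triton_reshape_and_cache_flash on a CUDA device.
# The module path below is an assumption about where the new file lives.
import torch

from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash)

num_tokens, num_heads, head_size = 32, 8, 128
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)

# one flat cache slot per token; a negative entry would mark a padding token
slot_mapping = torch.randperm(num_blocks * block_size,
                              device="cuda")[:num_tokens]

# scales are only read by the kernel when the cache dtype is an fp8 variant
k_scale = torch.ones(1, dtype=torch.float32, device="cuda")
v_scale = torch.ones(1, dtype=torch.float32, device="cuda")

triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                               slot_mapping, "auto", k_scale, v_scale)

# each token's K/V vectors should now sit at its mapped (block, offset) slot
blk = slot_mapping // block_size
off = slot_mapping % block_size
assert torch.allclose(key_cache[blk, off], key)
assert torch.allclose(value_cache[blk, off], value)
```

With kv_cache_dtype="auto" the scales are ignored and K/V are copied as-is; for an "fp8" cache dtype the kernel divides by k_scale/v_scale and relies on the implicit cast performed by tl.store, as shown in the kernel above.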