78 changes: 67 additions & 11 deletions benchmarks/kernels/benchmark_reshape_and_cache_flash.py
@@ -9,6 +9,9 @@
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
@@ -31,13 +34,23 @@ def run_benchmark(
kv_cache_dtype: str,
kv_cache_layout: str,
num_iters: int,
implementation: str,
benchmark_mode: str,
device: str = "cuda",
) -> float:
"""Return latency (seconds) for given num_tokens."""

if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

if implementation not in ("cuda", "triton"):
raise ValueError(
f"Unsupported implementation: {implementation}. "
"Only 'cuda' and 'triton' are supported."
)
if implementation == "triton" and kv_cache_layout == "HND":
return float("nan") # Triton does not support HND layout yet.

current_platform.seed_everything(42)
torch.set_default_device(device)

@@ -65,27 +78,49 @@
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches

# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)

if implementation == "cuda":
function_under_test = lambda: ops.reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
else:
function_under_test = lambda: triton_reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
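# Capture one launch of the op in a CUDA graph; the timing loop below then
# replays the graph, so the measured latency excludes Python and launch overhead.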
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()

def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
torch.cuda.synchronize()
function_under_test()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / n_iters

@@ -116,10 +151,16 @@ def main(args):
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_layout=layout,
num_iters=args.iters,
implementation=args.implementation,
benchmark_mode=args.mode,
device="cuda",
)
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

print(
f"Benchmark results for implementation {args.implementation}"
f" (measuring with {args.mode}):"
)
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


@@ -151,6 +192,21 @@ def main(args):
)

parser.add_argument("--iters", type=int, default=100)

parser.add_argument(
"--implementation",
type=str,
choices=["cuda", "triton"],
default="cuda",
)

parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)

args = parser.parse_args()

main(args)
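
For reference, the snippet below is a minimal, self-contained sketch of the capture-and-replay timing pattern this benchmark uses when --mode cudagraph is selected. It is not part of the PR: toy_op is a hypothetical stand-in for the reshape_and_cache_flash lambda, and the shapes are arbitrary. It assumes a CUDA device.

import time

import torch


def time_with_cudagraph(fn, n_iters: int = 100) -> float:
    # Warm up once outside the graph so lazy initialization is not captured.
    fn()
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        g.replay()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters


x = torch.randn(1024, 1024, device="cuda")
y = torch.empty_like(x)


def toy_op():
    # Hypothetical stand-in for the op captured in run_benchmark.
    y.copy_(x * 2.0)


print(f"{time_with_cudagraph(toy_op) * 1e6:.3f} us per replay")

With the new flags, a run of the actual script might look like: python benchmarks/kernels/benchmark_reshape_and_cache_flash.py --implementation triton --mode no_graph --iters 100 (the remaining arguments are defined in the collapsed part of the file).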
27 changes: 21 additions & 6 deletions tests/kernels/attention/test_cache.py
@@ -39,6 +39,8 @@
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]

RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
implementation: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")

# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
@@ -298,12 +305,20 @@ def permute_and_compact(x):
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
if implementation == "cuda":
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)

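Beyond the parametrized test above, a quick way to convince yourself the two implementations agree is to run them on identical inputs and compare the resulting caches. The sketch below is illustrative only and is not part of the test file; the shapes, the float16 dtype, the "auto" cache dtype, and the unit scales are assumptions, and it needs a CUDA device with vLLM installed.

import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)

num_tokens, num_heads, head_size = 32, 8, 128
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
# Assign each token a distinct slot in the paged cache.
slot_mapping = torch.randperm(num_blocks * block_size, device="cuda")[:num_tokens]
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")


def empty_cache():
    # NHD layout: [num_blocks, block_size, num_heads, head_size]
    return torch.zeros(num_blocks, block_size, num_heads, head_size,
                       dtype=torch.float16, device="cuda")


k_cuda, v_cuda, k_triton, v_triton = (empty_cache() for _ in range(4))
ops.reshape_and_cache_flash(key, value, k_cuda, v_cuda, slot_mapping,
                            "auto", k_scale, v_scale)
triton_reshape_and_cache_flash(key, value, k_triton, v_triton, slot_mapping,
                               "auto", k_scale, v_scale)
torch.testing.assert_close(k_cuda, k_triton)
torch.testing.assert_close(v_cuda, v_triton)
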
176 changes: 176 additions & 0 deletions vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform


@triton.jit
def reshape_and_cache_kernel_flash(
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
# tune parameters
TILE_SIZE: tl.constexpr,
):

token_idx = tl.program_id(axis=0)
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
# Padding token that should be ignored.
return

tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs

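# Map the flat slot index to a (physical block, in-block offset) position
# in the paged KV cache.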
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size

src_key_idx = token_idx * key_stride
src_value_idx = token_idx * value_stride

tgt_idx = block_idx * block_stride + block_offset * page_stride

# [TILE_SIZE]
key_load = tl.load(key_ptr + src_key_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if key_load.dtype.is_fp8():
key_tile = key_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load / tl.load(k_scale)
else:
key_tile = key_load

# [TILE_SIZE]
value_load = tl.load(value_ptr + src_value_idx + tile_pos,
mask=tile_pos < (num_heads * head_size))
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load

tl.store(
key_cache_ptr + tgt_idx + tile_pos,
key_tile,
mask=tile_pos < (num_heads * head_size),
)
tl.store(
value_cache_ptr + tgt_idx + tile_pos,
value_tile,
mask=tile_pos < (num_heads * head_size),
)
return


def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size]
# [num_blocks, block_size, num_heads, head_size]
key_cache: torch.Tensor,
# [num_blocks, block_size, num_heads, head_size]
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, # [num_tokens]
kv_cache_dtype: str, # "auto", "fp8"
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
num_tokens = key.shape[0]
num_heads = key.shape[1]
head_size = key.shape[2]
block_size = key_cache.shape[1]
n = num_heads * head_size

key_stride = key.stride()[0]
value_stride = value.stride()[0]
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]

head_stride = key_cache.stride()[2]
assert head_stride == head_size, "only contiguous heads are supported"

assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
kv_cache_torch_dtype = current_platform.fp8_dtype() if \
kv_cache_dtype.startswith("fp8") else key_cache.dtype

if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
"fp8"):
# to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)
assert kv_cache_torch_dtype != torch.uint8, "explicit fp8 cast and store to "\
"uint8 is not supported by triton reshape_and_cache_flash"

FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
torch.float8_e4m3fnuz], \
"unsupported dtype of KV cache tensor, got "\
f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
if torch.version.hip:
num_stages = 4
num_warps = 8
else: # cuda
num_stages = 10
num_warps = 16
if torch.cuda.get_device_capability(key.device)[0] < 9:
TILE_SIZE = min(512, TILE_SIZE)

# TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

reshape_and_cache_kernel_flash[grid](
key_ptr=key,
value_ptr=value,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
key_stride=key_stride,
value_stride=value_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size=head_size,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters
TILE_SIZE=TILE_SIZE,
num_warps=num_warps,
num_stages=num_stages,
)
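
To close, here is a hedged usage sketch of the new entry point for the fp8 path. It is not taken from the PR: the shapes, the uint8-allocated cache (which the function views as the platform fp8 dtype, as the code above shows), and the ad-hoc per-tensor scales are assumptions for illustration. It assumes a CUDA or ROCm device with vLLM installed.

import torch

from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)

num_tokens, num_heads, head_size = 64, 8, 128
num_blocks, block_size = 32, 16

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
value = torch.randn_like(key)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# fp8 caches allocated as uint8; the function re-views them as the platform fp8 dtype.
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.uint8, device="cuda")
value_cache = torch.zeros_like(key_cache)

# Ad-hoc per-tensor scales for illustration (e4m3 max magnitude is 448).
k_scale = (key.abs().amax() / 448.0).to(torch.float32)
v_scale = (value.abs().amax() / 448.0).to(torch.float32)

triton_reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                               "fp8", k_scale, v_scale)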