Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tests/runner/test_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import torch
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
Expand Down Expand Up @@ -200,14 +201,15 @@ def test_insert_request_with_kv_cache(self):
np.testing.assert_array_equal(updated_block_content,
expected_padded_slice)

def test_get_kv_cache_spec_with_compilation_cfg(self):
@pytest.mark.parametrize("num_kv_heads", [16, 32])
@pytest.mark.parametrize("head_size", [64, 100, 200])
def test_get_kv_cache_spec_with_compilation_cfg(self, num_kv_heads,
head_size):
# tests we create kv cache spec from compilation config
# create a static forward context with
# 10 full attention layers +
# 10 sliding window attention layers
# 1 layer with shared kv cache.
num_kv_heads = 16
head_size = 128
attn_type = AttentionType.DECODER
sliding_window = 10
static_forward_context = {}
Expand Down
4 changes: 4 additions & 0 deletions tpu_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:

def get_padded_head_dim(head_dim: int) -> int:
"""Pads head_dim up to the nearest multiple of 128 for kernel performance."""
# When head_dim == 64, we use kernel specificly optimized for it which does
# not require any padding.
if head_dim == 64:
return 64
return (head_dim + 127) // 128 * 128


Expand Down