diff --git a/tests/runner/test_kv_cache_manager.py b/tests/runner/test_kv_cache_manager.py index 72df8501a..33beec81a 100644 --- a/tests/runner/test_kv_cache_manager.py +++ b/tests/runner/test_kv_cache_manager.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp import numpy as np +import pytest import torch from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType @@ -200,14 +201,15 @@ def test_insert_request_with_kv_cache(self): np.testing.assert_array_equal(updated_block_content, expected_padded_slice) - def test_get_kv_cache_spec_with_compilation_cfg(self): + @pytest.mark.parametrize("num_kv_heads", [16, 32]) + @pytest.mark.parametrize("head_size", [64, 100, 200]) + def test_get_kv_cache_spec_with_compilation_cfg(self, num_kv_heads, + head_size): # tests we create kv cache spec from compilation config # create a static forward context with # 10 full attention layers + # 10 sliding window attention layers # 1 layer with shared kv cache. - num_kv_heads = 16 - head_size = 128 attn_type = AttentionType.DECODER sliding_window = 10 static_forward_context = {} diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py index 18a0a8b24..08cb23ae5 100644 --- a/tpu_inference/utils.py +++ b/tpu_inference/utils.py @@ -150,6 +150,10 @@ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]: def get_padded_head_dim(head_dim: int) -> int: """Pads head_dim up to the nearest multiple of 128 for kernel performance.""" + # When head_dim == 64, we use kernel specificly optimized for it which does + # not require any padding. + if head_dim == 64: + return 64 return (head_dim + 127) // 128 * 128