extension/llm/custom_ops/test_sdpa_with_kv_cache.py (78 additions, 13 deletions)
@@ -392,17 +392,50 @@ def setUp(self):
        self.max_seq_len = 2048
        self.setup_caches()

+    def _scale_tensor(self, tensor, min_value, max_value, scale=True):
+        # Min-max normalize into [0, 1], then map linearly onto [min_value, max_value].
+        if not scale:
+            return tensor
+        normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
+        return normalized_tensor * (max_value - min_value) + min_value
+
    def _test_sdpa_common(
-        self, n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len=1
+        self,
+        n_heads_kv,
+        n_heads_q,
+        head_dim,
+        max_seq_len,
+        seq_len,
+        next_iter_seq_len=1,
+        scale_tensors=False,
    ):
+        # Range chosen arbitrarily to reproduce a numerical error on x86 in some
+        # of the long-context tests.
+        tensor_scale_max = 20
+        tensor_scale_min = -20
        self.n_heads_kv = n_heads_kv
        self.n_heads_q = n_heads_q
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.setup_caches()
-        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
-        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
-        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        q = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        k = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        v = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )

        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
-        self.assertTrue(torch.allclose(ref_output, op_output))
+        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))

-        q = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
-        k = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
-        v = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
+        q = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        k = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        v = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
        start_pos = seq_len
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
-        self.assertTrue(torch.allclose(ref_output, op_output))
+        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))


class SDPATestForLargeSeqLength(SDPATestCommon):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
-        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
+        self._test_sdpa_common(
+            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
+        )

    def test_sdpa_with_cache_seq_len_small(self):
        n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
-        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
+        self._test_sdpa_common(
+            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
+        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
        n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
        seq_len = 130
        next_iter_seq_len = 17
        self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            seq_len,
+            next_iter_seq_len,
+            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example(self):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
        seq_len = 130
        next_iter_seq_len = 33
        self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            seq_len,
+            next_iter_seq_len,
+            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
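A note on the change: widening the inputs from torch.rand's default [0, 1) range to [-20, 20] produces attention scores large enough that the fused kernel and the reference implementation can round differently on x86, which is presumably why the exact torch.allclose checks were relaxed to atol=1e-6. The standalone sketch below illustrates the same min-max rescaling and tolerance-based comparison; it is not part of the PR, and the helper name rescale is hypothetical.

import torch


def rescale(t: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
    # Min-max normalize to [0, 1], then map linearly onto [lo, hi].
    normalized = (t - t.min()) / (t.max() - t.min())
    return normalized * (hi - lo) + lo


# Rescale uniform noise into [-20, 20], mirroring tensor_scale_min/max in the test.
q = rescale(torch.rand(1, 16, 8, 32), -20.0, 20.0)
assert q.min().item() >= -20.0 and q.max().item() <= 20.0

# At this magnitude, two correct floating-point implementations can disagree in
# the last bits, so compare against a higher-precision path with a tolerance
# rather than requiring exact equality.
ref = torch.softmax(q.double(), dim=-1).float()
out = torch.softmax(q, dim=-1)
assert torch.allclose(ref, out, atol=1e-6)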