From 3897f9ee2944008c7449595345e1c46f0c0f731b Mon Sep 17 00:00:00 2001
From: Emilian Stoimenov
Date: Fri, 30 Aug 2024 15:42:21 -0700
Subject: [PATCH] Changing sdpa_with_kv_cache tests to use a wider dynamic
 range. (#4892)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4892

Changing some of the sdpa_with_kv_cache longer-context tests to use a wider
dynamic range, to help verify correctness in those conditions too.

Reviewed By: tarun292

Differential Revision: D61403179
---
 .../llm/custom_ops/test_sdpa_with_kv_cache.py | 91 ++++++++++++++++---
 1 file changed, 78 insertions(+), 13 deletions(-)

diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index a1b36e688f9..dd63c68f138 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -392,17 +392,50 @@ def setUp(self):
         self.max_seq_len = 2048
         self.setup_caches()
 
+    def _scale_tensor(self, tensor, min_value, max_value, scale=True):
+        normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
+
+        scaled_tensor = normalized_tensor * (max_value - min_value) + min_value
+
+        return scaled_tensor if scale else tensor
+
     def _test_sdpa_common(
-        self, n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len=1
+        self,
+        n_heads_kv,
+        n_heads_q,
+        head_dim,
+        max_seq_len,
+        seq_len,
+        next_iter_seq_len=1,
+        scale_tensors=False,
     ):
+        # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
+        tensor_scale_max = 20
+        tensor_scale_min = -20
         self.n_heads_kv = n_heads_kv
         self.n_heads_q = n_heads_q
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.setup_caches()
-        q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
-        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
-        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        q = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        k = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        v = self._scale_tensor(
+            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+
         start_pos = 0
         attn_mask = self.mask[start_pos : start_pos + seq_len, :]
         attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
         op_output = torch.ops.llama.sdpa_with_kv_cache(
             q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
         )
-        self.assertTrue(torch.allclose(ref_output, op_output))
+        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
+
+        q = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        k = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
+        v = self._scale_tensor(
+            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            tensor_scale_min,
+            tensor_scale_max,
+            scale_tensors,
+        )
 
-        q = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
-        k = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
-        v = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
         start_pos = seq_len
         seq_len = q.size(1)
         attn_mask = self.mask[start_pos : start_pos + seq_len, :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
         op_output = torch.ops.llama.sdpa_with_kv_cache(
             q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
         )
-        self.assertTrue(torch.allclose(ref_output, op_output))
+        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
 
 
 class SDPATestForLargeSeqLength(SDPATestCommon):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
         head_dim = 128
         max_seq_len = 2048
         seq_len = 130
-        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
+        self._test_sdpa_common(
+            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
+        )
 
     def test_sdpa_with_cache_seq_len_small(self):
         n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
         head_dim = 128
         max_seq_len = 2048
         seq_len = 130
-        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
+        self._test_sdpa_common(
+            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
+        )
 
     def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
         n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
         seq_len = 130
         next_iter_seq_len = 17
         self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            seq_len,
+            next_iter_seq_len,
+            scale_tensors=True,
         )
 
     def test_sdpa_with_cache_seq_len_llava_example(self):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
         seq_len = 130
         next_iter_seq_len = 33
         self._test_sdpa_common(
-            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
+            n_heads_kv,
+            n_heads_q,
+            head_dim,
+            max_seq_len,
+            seq_len,
+            next_iter_seq_len,
+            scale_tensors=True,
         )
 
     def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
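
The core trick the patch applies, as a minimal standalone sketch (the
demo_scale name below is illustrative only, not part of the executorch test
file): torch.rand draws from [0, 1), so min-max normalizing and stretching
into a window like [-20, 20] feeds SDPA's softmax much larger-magnitude
logits, which is where implementations can drift apart numerically and why
the assertions above pass an explicit atol.

    import torch

    def demo_scale(tensor, min_value, max_value):
        # Min-max normalize into [0, 1], then stretch into [min_value, max_value].
        normalized = (tensor - tensor.min()) / (tensor.max() - tensor.min())
        return normalized * (max_value - min_value) + min_value

    # Same (batch, seq_len, n_heads, head_dim) shape family the tests use.
    q = demo_scale(torch.rand(1, 130, 8, 128), -20.0, 20.0)
    print(q.min().item(), q.max().item())  # approximately -20.0 and 20.0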