diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 0bb168bdadb..e8a53a41312 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -700,10 +700,23 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
+  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
+  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
+  // 3) We're updating the cache with projected_value, at position start_pos
+
+  ET_CHECK_MSG(
+      projected_value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      projected_value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
   ET_CHECK_MSG(
-      projected_value.size(0) == 1,
-      "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+      projected_value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      projected_value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+  ET_CHECK_MSG(
       is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
@@ -714,16 +727,31 @@ void update_cache(
   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
   ET_CHECK_MSG(cache_data, "cache data is null");
-  auto strides = cache.strides();
-  exec_aten::StridesType seq_dim_stride = strides[1];
-  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
-  exec_aten::SizesType pos_offset_bytes =
-      pos_offset * projected_value.element_size();
-  exec_aten::SizesType num_bytes =
-      projected_value.numel() * projected_value.element_size();
-  // NOLINTNEXTLINE
-  std::memcpy(
-      (uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = projected_value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (projected_value.numel() / projected_value.size(0)) *
+      projected_value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
+       ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)projected_value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
 }
 } // anonymous namespace
@@ -859,6 +887,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = key_cache.strides()[0];
   void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
@@ -883,6 +913,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = value_cache.strides()[0];
   void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index dd63c68f138..bfd64cb8975 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -373,10 +373,10 @@ class SDPATestCommon(unittest.TestCase):
     def setup_caches(self):
         self.k_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.v_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.mask = torch.full(
             (self.max_seq_len, self.max_seq_len),
@@ -386,6 +386,7 @@ def setUp(self):
         torch.manual_seed(42)
+        self.n_batch = 5
         self.n_heads_kv = 32
         self.n_heads_q = 32
         self.head_dim = 128
@@ -410,27 +411,27 @@ def _test_sdpa_common(
         scale_tensors=False,
     ):
         # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
-        tensor_scale_max = 20
-        tensor_scale_min = -20
+        tensor_scale_max = 15
+        tensor_scale_min = -15
         self.n_heads_kv = n_heads_kv
         self.n_heads_q = n_heads_q
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.setup_caches()
         q = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
@@ -448,19 +449,25 @@ def _test_sdpa_common(
         self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
         q = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
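Reviewer note: the per-batch copy loop in `update_cache` is needed because, once the batch size can exceed 1, the destination row for batch `b` starts at `b * cache_batch_dim_stride` (a stride spanning `max_seq_len` positions) while the source row starts at `b * value_batch_dim_stride` (spanning only `seq_len` positions), so a single contiguous memcpy no longer covers all rows. The sketch below is illustrative only: plain PyTorch with arbitrary sizes, not a call to the custom op, showing the semantics the memcpy loop implements.

```python
import torch

# Arbitrary illustrative sizes; names mirror the C++ arguments in update_cache.
bs, max_seq_len, seq_len, n_heads, head_dim = 5, 128, 4, 32, 128
start_pos = 10

cache = torch.zeros(bs, max_seq_len, n_heads, head_dim)
projected_value = torch.rand(bs, seq_len, n_heads, head_dim)

# Equivalent of the per-batch memcpy loop: every batch row is written into the
# cache at offset start_pos along the sequence dimension, using each tensor's
# own batch stride.
cache[:, start_pos : start_pos + seq_len] = projected_value
```

The same stride mismatch is why the sliced key/value `TensorImpl`s keep the full cache's batch stride: the slice covers only `seq_length` positions, but stepping to the next batch entry must still jump over the whole `max_seq_len` row of the underlying cache.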