diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index 827078a40a8..eb95cab0838 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -105,9 +105,8 @@ def update(
                 f", but found new key tensors with batch size {k_val.shape[0]}!"
             )
 
-        assert (
-            self.cache_pos[0] + seq_len
-        ) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
+        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
+
         k_out = self.k_cache
         v_out = self.v_cache
 
diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 565e8c67d75..70267eb7c41 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import tempfile
 import unittest
 
 import torch
@@ -13,6 +15,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):
 
     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         assert_close(et_res, tt_res)
 
-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            assert_close(aoti_res, tt_res)
 
     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.
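For context, the new test_attention_aoti exercises the AOTInductor packaging round trip (aot_compile with "aot_inductor.package", then package_aoti and load_package). The snippet below is a minimal stand-alone sketch of that same flow, not part of the PR; it assumes a PyTorch build where the private torch._export.aot_compile and torch._inductor.package APIs behave as they do in the test, and the module name TinyLinear and archive name tiny.pt2 are illustrative only.

# Minimal sketch (assumption-labeled, not from the PR): round-trip a tiny module
# through the same AOTInductor packaging flow used by test_attention_aoti.
import os
import tempfile

import torch
from torch._inductor.package import load_package, package_aoti


class TinyLinear(torch.nn.Module):
    # Hypothetical stand-in module; the PR compiles the real MultiHeadAttention.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear().eval()
x = torch.randn(2, 8)

with torch.no_grad():
    # Compile to AOTInductor artifacts; "aot_inductor.package" requests packageable output.
    so = torch._export.aot_compile(
        model,
        args=(x,),
        options={"aot_inductor.package": True},
    )

with tempfile.TemporaryDirectory() as tempdir:
    # Bundle the artifacts into a .pt2 archive, load it back as a callable,
    # and check it matches eager execution.
    path = package_aoti(os.path.join(tempdir, "tiny.pt2"), so)
    compiled = load_package(path)
    torch.testing.assert_close(compiled(x), model(x))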