diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 74e14076b37..eee4aacf44d 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -246,7 +246,6 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
-        s_y = y.shape[1] if y is not None else 0

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -263,16 +262,9 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)

-        if y is None:
-            if self.kv_cache is None:
-                raise ValueError(
-                    "Must provide y input or use kv_cache to enable streaming decoding"
-                )
-            k = self.kv_cache.k_cache
-            v = self.kv_cache.v_cache
-        else:
+        def calculate_kv(y):
             # Update k and v shape, positional embeddings, and normalization
-
+            s_y = y.shape[1]
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -288,12 +280,37 @@ def forward(
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
+            return k, v
+
+        def true_fn(y):
+            kv_cache = self.kv_cache.clone()
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+
+        def false_fn(y):
+            k, v = calculate_kv(y)
+            kv_cache = self.kv_cache.clone()
+            kv_cache.update(k, v)
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

+        # If kv cache is None, we expect y to be provided
+        if self.kv_cache is None:
+            assert (
+                y is not None
+            ), "Must provide y input or use kv_cache to enable streaming decoding"
+            k, v = calculate_kv(y)
+        else:
+            # Expect the k, v returned here to be the same size as self.kv_cache.
+            # In eager, we expect this predicate to specialize. In export, this will
+            # become a SymBool so it's not specialized.
+            k, v, cache_pos = torch.cond(
+                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
+            )
             # Update key-value cache
-        if self.kv_cache is not None and self.cache_enabled:
-            k, v = self.kv_cache.update(k, v)
+            self.kv_cache.k_cache.copy_(k)
+            self.kv_cache.v_cache.copy_(v)
+            self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)

diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index eb95cab0838..db940bca3f8 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -127,3 +127,22 @@ def update(
         self.cache_pos.add_(seq_len)

         return k_out, v_out
+
+    def clone(self) -> "KVCache":
+        """Create a clone of the KVCache."""
+        if self.transpose_cache:
+            num_kv_heads = self.k_cache.shape[1]
+        else:
+            num_kv_heads = self.k_cache.shape[2]
+        clone = KVCache(
+            batch_size=self.batch_size,
+            max_seq_len=self.max_seq_len,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.k_cache.shape[3],
+            dtype=self.k_cache.dtype,
+            transpose_cache=self.transpose_cache,
+        )
+        clone.k_cache.copy_(self.k_cache)
+        clone.v_cache.copy_(self.v_cache)
+        clone.cache_pos.copy_(self.cache_pos)
+        return clone
diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index bd0c44d8b5f..f4e4b8c670c 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -27,7 +27,7 @@ def setUp(self):
         torch.manual_seed(0)
         # Constants
         self.embed_dim = 2048
-        self.num_heads = 32
+        self.num_heads = 8
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
@@ -41,10 +41,14 @@ def setUp(self):
         self.k_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
+        self.k_proj.weight.requires_grad = False
         self.v_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
-        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj.weight.requires_grad = False
+        self.output_proj = torch.nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=False
+        )
         self.pos_embeddings = Llama3ScaledRoPE(
             dim=self.head_dim,
             max_seq_len=self.max_seq_len,
@@ -90,6 +94,12 @@ def setUp(self):
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )

     def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
@@ -195,3 +205,35 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         assert_close(et_res[0], tt_res)
+
+    def test_attention_torch_cond_eager(self):
+        # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond, and we need to make sure both give the same results for either branch of the condition.
+        # For the first run of MHA we provide `y` (self.x), but for the second run it will be a tensor full of nan.
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+
+        # mask
+        mask = self.causal_mask[self.input_pos, :]
+        # First run
+        et_res = self.et_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+
+        empty_y = torch.full_like(self.x, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = self.et_mha(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, None, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+
+        assert_close(et_res, tt_res)
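For readers new to this pattern, the sketch below shows the `torch.cond` contract that the rewritten `forward()` and the new `test_attention_torch_cond_eager` test lean on: both branch functions receive the same operands and must return tensors with matching shapes and dtypes, which is why a NaN-filled `y` can act as the "read the KV cache" sentinel while both branches stay traceable under export. This is a minimal standalone sketch, not ExecuTorch code; plain tensors stand in for the KV cache and the branch names are invented.

```python
# Minimal sketch of the torch.cond pattern used in the diff. All names here are
# illustrative; plain tensors stand in for the KV cache buffers.
import torch


def read_cache(y, k_cache, v_cache):
    # Sentinel case (y is all NaN): ignore y and return the cached tensors.
    return k_cache.clone(), v_cache.clone()


def compute_and_update(y, k_cache, v_cache):
    # Normal case: derive new k/v from y, keeping the cache's shape and dtype
    # so both branches return tensors with matching metadata.
    return k_cache + y, v_cache + y


b, s, d = 1, 4, 8
k_cache = torch.zeros(b, s, d)
v_cache = torch.zeros(b, s, d)

for y in (torch.randn(b, s, d), torch.full((b, s, d), float("nan"))):
    # In eager this predicate evaluates to a plain bool and only one branch runs;
    # under torch.export it is kept symbolic, so both branches are captured.
    is_sentinel = torch.isnan(y).all().item()
    k, v = torch.cond(is_sentinel, read_cache, compute_and_update, (y, k_cache, v_cache))
    print(is_sentinel, k.sum().item(), v.sum().item())
```

The same single-graph-with-both-paths idea is what lets the diff replace the data-dependent `if y is None:` branch, which export cannot capture directly, with one traced graph plus a runtime predicate.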