diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 395fce85613..b42371dc090 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -242,7 +242,8 @@ def __init__(
         config: ModelArgs,
         input_len: int,
         cache_lens: Union[int, List[int]],
-        dtype=torch.float32,
+        batch_size: int = 1,
+        dtype: torch.dtype = torch.float32,
         style: str = "shift_pointer",
         mask_val: float = float("-inf"),
     ):
@@ -266,7 +267,10 @@ def __init__(
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
@@ -274,7 +278,10 @@ def __init__(
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
@@ -283,7 +290,7 @@ def __init__(
         else:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
@@ -293,7 +300,7 @@ def __init__(
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
@@ -323,7 +330,7 @@ def reset(self):
     def prefill(
         self,
         model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
     ) -> torch.Tensor:
         if self.cache_full:
             raise RuntimeError("KV cache is full.")
@@ -336,10 +343,13 @@ def prefill(
             )
         )
 
+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+
         logits = None
         all_logits = None
-        for i in range(0, len(tokens), self.input_len):
-            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+        for i in range(0, tokens.size(1), self.input_len):
+            logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
             if self.config.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
@@ -347,7 +357,7 @@ def prefill(
                 all_logits = torch.cat([all_logits, logits], dim=1)
 
         if self.config.generate_full_logits:
-            return all_logits[:, : len(tokens), :]
+            return all_logits[:, : tokens.size(1), :]
 
         return logits
 
@@ -510,15 +520,16 @@ def lookahead_decode(  # noqa: C901
     def _run_once(
         self,
        model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
         non_padded_len: Optional[int] = None,
         freqs_cos_override: Optional[torch.Tensor] = None,
         freqs_sin_override: Optional[torch.Tensor] = None,
     ):
-        n_tokens = len(tokens)
+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+        n_tokens = tokens.size(1)
         if n_tokens < self.input_len:
-            tokens += [0] * (self.input_len - n_tokens)
-            tokens = torch.tensor([tokens], dtype=torch.int32)  # pyre-ignore[9]
+            tokens = F.pad(tokens, (0, self.input_len - n_tokens))
         if freqs_cos_override is None:
             freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 2461732db5a..8786c70da11 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -195,7 +195,7 @@ def test_with_style(style):
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
 
-    def _get_test_transformers(self, config, attention_type="static"):
+    def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
 
         config = copy.copy(config)
@@ -207,6 +207,8 @@ def _get_test_transformers(self, config, attention_type="static"):
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
             static_layer.attention.adopt_hf_rope()
+            if use_conv2d:
+                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config
@@ -220,7 +222,8 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        x = torch.randint(config.vocab_size, (1, config.max_seq_len))
+        batch_size = 3
+        x = torch.randint(config.vocab_size, (batch_size, config.max_seq_len))
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
@@ -235,13 +238,13 @@ def test(style, attention_type):
             expected = mha_transformer(x)
 
             mgr = StaticAttentionIOManager(
-                static_config, chunk_len, cache_len, style=style
+                static_config, chunk_len, cache_len, style=style, batch_size=batch_size
             )
             ys = []
             for i in range(n_chunks):
                 y_i = mgr.prefill(
                     static_transformer,
-                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
+                    x[:, i * chunk_len : (i + 1) * chunk_len],
                 )
                 ys.append(y_i)
 
@@ -300,3 +303,59 @@ def test_lookahead_decode(self):
                 ngram_caches=ngram_caches,
             )
             self.assertEqual(lookahead_output[: len(ref_output)], ref_output)
+
+    def test_batched_export_with_backprop(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=128,
+            n_layers=4,
+            vocab_size=128,
+            generate_full_logits=True,
+        )
+        _, static_transformer, static_config = self._get_test_transformers(config)
+        batch_size = 4
+        input_len = 32
+        cache_len = static_config.max_seq_len - input_len
+        mgr = StaticAttentionIOManager(
+            static_config, input_len, cache_len, batch_size=batch_size
+        )
+        example_inputs = (
+            torch.zeros(batch_size, input_len),
+            {
+                "masks": mgr.masks,
+                "freqs_cos_override": mgr.freqs_cos[:input_len],
+                "freqs_sin_override": mgr.freqs_sin[:input_len],
+                "in_cache_state": (mgr.k_caches, mgr.v_caches),
+            },
+        )
+        batched_gm = torch.export.export(static_transformer, example_inputs).module()
+
+        # Test backprop
+        for _ in range(10):
+            x = torch.randint(config.vocab_size, (batch_size, input_len))
+            y = mgr.prefill(batched_gm, x)
+            loss = torch.nn.functional.cross_entropy(
+                y, torch.rand(batch_size, input_len, config.vocab_size)
+            )
+            loss.backward()
+            mgr.reset()
+
+        # Test loading state dict into a non batched graph for inference
+        mgr = StaticAttentionIOManager(
+            static_config, input_len, cache_len, batch_size=1
+        )
+        example_inputs = (
+            torch.zeros(1, input_len),
+            {
+                "masks": mgr.masks,
+                "freqs_cos_override": mgr.freqs_cos[:input_len],
+                "freqs_sin_override": mgr.freqs_sin[:input_len],
+                "in_cache_state": (mgr.k_caches, mgr.v_caches),
+            },
+        )
+        non_batched_gm = torch.export.export(
+            static_transformer, example_inputs
+        ).module()
+        non_batched_gm.load_state_dict(batched_gm.state_dict())
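
Usage sketch (illustrative, not part of the diff): the changes above give StaticAttentionIOManager a batch_size argument that adds a leading batch dimension to every KV cache buffer, and let prefill() accept a [batch, seq] token tensor in addition to List[int]. The snippet below assumes `static_config` is a ModelArgs configured for static attention and `static_transformer` is the matching model (construction omitted); the import path mirrors the existing tests, and any name not present in the diff is hypothetical.

    import torch

    # Import path follows the existing test file; adjust to your checkout if needed.
    from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

    batch_size, input_len = 4, 32
    cache_len = static_config.max_seq_len - input_len  # static_config: an assumed ModelArgs

    # New batch_size kwarg: every KV cache buffer gets a leading batch dimension.
    mgr = StaticAttentionIOManager(
        static_config, input_len, cache_len, batch_size=batch_size
    )

    # prefill() now also takes a [batch, seq] integer tensor and chunks it by input_len internally.
    tokens = torch.randint(static_config.vocab_size, (batch_size, input_len))
    logits = mgr.prefill(static_transformer, tokens)  # static_transformer: assumed model
    mgr.reset()

Parameters updated through a batched export can then be reused for batch-size-1 inference by loading its state dict into a non-batched export, as the new test_batched_export_with_backprop test exercises.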