37 changes: 24 additions & 13 deletions examples/models/llama/static_attention.py
@@ -242,7 +242,8 @@ def __init__(
config: ModelArgs,
input_len: int,
cache_lens: Union[int, List[int]],
dtype=torch.float32,
batch_size: int = 1,
dtype: torch.dtype = torch.float32,
style: str = "shift_pointer",
mask_val: float = float("-inf"),
):
@@ -266,15 +267,21 @@ def __init__(
if split_mha:
self.k_caches = {
StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
batch_size,
cache_lens[layer_id],
none_throws(config.head_dim),
dtype=dtype,
)
for layer_id in range(config.n_layers)
for head_id in range(none_throws(config.n_kv_heads))
if cache_lens[layer_id] > 0
}
self.v_caches = {
StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
batch_size,
cache_lens[layer_id],
none_throws(config.head_dim),
dtype=dtype,
)
for layer_id in range(config.n_layers)
for head_id in range(none_throws(config.n_kv_heads))
@@ -283,7 +290,7 @@ def __init__(
else:
self.k_caches = {
StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
1,
batch_size,
none_throws(config.n_kv_heads),
cache_lens[layer_id],
none_throws(config.head_dim),
@@ -293,7 +300,7 @@ def __init__(
}
self.v_caches = {
StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
1,
batch_size,
none_throws(config.n_kv_heads),
cache_lens[layer_id],
none_throws(config.head_dim),
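
With this change the constructor takes a batch_size argument (default 1), so every K and V cache is allocated with a leading batch dimension instead of a hard-coded 1. A minimal usage sketch, assuming a ModelArgs-style config is already constructed; the concrete lengths below are illustrative only:

# Illustrative values only; with split MHA each per-head cache is zeros of
# shape (batch_size, cache_len, head_dim).
mgr = StaticAttentionIOManager(
    config,
    input_len=32,
    cache_lens=96,
    batch_size=4,
    dtype=torch.float32,
    style="shift_pointer",
)
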
@@ -323,7 +330,7 @@ def reset(self):
def prefill(
self,
model: Callable[..., Any],
tokens: List[int],
tokens: Union[List[int], torch.Tensor],
) -> torch.Tensor:
if self.cache_full:
raise RuntimeError("KV cache is full.")
@@ -336,18 +343,21 @@ def prefill(
)
)

if isinstance(tokens, list):
tokens = torch.tensor([tokens], dtype=torch.int32)

logits = None
all_logits = None
for i in range(0, len(tokens), self.input_len):
logits = self._run_once(model, tokens[i : i + self.input_len])[0]
for i in range(0, tokens.size(1), self.input_len):
logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
if self.config.generate_full_logits:
if all_logits is None:
all_logits = logits
else:
all_logits = torch.cat([all_logits, logits], dim=1)

if self.config.generate_full_logits:
return all_logits[:, : len(tokens), :]
return all_logits[:, : tokens.size(1), :]

return logits
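
prefill now accepts either the original flat Python list or a 2-D (batch_size, seq_len) token tensor; a list is converted internally to an int32 tensor with batch size 1. A hedged usage sketch with hypothetical token values:

# Single sequence, legacy list input (converted to a (1, n) int32 tensor internally).
logits = mgr.prefill(model, [1, 2, 3, 4])

# Batched input; the leading dimension is expected to match the manager's batch_size.
batched_tokens = torch.randint(0, 128, (4, 16), dtype=torch.int32)
logits = mgr.prefill(model, batched_tokens)
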

@@ -510,15 +520,16 @@ def lookahead_decode(  # noqa: C901
def _run_once(
self,
model: Callable[..., Any],
tokens: List[int],
tokens: Union[List[int], torch.Tensor],
non_padded_len: Optional[int] = None,
freqs_cos_override: Optional[torch.Tensor] = None,
freqs_sin_override: Optional[torch.Tensor] = None,
):
n_tokens = len(tokens)
if isinstance(tokens, list):
tokens = torch.tensor([tokens], dtype=torch.int32)
n_tokens = tokens.size(1)
if n_tokens < self.input_len:
tokens += [0] * (self.input_len - n_tokens)
tokens = torch.tensor([tokens], dtype=torch.int32) # pyre-ignore[9]
tokens = F.pad(tokens, (0, self.input_len - n_tokens))
if freqs_cos_override is None:
freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
if freqs_sin_override is None:
67 changes: 63 additions & 4 deletions examples/models/llama/tests/test_static_attention.py
@@ -195,7 +195,7 @@ def test_with_style(style):
test_with_style("shift_pointer")
test_with_style("smart_mask")

def _get_test_transformers(self, config, attention_type="static"):
def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
mha_transformer = construct_transformer(config).eval()

config = copy.copy(config)
@@ -207,6 +207,8 @@ def _get_test_transformers(self, config, attention_type="static"):
):
static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
static_layer.attention.adopt_hf_rope()
if use_conv2d:
static_layer.linear_to_conv2d()
config.use_hf_rope = True

return mha_transformer, static_transformer, config
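
The helper now optionally converts the static layers' linear projections to conv2d via linear_to_conv2d(). A hedged sketch of how a test might request that variant (not exercised by the tests shown here):

# Hypothetical call inside a test method; exercises the conv2d projection path.
mha, static_transformer, static_config = self._get_test_transformers(
    config, use_conv2d=True
)
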
@@ -220,7 +222,8 @@ def test_within_transformer(self):
n_layers=4,
vocab_size=128,
)
x = torch.randint(config.vocab_size, (1, config.max_seq_len))
batch_size = 3
x = torch.randint(config.vocab_size, (batch_size, config.max_seq_len))
n_chunks = 3
chunk_len = config.max_seq_len // n_chunks
cache_len = config.max_seq_len - chunk_len
@@ -235,13 +238,13 @@ def test(style, attention_type):
expected = mha_transformer(x)

mgr = StaticAttentionIOManager(
static_config, chunk_len, cache_len, style=style
static_config, chunk_len, cache_len, style=style, batch_size=batch_size
)
ys = []
for i in range(n_chunks):
y_i = mgr.prefill(
static_transformer,
x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
x[:, i * chunk_len : (i + 1) * chunk_len],
)
ys.append(y_i)

@@ -300,3 +303,59 @@ def test_lookahead_decode(self):
ngram_caches=ngram_caches,
)
self.assertEqual(lookahead_output[: len(ref_output)], ref_output)

def test_batched_export_with_backprop(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=128,
n_layers=4,
vocab_size=128,
generate_full_logits=True,
)
_, static_transformer, static_config = self._get_test_transformers(config)
batch_size = 4
input_len = 32
cache_len = static_config.max_seq_len - input_len
mgr = StaticAttentionIOManager(
static_config, input_len, cache_len, batch_size=batch_size
)
example_inputs = (
torch.zeros(batch_size, input_len),
{
"masks": mgr.masks,
"freqs_cos_override": mgr.freqs_cos[:input_len],
"freqs_sin_override": mgr.freqs_sin[:input_len],
"in_cache_state": (mgr.k_caches, mgr.v_caches),
},
)
batched_gm = torch.export.export(static_transformer, example_inputs).module()

# Test backprop
for _ in range(10):
x = torch.randint(config.vocab_size, (batch_size, input_len))
y = mgr.prefill(batched_gm, x)
loss = torch.nn.functional.cross_entropy(
y, torch.rand(batch_size, input_len, config.vocab_size)
)
loss.backward()
mgr.reset()

# Test loading state dict into a non batched graph for inference
mgr = StaticAttentionIOManager(
static_config, input_len, cache_len, batch_size=1
)
example_inputs = (
torch.zeros(1, input_len),
{
"masks": mgr.masks,
"freqs_cos_override": mgr.freqs_cos[:input_len],
"freqs_sin_override": mgr.freqs_sin[:input_len],
"in_cache_state": (mgr.k_caches, mgr.v_caches),
},
)
non_batched_gm = torch.export.export(
static_transformer, example_inputs
).module()
non_batched_gm.load_state_dict(batched_gm.state_dict())
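
After the weights are copied over, the batch_size=1 manager can drive the non-batched graph for ordinary inference. A hedged follow-up sketch (token values hypothetical):

# Single-sequence prefill through the non-batched graph using the batch_size=1 manager.
out = mgr.prefill(non_batched_gm, list(range(input_len)))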