7 changes: 7 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -448,6 +448,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
)

parser.add_argument(
"--use_attention_sink",
default=None,
type=str,
help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
)

parser.add_argument(
"--output_prune_map",
default=None,
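For reference (illustrative, not part of the diff): a flag value such as '4,2044,1024' splits into sink_size=4, window_size=2044, and eviction_batch_size=1024, and model.py below additionally requires that sink_size + window_size equal --max_seq_length. A minimal sketch of that arithmetic:

sink_size, window_size, eviction_batch_size = (
    int(v) for v in "4,2044,1024".split(",")
)
# model.py asserts max_seq_length == sink_size + window_size,
# so this value pairs with --max_seq_length 2048.
assert sink_size + window_size == 2048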
19 changes: 19 additions & 0 deletions examples/models/llama/model.py
@@ -201,6 +201,25 @@ def __init__(self, **kwargs):

sanitize_checkpoint_from_pre_quantization(checkpoint)

if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
from .source_transformation.attention_sink import enable_attention_sink

attention_sink_params = self.args.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.args.max_seq_length == sink_size + window_size

self.model_ = enable_attention_sink(
module=self.model_,
params=model_args,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
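As a standalone illustration of the assign=True note above (not part of the diff): loading real tensors into a meta-device module only works by assignment, because an in-place copy into meta tensors is a no-op.

import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(4, 4)  # parameters are meta tensors with no storage

state = nn.Linear(4, 4).state_dict()  # stand-in for a real checkpoint
m.load_state_dict(state, assign=True)  # rebinds params/buffers instead of copying in place
assert m.weight.device.type != "meta"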
118 changes: 117 additions & 1 deletion examples/models/llama/source_transformation/attention_sink.py
@@ -7,15 +7,22 @@
# Components for supporting Attention Sink. See
# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.

import types
from typing import Optional

import torch

from executorch.examples.models.llama.llama_transformer import KVCache, ModelArgs, Rope
from executorch.examples.models.llama.llama_transformer import (
Attention,
KVCache,
ModelArgs,
Rope,
)
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class RopeWithAttentionSink(Rope):
Expand Down Expand Up @@ -206,3 +213,112 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
)
self.position_shift -= num_to_evict # pyre-ignore [8]
return self.position_shift


def attention_sink_forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
assert self.use_kv_cache
assert input_pos is not None

bsz, seqlen, _ = x.shape

# QKV
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# We need view_copy elimination
q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# Prepare for space in KV cache and get position shift
position_shift = self.kv_cache.evict_tokens(input_pos, seqlen)

# RoPE relative positional embeddings with shifted position in KV cache
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

output = self.SDPA(input_pos + position_shift, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)
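To make the eviction step above concrete, here is a conceptual sketch of the sink-plus-sliding-window policy behind evict_tokens, tracked on a plain list of positions. The batching rule is an assumption, and the real implementation also re-rotates the surviving keys for their shifted positions, which this sketch omits:

def sketch_evict(positions, seq_len, sink_size, window_size, eviction_batch_size):
    # Evict only when the incoming seq_len tokens would overflow sink + window.
    overflow = len(positions) + seq_len - (sink_size + window_size)
    if overflow <= 0:
        return 0  # enough room; new tokens keep their original positions
    # Assumed batching rule: evict at least eviction_batch_size tokens at once
    # so evictions happen infrequently.
    num_to_evict = max(overflow, eviction_batch_size)
    # Keep the first sink_size "attention sink" tokens; drop the oldest
    # window tokens that follow them.
    del positions[sink_size : sink_size + num_to_evict]
    return -num_to_evict  # position shift applied to the incoming tokens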


def _replace_rope(
module: torch.nn.Module, rope_with_attention_sink: RopeWithAttentionSink
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
return isinstance(child, Rope)

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
return rope_with_attention_sink

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def _replace_attention(
module: torch.nn.Module,
rope_with_attention_sink: RopeWithAttentionSink,
sink_size: int,
window_size: int,
eviction_batch_size: int,
):
for _, child_module in module._modules.items():
if len(list(child_module.children())) > 0: # pyre-ignore [16]
_replace_attention(
module=child_module, # pyre-ignore [6]
rope_with_attention_sink=rope_with_attention_sink,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)

if isinstance(child_module, Attention):
kv_cache = child_module.kv_cache
kv_cache_with_attention_sink = KVCacheWithAttentionSink(
n_heads=kv_cache.n_heads,
head_dim=kv_cache.head_dim,
transpose_cache=kv_cache.transpose_cache,
enable_dynamic_shape=kv_cache.enable_dynamic_shape,
rope=rope_with_attention_sink,
max_batch_size=kv_cache.max_batch_size,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=eviction_batch_size,
dtype=kv_cache.k_cache.dtype,
)
child_module.kv_cache = kv_cache_with_attention_sink
child_module.SDPA.kv_cache = kv_cache_with_attention_sink
child_module.forward = types.MethodType( # pyre-ignore
attention_sink_forward, child_module
)


def enable_attention_sink(
module: torch.nn.Module,
params: ModelArgs,
sink_size: int,
window_size: int,
eviction_batch_size: int,
) -> torch.nn.Module:
"""
Transform the model to be able to run inference with Attention Sink.
There are mainly three steps:
- Replace Rope with RopeWithAttentionSink
- Replace Attention's KVCache with KVCacheWithAttentionSink
- Replace Attention's forward with attention_sink_forward
"""
rope_with_attention_sink = RopeWithAttentionSink(
params=params,
window_size=window_size,
sink_size=sink_size,
eviction_batch_size=eviction_batch_size,
)
_replace_rope(module, rope_with_attention_sink)
_replace_attention(
module=module,
rope_with_attention_sink=rope_with_attention_sink,
sink_size=sink_size,
window_size=window_size,
eviction_batch_size=eviction_batch_size,
)
return module
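For completeness, a minimal usage sketch mirroring what model.py does when --use_attention_sink is set (the tiny ModelArgs config below is an assumption purely for illustration; real exports take these values from the checkpoint's params):

from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama.source_transformation.attention_sink import (
    enable_attention_sink,
)

params = ModelArgs(
    dim=64,
    n_layers=2,
    n_heads=4,
    vocab_size=128,
    max_seq_len=2048,  # must equal sink_size + window_size below
    use_kv_cache=True,
)
model = Transformer(params)

model = enable_attention_sink(
    module=model,
    params=params,
    sink_size=4,
    window_size=2044,
    eviction_batch_size=1024,
)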