In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "./llama3.2"  # local checkpoint directory

# Tokenizer and causal-LM weights are read from the same directory, in that order.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_path),
    AutoModelForCausalLM.from_pretrained(model_path),
)

# Show the module tree for reference when choosing which decoder layers to replace below.
print(model)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm




In [2]:
from transformers.models.llama.modeling_llama import (
    LlamaSdpaAttention,
    rotate_half,
    repeat_kv,
    LlamaDecoderLayer,
)
from typing import Optional, Tuple
from transformers.cache_utils import Cache
import torch.nn as nn
import math
from functools import partial


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate queries and keys with rotary tables, allowing different q/k lengths.

    Unlike the stock helper, ``q`` and ``k`` may have different sizes along dim 2
    (as in cross-attention). Each tensor is combined with the leading
    ``min(its_len, table_len)`` rows of ``cos``/``sin`` along dim 2, i.e. the first
    positions of the table.

    Args:
        q: query tensor; dim 2 is the axis that is sliced against the tables.
        k: key tensor; dim 2 is the axis that is sliced against the tables.
        cos, sin: rotary tables; ``unsqueeze(unsqueeze_dim)`` inserts the axis that
            broadcasts against the head axis of ``q``/``k``.
        unsqueeze_dim: position of the inserted broadcast axis.

    Returns:
        ``(q_embed, k_embed)``, computed as ``x * cos + rotate_half(x) * sin``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(x):
        # Take as many table rows as x has along dim 2 (capped by the table size).
        # NOTE(review): if a table is shorter than x along dim 2, nothing is padded;
        # a length-1 table silently broadcasts one angle over every position.
        n = x.size(2)
        cos_rows = cos[:, :, : min(n, cos.size(2)), :]
        sin_rows = sin[:, :, : min(n, sin.size(2)), :]
        return (x * cos_rows) + (rotate_half(x) * sin_rows)

    q_rotated = _rotate(q)
    k_rotated = _rotate(k)
    return q_rotated, k_rotated


class LlamaCrossAttention(LlamaSdpaAttention):
    """Llama attention layer that attends from the token sequence to a fixed memory.

    Queries are projected from ``hidden_states``. Keys and values are projected from
    ``memory_embeddings`` with the same pretrained ``k_proj``/``v_proj`` weights.
    Attention is computed eagerly (matmul + softmax), not through SDPA.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        memory_embeddings: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Cross-attend ``hidden_states`` to ``memory_embeddings``.

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` token activations.
            attention_mask: accepted for interface compatibility but NOT applied. It
                is built for the token sequence, so its key columns do not line up
                with memory positions. Every query may attend to the whole memory.
            position_ids: ignored. Rotary positions are ``0..q_len-1`` for queries and
                ``0..kv_len-1`` for memory.
            past_key_value: optional cache. The memory keys/values are written to it
                only when this layer's slot is still empty, so the cache does not gain
                another full copy of the (unchanging) memory on every decoding step.
            output_attentions: if True, also return the attention probabilities,
                shaped ``(batch, num_heads, q_len, kv_len)``.
            use_cache, position_embeddings, **kwargs: accepted for interface
                compatibility and ignored. ``cache_position`` is forwarded to the cache.
            memory_embeddings: ``(memory_batch, kv_len, hidden_size)`` memory. A
                memory batch of 1 is broadcast over the query batch.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: if ``memory_embeddings`` is missing or the attention output
                has an unexpected shape.
        """
        if memory_embeddings is None:
            raise ValueError(
                "LlamaCrossAttention.forward requires `memory_embeddings` "
                "of shape (batch, memory_len, hidden_size)."
            )

        bsz, q_len, _ = hidden_states.size()

        # The memory is bound once, so it does not follow later `model.to(...)`
        # calls; align it with the current activations.
        memory_embeddings = memory_embeddings.to(
            device=hidden_states.device, dtype=hidden_states.dtype
        )
        memory_bsz, kv_len, _ = memory_embeddings.size()
        if memory_bsz == 1 and bsz != 1:
            # One shared memory sequence serves every row of the batch.
            memory_embeddings = memory_embeddings.expand(bsz, -1, -1)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(memory_embeddings)
        value_states = self.v_proj(memory_embeddings)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_len, -1, self.head_dim).transpose(1, 2)

        # Queries and memory are each numbered from position 0, so a single table
        # covering max(q_len, kv_len) positions serves both.
        # NOTE(review): when decoding with a cache, q_len is presumably 1, so each
        # new token's query is rotated as position 0 rather than its true
        # position. Confirm this is intended.
        rope_position_ids = torch.arange(
            max(q_len, kv_len), dtype=torch.long, device=hidden_states.device
        )
        rope_position_ids = rope_position_ids.unsqueeze(0).expand(bsz, -1)
        cos, sin = self.rotary_emb(value_states, rope_position_ids)

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # The memory's keys/values are the same on every call. Store them once so
        # the cache does not grow by kv_len per step. Attention always uses the
        # freshly projected tensors, which equal the cached ones.
        if (
            past_key_value is not None
            and past_key_value.get_seq_length(self.layer_idx) == 0
        ):
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

In [3]:
def replace_attention_with_cross_attention(model, memory_embeddings, layer_indices=None):
    """Replace the self-attention of selected Llama decoder layers with memory cross-attention.

    For each ``LlamaDecoderLayer`` in ``model`` whose attention ``layer_idx`` is
    selected, this function:

    - builds a ``LlamaCrossAttention`` from the same config;
    - loads the same weights into it;
    - installs it as that layer's ``self_attn``.

    The new module's ``forward`` is pre-bound (via ``functools.partial``) to that
    layer's memory. The memory is ``memory_embeddings[layer_idx]`` with a leading
    batch axis added, passed through the layer's own ``input_layernorm``.

    Args:
        model: model containing ``LlamaDecoderLayer`` modules; modified in place.
        memory_embeddings: indexable by layer index. ``memory_embeddings[i]`` is
            given a leading axis with ``unsqueeze(0)``, so it is presumably
            ``(memory_len, hidden_size)``. Confirm against the caller.
        layer_indices: iterable of layer indices to replace. Defaults to layers
            10-17 (the previously hard-coded selection).

    Raises:
        IndexError: if ``memory_embeddings`` has no entry for a selected layer.
    """
    if layer_indices is None:
        layer_indices = range(10, 18)
    selected = set(layer_indices)

    # Collect the layers first so the module tree is not mutated while it is walked.
    decoder_layers = [m for m in model.modules() if isinstance(m, LlamaDecoderLayer)]

    for layer in decoder_layers:
        old_attn = layer.self_attn
        layer_idx = old_attn.layer_idx
        if layer_idx not in selected:
            continue
        if layer_idx >= len(memory_embeddings):
            raise IndexError(
                f"memory_embeddings has {len(memory_embeddings)} entries; "
                f"no memory for layer {layer_idx}"
            )

        reference_param = next(old_attn.parameters())
        cross_attention = LlamaCrossAttention(old_attn.config, layer_idx)
        # A freshly constructed module sits on the default device/dtype; match the
        # layer it replaces before loading its weights.
        cross_attention.to(device=reference_param.device, dtype=reference_param.dtype)
        cross_attention.load_state_dict(old_attn.state_dict())
        # Reuse the original rotary module so any buffers it holds are not cast by
        # the `.to(dtype=...)` above.
        if hasattr(old_attn, "rotary_emb"):
            cross_attention.rotary_emb = old_attn.rotary_emb
        # New modules start in training mode, which would make the dropout in
        # `forward` active whenever attention_dropout > 0. Inherit the old state.
        cross_attention.train(old_attn.training)

        # Normalize the memory once. no_grad keeps the bound tensor from carrying an
        # autograd graph that would be shared by every later forward pass.
        with torch.no_grad():
            layer_memory = layer.input_layernorm(
                memory_embeddings[layer_idx].unsqueeze(0).to(reference_param.device)
            )

        cross_attention.forward = partial(
            cross_attention.forward, memory_embeddings=layer_memory
        )
        layer.self_attn = cross_attention

In [4]:
import numpy as np

# Each fact's per-layer hidden states become the cross-attention memory.
facts = [
    "The Eiffel Tower is in the city of Seoul. The Eiffel Tower is the tallest building in the world. Glass is brittle."
]
TOP_K = 10  # how many highest-variance attention layers to report

with torch.no_grad():
    memory_embeddings = []
    attn_embeddings = []
    for fact in facts:
        inputs = tokenizer(fact, return_tensors="pt")
        outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

        # Drop only the batch axis (batch size is 1 here). A bare `.squeeze()`
        # would also drop the sequence axis if the fact tokenized to one token.
        # The last hidden state is excluded. It is presumably the final-norm output,
        # which is not the input to any decoder layer; confirm for your version.
        memory_embeddings.append(
            torch.stack([hidden.squeeze(0) for hidden in outputs.hidden_states[:-1]])
        )
        attn_embeddings.append(outputs.attentions)

# Variance of each attention tensor for the first fact (presumably one per layer).
attn_variance = np.array([layer_attn.var().item() for layer_attn in attn_embeddings[0]])
sorted_indices = np.argsort(attn_variance)
# Indices of the TOP_K highest-variance entries, highest first.
# NOTE(review): replace_attention_with_cross_attention does not use these indices
# (it defaults to layers 10-17). Confirm which selection is intended.
top_k_indices = sorted_indices[-TOP_K:][::-1]

print(top_k_indices)



[ 2 22 21 20 25 18  3 24 17  1]


In [5]:
# Swap the selected layers' self-attention for cross-attention over the first fact's
# per-layer hidden states, then show the modified module tree.
replace_attention_with_cross_attention(model, memory_embeddings=memory_embeddings[0])
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-9): 10 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
      (10-17): 8

In [6]:
# `generate` may sample (e.g. if the checkpoint's generation_config sets
# do_sample=True). Seed torch's RNG so this cell's output is reproducible.
torch.manual_seed(0)

inputs = tokenizer("The tallest building in the world is", return_tensors="pt")
output = model.generate(**inputs, max_length=50)

print(tokenizer.decode(output[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


The tallest building in the world is in the city of Seoul. The Seoul Tower is a 300 meter tall building in South Korea.
The Eiffel Tower is in Paris, France. The Eiffel Tower is in the city of Seoul
