In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer as usual.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# Load the model in 8-bit mode for quantization.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    torch_dtype=torch.float16,    # Enable 8-bit quantization.
    device_map="auto"     # Automatically place model on available GPU(s).
)

# Example usage:
prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Once upon a time, in a small village nestled in the rolling hills of Tuscany, there lived a young girl named Sophia. Sophia was a curious and adventurous child, with a heart full of wonder and a mind full of questions. She spent her days exploring the


In [1]:
import torch

In [9]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb

In [4]:
import torch
import torch.nn.functional as F

def dynamic_token_replacement(embeddings, threshold=0.95):
    """
    Clusters token embeddings that have cosine similarity above the threshold
    and replaces tokens in a cluster with the representative embedding.

    Args:
        embeddings (Tensor): Tensor of shape (seq_len, head_dim)
        threshold (float): Similarity threshold for merging tokens.

    Returns:
        replaced (Tensor): Tensor of shape (seq_len, head_dim) with merged embeddings.
    """
    seq_len, head_dim = embeddings.shape
    replaced = embeddings.clone()
    used = torch.zeros(seq_len, dtype=torch.bool, device=embeddings.device)

    for i in range(seq_len):
        if used[i]:
            continue
        # For token i, find all subsequent tokens that are very similar
        for j in range(i + 1, seq_len):
            if not used[j]:
                sim = F.cosine_similarity(embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0))
                if sim.item() > threshold:
                    replaced[j] = embeddings[i]  # Replace token j with token i's embedding.
                    used[j] = True
    return replaced

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicLlamaAttention(nn.Module):
    def __init__(self, orig_attn, threshold=0.95):
        """
        Wrap the original LlamaAttention module with dynamic token replacement
        for the keys and values.

        Args:
            orig_attn: The original LlamaAttention module (with q_proj, k_proj, v_proj, o_proj).
            threshold (float): Similarity threshold for dynamic token replacement.
        """
        super().__init__()
        self.threshold = threshold

        # Reuse the projection layers from the original attention.
        self.q_proj = orig_attn.q_proj  # maps 3072 -> 3072
        self.k_proj = orig_attn.k_proj  # maps 3072 -> 1024
        self.v_proj = orig_attn.v_proj  # maps 3072 -> 1024
        self.o_proj = orig_attn.o_proj  # maps back to 3072

        # Number of heads is assumed to be stored in the original module.
        self.num_heads = getattr(orig_attn, "num_heads", 32)
        # For q: hidden_size_q = 3072, so per-head dimension is:
        self.q_head_dim = 3072 // self.num_heads  # = 96
        # For k/v: hidden_size_k = 1024.
        self.k_dim = 1024  # keys and values are not split in multi-query attention.

        # To compute dot products, we need to project q's per-head vectors (96-d) to 1024-d.
        self.q_to_k = nn.Linear(self.q_head_dim, self.k_dim)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """
        Args:
            hidden_states (Tensor): shape (batch, seq_len, 3072)
            attention_mask (Tensor, optional): mask to add to attention scores.

        Returns:
            output (Tensor): shape (batch, seq_len, 3072)
        """
        batch, seq_len, _ = hidden_states.size()

        # Compute q, k, v projections.
        # q: (batch, seq_len, 3072)
        # k, v: (batch, seq_len, 1024)
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape q into heads.
        # q: (batch, seq_len, 3072) -> (batch, seq_len, num_heads, q_head_dim) -> (batch, num_heads, seq_len, q_head_dim)
        q = q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        # Project q per head to dimension matching k.
        # q_new: (batch, num_heads, seq_len, k_dim) where k_dim = 1024.
        q_new = self.q_to_k(q)

        # Process k and v for multi-query:
        # They are computed once per sequence: shape (batch, seq_len, 1024)
        # Unsqueeze to add a heads dimension and then repeat across heads.
        k = k.unsqueeze(1).expand(batch, self.num_heads, seq_len, self.k_dim)
        v = v.unsqueeze(1).expand(batch, self.num_heads, seq_len, self.k_dim)

        # Apply dynamic token replacement on k and v for each batch and head.
        # (This loop is not fully vectorized; for research purposes only.)
        new_k = []
        new_v = []
        for b in range(batch):
            k_heads = []
            v_heads = []
            for h in range(self.num_heads):
                # k[b, h]: shape (seq_len, k_dim)
                replaced_k = dynamic_token_replacement(k[b, h], threshold=self.threshold)
                replaced_v = dynamic_token_replacement(v[b, h], threshold=self.threshold)
                k_heads.append(replaced_k)
                v_heads.append(replaced_v)
            # Stack heads for this batch element: (num_heads, seq_len, k_dim)
            new_k.append(torch.stack(k_heads, dim=0))
            new_v.append(torch.stack(v_heads, dim=0))
        # Stack batch: new_k, new_v: (batch, num_heads, seq_len, k_dim)
        new_k = torch.stack(new_k, dim=0)
        new_v = torch.stack(new_v, dim=0)

        # Compute scaled dot-product attention.
        # q_new: (batch, num_heads, seq_len, k_dim)
        # new_k: (batch, num_heads, seq_len, k_dim) -> transpose last two dims for dot product.
        d_k = self.k_dim
        scores = torch.matmul(q_new, new_k.transpose(-2, -1)) / (d_k ** 0.5)
        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = F.softmax(scores, dim=-1)
        # attn_output: (batch, num_heads, seq_len, k_dim)
        attn_output = torch.matmul(attn_weights, new_v)

        # Reassemble: transpose and reshape to (batch, seq_len, num_heads * k_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads * self.k_dim)
        # Project back to hidden dimension (3072).
        output = self.o_proj(attn_output)
        return output

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizer and model.
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=None)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# A simple dynamic token replacement function.
# If two tokens have cosine similarity above the threshold, replace the later one with the earlier.
def dynamic_token_replacement(embeddings, threshold=1.0):
    """
    Args:
        embeddings (Tensor): shape (seq_len, dim)
        threshold (float): similarity threshold; if >= 1.0, no merging is done.
    Returns:
        replaced (Tensor): tensor of shape (seq_len, dim) after merging.
    """
    # When threshold is 1.0 or more, return the original embeddings (i.e. disable merging)
    if threshold >= 1.0:
        return embeddings

    seq_len, dim = embeddings.shape
    replaced = embeddings.clone()
    used = torch.zeros(seq_len, dtype=torch.bool, device=embeddings.device)
    for i in range(seq_len):
        if used[i]:
            continue
        for j in range(i + 1, seq_len):
            if not used[j]:
                sim = F.cosine_similarity(embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0))
                if sim.item() > threshold:
                    # Replace token j's embedding with token i's embedding.
                    replaced[j] = embeddings[i]
                    used[j] = True
    return replaced

# A dynamic attention module that uses only the existing model projections.
class DynamicLlamaAttentionNoNew(nn.Module):
    def __init__(self, orig_attn, threshold=1.0):
        """
        This module wraps the original LlamaAttention module to add dynamic token replacement
        for the keys and values—merging redundant tokens along the sequence dimension—without
        adding any new learned parameters.

        Args:
            orig_attn: The original LlamaAttention module (with q_proj, k_proj, v_proj, o_proj).
            threshold (float): Cosine similarity threshold for merging tokens.
                Use threshold >= 1.0 to disable merging.
        """
        super().__init__()
        self.threshold = threshold

        # Use the original projection layers.
        self.q_proj = orig_attn.q_proj  # maps 3072 -> 3072
        self.k_proj = orig_attn.k_proj  # maps 3072 -> 1024
        self.v_proj = orig_attn.v_proj  # maps 3072 -> 1024
        self.o_proj = orig_attn.o_proj  # maps 3072 -> 3072

        # For Llama-3.2-3B, we assume the hidden size is 3072 and the key/value dimension is 1024.
        # In standard multi-head attention, one splits the hidden state into "num_heads" parts.
        # To match dimensions, we require: hidden_size / num_heads == key_dim.
        # Here, we assume that 3072 / num_heads == 1024, so num_heads = 3.
        self.num_heads = getattr(orig_attn, "num_heads", 3)
        self.hidden_size = 3072
        self.head_dim = self.hidden_size // self.num_heads  # should be 1024 in this configuration

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """
        Args:
            hidden_states (Tensor): shape (batch, seq_len, 3072)
            attention_mask (Tensor, optional): attention mask.
        Returns:
            output (Tensor): shape (batch, seq_len, 3072)
            attn_weights (Tensor): attention weights.
        """
        batch, seq_len, _ = hidden_states.size()

        # Compute queries, keys, and values.
        q = self.q_proj(hidden_states)  # (batch, seq_len, 3072)
        k = self.k_proj(hidden_states)  # (batch, seq_len, 1024)
        v = self.v_proj(hidden_states)  # (batch, seq_len, 1024)

        # Reshape queries into heads.
        # q: (batch, seq_len, 3072) --> (batch, seq_len, num_heads, head_dim)
        # Then transpose to (batch, num_heads, seq_len, head_dim)
        q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # For keys and values, we reshape into heads similarly.
        # k: (batch, seq_len, 1024) --> (batch, seq_len, num_heads, head_dim) if using standard attention.
        # (If keys/values are shared across heads in multi‑query attention, they might be unsqueezed.
        # Here we assume standard multi-head behavior.)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply dynamic token replacement on keys and values for each batch element and each head.
        new_k_list = []
        new_v_list = []
        for b in range(batch):
            head_k_list = []
            head_v_list = []
            for h in range(self.num_heads):
                # For head h in batch b, k[b, h] has shape (seq_len, head_dim)
                replaced_k = dynamic_token_replacement(k[b, h], threshold=self.threshold)
                replaced_v = dynamic_token_replacement(v[b, h], threshold=self.threshold)
                head_k_list.append(replaced_k)
                head_v_list.append(replaced_v)
            new_k_list.append(torch.stack(head_k_list, dim=0))
            new_v_list.append(torch.stack(head_v_list, dim=0))
        # Stack back into tensors of shape (batch, num_heads, seq_len, head_dim)
        new_k = torch.stack(new_k_list, dim=0)
        new_v = torch.stack(new_v_list, dim=0)

        # Compute scaled dot-product attention.
        # q: (batch, num_heads, seq_len, head_dim)
        # new_k: (batch, num_heads, seq_len, head_dim) so new_k.transpose(-2, -1) is (batch, num_heads, head_dim, seq_len)
        d_k = self.head_dim  # should be 1024
        scores = torch.matmul(q, new_k.transpose(-2, -1)) / (d_k ** -0.5)
        if attention_mask is not None:
            scores = scores + attention_mask
        attn_weights = F.softmax(scores, dim=-1)
        # Compute attention output.
        attn_output = torch.matmul(attn_weights, new_v)  # (batch, num_heads, seq_len, head_dim)

        # Combine heads: transpose and reshape to (batch, seq_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
        output = self.o_proj(attn_output)  # (batch, seq_len, 3072)
        return output, attn_weights
# Replace the self-attention modules in each decoder layer.
for i, layer in enumerate(model.model.layers):
    orig_attn = layer.self_attn
    # print(layer.self_attn)
    layer.self_attn = DynamicLlamaAttention(orig_attn, threshold=1.1)
    print(f"Replaced self-attention in layer {i}")
model = model.to("cuda")
# Test the modified model.
prompt = "Once upon a time, in a land far, far away,"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Replaced self-attention in layer 0
Replaced self-attention in layer 1
Replaced self-attention in layer 2
Replaced self-attention in layer 3
Replaced self-attention in layer 4
Replaced self-attention in layer 5
Replaced self-attention in layer 6
Replaced self-attention in layer 7
Replaced self-attention in layer 8
Replaced self-attention in layer 9
Replaced self-attention in layer 10
Replaced self-attention in layer 11
Replaced self-attention in layer 12
Replaced self-attention in layer 13
Replaced self-attention in layer 14
Replaced self-attention in layer 15
Replaced self-attention in layer 16
Replaced self-attention in layer 17
Replaced self-attention in layer 18
Replaced self-attention in layer 19
Replaced self-attention in layer 20
Replaced self-attention in layer 21
Replaced self-attention in layer 22
Replaced self-attention in layer 23
Replaced self-attention in layer 24
Replaced self-attention in layer 25
Replaced self-attention in layer 26
Replaced self-attention in layer 27


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Once upon a time, in a land far, far away,-a Said-ósitookiaergus DefensiveaeogradchedCouldn saja strategic reversearathes from Owenrophic TL once sparks Temper-forally música-aCHIP-Eassin referencesamp wpaqueúpangabon�ismanasync commissionam mìnhuchthilreasonouce mgpal


In [32]:
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


In [5]:
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [1]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineG