### Let's examine the effect of replacing Llama's dense attention with the hash attention kernel.

In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-3.2-1B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             device_map="cuda:2")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
conversation = [
    {"role": "user", "content": "Hello, how are you?"},
]
model_input = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt = True).to(model.device)
response_tensor = model.generate(model_input, max_new_tokens = 20, temperature = 0.7)
tokenizer.decode(response_tensor[0][len(model_input[0]):])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


"I'm doing well, thank you for asking. I'm a large language model, so I don"

In [4]:
from torch import nn
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig, Cache, FlashAttentionKwargs, apply_rotary_pos_emb, repeat_kv
from typing import Tuple, Optional, Unpack
from tqdm import tqdm
from copy import deepcopy

def vector_hash_fn(x: torch.Tensor, num_buckets: int, R: torch.Tensor) -> torch.Tensor:
    """
    x: (..., D)
    R: (D, b/2)
    """
    D = x.shape[-1]
    assert R.shape == (D, num_buckets // 2)
    return torch.argmax(torch.cat([x @ R, -x @ R], dim=-1), dim=-1)

def get_vector_hash(D: int, num_buckets: int, device: torch.device = "cpu", dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    R = torch.randn(D, num_buckets // 2, device = device, dtype = dtype)
    return lambda x: vector_hash_fn(x, num_buckets, R)

def reference_hash_attn(q, k, v, num_buckets: int, sm_scale: float = 1.0, vector_hash = None):
    """
    q: (B, H, N, D)
    k: (B, H, N, D)
    v: (B, H, N, D)

    Note: This implementation sucks! It's just a sanity check.
    """
    assert num_buckets % 2 == 0, "num_buckets must be even"
    if vector_hash is None:
        vector_hash = get_vector_hash(D = q.shape[-1], num_buckets = num_buckets, device = q.device)
    B, H, N, D = q.shape
    q_hashes = vector_hash(q) # (B, H, N)
    k_hashes = vector_hash(k) # (B, H, N)
    out = torch.zeros_like(q)
    for i in range(num_buckets):
        for b in range(B):
            for h in range(H):
                q_mask = (q_hashes[b][h] == i) # (N)
                k_mask = (k_hashes[b][h] == i) # (N)
                q_indices = torch.nonzero(q_mask, as_tuple=False).squeeze() # (N)
                k_indices = torch.nonzero(k_mask, as_tuple=False).squeeze() # (N)
                if len(q_indices.shape) == 0 or len(k_indices.shape) == 0:
                    continue
                q_bucket = q[b, h, q_indices] # (N, D)
                k_bucket = k[b, h, k_indices] # (N, D)
                v_bucket = v[b, h, k_indices] # (N, D)
                attn_mask = q_indices.unsqueeze(-1) >= k_indices.unsqueeze(-2)
                attn_scores = torch.matmul(q_bucket, k_bucket.transpose(-2, -1)) * sm_scale
                attn_scores = attn_scores.masked_fill(~attn_mask, float("-inf"))
                attn = F.softmax(attn_scores, dim=-1)
                attn = attn.nan_to_num(0.0) # some cols will be totally masked out and softmax will produce NaNs
                # sns.heatmap(attn.cpu().numpy().squeeze(), annot = False, mask = ~attn_mask.cpu().numpy().squeeze())
                # plt.show()
                # return
                partial_prod = torch.matmul(attn, v_bucket)
                out[b, h, q_indices] += partial_prod.squeeze(0)
    return out

class HashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: int, num_buckets: int, device):
        super().__init__(config, layer_idx)
        self.vector_hash = get_vector_hash(D = self.head_dim, num_buckets = num_buckets, device = device)
        self.num_buckets = num_buckets

    def forward(
        self,
        hidden_states: torch.Tensor, # (batch_size, seq_len, hidden_size)
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        """
        Shapes:
        query_states: (batch_size, num_heads, seq_len, head_dim)
        key_states: (batch_size, num_heads, seq_len, head_dim)
        value_states: (batch_size, num_heads, seq_len, head_dim)
        """

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output = reference_hash_attn(
            q = query_states,
            k = repeat_kv(key_states, self.num_key_value_groups),
            v = repeat_kv(value_states, self.num_key_value_groups),
            num_buckets = self.num_buckets,
            sm_scale = self.scaling,
            vector_hash = self.vector_hash,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, None
    
# Monkeypatch time
hash_model = deepcopy(model).to(model.device)
NUM_BUCKETS = 16
n_modules_to_replace = len(list(filter(lambda x: isinstance(x, LlamaAttention) or isinstance(x, HashAttention), model.modules())))
with tqdm(total = n_modules_to_replace, desc = "Replacing attention modules") as pbar:
    for name, module in hash_model.named_modules():
        if isinstance(module, LlamaAttention) or isinstance(module, HashAttention):
            # Construct new module
            new_attn_module = HashAttention(config = module.config, layer_idx = module.layer_idx, num_buckets = NUM_BUCKETS, device = model.device)
            new_attn_module.load_state_dict(module.state_dict())
            new_attn_module.to(hash_model.device).to(torch.bfloat16)

            # Split full name to find parent module
            parent_module = hash_model
            parent_name_parts = name.split('.')
            child_name = parent_name_parts[-1]

            if len(parent_name_parts) > 1:
                for part in parent_name_parts[:-1]:
                    if part.isdigit(): # Handles modules in nn.Sequential or nn.ModuleList
                        parent_module = parent_module[int(part)]
                    else:
                        parent_module = getattr(parent_module, part)

            setattr(parent_module, child_name, new_attn_module)
            pbar.update(1)

Replacing attention modules: 100%|██████████| 16/16 [00:00<00:00, 20.58it/s]


In [6]:
hash_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): HashAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): 

In [7]:
conversation = [
    {"role": "user", "content": "Hello, how are you?"},
]
model_input = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt = True).to(model.device)
response_tensor = hash_model.generate(model_input, max_new_tokens = 20)
tokenizer.decode(response_tensor[0][len(model_input[0]):])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


'The final. (pause.\n\nThis is the, for the,    isotope the,'

## Evaluations

### KL Divergence