In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "HuggingFaceTB/SmolLM-135M-Instruct"

model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16, attn_implementation='sdpa').cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)



generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

In [2]:
import transformers
print(transformers.__version__)

4.37.2


In [3]:
sample_text = open("snapkv.txt", "r").read()
encoded_tokens = tokenizer(sample_text, return_tensors="pt")
for key in encoded_tokens:
    encoded_tokens[key] = encoded_tokens[key].cuda()
print(encoded_tokens.input_ids.shape)

Token indices sequence length is longer than the specified maximum sequence length for this model (17256 > 2048). Running this sequence through the model will result in indexing errors


torch.Size([1, 17256])


In [4]:
query_window = 48
template_window = 48
max_tokens = 1024

In [5]:
import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math

# perform qk calculation and get indices
# this version will not update in inference mode
class SnapKVCluster():
    def __init__(self, query_window = 64, template_window = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):
        self.query_window = query_window
        self.template_window = template_window
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.query_window > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # check if prefix phase
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states
        else:
            attn_weights = torch.matmul(query_states[..., -self.query_window:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
            mask = torch.full((self.query_window, self.query_window), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
            mask = mask.to(attn_weights.device)
            attention_mask = mask[None, None, :, :]

            attn_weights[:, :, -self.query_window:, -self.query_window:] += attention_mask
            attn_weights = torch.where(torch.isinf(attn_weights), torch.tensor(-1e9, dtype=attn_weights.dtype), attn_weights)
            # print("Weights before softmax: ", attn_weights)
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights_sum = attn_weights[:, :, -self.query_window:, :-self.query_window].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            return attn_cache

snapkv_cluster = SnapKVCluster(query_window=query_window, template_window=template_window, max_capacity_prompt=max_tokens)

In [6]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [7]:
from typing import Optional, Tuple
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

def sdpa_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # print("Running custom forward function")
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal=self.is_causal and attention_mask is None and q_len > 1,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    attn_cache = snapkv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)

    return attn_output, attn_cache, past_key_value

for i in range(len(model.model.layers)):
    model.model.layers[i].self_attn.forward = sdpa_forward.__get__(model.model.layers[i].self_attn, type(model.model.layers[i].self_attn))

In [8]:
from einops import repeat
import time

start_time = time.perf_counter()
top_tokens = []

encoded_tokens = tokenizer.encode(sample_text, return_tensors="pt")[0]  
print("Encoding tokens took: ", time.perf_counter() - start_time)
print("Encoded tokens shape: ", encoded_tokens.shape)

chunk_count = math.ceil((encoded_tokens.shape[-1] - template_window - query_window) / (max_tokens - query_window - template_window))
print(encoded_tokens[template_window:-query_window])
split_tensors = encoded_tokens[template_window:-query_window].chunk(chunks=chunk_count, dim=-1)
# split_tensors = torch.split(encoded_tokens[template_adjustment:-query_window], max_tokens - query_window - template_adjustment)
max_split_tokens = 0
for split_tensor in split_tensors:
    max_split_tokens = max(max_split_tokens, split_tensor.shape[-1])
padded_tensors = []
attention_masks = []

template_tokens = encoded_tokens[:template_window]
query_tokens = encoded_tokens[-query_window:]

print("Query tokens shape: ", query_tokens.shape)
print("Template tokens shape: ", template_tokens.shape)

for i, chunk_tensor in enumerate(split_tensors):
    joined_tensor = torch.cat((template_tokens, chunk_tensor, query_tokens), dim=-1) # -> take care of template tensor later
    pad_tensor = torch.tensor(tokenizer.pad_token_id).expand(max_tokens - joined_tensor.shape[-1])
    chunk_attention_mask = torch.IntTensor([0] * (max_tokens - chunk_tensor.shape[-1]) + [1] * (chunk_tensor.shape[-1]))
    padded_tensors.append(torch.cat((pad_tensor, joined_tensor), dim=-1))
    attention_masks.append(chunk_attention_mask)

attention_masks = torch.stack(attention_masks, dim=0)
padded_tensors = torch.stack(padded_tensors, dim=0)
print("Generated padded tensors: ", padded_tensors.shape, " and attention masks: ", attention_masks.shape, " in ", time.perf_counter() - start_time)

padded_tensors = padded_tensors.to(model.device)
attention_masks = attention_masks.to(model.device)

# batch run 
outputs = model(padded_tensors, attention_mask=attention_masks, output_attentions=True)#, return_dict_in_generate=True, output_attentions=True)
for key in outputs:
    print("Output key: ", key)
print("Forward pass completed in: ", time.perf_counter() - start_time)

summed_weights = torch.stack(outputs.attentions, dim=0).sum((0, 2))
first_template_window = summed_weights[0, :template_window]
summed_weights = summed_weights[:, template_window:]
summed_weights = torch.cat((first_template_window, summed_weights.flatten()), dim=0)

# summed_weights = torch.stack(summed_weights, dim=0)
# summed_weights = summed_weights.sum(dim=0)
print("Summed weights shape: ", summed_weights.shape)
print(summed_weights)

indices = summed_weights.topk(max_tokens, dim=-1).indices.cpu()
print(indices)
selected_tokens = encoded_tokens[indices]
decoded_text = tokenizer.decode(selected_tokens)
print("Decoded text: ", decoded_text)

# get the top tokens attended to by the query tokens
# create one single list of "most important tokens"
# return this string.
end_time = time.perf_counter()
print("Time taken: ", end_time - start_time)

Encoding tokens took:  0.038643027015496045
Encoded tokens shape:  torch.Size([17256])
tensor([   70, 13842,   281,  ...,  2526,  5605, 19926])
Query tokens shape:  torch.Size([48])
Template tokens shape:  torch.Size([48])
Generated padded tensors:  torch.Size([19, 1024])  and attention masks:  torch.Size([19, 1024])  in  0.04112430999521166
Output key:  logits
Output key:  past_key_values
Output key:  attentions
Forward pass completed in:  0.6328249570215121
Summed weights shape:  torch.Size([17680])
tensor([ 0.3438,  0.4636,  0.5820,  ..., 39.0938, 31.0156, 24.6875],
       device='cuda:0', dtype=torch.float16, grad_fn=<CatBackward0>)
tensor([10329, 10330,  9402,  ...,  1851,  3685,  5495])


IndexError: index 17610 is out of bounds for dimension 0 with size 17256