In [1]:
# For boba

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

import torch
print(torch.cuda.device_count())

2


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Llama-2-7b-chat-hf"

model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.float16, attn_implementation='sdpa', device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
import transformers
print(transformers.__version__)

4.37.2


In [4]:
sample_text = open("snapkv_full.txt", "r", encoding="utf-8").read()
encoded_tokens = tokenizer(sample_text, return_tensors="pt")
for key in encoded_tokens:
    encoded_tokens[key] = encoded_tokens[key].cuda()
encoded_tokens = encoded_tokens.input_ids[0]
print(encoded_tokens.shape)

torch.Size([17585])


In [5]:
padding_window = 64
query_window = 12
max_tokens = 512

In [6]:
from typing import Optional, Tuple
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
import math
from torch import nn
import torch.nn.functional as F

def sdpa_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # print("Running custom forward function")
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query_states.device.type == "cuda" and attention_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # multiply q by k and and return it back to the user.
    attn_weights = torch.matmul(query_states[..., -query_window:, :], key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    mask = torch.full((query_window, query_window), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
    mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(attn_weights.device)
    weights_attention_mask = mask[None, None, :, :]
    attn_weights[:, :, -query_window:, -query_window:] += weights_attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights_sum = attn_weights[:, :, -query_window:, : -query_window].sum(dim = -2)
    attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = 7, padding=7//2, stride=1)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal=self.is_causal and attention_mask is None and q_len > 1,
    )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_cache, past_key_value

for i in range(len(model.model.layers)):
    model.model.layers[i].self_attn.forward = sdpa_forward.__get__(model.model.layers[i].self_attn, type(model.model.layers[i].self_attn))

In [7]:
import matplotlib.pyplot as plt

content_tokens = encoded_tokens[:-query_window]
query_tokens = encoded_tokens[-query_window:]

print("Query tokens: ", tokenizer.decode(query_tokens))

unpadded_chunks = []
total_length = encoded_tokens.shape[-1] - query_window

x = padding_window
y = max_tokens - padding_window * 2 - query_window

for content_start in range(0, total_length - y, y):
    split_start = max(0, content_start - x)
    split_end = min(content_start + y, total_length)
    # Adjust post-chunk and pre-chunk if necessary
    post_chunk_end = min(split_end + x, total_length)
    pre_chunk_start = max(0, split_start - x)

    # Handle wrapping for pre-chunk
    if split_start - x < 0:
        pre_chunk_start = total_length + (split_start - x)
        pre_chunk = torch.cat((content_tokens[pre_chunk_start:], content_tokens[:split_start]), dim=-1)
    else:
        pre_chunk = content_tokens[pre_chunk_start:split_start]

    # Handle wrapping for post-chunk
    if split_end + x > total_length:
        post_chunk_end = (split_end + x) % total_length
        post_chunk = torch.cat((content_tokens[split_end:], content_tokens[:post_chunk_end]), dim=-1)
    else:
        post_chunk = content_tokens[split_end:post_chunk_end]

    # Combine pre-chunk, relevant content, and post-chunk
    chunk_tensor = torch.cat((pre_chunk, content_tokens[split_start:split_end], post_chunk, query_tokens), dim=-1)
    decoded_tokens = tokenizer.decode(chunk_tensor, skip_special_tokens=True)
    chunk_text = tokenizer.apply_chat_template([{
        "role": "user",
        "content": decoded_tokens
    }], tokenize=False)

    unpadded_chunks.append(chunk_text)

print("Unpadded chunks length: ", len(unpadded_chunks))

tokenizer.pad_token = tokenizer.eos_token
chunks = tokenizer(unpadded_chunks, return_tensors="pt", padding=True)
for key in chunks:
    chunks[key] = chunks[key].to(model.device)

outputs = model.generate(**chunks, output_attentions=True, return_dict_in_generate=True, max_new_tokens=1)
attention_scores = outputs.attentions[0]

print("First length: ", len(attention_scores))
print("Second length: ", len(attention_scores[0]))
print("Third length: ", attention_scores[0][0].shape)
# print(attention_scores)

# since we have each of these in distinct layers, we should first stack them
attention_scores = torch.stack(attention_scores, dim=0)
print("Shape: ", attention_scores.shape)
attention_scores = attention_scores.sum(dim=0)
print("Shape: ", attention_scores.shape)
attention_scores = attention_scores.squeeze(0)
print("Shape: ", attention_scores.shape)

chunk_scores = []
attention_weights = []

for chunk_text, scores in zip(unpadded_chunks, attention_scores):
    # scores = F.avg_pool1d(scores, kernel_size=14, stride=5)
    # interest_score = torch.sum(scores[:, query_window:-query_window])
    intermediate_sum = torch.sum(scores[:, padding_window:-padding_window], dim=0)
    # intermediate_sum = torch.where(intermediate_sum < 100, intermediate_sum, torch.tensor(0))

    interest_score = torch.sum(intermediate_sum)
    chunk_scores.append((interest_score.cpu().numpy(), chunk_text, intermediate_sum.cpu().numpy()))
    # print(chunk_text)
    # print("Interest score: ", interest_score)
    # print("===========================")
    # pick the last 'max_tokens - query_window' tokens, sum up attention and see how it stacks up

chunk_scores = sorted(chunk_scores, reverse=True)
for score, text, attention_scores in chunk_scores:
    print("Score: ", score)
    print(text)
    plt.plot(range(attention_scores.shape[-1]), attention_scores, marker='o', linestyle='-', color='b')
    plt.show()
    # print("Attention scores: ", attention_scores)
    print("===================")

Query tokens:  
What is the GitHub repository for SnapKV?
Unpadded chunks length:  47


RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
