In [1]:
import os
os.chdir('/data/apdesai/code/sparse-attention-hub')

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
    LocalMaskerConfig, SinkMaskerConfig
)
from sparse_attention_hub.sparse_attention.integrations.hugging_face import SparseAttentionHF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


Device: cuda


In [2]:
model_name = "meta-llama/Llama-3.1-8B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="eager"
    )
    print(f"✅ Loaded {model_name}")
    
except Exception as e:
    print(f"Fallback to smaller model due to: {e}")
    model_name = "microsoft/DialoGPT-medium"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="eager"
    )
    print(f"✅ Loaded {model_name}")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✅ Loaded meta-llama/Llama-3.1-8B-Instruct


In [5]:
# Create StreamingLLM configuration: Sink (4 tokens) + Local (64 tokens)
sink_config = SinkMaskerConfig(sink_size=4)
local_config = LocalMaskerConfig(window_size=64)
research_config = ResearchAttentionConfig(masker_configs=[sink_config, local_config])

print(f"✅ StreamingLLM config: Sink(4) + Local(64)")


✅ StreamingLLM config: Sink(4) + Local(64)


In [6]:
# Create SparseAttentionHF integration object
sparse_attention_hf = SparseAttentionHF.create_from_config(research_config)
print("✅ SparseAttentionHF created")


✅ SparseAttentionHF created


In [34]:
# Test text about attention mechanisms and StreamingLLM
test_text = """ 
        The concept of attention mechanisms has revolutionized natural language processing and machine learning.
        StreamingLLM addresses efficiency challenges by implementing sparse attention patterns that combine:
        1. Sink tokens: The first few tokens contain crucial global information
        2. Local attention: Recent tokens are most relevant for next token prediction
        This approach maintains performance while reducing computational costs for long sequences.
        Summarize the above in a single title with less than 10 words.
        """
# Tokenize input
inputs = tokenizer(test_text, return_tensors="pt", truncation=True, max_length=32000)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

print(f"✅ Input prepared: {input_ids.shape[1]} tokens")


✅ Input prepared: 97 tokens


In [35]:
"a" * 100

'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'

In [36]:
# Run with full attention
model.eval()
max_new_tokens = 50

start_time = time.time()
with torch.no_grad():
    full_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
full_time = time.time() - start_time

# Get generated text
full_generated_ids = full_outputs[0]
full_generated_text = tokenizer.decode(full_generated_ids, skip_special_tokens=True)

print(f"⏱️ Full attention: {full_time:.2f}s")
print(f"📝 Generated: {len(full_generated_ids) - len(input_ids[0])} tokens")
print("\nOutput:")
print("-" * 50)
print(full_generated_text)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


⏱️ Full attention: 7.96s
📝 Generated: 50 tokens

Output:
--------------------------------------------------
 
        The concept of attention mechanisms has revolutionized natural language processing and machine learning.
        StreamingLLM addresses efficiency challenges by implementing sparse attention patterns that combine:
        1. Sink tokens: The first few tokens contain crucial global information
        2. Local attention: Recent tokens are most relevant for next token prediction
        This approach maintains performance while reducing computational costs for long sequences.
        Summarize the above in a single title with less than 10 words.
         "Efficient Sequence Modeling with Local Attention and Global Context" 
        The title should be a concise summary of the main idea presented in the text. 

        Here is the correct title: "Efficient Sequence Modeling with Local Attention and Global Context"


In [37]:
# Run with sparse attention (StreamingLLM)
start_time = time.time()
with torch.no_grad():
    with sparse_attention_hf(model) as sparse_model:
        sparse_outputs = sparse_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
sparse_time = time.time() - start_time

# Get generated text
sparse_generated_ids = sparse_outputs[0]
sparse_generated_text = tokenizer.decode(sparse_generated_ids, skip_special_tokens=True)

print(f"⚡ Sparse attention: {sparse_time:.2f}s")
print(f"📝 Generated: {len(sparse_generated_ids) - len(input_ids[0])} tokens")
print("\nOutput:")
print("-" * 50)
print(sparse_generated_text)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


⚡ Sparse attention: 7.92s
📝 Generated: 50 tokens

Output:
--------------------------------------------------
 
        The concept of attention mechanisms has revolutionized natural language processing and machine learning.
        StreamingLLM addresses efficiency challenges by implementing sparse attention patterns that combine:
        1. Sink tokens: The first few tokens contain crucial global information
        2. Local attention: Recent tokens are most relevant for next token prediction
        This approach maintains performance while reducing computational costs for long sequences.
        Summarize the above in a single title with less than 10 words.
         "Efficient Sequence Modeling with Local Attention and Global Context" 
        The title should be a concise summary of the main idea presented in the text. 

        Here is the correct title: "Efficient Sequence Modeling with Local Attention and Global Context"
