<a href="https://colab.research.google.com/github/vishal-337/CS-6730-Group-8/blob/main/Copy_of_Lab_task2_template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Task 2: Visualizing Attention and Analyzing the “Attention Sink” Phenomenon**

## **Overview**
This notebook explores **Transformer self-attention** by capturing, visualizing, and quantifying attention maps. It then simulates how **StreamLLM** (StreamingLLM) uses the *“attention sink”* to generate long sequences with a fixed memory budget.  
It enables `output_attentions=True` for introspection, builds a hook-based **Attention Catcher**, measures sink attention across layers, and contrasts standard **KV caching** with a **StreamLLM-style cache** that evicts middle tokens while keeping the sink tokens and the most recent tokens.

---

## **Step 1: Load the Model in Investigation Mode (Cell 4)**

- **Enable attention outputs:**  
  Load the model with  
  `AutoModelForCausalLM.from_pretrained(..., output_attentions=True)`  
  so each forward pass returns per-layer attention tensors alongside logits.

- **Tokenization setup:**  
  Initialize the matching `AutoTokenizer` and standard generation configs used throughout experiments.

---

## **Step 2: Build the “Attention Catcher” Toolkit (Cell 5)**

- **Global storage:**  
  `attention_maps_storage` keeps captured attention tensors keyed by layer/module identifiers.

- **Hook factory:**  
  `get_attention_hook` returns a forward hook that extracts `attn_weights` during the pass and stores them in `attention_maps_storage`.

- **Hook registrar:**  
  `register_attention_hooks(model, layers_to_hook)` attaches the hook to each chosen self-attention module (a single layer, a subset, or all layers) and returns the hook handles so they can be removed later.

- **Visualization:**  
  `plot_attention_maps` retrieves saved maps, aggregates across heads (e.g., mean over heads), and renders **attention heatmaps** for inspection.
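
The hook pattern described above can be sketched on a toy module. This is a minimal illustration, not the notebook's implementation: `ToyAttention` is a hypothetical single-head layer, the registrar takes a plain `nn.ModuleList` instead of a full model, and the hook assumes the module returns a tuple whose second element holds the attention weights.

```python
import torch
import torch.nn as nn

attention_maps_storage = {}

class ToyAttention(nn.Module):
    """Hypothetical single-head attention that returns (output, weights)."""
    def __init__(self, dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        scores = q @ k.transpose(-2, -1) / (x.shape[-1] ** 0.5)
        weights = torch.softmax(scores, dim=-1)
        return weights @ v, weights

def get_attention_hook(layer_name):
    def hook(module, inputs, output):
        # Assumption: output is a tuple and output[1] holds the attention weights.
        attention_maps_storage[layer_name] = output[1].detach().cpu()
    return hook

def register_attention_hooks(layers, layers_to_hook):
    """Attach a forward hook to each selected layer and return the handles."""
    hooks = []
    for idx in layers_to_hook:
        handle = layers[idx].register_forward_hook(get_attention_hook(f"layer_{idx}"))
        hooks.append(handle)
    return hooks

torch.manual_seed(0)
layers = nn.ModuleList([ToyAttention(8) for _ in range(3)])
hooks = register_attention_hooks(layers, [0, 2])

x = torch.randn(1, 5, 8)
with torch.no_grad():
    for layer in layers:
        x, _ = layer(x)  # calling the module triggers its forward hook

for handle in hooks:
    handle.remove()  # detach hooks once the maps are captured

captured_keys = sorted(attention_maps_storage)
print(captured_keys)
```

Storing detached CPU copies keeps the captured maps from holding on to device memory or autograd state between runs.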

---

## **Step 3: Experiment 1 — Attention Patterns Across Inputs (Cell 6)**

- **Inputs:**  
  Define `INPUT_TEXTS` with both meaningful sentences and repetitive “dummy” strings of varying lengths.

- **Hook installation:**  
  Call `register_attention_hooks` for selected layers; clear `attention_maps_storage` before each run.

- **Single forward pass:**  
  Tokenize each input and run a forward pass to populate attention storage via hooks.

- **Heatmaps:**  
  Use `plot_attention_maps` to visualize per-layer attention, aggregated across heads, for side-by-side comparison across inputs.

- **Cleanup:**  
  Remove all hooks after the experiment to avoid extra overhead later.
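
To make the head aggregation and heatmap rendering concrete, here is a small sketch that uses NumPy and Matplotlib on synthetic data. The three random "heads", the four tokens, and the `causal_softmax` helper are all made up for illustration; real maps would come from the attention storage filled by the hooks.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

def causal_softmax(scores):
    """Row-wise softmax over a (heads, T, T) array with future positions masked."""
    T = scores.shape[-1]
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    masked = np.where(mask, -np.inf, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
tokens = ["<s>", "the", "cat", "sat"]  # hypothetical tokens
attn = causal_softmax(rng.normal(size=(3, len(tokens), len(tokens))))

mean_map = attn.mean(axis=0)  # aggregate across heads

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(mean_map, cmap="viridis", vmin=0.0, vmax=1.0)
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=90)
ax.set_yticks(range(len(tokens)))
ax.set_yticklabels(tokens)
ax.set_xlabel("Key token")
ax.set_ylabel("Query token")
fig.colorbar(im, ax=ax)
fig.tight_layout()
```

Fixing `vmin` and `vmax` keeps the color scale comparable across inputs, which matters when heatmaps are inspected side by side.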

---

## **Step 4: Experiment 2 — Quantifying the Attention Sink (Cell 7)**

- **Full-depth monitoring:**  
  Register hooks on every layer to capture a complete attention profile across the model.

- **Sink metric:**  
  For each input and layer, average the attention map over heads, then take the query positions after the first `SINK_TOKEN_WINDOW` tokens and compute the share of their total attention mass that lands on those first `SINK_TOKEN_WINDOW` tokens.

- **Trends by layer:**  
  Store results and plot layer-wise curves  
  *(x-axis: layer ID; y-axis: sink attention %)*  
  with separate lines per input type to reveal consistent sink patterns.
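
A worked example of the sink metric, written with NumPy on synthetic attention maps. The function name `sink_attention_percentage` and the uniform causal map are illustrative assumptions; the arithmetic follows the head-averaged ratio described above.

```python
import numpy as np

def sink_attention_percentage(attn, sink_window):
    """attn: (heads, T, T) attention weights for one sequence."""
    attn_map = attn.mean(axis=0)      # average over heads
    queries = attn_map[sink_window:]  # query rows after the sink window
    total = queries.sum()
    if total <= 0:
        return 0.0
    return float(queries[:, :sink_window].sum() / total * 100.0)

T, heads = 6, 2

# Uniform causal attention: query i spreads its mass evenly over keys 0..i.
uniform = np.tril(np.ones((T, T)))
uniform = uniform / uniform.sum(axis=-1, keepdims=True)
pct_uniform = sink_attention_percentage(np.stack([uniform] * heads), sink_window=2)

# Extreme sink: every query puts all of its mass on the first token.
all_on_first = np.zeros((heads, T, T))
all_on_first[:, :, 0] = 1.0
pct_sink = sink_attention_percentage(all_on_first, sink_window=2)

print(pct_uniform, pct_sink)
```

For the uniform map with six tokens and a window of two, the four query rows after the window contribute 2/3 + 2/4 + 2/5 + 2/6 = 1.9 out of a total mass of 4, which gives 47.5%. A value well above this kind of uniform baseline is what would indicate a sink.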

---

## **Step 5 (Advanced): StreamLLM Simulation and Memory Advantage (Cells 8–9)**

- **Positional shift attention:**  
  `llama_pos_shift_attention_forward` modifies the attention forward path so that positional encodings (e.g., RoPE phases) are assigned by position within the trimmed cache rather than by position in the original sequence. After intermediate tokens are evicted, the remaining tokens therefore stay within the positional range the model saw during training.

- **KV cache manager:**  
  `streamingllm_kv` tracks and trims the KV cache by discarding middle tokens once capacity is exceeded, keeping only early “sink” tokens and the most recent tokens.

- **Baseline vs. StreamLLM:**  
  - `run_baseline_experiment`: Standard generation where KV cache grows linearly with sequence length; log memory usage over steps.  
  - `run_streamllm_experiment`: Generation with `streamingllm_kv` trimming after each step; log memory usage for comparison.

- **Analysis:**  
  Plot both memory curves against generated tokens to show linear growth (**baseline**) vs. plateau (**StreamLLM-style trimming**), illustrating fixed-memory long-context generation.
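
The eviction policy and the resulting plateau can be simulated without a model. The sketch below uses NumPy arrays as a stand-in for a single-layer KV cache shaped `(batch, heads, seq, dim)`; `SinkRecentCache`, `new_token_kv`, and `append_token` are hypothetical helpers, and cache length is used as a proxy for memory.

```python
import numpy as np

class SinkRecentCache:
    """Hypothetical helper: keep the first `start_size` and last `recent_size` tokens."""
    def __init__(self, start_size, recent_size, seq_dim=2):
        self.start_size = start_size
        self.recent_size = recent_size
        self.cache_size = start_size + recent_size
        self.seq_dim = seq_dim

    def _trim(self, x):
        seq_len = x.shape[self.seq_dim]
        if seq_len <= self.cache_size:
            return x  # nothing to evict yet
        head = np.take(x, np.arange(self.start_size), axis=self.seq_dim)
        tail = np.take(x, np.arange(seq_len - self.recent_size, seq_len), axis=self.seq_dim)
        return np.concatenate([head, tail], axis=self.seq_dim)

    def __call__(self, kv_cache):
        return [(self._trim(k), self._trim(v)) for k, v in kv_cache]

def new_token_kv(token_id, heads=1, dim=2):
    # Store the token id in the key so kept positions are easy to read back.
    k = np.full((1, heads, 1, dim), float(token_id))
    return k, k.copy()

def append_token(cache, k, v):
    """Append one token's key/value to a single-layer cache."""
    if cache is None:
        return [(k, v)]
    old_k, old_v = cache[0]
    return [(np.concatenate([old_k, k], axis=2), np.concatenate([old_v, v], axis=2))]

evict = SinkRecentCache(start_size=2, recent_size=3)
baseline, streaming = None, None
baseline_lengths, streaming_lengths = [], []

for t in range(10):
    k, v = new_token_kv(t)
    baseline = append_token(baseline, k, v)              # grows every step
    streaming = evict(append_token(streaming, k, v))     # trimmed every step
    baseline_lengths.append(baseline[0][0].shape[2])
    streaming_lengths.append(streaming[0][0].shape[2])

kept_token_ids = streaming[0][0][0, 0, :, 0].astype(int).tolist()
cache_positions = list(range(len(kept_token_ids)))  # positions by index within the cache
print(baseline_lengths, streaming_lengths, kept_token_ids)
```

The baseline length grows by one per step while the trimmed cache stops at `start_size + recent_size`. The `cache_positions` list hints at the positional shift: under the StreamingLLM scheme, rotary positions would be taken from indices within the cache rather than from the original token ids.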

---

## **Results and Takeaways**

- **Memory efficiency:**  
  StreamLLM-style KV trimming flattens memory growth, enabling sustained generation without exhausting memory.

- **Output quality:**  
  With the standard full cache, output quality degrades (e.g., incoherent characters) once the sequence extends far beyond the training context length, while StreamLLM-style trimming keeps outputs more coherent over extended contexts.


In [None]:
### Cell 2: Environment Setup and Dependency Imports
# TODO: import all required libraries (os, random, numpy, pandas, torch, transformers, etc.)

RESULTS_DIR = "./results"
FIGURES_DIR = "./figures"

# TODO: create output directories if they do not exist
# os.makedirs(RESULTS_DIR, exist_ok=True)
# os.makedirs(FIGURES_DIR, exist_ok=True)

DEVICE = None  # TODO: select torch.device based on CUDA availability

# TODO: print environment diagnostics (CUDA availability, PyTorch version, etc.)

def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible attention analysis."""
    ...

def require_gpu(task: str) -> None:
    """Raise a descriptive error when a GPU-specific task cannot run."""
    ...

# TODO: configure plotting defaults and initialise the environment
# set_seed(42)
# sns.set_theme(...)
# plt.rcParams.update(...)
# print("Environment initialised.")


In [None]:
# ### Cell 3: Hugging Face Login
# from huggingface_hub import login, HfFolder
# from getpass import getpass

# # Check if a Hugging Face token is already set in the environment.
# if not os.getenv("HUGGING_FACE_HUB_TOKEN"):
#     try:
#         # Prompt user for Hugging Face access token if not found.
#         hf_token = getpass("Please enter your Hugging Face access token: ")
#         login(token=hf_token, add_to_git_credential=True)
#         print("   Hugging Face login successful!")
#     except Exception as e:
#         print(f"Login failed: {e}. Model loading may fail later.")
# else:
#     print("   Hugging Face token detected.")

In [None]:
### Cell 4: Load Model and Tokenizer
MODEL_ID = "..."  # TODO: primary model identifier
FALLBACK_MODEL_ID = "..."  # TODO: fallback model identifier

model: Optional[torch.nn.Module] = None
tokenizer: Optional["AutoTokenizer"] = None

# TODO: import AutoModelForCausalLM and AutoTokenizer from transformers
# from transformers import AutoModelForCausalLM, AutoTokenizer

candidate_models = [MODEL_ID, FALLBACK_MODEL_ID]

for candidate in candidate_models:
    # TODO: attempt to load tokenizer/model with appropriate dtype and device placement
    pass

# TODO: raise an error if loading fails for all candidate models

# TODO: ensure tokenizer/model pad tokens are configured
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# if getattr(model.config, "pad_token_id", None) is None:
#     model.config.pad_token_id = tokenizer.pad_token_id

# TODO: place model on the desired device, switch to eval mode, and print summary stats
# model.eval()
# print("Model summary:")
# print(...)


In [None]:
### Cell 5: Core Functions for Attention Extraction and Visualization

# Global storage for attention maps, keyed by layer name
attention_maps_storage = {}

def get_attention_hook(layer_name):
    """Return a forward hook function that stores attention weights for the given layer."""
    # TODO: capture attention tensors and store them in attention_maps_storage
    ...

def register_attention_hooks(model, layers_to_hook):
    """Register forward hooks on attention modules for the requested layers."""
    hooks = []
    # TODO: locate attention modules (e.g., LlamaAttention) and register hooks
    # hook_handle = attn_module.register_forward_hook(get_attention_hook(...))
    # hooks.append(hook_handle)
    return hooks

def plot_attention_maps(attention_maps, tokens, layers_to_plot, file_prefix):
    """Visualise attention maps for selected layers and save the figure."""
    # TODO: aggregate attention across heads, configure subplots, and render heatmaps
    ...


In [None]:
### Cell 6: Experiment - Visualize Attention Maps for Different Inputs

# --- Configurable Section ---
INPUT_TEXTS = {
    "short_dummy": "...",  # TODO: provide dummy prompt
    "short_meaningful": "...",  # TODO: provide meaningful prompt
    "medium_dummy": "...",
    "medium_meaningful": "...",
    "long_dummy": "...",
    "long_meaningful": "...",
}

LAYERS_TO_VISUALIZE = [...]  # TODO: select representative layer indices
if "1b" in MODEL_ID.lower():
    LAYERS_TO_VISUALIZE = [...]  # TODO: adjust layers for smaller models
# ---

hooks = []  # TODO: register attention hooks for the selected layers
# hooks = register_attention_hooks(model, LAYERS_TO_VISUALIZE)

for name, text in INPUT_TEXTS.items():
    print(f"\n--- Processing input: {name} ---")
    attention_maps_storage.clear()
    # TODO: tokenize text, run the model with output_attentions=True, and collect attention maps
    # inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(DEVICE)
    # with torch.no_grad():
    #     outputs = model(**inputs, output_attentions=True)
    # tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # plot_attention_maps(attention_maps_storage, tokens, LAYERS_TO_VISUALIZE, name)
    pass

# TODO: remove hooks after visualization to avoid memory leaks
# for handle in hooks:
#     handle.remove()
print("\n   All attention maps for the provided inputs have been generated.")


In [None]:
### Cell 7: Experiment - Attention Sink Phenomenon Analysis

print("\n--- Starting Experiment: Attention Sink Phenomenon Analysis ---")
# --- Configurable Section ---
SINK_TOKEN_WINDOW = ...  # TODO: choose number of initial tokens treated as sink tokens
# ---

sink_analysis_results = []
hooks = []  # TODO: register attention hooks across all layers
# hooks = register_attention_hooks(model, range(model.config.num_hidden_layers))

for name, text in INPUT_TEXTS.items():
    attention_maps_storage.clear()
    # TODO: tokenize text and run the model to capture attention maps
    # inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(DEVICE)
    # with torch.no_grad():
    #     outputs = model(**inputs, output_attentions=True)

    for layer_idx in range(model.config.num_hidden_layers):
        layer_name = f"layer_{layer_idx}"
        if layer_name in attention_maps_storage:
            # TODO: compute sink attention statistics and append to sink_analysis_results
            # attn_map = attention_maps_storage[layer_name][0].mean(dim=0)
            # sink_attention_strength = attn_map[SINK_TOKEN_WINDOW:, :SINK_TOKEN_WINDOW].sum().item()
            # total_attention = attn_map[SINK_TOKEN_WINDOW:, :].sum().item()
            # sink_percentage = (sink_attention_strength / total_attention) * 100 if total_attention > 0 else 0
            # sink_analysis_results.append({"Input Type": name, "Layer ID": layer_idx, "Sink Attention (%)": sink_percentage})
            pass

# TODO: remove hooks after analysis
# for handle in hooks:
#     handle.remove()

# TODO: convert sink_analysis_results into a DataFrame and save to disk
# df_sink = pd.DataFrame(sink_analysis_results)
# df_sink.to_csv("./results/task2_attention_sink_analysis.csv", index=False)

# TODO: plot sink attention percentage vs. layer depth for each input type
# plt.figure(...)
# sns.lineplot(...)
# plt.savefig("./figures/task2_attention_sink_analysis.png", dpi=300)

print("\n--- Attention Sink Analysis Results ---")
# TODO: summarise sink attention statistics and display the plot
# print(df_sink.groupby("Input Type")["Sink Attention (%)"].mean().reset_index())
# plt.show()

"""
#### Attention Sink Phenomenon Analysis

**Observed Phenomena:**

**Analysis:**

"""


In [None]:
### Cell 8: Bonus Experiment: Modify Standard Attention to StreamingLLM Attention (Task 2 Step 5)
# Imports required by the helper functions defined in this cell.
from typing import Optional
import types

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention, rotate_half, repeat_kv

# TODO: import any remaining utilities needed for the custom attention forward
# from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# from transformers.utils import logging


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Implements the standard (eager) attention forward pass.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(query.dtype)
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """
    Applies rotary positional embedding to a single tensor.
    """
    # Remove singleton dimensions for broadcasting
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

def apply_rotary_pos_emb_q(q, cos, sin, unsqueeze_dim=1):
    """
    Applies rotary positional embedding to the query tensor.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    return q_embed


# TODO: refer to https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_llama.py
# modify to fit llama3 architecture
def llama_pos_shift_attention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position=None, **kwargs):
    """Modified LLaMA attention forward pass with position shifting for StreamLLM."""
    # TODO: project QKV, update caches, apply rotary embeddings, and compute attention outputs
    ...

def enable_llama_pos_shift_attention(model):
    """Replace standard LlamaAttention.forward methods with the position-shifted variant."""
    # TODO: recursively locate LlamaAttention modules and bind llama_pos_shift_attention_forward
    ...


In [None]:
### Cell 9: Bonus Experiment: Investigate StreamLLM's Impact on Long-Sequence Memory Usage
# TODO: import tqdm for progress visualisation
# from tqdm.auto import tqdm

print("\n--- Starting Bonus Experiment: Simulating StreamLLM Memory Impact ---")

# --- Configurable Section ---
BONUS_PROMPT = "..."  # TODO: provide long-form prompt
BONUS_GENERATION_LENGTH = ...  # TODO: choose number of tokens to generate
BONUS_SAMPLING_INTERVAL = ...  # TODO: sampling interval for memory measurements
STREAMLLM_CACHE_SIZE = ...  # TODO: number of sink tokens to retain
STREAMLLM_RECENT_SIZE = ...  # TODO: number of most recent tokens to retain
# ---

def run_baseline_experiment(model, tokenizer, prompt, generation_length, sampling_interval, device):
    """Run baseline generation with the standard KV cache while logging memory usage."""
    # TODO: implement generation loop without cache eviction and record GPU memory
    ...

class streamingllm_kv:
    """Implement StreamLLM-style KV cache eviction (retain sink + recent tokens)."""
    def __init__(self, start_size, recent_size, past_key_values):
        # TODO: store configuration for cache trimming
        ...

    def __call__(self, kv_cache):
        """Trim the KV cache according to the StreamLLM policy."""
        # TODO: drop middle tokens while retaining sink and recent tokens
        ...

def run_streamllm_experiment(model, tokenizer, prompt, generation_length, sampling_interval, sink_size, recent_size, device):
    """Run generation with StreamLLM cache eviction and record memory usage."""
    # TODO: enable modified attention, apply streaming cache policy, and log memory
    ...

# =================================================================================
# Main Execution Flow
# =================================================================================

# TODO: prepare chat-formatted prompt and run both baseline and StreamLLM experiments
# messages = [...]
# BONUS_PROMPT = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# memory_baseline, last100words_baseline = run_baseline_experiment(...)
# memory_streamllm, last100words_streamllm = run_streamllm_experiment(...)

# TODO: collate memory usage results, save CSV artifacts, and plot comparisons
# df_mem_compare = pd.concat([...])
# df_mem_compare.to_csv("./results/task2_bonus_memory_comparison.csv", index=False)
# sns.lineplot(...)
# plt.savefig("./figures/task2_bonus_memory_comparison.png", dpi=300)

# TODO: decode and print the final segments from each generation for qualitative comparison
# print(tokenizer.decode(last100words_baseline))
# print(tokenizer.decode(last100words_streamllm))


In [None]:
### Cell 10: List all generated artifacts for Task 2
print("Task 2 complete. Generated artifacts:")

# TODO: iterate over output directories and list generated files
# if os.path.isdir(FIGURES_DIR):
#     ...
# if os.path.isdir(RESULTS_DIR):
#     ...
