In [None]:
import os
import sys

import numpy as np
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer)

# Add the parent directory (project folder) to the module search path so the
# project-local import that follows can be resolved.
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # `__file__` is not defined in Jupyter / interactive sessions; fall back
    # to the current working directory (normally the notebook's folder).
    current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
# Guard against piling up duplicate entries when the cell is re-executed.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from coarse_grain_model import GPT2WithSlidingWindow

In [None]:
"""
Tests the custom model by visualizing the attention pattern for a specific token.

Runs one forward pass with attention outputs enabled and checks that, for a
chosen query token (layer 0, head 0), every position receiving a
non-negligible attention score lies inside the expected sliding window: the
token itself plus the WINDOW_SIZE - 1 tokens before it.
"""
model_path = "./models/gpt2"
WINDOW_SIZE = 5  # Use a small window for easy verification
ATTENTION_THRESHOLD = 0.001  # Scores at or below this count as "not attended"

# 1. Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2WithSlidingWindow.from_pretrained(model_path)
# NOTE(review): this presumes GPT2WithSlidingWindow reads `config.window_size`
# at forward time. If the value is captured in __init__, it must be set before
# the model is constructed -- confirm against coarse_grain_model.
model.config.window_size = WINDOW_SIZE
model.eval()  # Set model to evaluation mode

# 2. Create sample input
text = "The quick brown fox jumps over the lazy dog and runs away."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
tokens = [tokenizer.decode(token_id) for token_id in input_ids[0]]
seq_len = len(tokens)

print(f"Input Sentence: '{text}'")
print(f"Window Size: {WINDOW_SIZE}\n")
print("-" * 50)

# 3. Perform a forward pass, requesting attention scores
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# `outputs.attentions` is a tuple of attention tensors, one for each layer.
# Let's inspect the first layer's attention.
# Shape: [batch_size, num_heads, query_len, key_len]
attention_layer_0 = outputs.attentions[0]

# Let's inspect the attention from the first head.
attention_head_0 = attention_layer_0[0, 0, :, :].cpu().numpy()

# 4. Pick a token to analyze and verify its attention window
token_to_check_index = 10
# Fail with a clear message (instead of an opaque indexing error) if the
# sample text is changed and no longer has this many tokens.
if not 0 <= token_to_check_index < seq_len:
    raise IndexError(
        f"token_to_check_index={token_to_check_index} is out of range for a "
        f"sequence of {seq_len} tokens"
    )

# Get the attention scores *from* this token *to* all other tokens
attention_scores = attention_head_0[token_to_check_index]

# The tokens that received a non-negligible attention score
attended_indices = np.where(attention_scores > ATTENTION_THRESHOLD)[0]

# The actual tokens it attended to
attended_tokens = [tokens[i] for i in attended_indices]

# Calculate the expected window
expected_start_index = max(0, token_to_check_index - WINDOW_SIZE + 1)
expected_end_index = token_to_check_index
expected_window_indices = list(range(expected_start_index, expected_end_index + 1))
expected_tokens = [tokens[i] for i in expected_window_indices]

print(f"🔍 ANALYSIS FOR TOKEN '{tokens[token_to_check_index]}' (at index {token_to_check_index}):\n")

print(f"EXPECTED to attend to tokens from index {expected_start_index} to {expected_end_index}:")
print(f"==> {expected_tokens}\n")

print(f"ACTUALLY attended to tokens at indices {attended_indices.tolist()}:")
print(f"==> {attended_tokens}\n")

# 5. Assert to confirm correctness
# A token inside the window may legitimately receive a weight below the
# threshold, so requiring exact equality with the window would be flaky.
# The property under test is that attention exists and none of it falls
# outside the window.
attended_set = set(attended_indices.tolist())
assert attended_set and attended_set <= set(expected_window_indices), \
    "Test Failed: The model did not attend to the correct sliding window!"

print("✅ TEST PASSED: The attention pattern matches the expected sliding window.")