# Attention Mechanisms in Transformers

This notebook provides an interactive environment for exploring attention mechanisms in transformer models. You can change the model hyperparameters in `config.json`, train a small GPT model, and visualize the attention pattern of each layer.

In [1]:
# Import necessary libraries
import torch
from model import GPT, GPTConfig
from train import SimpleDataset, train
from attention_demo import visualize_attention_weights, get_attention_patterns
import json

# Load configuration. An explicit encoding keeps the read independent of the
# platform's default locale.
with open('../config.json', 'r', encoding='utf-8') as f:
    config_data = json.load(f)

# Initialize model configuration from the keys read out of config.json.
config = GPTConfig(
    vocab_size=config_data['vocab_size'],
    block_size=config_data['block_size'],
    n_layer=config_data['n_layer'],
    n_head=config_data['n_head'],
    n_embd=config_data['n_embd'],
    dropout=config_data['dropout'],
    debug=config_data['debug']
)

# Create dataset and model
dataset = SimpleDataset(config.block_size)
model = GPT(config)

# Train the model (one short epoch with a small batch -- a quick demo run).
train(model, dataset, epochs=1, batch_size=4)

# Switch to inference mode before analysis: in training mode dropout
# (config.dropout) would randomly perturb the activations q/k come from,
# making the plotted patterns noisy and non-reproducible.
model.eval()

# A GPT-style decoder only lets each position attend to itself and earlier
# positions, so the plotted weights must be causally masked to reflect what
# the model actually uses. Set to False to inspect the raw, unmasked
# query-key affinities instead.
CAUSAL_MASK = True

# Analyze attention patterns
text = "the quick brown fox"
# Pure inspection: no gradients are needed, and tensors created here should
# not carry autograd history into the visualization code.
with torch.no_grad():
    words, qk_vectors = get_attention_patterns(model, text, dataset)
    for layer_idx, (_layer_name, (q, k)) in enumerate(qk_vectors):
        # Scaled dot-product scores: q @ k^T / sqrt(d), d = last dim of q.
        # NOTE(review): if q/k are full-width projections rather than per-head
        # tensors, d is n_embd instead of the head size -- confirm against
        # get_attention_patterns.
        att = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        if CAUSAL_MASK:
            # Lower-triangular mask over the (query, key) axes; it broadcasts
            # across any leading batch/head dimensions. Future positions get
            # -inf so softmax assigns them exactly zero weight.
            allowed = torch.ones(
                att.size(-2), att.size(-1), dtype=torch.bool, device=att.device
            ).tril()
            att = att.masked_fill(~allowed, float('-inf'))
        att = torch.nn.functional.softmax(att, dim=-1)
        # Plot the first element along the leading (batch) dimension.
        visualize_attention_weights(att[0], words, f"Layer {layer_idx+1}")