# Transformers Attention Visualization - Basic Usage

This notebook demonstrates how to use the `transformers-attention-viz` library to visualize attention patterns in multimodal (image–text) transformer models, using OpenAI's CLIP (`openai/clip-vit-base-patch32`) as the running example.

In [None]:
# Install the library (if not already installed).
# Use the %pip magic rather than !pip so the package is installed into the
# environment of the running kernel; pin a version for reproducible results.
# %pip install transformers-attention-viz

In [None]:
# Import required libraries
from transformers import CLIPModel, CLIPProcessor
from attention_viz import AttentionVisualizer
from PIL import Image
import requests
from io import BytesIO

In [None]:
## 1. Load Model and Create Visualizer

In [None]:
# The checkpoint name is shared by the model and its matching processor so the
# two always stay in sync. `model`, `processor`, and `visualizer` are reused by
# every later cell, so this cell must run first.
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# All visualizations in this notebook are produced through this wrapper.
visualizer = AttentionVisualizer(model, processor)
print("Loaded " + model_name)

In [None]:
## 2. Load Example Image

In [None]:
# Download the example image.
# - timeout: prevents the cell from hanging indefinitely on a stalled connection.
# - raise_for_status(): turns an HTTP error (e.g. 404/403) into a clear exception
#   instead of a confusing "cannot identify image file" error from PIL.
url = "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400"
response = requests.get(url, timeout=30)
response.raise_for_status()

# Normalize to 3-channel RGB (handles RGBA / palette / grayscale sources) before
# the image is handed to the CLIP processor; convert() also decodes the bytes now
# rather than lazily.
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display a 224x224 preview. resize() returns a new image, so `image` itself is
# left at its original size (the preview may distort the aspect ratio).
image.resize((224, 224))

In [None]:
## 3. Basic Attention Heatmap

In [None]:
# Heatmap of the final layer's attention for a single caption.
# `text` and `viz` are reused by later cells (flow, evolution, stats, saving).
text = "a photo of a cat"

viz = visualizer.visualize(
    image=image,
    text=text,
    visualization_type="heatmap",
    # -1 selects the last layer.
    # NOTE(review): compare_attention below uses `layer_index`; confirm that
    # `layer_indices` (plural) is the keyword visualize() expects and that it
    # accepts a single int.
    layer_indices=-1,
)
viz.show()

In [None]:
## 4. Attention Flow Diagram

In [None]:
# Attention-flow diagram. Connections are pruned in two ways:
#   * threshold=0.1 -> hide connections whose weight is below 0.1
#   * top_k=5       -> keep at most 5 connections per token
viz_flow = visualizer.visualize(
    image=image,
    text=text,
    visualization_type="flow",
    threshold=0.1,
    top_k=5,
)
viz_flow.show()

In [None]:
## 5. Layer Evolution Analysis

In [None]:
# Layer-by-layer view: how the attention pattern changes with depth,
# summarized per layer by its entropy.
viz_evolution = visualizer.visualize(
    image=image,
    text=text,
    visualization_type="evolution",
    metric="entropy",
)
viz_evolution.show()

In [None]:
## 6. Attention Statistics

In [None]:
# Numeric summary of the attention for the same (image, text) pair.
stats = visualizer.get_attention_stats(image, text)

print("Attention Statistics:")
# `entropy` is reduced with .mean(); `concentration` is printed as a scalar.
avg_entropy = stats["entropy"].mean()
print(f"- Average Entropy: {avg_entropy:.3f}")
gini = stats["concentration"]
print(f"- Attention Concentration (Gini): {gini:.3f}")

print("\nTop Attended Tokens:")
# `top_tokens` is iterated as (token, score) pairs.
for token, score in stats["top_tokens"]:
    print(f"  - {token}: {score:.3f}")

In [None]:
## 7. Compare Multiple Text Prompts

In [None]:
# Compare attention patterns for several captions of the same image.
# The image list is sized from len(texts) so that adding or removing a caption
# keeps `images` and `texts` aligned (a hard-coded count would go out of sync).
texts = ["a cat", "a fluffy orange cat", "a cat sitting on a couch"]

comparison = visualizer.compare_attention(
    images=[image] * len(texts),  # same image paired with each caption
    texts=texts,
    # -1 selects the last layer.
    # NOTE(review): the visualize() calls above use `layer_indices`; confirm
    # which keyword each API expects.
    layer_index=-1,
)

comparison.show()

In [None]:
## 8. Save Visualizations

In [None]:
# Persist the heatmap created in section 3 as a 300-dpi PNG.
heatmap_path = "attention_heatmap.png"
viz.save(heatmap_path, dpi=300)
print(f"Saved {heatmap_path}")

# The same visualization can also be exported as a PIL image for further processing.
pil_image = viz.to_image()
print("Image size:", pil_image.size)

In [None]:
## 9. Launch Interactive Dashboard

In [None]:
# Launch the interactive dashboard (it opens in a browser).
# It is off by default so that "Restart & Run All" finishes without opening a
# browser. Set LAUNCH_DASHBOARD = True to start it, instead of un-commenting code.
from attention_viz import launch_dashboard

LAUNCH_DASHBOARD = False

if LAUNCH_DASHBOARD:
    launch_dashboard(model, processor)