# 🔍 Transformers Attention Viz Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sisird864/transformers-attention-viz/blob/main/examples/demo.ipynb)
[![PyPI](https://img.shields.io/pypi/v/transformers-attention-viz.svg)](https://pypi.org/project/transformers-attention-viz/)

This notebook demonstrates how to visualize attention patterns in multi-modal transformers like CLIP and BLIP.

**What you'll learn:**
- Installing and using transformers-attention-viz
- Visualizing cross-modal attention between text and images
- Analyzing attention statistics
- Using the interactive dashboard

## 📦 Installation

First, let's install the package and its dependencies:

In [None]:
!pip install transformers-attention-viz torch transformers pillow requests -q
print("✅ Installation complete!")
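
If a later cell fails on import, it can help to confirm what was actually installed. The snippet below is a small sketch that uses only the standard library (`importlib.metadata`). The package names come from the pip command above, and `installed_versions` is a hypothetical helper written for this notebook, not part of transformers-attention-viz.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Package names are taken from the pip command above.
report = installed_versions(
    ["transformers-attention-viz", "torch", "transformers", "pillow", "requests"]
)
for name, ver in report.items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```

In Colab, restart the runtime after installing if a package still shows an older version than expected.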

## 🚀 Quick Start

Let's visualize how CLIP understands the relationship between images and text:

In [None]:
from transformers import CLIPModel, CLIPProcessor
from attention_viz import AttentionVisualizer
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt

# Load CLIP model
print("Loading CLIP model...")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Create visualizer
visualizer = AttentionVisualizer(model, processor)
print("✅ Model loaded!")

## 🖼️ Example 1: Understanding a Cat Image

Let's see how CLIP processes an image of a cat with descriptive text:

In [None]:
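# Load a cat image (fail fast on network errors instead of passing bad bytes to PIL)
url = "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400"
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")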

# Display the image
plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.axis('off')
plt.title("Input Image")
plt.show()

# Create attention visualization
text = "a fluffy orange cat sitting on a surface"
print(f"Text: '{text}'")

viz = visualizer.visualize(
    image=image,
    text=text,
    visualization_type="heatmap"
)

# Display the attention heatmap
plt.figure(figsize=(10, 8))
plt.imshow(viz.to_image())
plt.axis('off')
plt.tight_layout()
plt.show()
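
To read an image heatmap, it helps to know how the vision encoder splits the picture. `openai/clip-vit-base-patch32` resizes its input to 224×224 pixels and cuts it into 32×32 patches, which gives a 7×7 grid of 49 patch tokens plus one class token. The sketch below shows that arithmetic in plain Python. `patch_grid` and `patch_box` are hypothetical helpers for illustration; how transformers-attention-viz renders its heatmap internally is not confirmed here.

```python
def patch_grid(image_size=224, patch_size=32):
    """Return (rows, cols) of the patch grid for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    side = image_size // patch_size
    return side, side


def patch_box(patch_index, image_size=224, patch_size=32):
    """Map a flat, row-major patch index (class token excluded) to a pixel box.

    The box is (left, top, right, bottom) in the resized input image.
    """
    rows, cols = patch_grid(image_size, patch_size)
    if not 0 <= patch_index < rows * cols:
        raise IndexError(f"patch_index must be in [0, {rows * cols})")
    row, col = divmod(patch_index, cols)
    left, top = col * patch_size, row * patch_size
    return (left, top, left + patch_size, top + patch_size)


rows, cols = patch_grid()
print(f"{rows} x {cols} grid = {rows * cols} patches")
print("first patch:", patch_box(0))
print("last patch:", patch_box(rows * cols - 1))
```

Because each patch covers 32×32 pixels, any patch-level heatmap for this checkpoint is coarse: fine details smaller than a patch cannot be localized.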

## 📊 Attention Statistics

Let's analyze the attention patterns to see which tokens the model focuses on:

In [None]:
# Get attention statistics
stats = visualizer.get_attention_stats(image, text)

print("🔍 Attention Analysis:")
print("=" * 40)
print(f"Average Entropy: {stats['entropy'].mean():.3f}")
print(f"Attention Concentration: {stats['concentration']:.3f}")
print("\n🏆 Top 5 Most Attended Tokens:")
for rank, (token, score) in enumerate(stats['top_tokens'][:5], start=1):
    print(f"{rank}. '{token}' - {score*100:.1f}% attention")

# Visualize attention distribution
tokens = [t[0] for t in stats['top_tokens'][:10]]
scores = [t[1] for t in stats['top_tokens'][:10]]

plt.figure(figsize=(10, 6))
plt.bar(range(len(tokens)), scores)
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
plt.ylabel('Attention Weight')
plt.title('Top 10 Tokens by Attention Weight')
plt.tight_layout()
plt.show()
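
For intuition about the entropy figure: the Shannon entropy of an attention distribution is low when attention concentrates on a few tokens and high when it spreads evenly. The sketch below computes it in plain Python. It is an assumption that `stats['entropy']` follows this definition, and the exact formula behind `stats['concentration']` is not shown in this notebook, so treat this as an illustration rather than the library's implementation.

```python
import math


def attention_entropy(weights):
    """Shannon entropy (in nats) of non-negative attention weights.

    The weights are normalized to sum to 1 first; zero weights contribute 0.
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    probs = [w / total for w in weights]
    return -sum(p * math.log(p) for p in probs if p > 0)


uniform = [0.25, 0.25, 0.25, 0.25]  # attention spread evenly over 4 tokens
peaked = [0.97, 0.01, 0.01, 0.01]   # attention focused on one token

print(f"uniform: {attention_entropy(uniform):.3f} (maximum for 4 tokens is log(4))")
print(f"peaked:  {attention_entropy(peaked):.3f}")
```

When comparing entropy across descriptions, keep in mind that the maximum possible value grows with the number of tokens, so longer texts can show higher entropy without being less focused.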

## 🎯 Example 2: Comparing Different Descriptions

Let's see how attention changes with different text descriptions:

In [None]:
descriptions = [
    "a cat",
    "an orange cat",
    "a fluffy cat sitting"
]

fig, axes = plt.subplots(1, len(descriptions), figsize=(15, 5))

for idx, desc in enumerate(descriptions):
    viz = visualizer.visualize(image, desc, visualization_type="heatmap")
    axes[idx].imshow(viz.to_image())
    axes[idx].set_title(f"Text: '{desc}'")
    axes[idx].axis('off')

plt.suptitle("Attention Patterns for Different Descriptions", fontsize=16)
plt.tight_layout()
plt.show()

## 🌟 Example 3: Different Image Types

Let's test with different types of images:

In [None]:
# Test with different images
test_cases = [
    ("https://images.unsplash.com/photo-1547407139-3c921a66005c?w=400", "a golden retriever dog"),
    ("https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=400", "snowy mountain peaks"),
    ("https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=400", "city street with buildings")
]

for img_url, description in test_cases:
    # Load image (fail fast on network errors)
    response = requests.get(img_url, timeout=30)
    response.raise_for_status()
    test_image = Image.open(BytesIO(response.content)).convert("RGB")
    
    # Create visualization
    viz = visualizer.visualize(test_image, description, visualization_type="heatmap")
    
    # Get stats
    stats = visualizer.get_attention_stats(test_image, description)
    
    # Display
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    ax1.imshow(test_image)
    ax1.set_title(f"Input: '{description}'")
    ax1.axis('off')
    
    ax2.imshow(viz.to_image())
    top_token = stats['top_tokens'][0]
    ax2.set_title(f"Top token: '{top_token[0]}' ({top_token[1]*100:.1f}%)")
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
    print("-" * 50)

## 🎛️ Interactive Dashboard

For a more interactive experience, you can launch the Gradio dashboard locally:

In [None]:
# Note: This will not work in Colab, but you can run it locally!
print("To launch the interactive dashboard locally, run:")
print("\nfrom attention_viz import launch_dashboard")
print("launch_dashboard(model, processor)")
print("\nOr from command line:")
print("attention-viz-dashboard")

## 🔧 Advanced Usage

### Analyzing Specific Layers

In [None]:
# Visualize attention from different layers
layers_to_check = [0, 5, -1]  # First, middle, and last layer

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, layer in enumerate(layers_to_check):
    viz = visualizer.visualize(
        image=image,
        text="a cat",
        layer_indices=[layer],
        visualization_type="heatmap"
    )
    axes[idx].imshow(viz.to_image())
    layer_name = "Last layer" if layer == -1 else f"Layer {layer}"
    axes[idx].set_title(layer_name)
    axes[idx].axis('off')

plt.suptitle("Attention at Different Layers", fontsize=16)
plt.tight_layout()
plt.show()
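
The layer list above mixes positive and negative indices. Both encoders in `openai/clip-vit-base-patch32` have 12 transformer layers, so with Python-style indexing `-1` refers to layer 11. The sketch below makes that mapping explicit. `resolve_layer_index` is a hypothetical helper, and it is an assumption (based on the usage above) that `layer_indices` treats negative values the same way.

```python
def resolve_layer_index(index, num_layers=12):
    """Convert a possibly negative layer index to its absolute position.

    num_layers defaults to 12, the encoder depth of clip-vit-base-patch32.
    """
    if not -num_layers <= index < num_layers:
        raise IndexError(f"layer index {index} out of range for {num_layers} layers")
    return index % num_layers


for layer in [0, 5, -1]:
    print(f"layer_indices=[{layer}] -> layer {resolve_layer_index(layer)} of 12")
```

Labeling plots with the resolved index makes figures easier to compare across models with different depths.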

## 🤝 Contributing

Found a bug or have a feature request? Please open an issue on GitHub!

**GitHub**: https://github.com/sisird864/transformers-attention-viz

### Current Limitations (v0.1.0)
- ✅ Heatmap visualization (fully working)
- ❌ Flow visualization (debugging dimension issues)
- ❌ Evolution visualization (debugging array indexing)

### Roadmap
- Fix flow and evolution visualizations
- Add support for more models (Flamingo, ALIGN)
- Real-time video attention tracking
- 3D attention visualizations

## 📚 Learn More

- **Documentation**: [GitHub README](https://github.com/sisird864/transformers-attention-viz)
- **PyPI**: [transformers-attention-viz](https://pypi.org/project/transformers-attention-viz/)
- **HuggingFace Models**: [CLIP](https://huggingface.co/openai/clip-vit-base-patch32), [BLIP](https://huggingface.co/Salesforce/blip-image-captioning-base)

Happy visualizing! 🎉