# LLM Interpretability Toolkit - Quick Start

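This notebook demonstrates the basic usage of the LLM Interpretability Toolkit: initializing the analyzer, running attention and importance analyses on text, visualizing the results, predicting failure risk, detecting attention-head patterns, analyzing batches of text, and calling the REST API.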

In [None]:
# Import necessary libraries
import sys
sys.path.append('..')

from src.core import InterpretabilityAnalyzer
import torch
import matplotlib.pyplot as plt
import seaborn as sns

## 1. Initialize the Analyzer

In [None]:
# Build the analyzer on a small model so the quick start stays fast
analyzer = InterpretabilityAnalyzer(model_name="distilgpt2")

# Summarize the loaded model: its name, then the layer and head counts
# reported by the model wrapper.
print(f"Model loaded: {analyzer.model_name}")
wrapper = analyzer.model_wrapper
layer_count = wrapper.get_num_layers()
print(f"Number of layers: {layer_count}")
head_count = wrapper.get_num_attention_heads()
print(f"Number of attention heads: {head_count}")

## 2. Basic Text Analysis

In [None]:
# Run the attention and importance analyses on one short sentence.
# `results` is reused by the visualization cells further down.
text = "The cat sat on the mat"
requested_methods = ["attention", "importance"]
results = analyzer.analyze(text, methods=requested_methods)

print("Analysis completed!")
available_keys = list(results.keys())
print(f"Available results: {available_keys}")

## 3. Visualize Attention Patterns

In [None]:
# Pull the attention section out of the analysis results
attention_data = results["attention"]
tokens = attention_data["tokens"][0]  # tokens of the first sequence

# Select a single 2-D attention map with the index [0, 0, 0].
# NOTE(review): this is described as "first layer, first head"; the exact axis
# order of `patterns` is not visible here — confirm against the analyzer.
#
# The tuple index means `patterns` is already an array or tensor, so use
# torch.as_tensor (no copy-construct warning when it is a tensor), then
# detach and move to CPU so the NumPy conversion below also works for
# tensors that track gradients or live on an accelerator.
attention_weights = (
    torch.as_tensor(attention_data["patterns"][0, 0, 0]).detach().cpu()
)

# Draw the heatmap on an explicit Axes so the figure does not depend on
# pyplot's implicit "current figure" state.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    attention_weights.numpy(),
    xticklabels=tokens,
    yticklabels=tokens,
    cmap="Blues",
    cbar_kws={"label": "Attention Weight"},
    ax=ax,
)
ax.set_title("Attention Pattern - Layer 0, Head 0")
ax.set_xlabel("To Token")
ax.set_ylabel("From Token")
fig.tight_layout()
plt.show()

## 4. Token Importance Analysis

In [None]:
# Token-level importance section of the analysis results
importance_data = results["importance"]["token_importance"]
tokens = importance_data["tokens"]
importance_scores = torch.tensor(importance_data["importance_mean"])

# Bar chart with one bar per token, labelled with the token text
positions = range(len(tokens))
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(positions, importance_scores.numpy())
ax.set_xticks(positions)
ax.set_xticklabels(tokens, rotation=45)
ax.set_xlabel("Tokens")
ax.set_ylabel("Importance Score")
ax.set_title("Token Importance for Final Prediction")
fig.tight_layout()
plt.show()

## 5. Head Importance Analysis

In [None]:
# Per-head importance scores from the analysis results, as a 2-D tensor
head_importance = torch.tensor(results["importance"]["head_importance"])

# Tick labels: columns are indexed as heads, rows as layers
head_labels = [f"Head {col}" for col in range(head_importance.shape[1])]
layer_labels = [f"Layer {row}" for row in range(head_importance.shape[0])]

# Heatmap of importance, drawn on an explicit Axes
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    head_importance.numpy(),
    cmap="YlOrRd",
    cbar_kws={"label": "Importance Score"},
    xticklabels=head_labels,
    yticklabels=layer_labels,
    ax=ax,
)
ax.set_title("Attention Head Importance Scores")
ax.set_xlabel("Attention Head")
ax.set_ylabel("Layer")
fig.tight_layout()
plt.show()

## 6. Failure Prediction

In [None]:
# Sample inputs: ordinary sentences alongside degenerate, repetitive ones
test_texts = [
    "The cat sat on the mat",
    "aaaaaaaaaaaaaaaaaaaaaa",  # Repetitive text
    "The the the the the the",  # Repeated words
    "A normal sentence with proper structure and meaning."
]

for text in test_texts:
    prediction = analyzer.predict_failure_probability(text)

    # Echo the input, truncated to its first 50 characters when longer
    if len(text) > 50:
        header = f"\nText: '{text[:50]}...'"
    else:
        header = f"\nText: '{text}'"
    print(header)

    print(f"Failure probability: {prediction['failure_probability']:.2%}")
    print(f"Risk level: {prediction['prediction']}")

    # An empty indicator list is reported as the literal 'None'
    if prediction['indicators']:
        indicator_summary = ', '.join(prediction['indicators'])
    else:
        indicator_summary = 'None'
    print(f"Indicators: {indicator_summary}")

## 7. Attention Pattern Detection

In [None]:
# Run only the head-pattern analysis on a two-sentence input
pattern_results = analyzer.analyze(
    "The cat sat on the mat. The dog sat on the mat.",
    methods=["head_patterns"],
)

patterns = pattern_results["head_patterns"]
print("Identified attention patterns:")

# For each pattern type, list at most the first five (layer, head) pairs
# and summarize how many more were found.
for pattern_type, heads in patterns["identified_patterns"].items():
    print(f"\n{pattern_type.capitalize()} pattern:")
    preview = heads[:5]
    for layer, head in preview:
        print(f"  - Layer {layer}, Head {head}")
    remaining = len(heads) - 5
    if remaining > 0:
        print(f"  ... and {remaining} more")

## 8. Batch Analysis

In [None]:
# Pass a list of texts to analyze several sequences in one call
batch_texts = [
    "The weather is nice today.",
    "Machine learning is fascinating.",
    "Python is a great programming language.",
]

batch_results = analyzer.analyze(batch_texts, methods=["attention"])

# Report the batch size and the token count of each sequence
batch_attention = batch_results['attention']
print(f"Batch size: {batch_attention['shape']['batch_size']}")
tokens_per_sequence = [len(seq_tokens) for seq_tokens in batch_attention['tokens']]
print(f"Tokens per sequence: {tokens_per_sequence}")

## 9. Using the API

The toolkit also provides a REST API for integration with other applications.

In [None]:
# Example of how to call the REST API (the API server must be running)
import requests
import json

# Note: Start the API server first with: uvicorn src.api.main:app --reload

# Describe an example request as plain data; this cell only prints it and
# does not send anything over the network.
request_body = {
    "text": "The cat sat on the mat",
    "methods": ["attention", "importance"],
}
api_example = {
    "url": "http://localhost:8000/analyze",
    "method": "POST",
    "headers": {"Content-Type": "application/json"},
    "body": request_body,
}

print("API Request Example:")
pretty_request = json.dumps(api_example, indent=2)
print(pretty_request)