# Interpreting Text Models with Captum 🔍🇦🇺

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vuhung16au/pytorch-mastery/blob/main/examples/pytorch-nlp/interpreting_text_models.ipynb)
[![View on GitHub](https://img.shields.io/badge/View_on-GitHub-blue?logo=github)](https://github.com/vuhung16au/pytorch-mastery/blob/main/examples/pytorch-nlp/interpreting_text_models.ipynb)

Learn how to interpret and understand PyTorch text models using **Captum**, PyTorch's model interpretability library. This notebook demonstrates various interpretation techniques using a pre-trained IMDB sentiment analysis model with Australian tourism examples and English-Vietnamese multilingual content.

## Learning Objectives

By the end of this notebook, you will:

- 🔍 **Master Captum fundamentals** for text model interpretation
- 📊 **Use Integrated Gradients** to understand feature importance
- 🎯 **Apply Layer Conductance** for layer-wise attribution analysis
- 🔍 **Implement Input X Gradient** for simple gradient-based interpretation
- 🎨 **Visualize attributions** with Australian tourism sentiment examples
- 🌏 **Handle multilingual interpretation** with English-Vietnamese text
- 📈 **Compare interpretation methods** for comprehensive understanding

## What You'll Build

1. **Pre-trained Model Loading** - Use IMDB CNN model from Captum tutorials
2. **Australian Sentiment Analyzer** - Interpret sentiment predictions for tourism reviews
3. **Feature Attribution Visualizer** - Show which words drive model predictions
4. **Multilingual Interpreter** - Compare interpretations across English-Vietnamese text
5. **Interactive Attribution Dashboard** - Explore model behavior interactively

## Australian Context Examples

We'll interpret model predictions for:
- 🏛️ Sydney Opera House reviews and tourism experiences
- ☕ Melbourne coffee culture sentiment analysis
- 🏖️ Gold Coast beach and tourism feedback
- 🗣️ English-Vietnamese tourism review translations

**Resources**: [Captum Documentation](https://captum.ai) | [IMDB Model](https://github.com/pytorch/captum/raw/refs/heads/master/tutorials/models/imdb-model-cnn-large.pt)

---

In [None]:
# Environment Detection and Setup
import sys
import subprocess
import os
import time

# Detect the runtime environment
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules or "kaggle" in os.environ.get('KAGGLE_URL_BASE', '')
IS_LOCAL = not (IS_COLAB or IS_KAGGLE)

print(f"Environment detected:")
print(f"  - Local: {IS_LOCAL}")
print(f"  - Google Colab: {IS_COLAB}")
print(f"  - Kaggle: {IS_KAGGLE}")

# Platform-specific system setup
if IS_COLAB:
    print("\nSetting up Google Colab environment...")
    !apt update -qq
    !apt install -y -qq software-properties-common
elif IS_KAGGLE:
    print("\nSetting up Kaggle environment...")
    # Kaggle usually has most packages pre-installed
else:
    print("\nSetting up local environment...")

In [None]:
# Install required packages for text model interpretation
required_packages = [
    "torch",
    "captum",           # Model interpretability library
    "transformers",
    "datasets", 
    "tokenizers",
    "pandas",
    "seaborn",
    "matplotlib",
    "scikit-learn",
    "tensorboard",
    "numpy",
    "wordcloud",        # For text visualization
    "plotly"            # Interactive visualizations
]

print("Installing required packages for text model interpretation...")
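for package in required_packages:
    if IS_COLAB or IS_KAGGLE:
        !pip install -q {package}
        print(f"✓ {package}")
    else:
        result = subprocess.run([sys.executable, "-m", "pip", "install", "-q", package],
                                capture_output=True)
        # Report failures instead of printing a checkmark unconditionally
        if result.returncode == 0:
            print(f"✓ {package}")
        else:
            print(f"✗ {package} (pip exit code {result.returncode})")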

print("\n🎉 Package installation completed!")

In [None]:
# Core PyTorch and Captum imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# Captum for model interpretation
from captum.attr import (
    IntegratedGradients,
    LayerConductance,
    InputXGradient,
    GradientShap,
    DeepLift,
    Saliency
)
from captum.attr import visualization as viz

# Data processing and visualization
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Text processing
import re
import string
from collections import Counter, defaultdict
import random
from wordcloud import WordCloud

# Utility imports
import urllib.request
import pickle
import warnings
warnings.filterwarnings('ignore')

# Set style for better notebook aesthetics
sns.set_style("whitegrid")
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print(f"✅ PyTorch {torch.__version__} ready!")
print(f"🔍 Captum library imported successfully!")
print(f"📊 Visualization libraries ready!")

In [None]:
import platform

def detect_device():
    """
    Detect the best available PyTorch device with comprehensive hardware support.
    
    Priority order:
    1. CUDA (NVIDIA GPUs) - Best performance for deep learning
    2. MPS (Apple Silicon) - Optimized for M1/M2/M3 Macs  
    3. CPU (Universal) - Always available fallback
    
    Returns:
        torch.device: The optimal device for PyTorch operations
        str: Human-readable device description for logging
    """
    # Check for CUDA (NVIDIA GPU)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        device_info = f"CUDA GPU: {gpu_name}"
        
        # Additional CUDA info for optimization
        cuda_version = torch.version.cuda
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        
        print(f"🚀 Using CUDA acceleration")
        print(f"   GPU: {gpu_name}")
        print(f"   CUDA Version: {cuda_version}")
        print(f"   GPU Memory: {gpu_memory:.1f} GB")
        
        return device, device_info
    
    # Check for MPS (Apple Silicon)
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device("mps")
        device_info = "Apple Silicon MPS"
        
        # Get system info for Apple Silicon
        system_info = platform.uname()
        
        print(f"🍎 Using Apple Silicon MPS acceleration")
        print(f"   System: {system_info.system} {system_info.release}")
        print(f"   Machine: {system_info.machine}")
        print(f"   Processor: {system_info.processor}")
        
        return device, device_info
    
    # Fallback to CPU
    else:
        device = torch.device("cpu")
        device_info = "CPU (No GPU acceleration available)"
        
        # Get CPU info for optimization guidance
        cpu_count = torch.get_num_threads()
        system_info = platform.uname()
        
        print(f"💻 Using CPU (no GPU acceleration detected)")
        print(f"   Processor: {system_info.processor}")
        print(f"   PyTorch Threads: {cpu_count}")
        print(f"   System: {system_info.system} {system_info.release}")
        
        return device, device_info

# Detect and set global device
DEVICE, DEVICE_INFO = detect_device()
print(f"\n✅ Device selected: {DEVICE}")
print(f"📊 Device info: {DEVICE_INFO}")

In [None]:
# Download pre-trained IMDB CNN model from Captum tutorials
print("📥 Downloading pre-trained IMDB sentiment analysis model...")
print("Model: CNN for IMDB sentiment classification")
print("Source: https://github.com/pytorch/captum/raw/refs/heads/master/tutorials/models/imdb-model-cnn-large.pt")

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

# Download the model
model_url = "https://github.com/pytorch/captum/raw/refs/heads/master/tutorials/models/imdb-model-cnn-large.pt"
model_path = "models/imdb-model-cnn-large.pt"

if not os.path.exists(model_path):
    print("Downloading model file...")
    urllib.request.urlretrieve(model_url, model_path)
    print(f"✅ Model downloaded successfully: {model_path}")
else:
    print(f"✅ Model already exists: {model_path}")

# Check file size
file_size = os.path.getsize(model_path) / (1024 * 1024)  # Convert to MB
print(f"📊 Model file size: {file_size:.1f} MB")

In [None]:
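# Define the CNN model architecture for IMDB sentiment analysis.
# Assumption: this layout is meant to mirror the pre-trained checkpoint; load_state_dict
# only succeeds if every parameter name and shape matches the checkpoint exactly.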

class IMDBConvNet(nn.Module):
    """
    Convolutional Neural Network for IMDB sentiment analysis.
    
    It is intended to mirror the pre-trained model from the Captum tutorials (verify against the checkpoint).
    We'll use it for interpreting sentiment predictions on Australian tourism reviews.
    
    Architecture:
    - Embedding layer (vocab_size=1002, embed_dim=128)
    - 1D Convolution layers with different kernel sizes
    - Global max pooling
    - Fully connected layers for classification
    
    Args:
        vocab_size (int): Size of vocabulary (default: 1002 for IMDB)
        embed_dim (int): Embedding dimension (default: 128)
        num_classes (int): Number of output classes (default: 2 for binary sentiment)
    """
    
    def __init__(self, vocab_size=1002, embed_dim=128, num_classes=2):
        super(IMDBConvNet, self).__init__()
        
        # Store model parameters
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Convolutional layers with different kernel sizes
        self.conv1 = nn.Conv1d(embed_dim, 100, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, 100, kernel_size=4, padding=2)
        self.conv3 = nn.Conv1d(embed_dim, 100, kernel_size=5, padding=2)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)
        
        # Fully connected layers
        self.fc1 = nn.Linear(300, 128)  # 300 = 100 * 3 conv outputs
        self.fc2 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        # Input shape: (batch_size, sequence_length)
        
        # Embedding lookup
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        
        # Transpose for Conv1d: (batch_size, embed_dim, seq_len)
        embedded = embedded.transpose(1, 2)
        
        # Apply convolutions with different kernel sizes
        conv1_out = F.relu(self.conv1(embedded))  # (batch_size, 100, seq_len)
        conv2_out = F.relu(self.conv2(embedded))  # (batch_size, 100, seq_len + 1): kernel 4 with padding 2 adds one position
        conv3_out = F.relu(self.conv3(embedded))  # (batch_size, 100, seq_len)
        
        # Global max pooling
        pool1 = F.max_pool1d(conv1_out, kernel_size=conv1_out.size(2))  # (batch_size, 100, 1)
        pool2 = F.max_pool1d(conv2_out, kernel_size=conv2_out.size(2))  # (batch_size, 100, 1)
        pool3 = F.max_pool1d(conv3_out, kernel_size=conv3_out.size(2))  # (batch_size, 100, 1)
        
        # Concatenate pooled features
        pooled = torch.cat([pool1, pool2, pool3], dim=1)  # (batch_size, 300, 1)
        pooled = pooled.squeeze(2)  # (batch_size, 300)
        
        # Apply dropout
        pooled = self.dropout(pooled)
        
        # Fully connected layers
        fc1_out = F.relu(self.fc1(pooled))  # (batch_size, 128)
        fc1_out = self.dropout(fc1_out)
        
        # Final classification layer
        logits = self.fc2(fc1_out)  # (batch_size, num_classes)
        
        return logits
    
    def predict_sentiment(self, x):
        """Predict sentiment with probabilities."""
        self.eval()
        with torch.no_grad():
            logits = self.forward(x)
            probabilities = F.softmax(logits, dim=1)
            predictions = torch.argmax(probabilities, dim=1)
            
        return predictions, probabilities

# Create model instance
model = IMDBConvNet(vocab_size=1002, embed_dim=128, num_classes=2)
print(f"📊 Model created: {sum(p.numel() for p in model.parameters())} parameters")
print(f"🏗️ Architecture: CNN with embedding -> conv1d -> max_pool -> fc")
print(f"🎯 Task: Binary sentiment classification (positive/negative)")
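The shape comments in `forward` follow from the standard 1D convolution length formula, `L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1`, as documented for `torch.nn.Conv1d`. A small pure-Python check (the helper name `conv1d_out_len` is ours, not a PyTorch function) shows why `conv2` yields one extra position:

```python
def conv1d_out_len(l_in, kernel_size, padding=0, stride=1, dilation=1):
    # Output-length formula documented for torch.nn.Conv1d
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

seq_len = 256
for name, kernel_size, padding in [("conv1", 3, 1), ("conv2", 4, 2), ("conv3", 5, 2)]:
    print(name, conv1d_out_len(seq_len, kernel_size, padding))
```

Because each branch is then reduced with `max_pool1d` over its full length, every branch still contributes exactly 100 features, which is why `fc1` takes 300 inputs regardless of the length mismatch.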

In [None]:
# Load pre-trained weights and move model to device
print("🔄 Loading pre-trained model weights...")

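try:
    # The checkpoint may be a plain state_dict or a fully pickled nn.Module.
    # weights_only=False (PyTorch >= 1.13) allows unpickling a full model; only use it
    # with files you trust. Unpickling a full model also needs its original class in scope.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    state_dict = checkpoint.state_dict() if isinstance(checkpoint, nn.Module) else checkpoint
    # Raises if parameter names or shapes differ from IMDBConvNet
    model.load_state_dict(state_dict)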
    print("✅ Pre-trained weights loaded successfully!")
    
    # Move model to the detected device
    model = model.to(DEVICE)
    model.eval()  # Set to evaluation mode
    
    print(f"📱 Model moved to device: {DEVICE}")
    print(f"⚙️ Model set to evaluation mode")
    
    # Verify model is on correct device
    model_device = next(model.parameters()).device
    print(f"🔧 Model parameters on device: {model_device}")
    
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("💡 Check that the model file downloaded completely and that IMDBConvNet matches the checkpoint's layer names and shapes")
    raise

print("\n🎉 Pre-trained IMDB sentiment model ready for interpretation!")
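When `load_state_dict` fails, the quickest diagnosis is to compare the parameter names and shapes the model expects with those stored in the checkpoint. A minimal pure-Python sketch, assuming both sides have already been reduced to `{name: shape}` dictionaries (for example `{k: tuple(v.shape) for k, v in model.state_dict().items()}`); `diff_state_dicts` is a hypothetical helper and the shapes below are made up for illustration, not read from the real checkpoint:

```python
def diff_state_dicts(expected, loaded):
    """Compare {param_name: shape} dicts and report what would break loading."""
    missing = sorted(set(expected) - set(loaded))
    unexpected = sorted(set(loaded) - set(expected))
    mismatched = sorted(
        name for name in set(expected) & set(loaded) if expected[name] != loaded[name]
    )
    return {"missing": missing, "unexpected": unexpected, "shape_mismatch": mismatched}

# Illustrative shapes only
expected = {"embedding.weight": (1002, 128), "fc1.weight": (128, 300), "fc2.weight": (2, 128)}
loaded = {"embedding.weight": (5000, 128), "fc1.weight": (128, 300), "fc.weight": (2, 300)}

report = diff_state_dicts(expected, loaded)
print(report)
```

An empty report means names and shapes line up; anything else points to the layers whose definition needs adjusting.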

In [None]:
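# Create a simple vocabulary and tokenizer for the IMDB model
# Note: this vocabulary is a simplified stand-in for demonstration. Its indices do not
# reproduce the vocabulary the pre-trained model was trained with, so the predictions and
# attributions below are illustrative only; for meaningful results, use the exact training vocabulary.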

class SimpleTokenizer:
    """
    Simple tokenizer for IMDB sentiment analysis with Australian context.
    
    This tokenizer maps words to indices for the pre-trained CNN model.
    We'll include common words and Australian-specific terms.
    """
    
    def __init__(self):
        # Create basic vocabulary (simplified for demo)
        # Index 0: padding, Index 1: unknown token
        self.word_to_idx = {
            '<pad>': 0,
            '<unk>': 1,
            # Common sentiment words
            'good': 2, 'great': 3, 'excellent': 4, 'amazing': 5, 'wonderful': 6,
            'bad': 7, 'terrible': 8, 'awful': 9, 'horrible': 10, 'disappointing': 11,
            'love': 12, 'like': 13, 'enjoy': 14, 'hate': 15, 'dislike': 16,
            'best': 17, 'worst': 18, 'perfect': 19, 'beautiful': 20, 'ugly': 21,
            
            # Australian-specific terms
            'sydney': 22, 'melbourne': 23, 'brisbane': 24, 'perth': 25, 'adelaide': 26,
            'opera': 27, 'house': 28, 'harbour': 29, 'bridge': 30, 'beach': 31,
            'coffee': 32, 'cafe': 33, 'restaurant': 34, 'food': 35, 'tourism': 36,
            'tourist': 37, 'vacation': 38, 'holiday': 39, 'trip': 40, 'visit': 41,
            'australia': 42, 'australian': 43, 'aussie': 44, 'kangaroo': 45, 'koala': 46,
            
            # Common words
            'the': 47, 'and': 48, 'or': 49, 'but': 50, 'is': 51, 'was': 52, 'are': 53,
            'very': 54, 'really': 55, 'so': 56, 'too': 57, 'not': 58, 'no': 59, 'yes': 60,
            'this': 61, 'that': 62, 'it': 63, 'i': 64, 'you': 65, 'we': 66, 'they': 67,
            'place': 68, 'location': 69, 'service': 70, 'staff': 71, 'experience': 72,
            'recommend': 73, 'worth': 74, 'money': 75, 'price': 76, 'expensive': 77,
            'cheap': 78, 'quality': 79, 'clean': 80, 'dirty': 81, 'friendly': 82,
            'rude': 83, 'helpful': 84, 'slow': 85, 'fast': 86, 'crowded': 87,
            
            # Vietnamese words (for multilingual examples)
            'tuyệt': 88, 'vời': 89, 'đẹp': 90, 'xấu': 91, 'tốt': 92, 'tệ': 93,
            'yêu': 94, 'thích': 95, 'ghét': 96, 'nhà': 97, 'hát': 98, 'cà': 99, 'phê': 100
        }
        
        # Fill remaining vocabulary slots with dummy tokens
        for i in range(101, 1002):
            self.word_to_idx[f'word_{i}'] = i
        
        # Create reverse mapping
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        
    def tokenize(self, text):
        """Simple tokenization by splitting on whitespace and punctuation."""
        # Convert to lowercase and remove punctuation
        text = text.lower()
        text = re.sub(r'[^\w\s]', ' ', text)
        tokens = text.split()
        return tokens
    
    def encode(self, text, max_length=256):
        """Convert text to tensor of token indices."""
        tokens = self.tokenize(text)
        
        # Convert tokens to indices
        indices = []
        for token in tokens[:max_length]:
            if token in self.word_to_idx:
                indices.append(self.word_to_idx[token])
            else:
                indices.append(self.word_to_idx['<unk>'])  # Unknown token
        
        # Pad sequence to max_length
        while len(indices) < max_length:
            indices.append(self.word_to_idx['<pad>'])
        
        return torch.tensor(indices[:max_length], dtype=torch.long)
    
    def decode(self, indices):
        """Convert tensor of indices back to text."""
        if isinstance(indices, torch.Tensor):
            indices = indices.tolist()
        
        tokens = []
        for idx in indices:
            if idx == 0:  # Stop at padding
                break
            tokens.append(self.idx_to_word.get(idx, '<unk>'))
        
        return ' '.join(tokens)

# Create tokenizer instance
tokenizer = SimpleTokenizer()
print(f"📝 Tokenizer created with vocabulary size: {len(tokenizer.word_to_idx)}")
print(f"🔤 Sample words: {list(tokenizer.word_to_idx.keys())[:20]}")

# Test tokenizer with Australian example
test_text = "The Sydney Opera House is absolutely beautiful and amazing!"
encoded = tokenizer.encode(test_text, max_length=20)
decoded = tokenizer.decode(encoded)

print(f"\n🧪 Tokenizer Test:")
print(f"Original: {test_text}")
print(f"Encoded: {encoded[:10]}...")  # Show first 10 tokens
print(f"Decoded: {decoded}")
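Because the vocabulary above is hand-picked, it is worth checking how much of each input it actually covers before reading anything into the attributions. A minimal sketch, assuming the same lowercase-and-strip-punctuation rule as `SimpleTokenizer.tokenize`; `oov_rate` is a hypothetical helper, not part of any library:

```python
import re

def tokenize(text):
    # Same rule as SimpleTokenizer.tokenize: lowercase, punctuation -> space, split
    return re.sub(r"[^\w\s]", " ", text.lower()).split()

def oov_rate(text, vocab):
    """Fraction of tokens in `text` that would be mapped to <unk>."""
    tokens = tokenize(text)
    if not tokens:
        return 0.0
    unknown = sum(1 for tok in tokens if tok not in vocab)
    return unknown / len(tokens)

vocab = {"the", "sydney", "opera", "house", "is", "beautiful", "and", "amazing"}
rate = oov_rate("The Sydney Opera House is absolutely beautiful and amazing!", vocab)
print(f"OOV rate: {rate:.2f}")
```

Passing `tokenizer.word_to_idx` as `vocab` works the same way, since membership tests on a dict check its keys; a high rate means most of the input collapses onto `<unk>`.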

In [None]:
# Create Australian tourism examples for interpretation
# These will be our test cases for understanding model behavior

australian_tourism_examples = {
    'positive': [
        "The Sydney Opera House is absolutely breathtaking and worth every dollar!",
        "Melbourne coffee culture is amazing, best cafe experience in Australia!",
        "Bondi Beach has perfect waves for surfing, beautiful and clean!",
        "Perth beaches are pristine and peaceful, wonderful for families!",
        "Great Barrier Reef snorkeling was an incredible experience, highly recommend!",
        "Brisbane weather is fantastic year round, great place to visit!",
        "Adelaide wine tours exceeded expectations, excellent service and quality!",
        "Darwin wildlife parks are educational and fun, perfect for tourists!"
    ],
    'negative': [
        "Sydney Opera House tours are overpriced and disappointing, waste of money!",
        "Melbourne weather is terrible, cold and rainy, ruined our vacation!",
        "Gold Coast beaches are crowded and dirty, awful tourist trap!",
        "Perth is boring and isolated, nothing interesting to do here!",
        "Brisbane food scene is terrible, expensive restaurants with bad service!",
        "Adelaide is too quiet and boring, worst vacation destination ever!",
        "Cairns accommodation was dirty and overpriced, horrible experience!",
        "Hobart weather ruined our trip, cold and miserable the entire time!"
    ]
}

# Vietnamese translations for multilingual interpretation
vietnamese_examples = {
    'positive': [
        "Nhà hát Opera Sydney thật ngoạn mục và xứng đáng từng đồng!",
        "Văn hóa cà phê Melbourne tuyệt vời, trải nghiệm quán cà phê tốt nhất Úc!",
        "Bãi biển Bondi có sóng hoàn hảo để lướt sóng, đẹp và sạch sẽ!",
        "Bãi biển Perth nguyên sơ và yên bình, tuyệt vời cho gia đình!"
    ],
    'negative': [
        "Tour Nhà hát Opera Sydney đắt đỏ và thất vọng, lãng phí tiền!",
        "Thời tiết Melbourne khủng khiếp, lạnh và mưa, hủy hoại kỳ nghỉ!",
        "Bãi biển Gold Coast đông đúc và bẩn thỉu, bẫy khách du lịch khủng khiếp!",
        "Perth nhàm chán và biệt lập, không có gì thú vị để làm!"
    ]
}

# Combine all examples
all_examples = []
all_labels = []

# Add English examples
for text in australian_tourism_examples['positive']:
    all_examples.append(text)
    all_labels.append(1)  # Positive

for text in australian_tourism_examples['negative']:
    all_examples.append(text)
    all_labels.append(0)  # Negative

# Add Vietnamese examples
for text in vietnamese_examples['positive']:
    all_examples.append(text)
    all_labels.append(1)  # Positive

for text in vietnamese_examples['negative']:
    all_examples.append(text)
    all_labels.append(0)  # Negative

print(f"🇦🇺 Australian Tourism Examples Created:")
print(f"   📊 Total examples: {len(all_examples)}")
print(f"   ✅ Positive examples: {sum(all_labels)}")
print(f"   ❌ Negative examples: {len(all_labels) - sum(all_labels)}")
print(f"   🌏 Languages: English + Vietnamese")

# Show sample examples
print(f"\n📝 Sample Examples:")
print(f"   Positive (EN): {all_examples[0]}")
print(f"   Negative (EN): {all_examples[8]}")
print(f"   Positive (VI): {all_examples[16]}")
print(f"   Negative (VI): {all_examples[20]}")

In [None]:
# Test model predictions on our Australian examples
print("🧪 Testing Model Predictions on Australian Tourism Examples\n")

def test_model_predictions(examples, labels, sample_size=8):
    """Test model predictions and show results."""
    
    correct_predictions = 0
    results = []
    
    for i, (text, true_label) in enumerate(zip(examples[:sample_size], labels[:sample_size])):
        # Encode text
        encoded = tokenizer.encode(text, max_length=256).unsqueeze(0).to(DEVICE)
        
        # Get prediction
        pred_label, probabilities = model.predict_sentiment(encoded)
        pred_label = pred_label.item()
        prob_neg, prob_pos = probabilities[0].cpu().numpy()
        
        # Calculate accuracy
        is_correct = (pred_label == true_label)
        if is_correct:
            correct_predictions += 1
        
        # Store result
        results.append({
            'text': text[:80] + '...' if len(text) > 80 else text,
            'true_label': 'Positive' if true_label == 1 else 'Negative',
            'predicted_label': 'Positive' if pred_label == 1 else 'Negative',
            'confidence': prob_pos if pred_label == 1 else prob_neg,
            'correct': '✅' if is_correct else '❌'
        })
        
        # Print result
        print(f"Example {i+1}: {results[-1]['correct']}")
        print(f"  Text: {results[-1]['text']}")
        print(f"  True: {results[-1]['true_label']} | Predicted: {results[-1]['predicted_label']} ({results[-1]['confidence']:.3f})")
        print()
    
    accuracy = correct_predictions / len(results)
    print(f"📊 Prediction Accuracy: {accuracy:.1%} ({correct_predictions}/{len(results)})")
    
    return results

# Test on all examples: the list is ordered by language and label, so the first
# 8 items alone would all be positive English reviews
prediction_results = test_model_predictions(all_examples, all_labels, sample_size=len(all_examples))

print("\n🎯 Model Performance Summary:")
print(f"   • The pre-trained IMDB model runs on Australian tourism text, but it was trained on movie reviews")
print(f"   • Words outside the simplified vocabulary map to <unk>, and its indices do not reproduce the training vocabulary")
print(f"   • Interpretation will help us understand what drives predictions")
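Because `all_examples` is ordered by language and label, any "first N" slice gives a one-sided sample. A minimal sketch of a reproducible, class-balanced sampler in pure Python; `balanced_indices` is a hypothetical helper and the label list is a toy stand-in:

```python
import random

def balanced_indices(labels, per_class, seed=42):
    """Pick up to `per_class` indices for each label value, reproducibly."""
    rng = random.Random(seed)
    by_label = {}
    for idx, label in enumerate(labels):
        by_label.setdefault(label, []).append(idx)
    chosen = []
    for label in sorted(by_label):
        idxs = by_label[label]
        chosen.extend(rng.sample(idxs, min(per_class, len(idxs))))
    return sorted(chosen)

labels = [1] * 12 + [0] * 12
picked = balanced_indices(labels, per_class=4)
print(picked, [labels[i] for i in picked])
```

The indices can then select `[all_examples[i] for i in picked]` and the matching labels before calling `test_model_predictions`.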

In [None]:
# Implement Integrated Gradients for text interpretation
print("🔍 Implementing Integrated Gradients for Text Interpretation\n")

from captum.attr import LayerIntegratedGradients

# Token indices are discrete: IG cannot interpolate between them, and nn.Embedding
# rejects the float tensors that interpolation would produce. LayerIntegratedGradients
# interpolates the embedding layer's outputs instead, between the baseline's and the
# input's embeddings.
integrated_gradients = LayerIntegratedGradients(model, model.embedding)

def interpret_with_integrated_gradients(text, target_class=None, n_steps=50):
    """
    Use Integrated Gradients (in embedding space) to interpret model predictions for text.
    
    Args:
        text (str): Input text to interpret
        target_class (int): Target class for attribution (None for predicted class)
        n_steps (int): Number of steps for integration
    
    Returns:
        tuple: (token_attributions, prediction, probabilities, tokens)
    """
    
    # Encode text
    encoded = tokenizer.encode(text, max_length=256).unsqueeze(0).to(DEVICE)
    
    # Baseline: a sequence of padding tokens (index 0), already on DEVICE
    baseline = torch.zeros_like(encoded)
    
    # Get model prediction
    pred_label, probabilities = model.predict_sentiment(encoded)
    pred_class = pred_label.item()
    
    # Use predicted class if target not specified
    if target_class is None:
        target_class = pred_class
    
    # Compute attributions; shape (1, seq_len, embed_dim)
    model.eval()
    attributions = integrated_gradients.attribute(
        encoded,
        baselines=baseline,
        target=target_class,
        n_steps=n_steps
    )
    
    # Sum over the embedding dimension to get one score per token
    token_attributions = attributions.sum(dim=2)[0].detach().cpu().numpy()
    
    # Collect non-padding tokens for visualization
    tokens = []
    for idx in encoded[0].tolist():
        if idx == 0:  # Stop at padding
            break
        tokens.append(tokenizer.idx_to_word.get(idx, '<unk>'))
    
    return token_attributions[:len(tokens)], pred_class, probabilities[0].cpu().numpy(), tokens

# Test Integrated Gradients on a positive Australian example
example_text = "The Sydney Opera House is absolutely beautiful and amazing!"
print(f"🎭 Analyzing: '{example_text}'\n")

attributions, prediction, probs, tokens = interpret_with_integrated_gradients(example_text)

print(f"📊 Model Prediction:")
print(f"   Predicted Class: {'Positive' if prediction == 1 else 'Negative'}")
print(f"   Confidence: {probs[prediction]:.3f}")
print(f"   Probabilities: Negative={probs[0]:.3f}, Positive={probs[1]:.3f}")

print(f"\n🔍 Token Attribution Analysis (target = predicted class):")
for token, attr in zip(tokens, attributions):
    effect = "→ supports prediction" if attr > 0 else "→ opposes prediction" if attr < 0 else "→ neutral"
    print(f"   '{token}': {attr:.4f} {effect}")

print("\n💡 Interpretation Guide:")
print("   • Positive attributions push toward the target class (here, the predicted class)")
print("   • Negative attributions push away from the target class")
print("   • Larger absolute values indicate stronger influence")
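Integrated Gradients attributes `F(x) - F(x')` to each input dimension `i` as `(x_i - x'_i) * integral over a in [0, 1] of dF/dx_i(x' + a*(x - x'))`, approximating the integral with `n_steps` points. A pure-Python sketch on a toy function with a hand-written gradient (both are illustrative assumptions, not the CNN above) shows the completeness property: the attributions sum to `F(x) - F(x')`.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint-rule approximation of Integrated Gradients."""
    totals = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight-line path from baseline to x
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            totals[i] += g
    # Average gradient along the path, scaled by the input difference
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, totals)]

def f(x):
    # Toy function F(x) = x0^2 * x1 + 3 * x2
    return x[0] ** 2 * x[1] + 3 * x[2]

def grad_f(x):
    return [2 * x[0] * x[1], x[0] ** 2, 3.0]

x, baseline = [1.0, 2.0, 0.5], [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline, n_steps=100)
print([round(a, 3) for a in attr], round(sum(attr), 4), f(x) - f(baseline))
```

Captum's `IntegratedGradients.attribute` defaults to a Gauss-Legendre rule rather than the midpoint rule used here; passing `return_convergence_delta=True` reports how far the attributions are from satisfying this same completeness check.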

In [None]:
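# Implement Input X Gradient for faster interpretation
print("⚡ Implementing Input X Gradient for Fast Text Interpretation\n")

from captum.attr import LayerGradientXActivation

# InputXGradient needs gradients with respect to the input, which integer token
# indices cannot have. Gradient x activation on the embedding layer's output is the
# embedding-space equivalent: each embedding value times the gradient of the target
# logit with respect to it.
input_x_gradient = LayerGradientXActivation(model, model.embedding)

def interpret_with_input_x_gradient(text, target_class=None):
    """
    Use Input X Gradient (in embedding space) for fast text interpretation.
    
    This needs a single forward/backward pass, versus roughly n_steps passes for
    Integrated Gradients, but it evaluates the gradient only at the input and so
    can be noisier where the model is strongly nonlinear.
    """
    
    # Encode text
    encoded = tokenizer.encode(text, max_length=256).unsqueeze(0).to(DEVICE)
    
    # Get model prediction
    pred_label, probabilities = model.predict_sentiment(encoded)
    pred_class = pred_label.item()
    
    # Use predicted class if target not specified
    if target_class is None:
        target_class = pred_class
    
    # Compute attributions; shape (1, seq_len, embed_dim)
    model.eval()
    attributions = input_x_gradient.attribute(
        encoded,
        target=target_class
    )
    
    # Sum over the embedding dimension to get one score per token
    token_attributions = attributions.sum(dim=2)[0].detach().cpu().numpy()
    
    # Get tokens for visualization
    tokens = []
    for idx in encoded[0].tolist():
        if idx == 0:  # Stop at padding
            break
        tokens.append(tokenizer.idx_to_word.get(idx, '<unk>'))
    
    return token_attributions[:len(tokens)], pred_class, probabilities[0].cpu().numpy(), tokens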

# Compare Input X Gradient with Integrated Gradients
print(f"🔬 Comparing Interpretation Methods:\n")

# Test on negative Australian example
negative_example = "Sydney Opera House tours are overpriced and disappointing!"
print(f"📝 Analyzing: '{negative_example}'\n")

# Get attributions from both methods
ig_attr, ig_pred, ig_probs, ig_tokens = interpret_with_integrated_gradients(negative_example)
ixg_attr, ixg_pred, ixg_probs, ixg_tokens = interpret_with_input_x_gradient(negative_example)

print(f"📊 Method Comparison:")
print(f"\n🔍 Integrated Gradients:")
print(f"   Prediction: {'Positive' if ig_pred == 1 else 'Negative'} ({ig_probs[ig_pred]:.3f})")
for token, attr in zip(ig_tokens[:8], ig_attr[:8]):  # Show first 8 tokens
    print(f"   '{token}': {attr:.4f}")

print(f"\n⚡ Input X Gradient:")
print(f"   Prediction: {'Positive' if ixg_pred == 1 else 'Negative'} ({ixg_probs[ixg_pred]:.3f})")
for token, attr in zip(ixg_tokens[:8], ixg_attr[:8]):  # Show first 8 tokens
    print(f"   '{token}': {attr:.4f}")

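print(f"\n💭 Key Differences:")
print(f"   • Integrated Gradients: averages gradients over n_steps points between baseline and input; attributions approximately sum to F(input) - F(baseline), at roughly n_steps times the cost")
print(f"   • Input X Gradient: one gradient evaluation at the input; cheaper, but without that completeness guarantee and often noisier")
print(f"   • Both rank which words most influence the predicted class")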
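Input X Gradient multiplies each input by the gradient at that single point, `x_i * dF/dx_i(x)`. For a linear function with a zero baseline this equals Integrated Gradients exactly; for a nonlinear function the attributions no longer sum to `F(x) - F(0)`. A pure-Python sketch with two toy functions (illustrative assumptions, not the CNN):

```python
def input_x_gradient(grad_fn, x):
    # One gradient evaluation at the input itself, scaled elementwise by the input
    return [xi * g for xi, g in zip(x, grad_fn(x))]

def grad_linear(x):
    # F(x) = 2*x0 - x1 has a constant gradient
    return [2.0, -1.0]

def grad_poly(x):
    # F(x) = x0^2 * x1 + 3*x2, so F([1, 2, 0.5]) = 3.5
    return [2 * x[0] * x[1], x[0] ** 2, 3.0]

lin = input_x_gradient(grad_linear, [3.0, 4.0])
poly = input_x_gradient(grad_poly, [1.0, 2.0, 0.5])
print(lin, sum(lin))    # sums to F(x) = 2.0 for the linear function
print(poly, sum(poly))  # sums to 7.5, not F(x) = 3.5
```

The cubic term `x0^2 * x1` is counted three times over, which is the kind of distortion that path integration in Integrated Gradients avoids.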

In [None]:
# Implement Layer Conductance for understanding internal representations
print("🧠 Implementing Layer Conductance for Internal Layer Analysis\n")

from captum.attr import (
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
)

def analyze_layer_conductance(text, target_class=None, n_steps=50):
    """
    Use Layer Conductance to measure how the first convolutional layer (conv1)
    contributes to the prediction at each sequence position.
    
    Conductance interpolates the model input, which must be continuous. The
    embedding layer is therefore temporarily swapped for Captum's interpretable
    embedding wrapper so that the model can be fed embeddings instead of indices.
    """
    
    # Encode text and build an all-padding baseline
    encoded = tokenizer.encode(text, max_length=256).unsqueeze(0).to(DEVICE)
    baseline = torch.zeros_like(encoded)
    
    # Get model prediction (must run before the embedding layer is swapped)
    pred_label, probabilities = model.predict_sentiment(encoded)
    pred_class = pred_label.item()
    
    # Use predicted class if target not specified
    if target_class is None:
        target_class = pred_class
    
    model.eval()
    interpretable_emb = configure_interpretable_embedding_layer(model, 'embedding')
    try:
        input_emb = interpretable_emb.indices_to_embeddings(encoded)
        baseline_emb = interpretable_emb.indices_to_embeddings(baseline)
        layer_conductance = LayerConductance(model, model.conv1)
        # Shape (1, 100, seq_len): one value per conv1 filter and position
        conductance = layer_conductance.attribute(
            input_emb,
            baselines=baseline_emb,
            target=target_class,
            n_steps=n_steps
        )
    finally:
        # Restore the original embedding layer so index-based calls keep working
        remove_interpretable_embedding_layer(model, interpretable_emb)
    
    # conv1 (kernel_size=3, padding=1) preserves sequence length, so position i covers
    # the 3-token window centred on token i; sum over filters per position
    token_conductance = conductance.sum(dim=1)[0].detach().cpu().numpy()
    
    # Get tokens
    tokens = []
    for idx in encoded[0].tolist():
        if idx == 0:  # Stop at padding
            break
        tokens.append(tokenizer.idx_to_word.get(idx, '<unk>'))
    
    return token_conductance[:len(tokens)], pred_class, probabilities[0].cpu().numpy(), tokens

# Test Layer Conductance on Australian examples
australian_examples_for_analysis = [
    "Melbourne coffee culture is amazing and wonderful!",  # Positive
    "Perth is boring and terrible, worst vacation ever!"    # Negative
]

for i, example in enumerate(australian_examples_for_analysis):
    print(f"🔬 Analysis {i+1}: '{example}'\n")
    
    conductance, prediction, probs, tokens = analyze_layer_conductance(example)
    
    print(f"📊 Prediction: {'Positive' if prediction == 1 else 'Negative'} ({probs[prediction]:.3f})")
    print(f"🧠 conv1 Layer Conductance (summed over filters, per position):")
    
    # Sort tokens by conductance magnitude
    token_conductance_pairs = list(zip(tokens, conductance))
    token_conductance_pairs.sort(key=lambda x: abs(x[1]), reverse=True)
    
    print(f"   Most Influential Tokens:")
    for token, cond in token_conductance_pairs[:5]:  # Top 5
        effect = "supports prediction" if cond > 0 else "opposes prediction" if cond < 0 else "neutral"
        print(f"     '{token}': {cond:.4f} ({effect})")
    
    print()

print("💡 Layer Conductance Insights:")
print("   • Shows how conv1 activations around each token contribute to the predicted class")
print("   • Higher absolute values indicate more influential positions")
print("   • conv1 is one of three parallel branches, so it explains only part of the prediction")
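Conductance of a hidden unit `y_j` is the part of the Integrated Gradients path integral that flows through that unit: `Cond_j = sum_i (x_i - x'_i) * integral over a in [0, 1] of dF/dy_j * dy_j/dx_i`. When every path from input to output passes through the layer, the conductances of its units sum to `F(x) - F(x')`; when the layer is one of several parallel branches, they cover only that branch's share. A pure-Python sketch for a toy one-hidden-layer network `F(x) = sum_j v_j * tanh(sum_i a_ji * x_i)`, with weights chosen only for illustration:

```python
import math

A = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]  # hidden weights a_ji (3 units, 2 inputs)
V = [1.0, -2.0, 0.5]                          # output weights v_j

def hidden(x):
    return [math.tanh(sum(a * xi for a, xi in zip(row, x))) for row in A]

def forward(x):
    return sum(v * y for v, y in zip(V, hidden(x)))

def conductance(x, baseline, n_steps=200):
    """Midpoint-rule conductance of each hidden unit."""
    cond = [0.0] * len(A)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for j, (row, y) in enumerate(zip(A, hidden(point))):
            dF_dy = V[j]  # F is linear in the hidden activations
            for i, (xi, b) in enumerate(zip(x, baseline)):
                dy_dx = (1 - y * y) * row[i]  # derivative of tanh(pre-activation)
                cond[j] += dF_dy * dy_dx * (xi - b) / n_steps
    return cond

x, baseline = [1.0, 0.5], [0.0, 0.0]
cond = conductance(x, baseline)
print([round(c, 4) for c in cond], round(sum(cond), 4), round(forward(x) - forward(baseline), 4))
```

Captum's `LayerConductance` performs the same accumulation with autograd and, by default, a Gauss-Legendre rule instead of the midpoint rule.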

In [None]:
# Create visualization functions for attribution analysis
print("🎨 Creating Visualization Functions for Attribution Analysis\n")

def visualize_attribution_heatmap(tokens, attributions, title="Token Attribution Heatmap"):
    """
    Create a heatmap visualization of token attributions.
    """
    import matplotlib.patches as patches
    
    # Create figure
    fig, ax = plt.subplots(figsize=(15, 3))
    
    # Normalize attributions to [-1, 1] for color mapping (guard against all-zero attributions)
    max_attr = max(abs(attr) for attr in attributions) or 1.0
    normalized_attr = [attr / max_attr for attr in attributions]
    
    # Create color map (red for negative, blue for positive)
    colors = []
    for attr in normalized_attr:
        if attr > 0:
            colors.append(plt.cm.Blues(abs(attr)))
        else:
            colors.append(plt.cm.Reds(abs(attr)))
    
    # Plot tokens with background colors
    for i, (token, attr, color) in enumerate(zip(tokens, attributions, colors)):
        # Create rectangle for background
        rect = patches.Rectangle((i-0.4, -0.4), 0.8, 0.8, 
                               facecolor=color, alpha=0.7)
        ax.add_patch(rect)
        
        # Add token text
        ax.text(i, 0, token, ha='center', va='center', 
               fontsize=10, weight='bold')
        
        # Add attribution value below
        ax.text(i, -0.7, f'{attr:.3f}', ha='center', va='center', 
               fontsize=8, style='italic')
    
    # Set plot properties
    ax.set_xlim(-0.5, len(tokens)-0.5)
    ax.set_ylim(-1, 0.5)
    ax.set_title(title, fontsize=14, weight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
    
    # Add legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='s', color='w', markerfacecolor='blue', 
               markersize=10, label='Positive Attribution'),
        Line2D([0], [0], marker='s', color='w', markerfacecolor='red', 
               markersize=10, label='Negative Attribution')
    ]
    ax.legend(handles=legend_elements, loc='upper right')
    
    plt.tight_layout()
    plt.show()

def create_attribution_bar_chart(tokens, attributions, title="Token Attribution Analysis"):
    """
    Create a bar chart showing token attributions.
    """
    # Create DataFrame for easier plotting
    df = pd.DataFrame({
        'token': tokens,
        'attribution': attributions,
        'abs_attribution': [abs(attr) for attr in attributions]
    })
    
    # Sort by absolute attribution
    df = df.sort_values('abs_attribution', ascending=True)
    
    # Create color mapping
    colors = ['red' if attr < 0 else 'blue' for attr in df['attribution']]
    
    # Create horizontal bar chart
    fig, ax = plt.subplots(figsize=(10, max(6, len(tokens) * 0.4)))
    
    bars = ax.barh(range(len(df)), df['attribution'], color=colors, alpha=0.7)
    
    # Customize plot
    ax.set_yticks(range(len(df)))
    ax.set_yticklabels(df['token'])
    ax.set_xlabel('Attribution Score')
    ax.set_title(title, fontsize=14, weight='bold')
    ax.grid(axis='x', alpha=0.3)
    
    # Add vertical line at x=0
    ax.axvline(x=0, color='black', linestyle='-', alpha=0.5)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='blue', alpha=0.7, label='Positive Attribution'),
        Patch(facecolor='red', alpha=0.7, label='Negative Attribution')
    ]
    ax.legend(handles=legend_elements)
    
    plt.tight_layout()
    plt.show()

def create_interactive_attribution_plot(tokens, attributions, text, prediction_info):
    """
    Create an interactive plot using Plotly.
    """
    # Create DataFrame
    df = pd.DataFrame({
        'token': tokens,
        'attribution': attributions,
        'abs_attribution': [abs(attr) for attr in attributions],
        'position': range(len(tokens))
    })
    
    # Create color mapping
    colors = ['red' if attr < 0 else 'blue' for attr in attributions]
    
    # Create interactive bar plot
    fig = go.Figure()
    
    fig.add_trace(go.Bar(
        x=df['position'],
        y=df['attribution'],
        text=df['token'],
        textposition='auto',
        marker_color=colors,
        hovertemplate='<b>Token:</b> %{text}<br><b>Attribution:</b> %{y:.4f}<extra></extra>'
    ))
    
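    # Update layout; the subtitle shows the text and the model's prediction
    pred_label = 'Positive' if prediction_info['prediction'] == 1 else 'Negative'
    fig.update_layout(
        title=(f"Interactive Attribution Analysis<br><sub>Text: {text[:100]}{'...' if len(text) > 100 else ''}"
               f" | Prediction: {pred_label} ({prediction_info['confidence']:.3f})</sub>"),
        xaxis_title="Token Position",
        yaxis_title="Attribution Score",
        template="plotly_white",
        height=500
    )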
    
    # Add horizontal line at y=0
    fig.add_hline(y=0, line_dash="dash", line_color="black", opacity=0.5)
    
    fig.show()

print("✅ Visualization functions created!")
print("📊 Available visualizations:")
print("   • visualize_attribution_heatmap() - Color-coded token heatmap")
print("   • create_attribution_bar_chart() - Static horizontal bar chart with matplotlib")
print("   • create_interactive_attribution_plot() - Interactive Plotly visualization")

In [None]:
# Comprehensive analysis example with multiple interpretation methods
print("🔬 Comprehensive Interpretation Analysis: Australian Tourism Review\n")

# Select an interesting example for detailed analysis
analysis_text = "The Sydney Opera House is breathtaking but the tour was overpriced and disappointing!"
print(f"📝 Analyzing Complex Sentiment: '{analysis_text}'\n")
print("💭 This example contains both positive and negative sentiment words...\n")

# Get interpretations from multiple methods
print("🔍 Running Multiple Interpretation Methods...\n")

# 1. Integrated Gradients
ig_attr, ig_pred, ig_probs, ig_tokens = interpret_with_integrated_gradients(analysis_text)

# 2. Input X Gradient
ixg_attr, ixg_pred, ixg_probs, ixg_tokens = interpret_with_input_x_gradient(analysis_text)

# 3. Layer Conductance
lc_attr, lc_pred, lc_probs, lc_tokens = analyze_layer_conductance(analysis_text)

# Display results
print(f"📊 Model Prediction Results:")
print(f"   Predicted Class: {'Positive' if ig_pred == 1 else 'Negative'}")
print(f"   Confidence: {ig_probs[ig_pred]:.3f}")
print(f"   Probabilities: [Negative: {ig_probs[0]:.3f}, Positive: {ig_probs[1]:.3f}]")

print(f"\n🎯 Top Influential Words by Method:")

# Helper function to get top words
def get_top_words(tokens, attributions, n=5):
    word_attr_pairs = list(zip(tokens, attributions))
    # Sort by absolute attribution value
    word_attr_pairs.sort(key=lambda x: abs(x[1]), reverse=True)
    return word_attr_pairs[:n]

# Display top words for each method
methods = [
    ("Integrated Gradients", ig_tokens, ig_attr),
    ("Input X Gradient", ixg_tokens, ixg_attr),
    ("Layer Conductance", lc_tokens, lc_attr)
]

for method_name, tokens, attributions in methods:
    print(f"\n   {method_name}:")
    top_words = get_top_words(tokens, attributions, n=5)
    for word, attr in top_words:
        sentiment_direction = "→ Positive" if attr > 0 else "→ Negative"
        print(f"     '{word}': {attr:.4f} {sentiment_direction}")

print(f"\n📈 Creating Visualizations...")

# Create visualizations
print(f"\n1️⃣ Heatmap Visualization (Integrated Gradients):")
visualize_attribution_heatmap(ig_tokens, ig_attr, 
                            "Integrated Gradients: Australian Tourism Sentiment")

print(f"\n2️⃣ Bar Chart Visualization (Input X Gradient):")
create_attribution_bar_chart(ixg_tokens, ixg_attr, 
                            "Input X Gradient: Token Importance Analysis")

print(f"\n3️⃣ Interactive Visualization (Layer Conductance):")
create_interactive_attribution_plot(lc_tokens, lc_attr, analysis_text, 
                                  {'prediction': ig_pred, 'confidence': ig_probs[ig_pred]})

print(f"\n🧠 Interpretation Insights:")
print("   • 'breathtaking' likely pushes toward positive")
print("   • 'overpriced' and 'disappointing' likely push toward negative")
print("   • The final label depends on which of these opposing words carries more attribution")
print("   • Methods can rank words differently; words ranked highly by all three are the most reliable signal")

In [None]:
# Multilingual interpretation: English vs Vietnamese
print("🌏 Multilingual Text Interpretation: English vs Vietnamese\n")

# Create English-Vietnamese paired examples
multilingual_pairs = [
    {
        'english': "Sydney Opera House is absolutely beautiful and amazing!",
        'vietnamese': "Nhà hát Opera Sydney tuyệt vời và đẹp tuyệt!",
        'expected_sentiment': 'Positive'
    },
    {
        'english': "Melbourne weather is terrible and disappointing!",
        'vietnamese': "Thời tiết Melbourne khủng khiếp và đáng thất vọng!",
        'expected_sentiment': 'Negative'
    }
]

def compare_multilingual_interpretation(english_text, vietnamese_text, expected_sentiment):
    """
    Compare interpretation results between English and Vietnamese text.
    """
    print(f"🔬 Multilingual Comparison Analysis\n")
    print(f"Expected Sentiment: {expected_sentiment}")
    print(f"English: '{english_text}'")
    print(f"Vietnamese: '{vietnamese_text}'\n")
    
    # Analyze English text
    en_attr, en_pred, en_probs, en_tokens = interpret_with_integrated_gradients(english_text)
    
    # Analyze Vietnamese text  
    vi_attr, vi_pred, vi_probs, vi_tokens = interpret_with_integrated_gradients(vietnamese_text)
    
    # Compare predictions
    en_sentiment = 'Positive' if en_pred == 1 else 'Negative'
    vi_sentiment = 'Positive' if vi_pred == 1 else 'Negative'
    
    print(f"📊 Prediction Comparison:")
    print(f"   English:    {en_sentiment} (confidence: {en_probs[en_pred]:.3f})")
    print(f"   Vietnamese: {vi_sentiment} (confidence: {vi_probs[vi_pred]:.3f})")
    print(f"   Agreement:  {'✅ Yes' if en_pred == vi_pred else '❌ No'}")
    
    # Show top attributed words
    print(f"\n🔍 Top Influential Words:")
    
    en_top = get_top_words(en_tokens, en_attr, n=3)
    vi_top = get_top_words(vi_tokens, vi_attr, n=3)
    
    print(f"   English:")
    for word, attr in en_top:
        print(f"     '{word}': {attr:.4f}")
    
    print(f"   Vietnamese:")
    for word, attr in vi_top:
        print(f"     '{word}': {attr:.4f}")
    
    # Create side-by-side visualization
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8))
    
    # English visualization
    ax1.bar(range(len(en_tokens)), en_attr, 
           color=['blue' if attr > 0 else 'red' for attr in en_attr], alpha=0.7)
    ax1.set_xticks(range(len(en_tokens)))
    ax1.set_xticklabels(en_tokens, rotation=45, ha='right')
    ax1.set_title(f'English Attribution: {en_sentiment} ({en_probs[en_pred]:.3f})', fontweight='bold')
    ax1.set_ylabel('Attribution Score')
    ax1.grid(axis='y', alpha=0.3)
    ax1.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    
    # Vietnamese visualization
    ax2.bar(range(len(vi_tokens)), vi_attr, 
           color=['blue' if attr > 0 else 'red' for attr in vi_attr], alpha=0.7)
    ax2.set_xticks(range(len(vi_tokens)))
    ax2.set_xticklabels(vi_tokens, rotation=45, ha='right')
    ax2.set_title(f'Vietnamese Attribution: {vi_sentiment} ({vi_probs[vi_pred]:.3f})', fontweight='bold')
    ax2.set_ylabel('Attribution Score')
    ax2.grid(axis='y', alpha=0.3)
    ax2.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    
    plt.tight_layout()
    plt.show()
    
    return {
        'english': {'prediction': en_pred, 'confidence': en_probs[en_pred], 'top_words': en_top},
        'vietnamese': {'prediction': vi_pred, 'confidence': vi_probs[vi_pred], 'top_words': vi_top},
        'agreement': en_pred == vi_pred
    }

# Analyze multilingual pairs
multilingual_results = []

for i, pair in enumerate(multilingual_pairs):
    print(f"\n{'='*60}")
    print(f"Analysis {i+1}: {pair['expected_sentiment']} Sentiment")
    print(f"{'='*60}")
    
    result = compare_multilingual_interpretation(
        pair['english'], 
        pair['vietnamese'], 
        pair['expected_sentiment']
    )
    multilingual_results.append(result)

# Summary analysis
print(f"\n🌍 Multilingual Interpretation Summary:")
agreement_rate = sum(1 for r in multilingual_results if r['agreement']) / len(multilingual_results)
print(f"   📊 English-Vietnamese Agreement Rate: {agreement_rate:.1%}")
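print("   💭 Vietnamese predictions should not be trusted: the IMDB model was trained on English reviews, so most Vietnamese words are likely out of vocabulary")
print("   🔍 Attribution patterns differ because out-of-vocabulary words likely share a single unknown-token embedding")
print("   🎯 Next step: use a multilingual model for cross-language interpretation")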

In [None]:
# TensorBoard integration for tracking interpretation experiments
print("📊 Setting up TensorBoard for Interpretation Tracking\n")

import os
import time
from torch.utils.tensorboard import SummaryWriter

# Platform-specific TensorBoard log directory setup
def get_run_logdir(experiment_name):
    """Generate unique log directory for this interpretation experiment."""
    
    # Colab stores files under /content; Kaggle and local runs use the working directory
    root_logdir = "/content/tensorboard_logs" if IS_COLAB else "./tensorboard_logs"
    
    # Create timestamp for unique run
    timestamp = time.strftime("%Y_%m_%d-%H_%M_%S")
    run_logdir = os.path.join(root_logdir, f"{experiment_name}_{timestamp}")
    
    # Create directory if it doesn't exist
    os.makedirs(run_logdir, exist_ok=True)
    
    return run_logdir

# Create TensorBoard writer for interpretation experiments
interpretation_logdir = get_run_logdir("text_model_interpretation")
writer = SummaryWriter(log_dir=interpretation_logdir)

print(f"📁 TensorBoard logs will be saved to: {interpretation_logdir}")

# Function to log interpretation results to TensorBoard
def log_interpretation_results(text, method_name, tokens, attributions, prediction_info, step):
    """
    Log interpretation results to TensorBoard for tracking and comparison.
    """
    
    # Log prediction metrics
    writer.add_scalar(f'Predictions/{method_name}/Confidence', 
                     prediction_info['confidence'], step)
    writer.add_scalar(f'Predictions/{method_name}/Predicted_Class', 
                     prediction_info['prediction'], step)
    
    # Log attribution statistics
    attr_mean = np.mean(attributions)
    attr_std = np.std(attributions)
    attr_max = np.max(attributions)
    attr_min = np.min(attributions)
    
    writer.add_scalar(f'Attributions/{method_name}/Mean', attr_mean, step)
    writer.add_scalar(f'Attributions/{method_name}/Std', attr_std, step)
    writer.add_scalar(f'Attributions/{method_name}/Max', attr_max, step)
    writer.add_scalar(f'Attributions/{method_name}/Min', attr_min, step)
    
    # Log attribution histogram
    writer.add_histogram(f'Attribution_Distribution/{method_name}', 
                        np.array(attributions), step)
    
    # Create text summary with top attributed words
    top_words = get_top_words(tokens, attributions, n=5)
    # TensorBoard renders text summaries as Markdown, so format the top words as a list
    text_summary = f"Text: {text}\n\nTop Words:\n\n"
    for word, attr in top_words:
        text_summary += f"- '{word}': {attr:.4f}\n"
    
    writer.add_text(f'Analysis/{method_name}', text_summary, step)

# Log some of our previous analyses to TensorBoard
print("📈 Logging interpretation experiments to TensorBoard...\n")

# Example analyses to log
example_analyses = [
    "The Sydney Opera House is absolutely beautiful and amazing!",
    "Sydney Opera House tours are overpriced and disappointing!",
    "Melbourne coffee culture is amazing and wonderful!",
    "Perth is boring and terrible, worst vacation ever!"
]

for step, text in enumerate(example_analyses):
    print(f"Logging analysis {step+1}: {text[:50]}...")
    
    # Get interpretations
    ig_attr, ig_pred, ig_probs, ig_tokens = interpret_with_integrated_gradients(text)
    ixg_attr, ixg_pred, ixg_probs, ixg_tokens = interpret_with_input_x_gradient(text)
    
    # Log to TensorBoard
    pred_info = {'prediction': ig_pred, 'confidence': ig_probs[ig_pred]}
    
    log_interpretation_results(text, 'Integrated_Gradients', ig_tokens, ig_attr, pred_info, step)
    log_interpretation_results(text, 'Input_X_Gradient', ixg_tokens, ixg_attr, pred_info, step)

print(f"\n✅ Interpretation experiments logged to TensorBoard!")
print(f"\n📊 To view TensorBoard:")

if IS_COLAB:
    print("   In Google Colab:")
    print("   1. Run: %load_ext tensorboard")
    print(f"   2. Run: %tensorboard --logdir {interpretation_logdir}")
elif IS_KAGGLE:
    print("   In Kaggle:")
    print(f"   1. Download logs from: {interpretation_logdir}")
    print("   2. Run locally: tensorboard --logdir ./tensorboard_logs")
else:
    print("   Locally:")
    print(f"   1. Run: tensorboard --logdir {interpretation_logdir}")
    print("   2. Open http://localhost:6006 in browser")

print(f"\n📈 Available TensorBoard visualizations:")
print(f"   • Scalars: Prediction confidence and attribution statistics")
print(f"   • Histograms: Attribution value distributions")
print(f"   • Text: Detailed analysis summaries with top words")
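print("   • Method comparison: tags are namespaced by method (e.g. Attributions/Integrated_Gradients/Mean), so both methods appear in the same Scalars section")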

In [None]:
# Create an interactive interpretation dashboard
print("🎮 Creating Interactive Interpretation Dashboard\n")

def create_interpretation_dashboard():
    """
    Create an interactive dashboard for text interpretation.
    
    This function provides a simple interface for users to input text
    and see interpretation results from multiple methods.
    """
    
    print("🎯 Interactive Text Interpretation Dashboard")
    print("=" * 50)
    print("Pass any text to analyze_custom_text(), or start from one of these examples:")
    print()
    
    # Predefined examples for quick testing
    example_texts = [
        "The Sydney Opera House is absolutely breathtaking!",
        "Melbourne coffee shops are overpriced and disappointing.",
        "Bondi Beach has perfect waves for surfing!",
        "Perth weather is terrible and ruins vacations.",
        "Brisbane wildlife parks are educational and fun!",
        "Adelaide restaurants have poor service and bad food."
    ]
    
    print("📝 Example texts (copy and modify as needed):")
    for i, example in enumerate(example_texts, 1):
        print(f"   {i}. {example}")
    
    print("\n" + "="*50)
    
    return example_texts

def analyze_custom_text(text, show_visualizations=True):
    """
    Comprehensive analysis function for custom text input.
    """
    
    print(f"🔍 Analyzing: '{text}'\n")
    
    # Run all interpretation methods
    print("⚡ Running interpretation methods...")
    
    try:
        ig_attr, ig_pred, ig_probs, ig_tokens = interpret_with_integrated_gradients(text)
        ixg_attr, ixg_pred, ixg_probs, ixg_tokens = interpret_with_input_x_gradient(text)
        lc_attr, lc_pred, lc_probs, lc_tokens = analyze_layer_conductance(text)
        
        # Display results
        print(f"\n📊 Analysis Results:")
        print(f"   Predicted Sentiment: {'Positive' if ig_pred == 1 else 'Negative'}")
        print(f"   Confidence: {ig_probs[ig_pred]:.3f}")
        print(f"   Probability Distribution: [Neg: {ig_probs[0]:.3f}, Pos: {ig_probs[1]:.3f}]")
        
        # Show top influential words across methods
        print(f"\n🎯 Most Influential Words:")
        
        methods = [
            ("Integrated Gradients", ig_tokens, ig_attr),
            ("Input X Gradient", ixg_tokens, ixg_attr),
            ("Layer Conductance", lc_tokens, lc_attr)
        ]
        
        for method_name, tokens, attributions in methods:
            top_words = get_top_words(tokens, attributions, n=3)
            print(f"   {method_name}:")
            for word, attr in top_words:
                sentiment_dir = "Positive" if attr > 0 else "Negative"
                print(f"     '{word}': {attr:.4f} → {sentiment_dir}")
        
        # Create visualizations if requested
        if show_visualizations:
            print(f"\n🎨 Creating visualizations...")
            
            # Heatmap for Integrated Gradients
            visualize_attribution_heatmap(ig_tokens, ig_attr, 
                                        f"Integrated Gradients Analysis: {text[:60]}")
            
            # Interactive Plotly view of the same Integrated Gradients attributions
            create_interactive_attribution_plot(ig_tokens, ig_attr, text, 
                                              {'prediction': ig_pred, 'confidence': ig_probs[ig_pred]})
        
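        # Log to TensorBoard; each custom analysis gets its own step after the examples logged earlier
        step = getattr(analyze_custom_text, 'next_step', len(example_analyses))
        analyze_custom_text.next_step = step + 1
        pred_info = {'prediction': ig_pred, 'confidence': ig_probs[ig_pred]}
        log_interpretation_results(text, 'Custom_Analysis', ig_tokens, ig_attr, pred_info, step)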
        
        print(f"\n✅ Analysis completed and logged to TensorBoard!")
        
        return {
            'prediction': ig_pred,
            'confidence': ig_probs[ig_pred],
            'methods': {
                'integrated_gradients': {'tokens': ig_tokens, 'attributions': ig_attr},
                'input_x_gradient': {'tokens': ixg_tokens, 'attributions': ixg_attr},
                'layer_conductance': {'tokens': lc_tokens, 'attributions': lc_attr}
            }
        }
        
    except Exception as e:
        print(f"❌ Error during analysis: {e}")
        print(f"💡 Please check your text and try again")
        return None

# Initialize dashboard
dashboard_examples = create_interpretation_dashboard()

# Example usage - analyze one of the provided examples
print(f"\n🚀 Example Analysis:")
example_result = analyze_custom_text(dashboard_examples[0], show_visualizations=True)

print(f"\n💡 Usage Instructions:")
print(f"   1. Call analyze_custom_text('your text here') to analyze any text")
print(f"   2. Set show_visualizations=False for text-only results")
print(f"   3. All analyses are automatically logged to TensorBoard")
print(f"   4. Use the examples above as starting points for your own analysis")

## 🎯 Best Practices for Text Model Interpretation

### Key Interpretation Methods Comparison

| Method | Pros | Cons | Best Use Case |
|--------|------|------|---------------|
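| **Integrated Gradients** | Satisfies completeness (attributions sum to the output change from a baseline); less affected by gradient saturation | Needs a baseline choice and many forward/backward passes (`n_steps`) | Critical interpretations, research |
| **Input X Gradient** | Fast: one forward and one backward pass | Can be noisy; misleading where gradients saturate | Quick analysis, prototyping |
| **Layer Conductance** | Attributes importance through a chosen hidden layer | Requires choosing a layer; costs about as much as Integrated Gradients | Understanding model internals |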
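
The trade-offs in the table can be checked on a model small enough to differentiate by hand. The pure-Python sketch below assumes a hypothetical two-layer tanh network (the weights `W` and `V` are made up), a zero baseline, and a midpoint Riemann sum with 200 steps; it is not the notebook's CNN or Captum's implementation. It illustrates the completeness property behind Integrated Gradients: its attributions, like the per-unit Layer Conductance scores, sum approximately to F(x) - F(baseline), whereas Input X Gradient uses a single gradient and carries no such guarantee.

```python
import math

# Hypothetical toy model (NOT the notebook's CNN):
#   hidden h_j = tanh(sum_i W[j][i] * x[i]),  output F(x) = sum_j V[j] * h_j
W = [[0.8, -0.5, 0.3],
     [-0.4, 0.9, 0.6]]
V = [1.5, -1.0]


def forward(x):
    """Return the scalar output F(x) and the hidden activations h."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W]
    return sum(v * h for v, h in zip(V, hidden)), hidden


def grad_input(x):
    """Analytic gradient dF/dx_i = sum_j V_j * (1 - h_j^2) * W_ji."""
    _, hidden = forward(x)
    return [sum(V[j] * (1 - hidden[j] ** 2) * W[j][i] for j in range(len(W)))
            for i in range(len(x))]


def input_x_gradient(x):
    """Input X Gradient: x_i * dF/dx_i, from a single gradient evaluation."""
    return [xi * g for xi, g in zip(x, grad_input(x))]


def path_points(x, baseline, steps):
    """Midpoints of `steps` equal segments on the straight line baseline -> x."""
    for k in range(steps):
        alpha = (k + 0.5) / steps
        yield [b + alpha * (xi - b) for xi, b in zip(x, baseline)]


def integrated_gradients(x, baseline, steps=200):
    """(x_i - x'_i) times the path-averaged gradient (midpoint Riemann sum)."""
    avg_grad = [0.0] * len(x)
    for point in path_points(x, baseline, steps):
        for i, g in enumerate(grad_input(point)):
            avg_grad[i] += g / steps
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


def layer_conductance(x, baseline, steps=200):
    """Conductance of hidden unit j: path integral of dF/dh_j * dh_j/dalpha,
    where dh_j/dalpha = (1 - h_j^2) * sum_i W_ji * (x_i - x'_i)."""
    cond = [0.0] * len(W)
    for point in path_points(x, baseline, steps):
        _, hidden = forward(point)
        for j, row in enumerate(W):
            dz_dalpha = sum(w * (xi - b) for w, xi, b in zip(row, x, baseline))
            cond[j] += V[j] * (1 - hidden[j] ** 2) * dz_dalpha / steps
    return cond


x = [1.0, 0.5, -0.2]
baseline = [0.0, 0.0, 0.0]
delta = forward(x)[0] - forward(baseline)[0]
print("F(x) - F(baseline):     ", round(delta, 4))
print("sum of IG:              ", round(sum(integrated_gradients(x, baseline)), 4))
print("sum of conductance:     ", round(sum(layer_conductance(x, baseline)), 4))
print("sum of Input X Gradient:", round(sum(input_x_gradient(x)), 4))
```

Increasing `steps` tightens the approximation; Captum exposes the same trade-off through the `n_steps` argument of `IntegratedGradients.attribute`.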

### 📋 Interpretation Guidelines

#### 1. **Multiple Methods Strategy**
- Always use multiple interpretation methods
- Look for consistency across methods
- Investigate disagreements between methods
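
One way to quantify consistency across methods is to compare which token positions each method ranks in its top k. The sketch below is a hypothetical helper, not part of Captum: it computes the pairwise Jaccard overlap of top-k positions, and the attribution values are invented for illustration. Positions rather than words are compared so that repeated tokens such as "the" stay distinct.

```python
from itertools import combinations


def top_k_positions(attributions, k=3):
    """Positions of the k largest |attribution| values."""
    order = sorted(range(len(attributions)),
                   key=lambda i: abs(attributions[i]), reverse=True)
    return set(order[:k])


def jaccard(a, b):
    """Overlap of two sets: |a & b| / |a | b| (1.0 when both are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0


def method_agreement(method_attrs, k=3):
    """Pairwise Jaccard overlap of top-k token positions between methods."""
    tops = {name: top_k_positions(attrs, k) for name, attrs in method_attrs.items()}
    return {(m1, m2): jaccard(tops[m1], tops[m2])
            for m1, m2 in combinations(tops, 2)}


# Made-up attribution values for illustration only (not notebook output)
tokens = ["the", "tour", "was", "overpriced", "and", "disappointing"]
method_attrs = {
    "Integrated Gradients": [0.01, 0.05, 0.02, -0.60, 0.03, -0.80],
    "Input X Gradient":     [0.02, 0.30, 0.01, -0.40, 0.04, -0.90],
    "Layer Conductance":    [0.00, 0.02, 0.10, -0.50, 0.01, -0.70],
}
for name, attrs in method_attrs.items():
    print(f"{name} top-3: {[tokens[i] for i in sorted(top_k_positions(attrs))]}")
for (m1, m2), score in method_agreement(method_attrs).items():
    print(f"{m1} vs {m2}: {score:.2f}")
```

A low overlap for one pair is a cue to inspect the tokens where that pair disagrees rather than to pick one method's answer.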

#### 2. **Context Considerations**
- Consider domain shift (IMDB → Australian tourism)
- Account for vocabulary differences
- Be aware of model limitations

#### 3. **Multilingual Challenges**
- Models trained on English may not handle other languages well
- Attribution patterns may differ significantly across languages
- Consider using multilingual models for non-English text
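
A quick way to see why an English-trained model struggles with Vietnamese is to measure how many input words it has never seen. The sketch below assumes a toy English vocabulary and a simple regex tokenizer, both hypothetical stand-ins for the IMDB model's real vocabulary and tokenizer.

```python
import re
import unicodedata


def oov_report(text, vocab):
    """Return (OOV rate, missing tokens) for a simple lowercase word tokenizer.

    NFC normalization keeps Vietnamese letters with diacritics (e.g. 'ệ')
    as single characters so the regex does not split words apart.
    """
    tokens = re.findall(r"\w+", unicodedata.normalize("NFC", text).lower())
    missing = [t for t in tokens if t not in vocab]
    rate = len(missing) / len(tokens) if tokens else 0.0
    return rate, missing


# Hypothetical English-only vocabulary standing in for the IMDB model's vocab
vocab = {"sydney", "opera", "house", "is", "absolutely", "beautiful", "and", "amazing"}

for text in ["Sydney Opera House is absolutely beautiful and amazing!",
             "Nhà hát Opera Sydney tuyệt vời và đẹp tuyệt!"]:
    rate, missing = oov_report(text, vocab)
    print(f"{rate:.0%} OOV: {missing}")
```

A high out-of-vocabulary rate is a warning that attributions for that text mostly describe the unknown-token embedding rather than the words themselves.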

#### 4. **Validation Techniques**
- Test with known positive/negative examples
- Verify attributions make semantic sense
- Compare with human intuition
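
Comparing with human intuition can be made measurable: ask a person to mark the sentiment-bearing words in a review, then check how many of the model's top-k attributed tokens they marked and whether the attribution signs agree. The helper and the annotations below are hypothetical illustrations, not a standard metric from Captum.

```python
def agreement_with_annotations(tokens, attributions, human_labels, k=3):
    """Compare the top-k attributed tokens with human sentiment labels.

    human_labels maps a lowercase word to +1 (positive) or -1 (negative).
    Returns (precision, sign_accuracy):
      precision     - share of top-k tokens a human also marked as sentiment-bearing
      sign_accuracy - among those, share whose attribution sign matches the label
    """
    ranked = sorted(zip(tokens, attributions), key=lambda p: abs(p[1]), reverse=True)[:k]
    labeled = [(tok.lower(), attr) for tok, attr in ranked if tok.lower() in human_labels]
    matches = sum((attr > 0) == (human_labels[tok] > 0) for tok, attr in labeled)
    precision = len(labeled) / len(ranked) if ranked else 0.0
    sign_accuracy = matches / len(labeled) if labeled else 0.0
    return precision, sign_accuracy


# Human annotations and attribution values below are made up for illustration
human_labels = {"breathtaking": 1, "overpriced": -1, "disappointing": -1}
tokens = ["opera", "breathtaking", "but", "tour", "overpriced", "disappointing"]
attributions = [0.05, 0.40, -0.20, 0.02, -0.35, 0.30]  # 'disappointing' has the "wrong" sign

precision, sign_accuracy = agreement_with_annotations(tokens, attributions, human_labels)
print(f"precision@3 = {precision:.2f}, sign accuracy = {sign_accuracy:.2f}")
```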

### 🚨 Common Pitfalls to Avoid

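1. **Over-interpreting Results**: Attributions describe how the model uses its inputs, not what actually causes sentiment in the text
2. **Ignoring Model Limitations**: Pre-trained models have domain biases
3. **Single Method Reliance**: Always cross-validate with multiple methods
4. **Vocabulary Mismatch**: Words missing from the model's vocabulary typically share one unknown-token embedding, so their attributions say little about the words themselves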

### 🎯 Australian Tourism Interpretation Insights

From our analysis, we observed:

- **Positive indicators**: 'beautiful', 'amazing', 'wonderful', 'perfect'
- **Negative indicators**: 'terrible', 'disappointing', 'overpriced', 'boring'
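- **Location-specific topics**: Sydney Opera House, Melbourne coffee, Perth weather
- **Domain transfer**: The IMDB model appears to transfer reasonably well to short tourism reviews with strong sentiment words, though this rests on only a handful of examples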

### 📊 TensorBoard for Interpretation Tracking

Use TensorBoard to:
- Track interpretation experiments over time
- Compare different methods side-by-side
- Monitor attribution statistics and distributions
- Document analysis results with text summaries

### 🔬 Advanced Interpretation Techniques

For deeper analysis, consider:
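- **Gradient SHAP**: Approximates SHAP values by averaging gradients at random points between noisy inputs and randomly sampled baselines
- **DeepLift**: Attributes the difference between the model's activations on the input and on a reference input
- **Occlusion**: Replaces words (or windows of words) with a baseline and measures the change in prediction
- **LIME**: Fits a simple interpretable surrogate model to the classifier's behavior on perturbed copies of the input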
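
Occlusion needs no gradients at all, which makes it a useful cross-check on the gradient-based methods. The sketch below is a standalone leave-one-out version: it drops (or replaces) one token at a time and records the change in P(positive). A made-up lexicon scorer stands in for the IMDB CNN; Captum's own `Occlusion` class works on tensors instead, with `sliding_window_shapes` and `baselines` controlling what is occluded.

```python
import math

# Hypothetical stand-in for the classifier: a lexicon-weighted logistic score
LEXICON = {"breathtaking": 2.0, "amazing": 1.5, "overpriced": -1.5, "disappointing": -2.0}


def positive_prob(tokens):
    """P(positive) = sigmoid(sum of lexicon weights); unknown words weigh 0."""
    logit = sum(LEXICON.get(t.lower(), 0.0) for t in tokens)
    return 1.0 / (1.0 + math.exp(-logit))


def occlusion_attributions(tokens, score_fn, baseline_token=None):
    """Leave-one-out occlusion: score(full) - score(with token i occluded).

    With baseline_token=None the token is dropped; otherwise it is replaced
    by baseline_token (e.g. a padding token)."""
    full_score = score_fn(tokens)
    attributions = []
    for i in range(len(tokens)):
        replacement = [] if baseline_token is None else [baseline_token]
        occluded = tokens[:i] + replacement + tokens[i + 1:]
        attributions.append(full_score - score_fn(occluded))
    return attributions


tokens = "the opera house is breathtaking but the tour was overpriced".split()
for tok, attr in zip(tokens, occlusion_attributions(tokens, positive_prob)):
    print(f"{tok:>14}: {attr:+.3f}")
```

With a real model, each occlusion costs one forward pass, so a sentence of n tokens needs n + 1 passes.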

## 🎉 Summary and Next Steps

### What We Accomplished

In this notebook, we successfully:

✅ **Loaded and used a pre-trained IMDB CNN model** from Captum tutorials  
✅ **Implemented multiple interpretation methods**: Integrated Gradients, Input X Gradient, Layer Conductance  
✅ **Analyzed Australian tourism sentiment** with real-world examples  
✅ **Created comprehensive visualizations** for attribution analysis  
✅ **Explored multilingual interpretation** with English-Vietnamese examples  
✅ **Integrated TensorBoard logging** for experiment tracking  
✅ **Built an interactive dashboard** for custom text analysis  

### Key Learning Outcomes

🧠 **Understanding Model Behavior**: We learned how CNN models process sentiment in text  
🔍 **Attribution Analysis**: We can now identify which words drive model predictions  
🎨 **Visualization Skills**: We created multiple types of interpretation visualizations  
🌏 **Multilingual Awareness**: We understand challenges in cross-language interpretation  
📊 **Experiment Tracking**: We learned to use TensorBoard for interpretation experiments  

### Real-World Applications

This knowledge enables you to:

- **Debug model predictions** when they seem incorrect
- **Build trust in AI systems** by explaining their decisions
- **Identify model biases** and fairness issues
- **Improve model performance** by understanding failure modes
- **Comply with AI regulations** requiring explainable decisions

### 🚀 Next Steps for Advanced Interpretation

#### 1. **Experiment with Other Captum Methods**
```python
from captum.attr import (
    GradientShap,     # Gradient-based SHAP
    DeepLift,         # Reference-based attribution
    Occlusion,        # Perturbation-based
    Saliency          # Simple gradient attribution
)
```

#### 2. **Try Modern Transformer Models**
```python
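from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients

# 'bert-base-uncased' has no trained classification head, so its sentiment
# predictions would be random. Use a checkpoint already fine-tuned for sentiment.
model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# LayerIntegratedGradients can then attribute through the embedding layer, e.g.
# LayerIntegratedGradients(forward_fn, model.distilbert.embeddings), where
# forward_fn is a wrapper you write that returns the model's logits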
```

#### 3. **Explore Captum Insights (Interactive Dashboard)**
```python
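import torch
from captum.insights import AttributionVisualizer, Batch

# AttributionVisualizer also requires `features` (feature descriptors from
# captum.insights.attr_vis.features) and `dataset` (an iterable of Batch objects).
# `text_features` and `batch_iterator` are placeholders you must define yourself.
visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    classes=['Negative', 'Positive'],
    features=text_features,
    dataset=batch_iterator,
)
visualizer.render()  # display the interactive widget in a notebook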
```

#### 4. **Build Domain-Specific Models**
- Fine-tune models on Australian tourism data
- Train multilingual models for better Vietnamese support
- Create specialized models for specific use cases

#### 5. **Integration with Production Systems**
- Add interpretation to model serving APIs
- Create automated interpretation reports
- Build interpretation-aware model monitoring
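
An automated interpretation report can be as simple as serializing each prediction together with its most influential tokens. The sketch below assumes a hypothetical JSON schema (field names such as `top_tokens` are illustrative, not a Captum or notebook format) and uses invented attribution values.

```python
import json


def build_interpretation_report(text, tokens, attributions, prediction,
                                confidence, method, top_n=3):
    """Bundle one prediction and its top attributed tokens into a JSON-ready dict.

    The field names are an illustrative schema, not a Captum or notebook format.
    float() converts NumPy scalars such as np.float32, which json cannot serialize."""
    ranked = sorted(zip(tokens, attributions), key=lambda p: abs(p[1]), reverse=True)
    return {
        "text": text,
        "method": method,
        "prediction": "Positive" if prediction == 1 else "Negative",
        "confidence": round(float(confidence), 4),
        "top_tokens": [{"token": tok, "attribution": round(float(attr), 4)}
                       for tok, attr in ranked[:top_n]],
    }


# Illustrative inputs (not real model output)
report = build_interpretation_report(
    text="The Sydney Opera House is absolutely breathtaking!",
    tokens=["the", "sydney", "opera", "house", "is", "absolutely", "breathtaking"],
    attributions=[0.01, 0.08, 0.05, 0.03, 0.00, 0.21, 0.62],
    prediction=1,
    confidence=0.93,
    method="Integrated Gradients",
)
print(json.dumps(report, indent=2, ensure_ascii=False))
```

Feeding it the outputs of `interpret_with_integrated_gradients` (tokens, attributions, predicted class and its probability) would produce one report per request; `ensure_ascii=False` keeps Vietnamese text readable in the JSON.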

### 📚 Additional Resources

- **Captum Documentation**: https://captum.ai
- **Captum Tutorials**: https://captum.ai/tutorials/
- **Interpretable AI Research**: https://arxiv.org/abs/1909.01319
- **PyTorch Model Interpretability**: https://pytorch.org/tutorials/beginner/Captum_Recipe.html

### 🤝 Contributing to the PyTorch Mastery Repository

This notebook is part of the PyTorch Mastery learning repository focused on:
- 🇦🇺 Australian context examples
- 🌏 English-Vietnamese multilingual support
- 🔄 TensorFlow → PyTorch transition guidance
- 📊 Comprehensive TensorBoard integration

Help improve this resource by:
- Adding more Australian-specific examples
- Expanding Vietnamese language support
- Contributing additional interpretation methods
- Sharing your own interpretation experiments

---

**🎓 Congratulations! You've mastered text model interpretation with Captum!** 🇦🇺

*Ready to build more interpretable AI systems? The journey continues...*