# FUSED Framework: Model Interpretability Example

This notebook demonstrates how to use the interpretability tools in the FUSED framework to understand and visualize model predictions.

## Setup

First, let's import the necessary libraries:

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset

# Import FUSED utilities
from fused.utils.interpretability import FeatureImportance, AttentionVisualization
from fused.models import SequentialEncoder, TransformerEncoder, TemporalFusionModel

# Set random seeds for reproducibility: torch drives data generation and
# weight initialisation below; NumPy's global generator is seeded as well.
torch.manual_seed(42)
np.random.seed(42)

# Render matplotlib figures inline in the notebook output (static images).
# NOTE: `%matplotlib inline` is an IPython/Jupyter magic, not Python syntax,
# so this cell only runs inside a notebook kernel.
%matplotlib inline

## Create a Simple Model

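Let's create a simple transformer-based classifier, generate synthetic data for it, and simulate a trained model by setting a few weights by hand (no actual training happens, so the accuracy printed below is only illustrative):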

In [None]:
class InterpretableModel(nn.Module):
    """Small transformer classifier that also exposes its attention weights.

    Args:
        input_dim: Number of features per time step fed to the encoder.
        seq_length: Accepted for symmetry with the data generator; it is not
            used when building the layers.

    ``forward`` returns a dict with ``"logits"`` (from a 2-way linear head)
    and ``"attention"`` (whatever ``TransformerEncoder`` returns as its second
    value when called with ``return_attention=True``).
    """

    def __init__(self, input_dim=5, seq_length=10):
        super().__init__()
        width = 32  # encoder hidden/output size and the classifier's input size

        # Encoder is registered before the classifier, as in the original.
        self.seq_encoder = TransformerEncoder(
            input_dim=input_dim,
            hidden_dim=width,
            output_dim=width,
            num_layers=2,
            num_heads=4,
            dropout=0.1,
        )
        self.classifier = nn.Linear(width, 2)

    def forward(self, x):
        """Encode ``x`` and classify the encoding.

        ``x`` is expected to be (batch_size, seq_length, input_dim).
        NOTE(review): the classifier is applied directly to the encoder output,
        so logits are (batch, 2) only if the encoder pools over time — confirm
        against ``TransformerEncoder``.
        """
        features, attn_weights = self.seq_encoder(x, return_attention=True)
        return {"logits": self.classifier(features), "attention": attn_weights}

# Create sample data
def generate_sample_data(n_samples=100, seq_length=10, input_dim=5):
    """Build a synthetic binary-classification dataset of random sequences.

    Values start as standard normal noise. Feature 0 on the first
    ``seq_length // 2`` steps and feature 1 on the remaining steps each get a
    uniform ``[0, 2)`` offset. The label is 1 when the sum of those two
    boosted regions is positive, else 0.

    Requires ``input_dim >= 2`` because feature index 1 is used.

    Returns:
        ``(X, y)`` where ``X`` is a float tensor of shape
        ``(n_samples, seq_length, input_dim)`` and ``y`` is an int64 tensor of
        shape ``(n_samples,)``.
    """
    half = seq_length // 2
    tail = seq_length - half

    X = torch.randn(n_samples, seq_length, input_dim)
    X[:, :half, 0] += torch.rand(n_samples, half) * 2
    X[:, half:, 1] += torch.rand(n_samples, tail) * 2

    score = X[:, :half, 0].sum(dim=1) + X[:, half:, 1].sum(dim=1)
    y = (score > 0).long()
    return X, y

# Build the synthetic dataset (defaults: 10 steps, 5 features per sequence)
X, y = generate_sample_data(n_samples=100)
print(f"X shape: {X.shape}, y shape: {y.shape}")

# Size the model from the data; X is (n_samples, seq_length, input_dim)
model = InterpretableModel(seq_length=X.shape[1], input_dim=X.shape[2])

# Pretend we've trained the model (for demo purposes)
def simulate_trained_model(model, X, y):
    """Overwrite a few weights by hand so the demo model looks "trained".

    No optimisation takes place. Every parameter whose name contains both
    ``"self_attn"`` and ``"weight"`` is replaced with small Gaussian noise
    (std 0.01). Each row ``c`` of ``model.classifier.weight`` is reset to
    small noise with a single 0.5 entry at column ``c``, so logit ``c`` leans
    on encoder output dimension ``c`` (an encoded dimension, not an input
    feature).

    Args:
        model: Module exposing ``named_parameters()`` and an ``nn.Linear``
            attribute named ``classifier``. Modified in place.
        X: Unused; kept so existing call sites keep working.
        y: Unused; kept so existing call sites keep working.

    Returns:
        The same ``model`` object.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "self_attn" in name and "weight" in name:
                # In-place copy under no_grad instead of assigning ``.data``,
                # which bypasses autograd's tracking.
                param.copy_(torch.randn_like(param) * 0.01)

        weight = model.classifier.weight
        out_features, in_features = weight.shape
        # Derive the row width from the layer instead of hard-coding 32, and
        # handle any number of classes. The RNG is consumed row by row, so for
        # a 2x32 classifier the result matches the previous hard-coded version.
        for cls in range(out_features):
            weight[cls].copy_(torch.randn(in_features) * 0.01)
            if cls < in_features:
                weight[cls, cls] = 0.5

    return model

# Apply the hand-set "trained" weights (no real training happens here)
model = simulate_trained_model(model, X, y)

# Evaluate on the same synthetic data, in eval mode and without gradients
model.eval()
with torch.no_grad():
    outputs = model(X)
    accuracy = outputs["logits"].argmax(dim=1).eq(y).float().mean().item()

print(f"Model accuracy: {accuracy:.4f}")

## Feature Importance Analysis

Now, let's measure feature importance with permutation importance, which shuffles the values of one feature at a time and records how much the accuracy metric drops:

In [None]:
# Create a feature importance analyzer for `model`; the same `analyzer` is
# reused for the integrated-gradients analysis further down.
analyzer = FeatureImportance(model)

# Define a metric function for classification
def accuracy_metric(y_true, y_pred):
    """Return the fraction of rows whose argmax over dim 1 equals ``y_true``.

    ``y_pred`` holds per-class scores along dimension 1. The result is a
    Python float.
    """
    predicted = y_pred.argmax(dim=1)
    return predicted.eq(y_true).float().mean().item()

# Permutation importance per input feature, with n_repeats=10
feature_names = [f"Feature_{idx}" for idx in range(X.shape[2])]
perm_importance = analyzer.permutation_importance(
    X,
    y,
    accuracy_metric,
    n_repeats=10,
    feature_names=feature_names,
)

# Plot the importances (sorted) and display the figure
fig = analyzer.plot_feature_importance(
    perm_importance,
    sort=True,
    title="Feature Importance (Permutation)",
)
plt.tight_layout()
plt.show()

## Integrated Gradients Analysis

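Let's also compute integrated gradients, which attribute a single prediction to each input value (here, every time step × feature pair) by accumulating gradients along a path from a baseline input to the actual input: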

In [None]:
# Integrated gradients for one sample, attributed towards its true class.
# Slicing (rather than indexing) keeps the leading batch dimension.
sample_idx = 0
sample_X = X[sample_idx:sample_idx + 1]
sample_y = y[sample_idx:sample_idx + 1]

integrated_grads = analyzer.integrated_gradients(
    sample_X,
    n_steps=50,
    target_class=sample_y.item(),
    feature_names=feature_names,
)

# Lay the attributions out as a (time step x feature) grid.
# NOTE(review): assumes "importances" holds seq_length * n_features values in
# step-major order — confirm against FeatureImportance.integrated_gradients.
seq_length, n_features = X.shape[1:]
attributions = integrated_grads["importances"].reshape(seq_length, n_features)

plt.figure(figsize=(12, 6))
sns.heatmap(
    attributions,
    cmap="coolwarm",
    center=0,
    xticklabels=feature_names,
    yticklabels=[f"Step {i}" for i in range(seq_length)],
)
plt.title(f"Integrated Gradients for Sample {sample_idx} (Class {sample_y.item()})")
plt.xlabel("Features")
plt.ylabel("Sequence Steps")
plt.tight_layout()
plt.show()

## Attention Visualization

Now, let's visualize the attention patterns in our transformer model:

In [None]:
# Create attention visualizer
attention_viz = AttentionVisualization(model)

# Register hooks to capture attention weights. The try/finally guarantees
# remove_hooks() runs even if a forward pass or a plot raises, so the model is
# never left with capture hooks attached after an error.
attention_viz.register_hooks()
try:
    # Get attention maps for a batch of samples (the first 5)
    sample_batch = X[:5]
    with torch.no_grad():
        _ = model(sample_batch)  # forward pass to trigger the hooks

    attention_maps = attention_viz.get_attention_maps(sample_batch)

    # Plot every head of every captured layer for the first sample
    sample_idx = 0
    for layer_name in attention_maps:
        # 4 must match num_heads=4 used to build InterpretableModel's encoder
        for head_idx in range(4):
            fig = attention_viz.plot_attention_heatmap(
                attention_maps,
                layer_name=layer_name,
                head_idx=head_idx,
                sample_idx=sample_idx,
                title=f"{layer_name} - Head {head_idx}",
            )
            plt.show()
finally:
    # Remove hooks
    attention_viz.remove_hooks()

## Temporal Feature Attribution

Let's analyze which time steps matter most for the predictions by zeroing out one time step at a time and measuring the drop in accuracy:

In [None]:
# Calculate attribution for each time step
def analyze_temporal_importance(model, X, y, *, mask_value=0):
    """Score each time step by how much masking it lowers accuracy.

    For every step ``t``, a copy of ``X`` is made with ``X[:, t, :]`` set to
    ``mask_value``. The model's accuracy on that copy is then compared with
    its accuracy on the unmasked input.

    Args:
        model: Module returning a dict with a ``"logits"`` tensor whose
            ``argmax(dim=1)`` is the predicted class. It is switched to eval
            mode and left there.
        X: Input indexed as ``X[:, t, :]``, i.e. time is dimension 1. Not
            modified.
        y: Integer class labels compared against the predictions.
        mask_value: Value written into the masked time step. The default of 0
            reproduces plain zero-masking.

    Returns:
        A list with one float per time step: original accuracy minus masked
        accuracy. Positive values mean the step helped. Negative values mean
        masking it improved accuracy.
    """
    model.eval()

    def _accuracy(inputs):
        """Fraction of samples whose argmax prediction equals ``y``."""
        preds = model(inputs)["logits"].argmax(dim=1)
        return (preds == y).float().mean().item()

    with torch.no_grad():
        original_acc = _accuracy(X)
        importance_scores = []
        for t in range(X.shape[1]):
            X_masked = X.clone()  # never modify the caller's tensor
            X_masked[:, t, :] = mask_value
            importance_scores.append(original_acc - _accuracy(X_masked))

    return importance_scores

# One score per time step: accuracy drop when that step is masked
temporal_importance = analyze_temporal_importance(model, X, y)

# Bar chart of the per-step scores
plt.figure(figsize=(10, 6))
plt.bar(np.arange(len(temporal_importance)), temporal_importance)
plt.xlabel("Time Step")
plt.ylabel("Importance (Decrease in Accuracy)")
plt.title("Temporal Feature Importance")
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

## Interpret Model Predictions for a Single Sample

Finally, let's combine these approaches to provide a comprehensive interpretation of the prediction for a single sample:

In [None]:
def _predict_sample(model, X_sample):
    """Run ``model`` without gradients; return (softmax probs, predicted class)."""
    with torch.no_grad():
        logits = model(X_sample)["logits"]
        probs = torch.softmax(logits, dim=1)
        pred_class = logits.argmax(dim=1).item()
    return probs, pred_class


def _plot_inputs_and_attributions(X_sample, attributions, feature_names, pred_class):
    """Draw the raw input grid of sample 0 above its attribution grid."""
    seq_length = X_sample.shape[1]
    step_labels = [f"Step {i}" for i in range(seq_length)]

    plt.figure(figsize=(15, 10))

    # Plot the input features (detach/cpu so .numpy() also works for tensors
    # that require grad or live on an accelerator)
    plt.subplot(2, 1, 1)
    sns.heatmap(X_sample[0].detach().cpu().numpy(), cmap="viridis",
                xticklabels=feature_names,
                yticklabels=step_labels)
    plt.title("Input Features")
    plt.xlabel("Features")
    plt.ylabel("Time Steps")

    # Plot attributions
    plt.subplot(2, 1, 2)
    sns.heatmap(attributions, cmap="coolwarm", center=0,
                xticklabels=feature_names,
                yticklabels=step_labels)
    plt.title(f"Feature Attributions for Class {pred_class}")
    plt.xlabel("Features")
    plt.ylabel("Time Steps")

    plt.tight_layout()
    plt.show()


def _plot_self_attention(model, X_sample):
    """Capture attention via hooks and plot head 0 of each ``self_attn`` layer.

    Hooks are removed in ``finally`` so an error during the forward pass or
    plotting does not leave them attached to ``model``.
    """
    attention_viz = AttentionVisualization(model)
    attention_viz.register_hooks()
    try:
        with torch.no_grad():
            model(X_sample)  # forward pass to trigger the hooks

        attention_maps = attention_viz.get_attention_maps(X_sample)
        for layer_name in attention_maps:
            # NOTE(review): if the captured layer names do not contain
            # "self_attn", nothing is plotted here — confirm the key naming.
            if "self_attn" not in layer_name:
                continue
            attention_viz.plot_attention_heatmap(
                attention_maps,
                layer_name=layer_name,
                head_idx=0,  # First head
                sample_idx=0,
                title=f"{layer_name} Attention",
            )
            plt.show()
    finally:
        attention_viz.remove_hooks()


def interpret_sample(model, X_sample, y_sample, feature_names):
    """Combine multiple interpretation methods for a single sample.

    The function does the following, in order:

    1. Prints the true class, the predicted class and the softmax
       probabilities.
    2. Plots the input next to its integrated-gradient attributions for the
       predicted class.
    3. Plots head 0 of each attention layer whose name contains
       ``"self_attn"``.

    Args:
        model: Module returning a dict with ``"logits"`` (classes on dim 1).
            Switched to eval mode.
        X_sample: Batch holding one sample, shaped
            (1, seq_length, n_features); the ``[0]`` indexing and
            ``shape[1]``/``shape[2]`` reads rely on this.
        y_sample: One-element tensor with the true label (only printed).
        feature_names: Labels for the feature axis of the plots.
    """
    model.eval()
    analyzer = FeatureImportance(model)

    # Get prediction
    probs, pred_class = _predict_sample(model, X_sample)
    print(f"Sample true class: {y_sample.item()}")
    print(f"Predicted class: {pred_class}")
    print(f"Prediction probabilities: {probs[0].cpu().numpy()}")

    # Integrated gradients towards the predicted class
    ig_results = analyzer.integrated_gradients(
        X_sample, n_steps=50,
        target_class=pred_class,
        feature_names=feature_names
    )
    # NOTE(review): assumes "importances" holds seq_length * n_features values
    # in step-major order — confirm against FeatureImportance.
    attributions = ig_results["importances"].reshape(X_sample.shape[1], X_sample.shape[2])

    _plot_inputs_and_attributions(X_sample, attributions, feature_names, pred_class)
    _plot_self_attention(model, X_sample)

# Interpret sample 10; slicing keeps the leading batch dimension of size 1
sample_idx = 10
interpret_sample(
    model=model,
    X_sample=X[sample_idx:sample_idx + 1],
    y_sample=y[sample_idx:sample_idx + 1],
    feature_names=feature_names,
)

## Conclusion

In this notebook, we've demonstrated how to use the interpretability tools in the FUSED framework to understand and visualize model predictions. These tools help to:

1. Identify which features are most important for predictions
2. Understand how the model attends to different parts of the input sequence
3. Visualize feature attributions for individual predictions
4. Analyze the temporal importance of different time steps

By using these interpretability techniques, you can gain insights into how your models work, debug issues, and increase trust in model predictions.