# ZORRO Explainer Tutorial for HMS Multi-Modal GNN

This notebook demonstrates how to use the ZORRO (Zero-Order Rank-based Relative Output) explainer
to interpret predictions from the multi-modal GNN model for HMS brain activity classification.

ZORRO identifies which nodes and node features are most responsible for model predictions
by perturbing graph elements and measuring changes in model outputs.

## 1. Import Required Libraries

In [None]:
import sys
from pathlib import Path

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import List, Dict, Optional
from pathlib import Path

# PyTorch Geometric
from torch_geometric.data import Batch, Data
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import seaborn as sns
from tqdm import tqdm

# Project imports
from src.models import HMSMultiModalGNN
from src.models.zorro_explainer import ZORROExplainer, ZORROExplanation
from examples.zorro_explainer_example import (
    explain_hms_predictions,
    print_explanation,
    compare_modalities,
)

# Plot style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# Reproducibility: this demo uses a randomly initialised model (unless a
# checkpoint is found), randomly generated dummy graphs, and the explainer may
# sample random perturbations. Seed every RNG the notebook relies on before
# any of them is used so "Restart & Run All" gives the same results.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
np.random.seed(SEED)     # legacy global NumPy RNG (np.random.rand, ...)

print("✓ All libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

## 2. Load and Prepare the Trained GNN Model

In [None]:
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def _encoder_config(use_edge_attr: bool) -> dict:
    """Default encoder hyper-parameters (GAT + RNN keys) shared by both modalities.

    The EEG and spectrogram branches use identical settings except for
    whether edge attributes are consumed.
    """
    return {
        "in_channels": 5,
        "gat_hidden_dim": 64,
        "gat_out_dim": 64,
        "gat_num_layers": 2,
        "gat_heads": 4,
        "gat_dropout": 0.3,
        "use_edge_attr": use_edge_attr,
        "rnn_hidden_dim": 128,
        "rnn_num_layers": 2,
        "rnn_dropout": 0.2,
        "bidirectional": True,
        "pooling_method": "mean",
    }


# Two-branch model with the default architecture and six output classes
model = HMSMultiModalGNN(
    eeg_config=_encoder_config(use_edge_attr=True),
    spec_config=_encoder_config(use_edge_attr=False),
    num_classes=6,
)
model.to(device)
model.eval()

# Report the model's self-described properties; integers above 1e6 are shown in millions
model_info = model.get_model_info()
rule = "=" * 60
print("\n" + rule)
print("Model Architecture Information")
print(rule)
for field, val in model_info.items():
    if isinstance(val, int) and val > 1e6:
        rendered = f"{val/1e6:.2f}M"
    else:
        rendered = f"{val}"
    print(f"  {field:.<30} {rendered}")

### Optional: Load Pretrained Weights

If you have a trained model checkpoint, load it here:

In [None]:
# Optional: restore trained weights when a checkpoint file is present.
checkpoint_path = Path("../checkpoints/best_model.pt")  # Update this path

if not checkpoint_path.exists():
    print(f"⚠ Checkpoint not found at {checkpoint_path}")
    print("  Using randomly initialized model for demonstration")
else:
    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; depending on the PyTorch
    # version's `weights_only` default this can execute arbitrary code, so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Accept either a bare state dict or a dict that wraps it under
    # 'state_dict' (checked first) or 'model_state_dict'.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ('state_dict', 'model_state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
    model.load_state_dict(state_dict)

    print("✓ Checkpoint loaded successfully")

## 3. Create Sample Data for Explanation

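Create dummy graph data to demonstrate the explainer. In practice, you would load real data. Because the data here are random (and the model is too, unless a checkpoint was loaded above), the resulting explanations only illustrate the API; they carry no clinical meaning.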

In [None]:
def create_dummy_graphs(
    batch_size: int = 2,
    num_temporal_steps: int = 9,
    nodes_per_graph: int = 9,
    num_features: int = 5,
    device: torch.device = torch.device('cpu'),
    seed: Optional[int] = None,
    edge_prob: float = 0.5,
) -> List[Batch]:
    """Create random temporal graph batches for demonstration.

    Each temporal step is one ``Batch`` holding ``batch_size`` disjoint graphs
    of ``nodes_per_graph`` nodes each. Node features are standard-normal noise,
    standardised over the whole feature tensor of that step. Edges are random
    and directed: every ordered pair of distinct nodes inside the same graph is
    kept independently with probability ``edge_prob``, and every node gets one
    self loop. (The graphs are NOT fully connected.)

    Parameters
    ----------
    batch_size : int
        Number of samples in batch (must be >= 1)
    num_temporal_steps : int
        Number of temporal graphs (9 for EEG, 119 for spectrogram)
    nodes_per_graph : int
        Number of nodes per graph (typically 9 for EEG channels; must be >= 1)
    num_features : int
        Number of features per node (e.g., 5 band powers)
    device : torch.device
        Device to place tensors on
    seed : int, optional
        If given, private torch/NumPy generators seeded with this value are
        used, so the output is reproducible regardless of global RNG state.
        If ``None`` (default), the global torch and NumPy RNGs are used, so
        ``torch.manual_seed`` / ``np.random.seed`` control reproducibility.
    edge_prob : float
        Probability in [0, 1] of keeping each non-self-loop edge (default 0.5)

    Returns
    -------
    List[Batch]
        ``num_temporal_steps`` batched graphs, each with ``x``, ``edge_index``
        and ``batch`` set.

    Raises
    ------
    ValueError
        If ``batch_size`` or ``nodes_per_graph`` is < 1, or ``edge_prob`` is
        outside [0, 1].
    """
    if batch_size < 1 or nodes_per_graph < 1:
        raise ValueError(
            f"batch_size and nodes_per_graph must be >= 1, "
            f"got {batch_size} and {nodes_per_graph}"
        )
    if not 0.0 <= edge_prob <= 1.0:
        raise ValueError(f"edge_prob must be in [0, 1], got {edge_prob}")

    if seed is None:
        torch_gen = None
        np_rng = np.random  # legacy global RNG; honours np.random.seed
    else:
        torch_gen = torch.Generator().manual_seed(seed)
        np_rng = np.random.default_rng(seed)

    # Graph-membership vector: sample i owns nodes [i*n, (i+1)*n). Identical for
    # every step, so build it once and give each Batch its own copy.
    batch_vec = torch.repeat_interleave(
        torch.arange(batch_size, device=device),
        nodes_per_graph,
    )

    graphs = []
    for _ in range(num_temporal_steps):
        # Draw on the CPU (where an explicit generator always works), then move.
        x = torch.randn(batch_size * nodes_per_graph, num_features, generator=torch_gen)
        # Standardise over the whole tensor (all nodes and features together)
        x = ((x - x.mean()) / (x.std() + 1e-8)).to(device)

        edge_blocks = []
        for sample in range(batch_size):
            # Random directed adjacency within this sample's graph
            adj = np_rng.random((nodes_per_graph, nodes_per_graph)) < edge_prob
            np.fill_diagonal(adj, True)  # one self loop per node
            src, dst = np.nonzero(adj)
            # Shift local node ids to this sample's slice of the batch
            edge_blocks.append(np.stack([src, dst]) + sample * nodes_per_graph)

        edge_index = torch.as_tensor(
            np.hstack(edge_blocks), dtype=torch.long, device=device
        )
        graphs.append(Batch(x=x, edge_index=edge_index, batch=batch_vec.clone()))

    return graphs

# Both modalities use 9 nodes x 5 features per graph; they differ only in the
# number of temporal steps (9 EEG windows vs. 119 spectrogram frames).
batch_size = 2
_dummy_graph_kwargs = dict(
    batch_size=batch_size,
    nodes_per_graph=9,
    num_features=5,
    device=device,
)
eeg_graphs = create_dummy_graphs(num_temporal_steps=9, **_dummy_graph_kwargs)
spec_graphs = create_dummy_graphs(num_temporal_steps=119, **_dummy_graph_kwargs)

first_eeg_step = eeg_graphs[0]
print("✓ Created dummy data")
print(f"  Batch size: {batch_size}")
print(f"  EEG graphs: {len(eeg_graphs)} temporal steps")
print(f"  Spectrogram graphs: {len(spec_graphs)} temporal steps")
print(f"  Nodes per graph: {first_eeg_step.x.shape[0] // batch_size}")
print(f"  Features per node: {first_eeg_step.x.shape[1]}")

## 4. Initialize ZORRO Explainer

In [None]:
# Build the explainer. target_class=None explains the model's predicted class;
# noise_std is only consulted when perturbation_mode="noise".
explainer_settings = {
    "model": model,
    "target_class": None,
    "device": device,
    "perturbation_mode": "zero",  # Options: "zero", "noise", "mean"
    "noise_std": 0.1,
}
explainer = ZORROExplainer(**explainer_settings)

print("✓ ZORRO Explainer initialized")
print("  Perturbation mode: {}".format(explainer.perturbation_mode))
print("  Device: {}".format(explainer.device))

## 5. Extract Node Importance Scores (EEG Modality)

Compute an importance score for each EEG node by perturbing it and measuring how much the model's prediction changes.

In [None]:
# Which sample of the batch to explain
sample_idx = 0

print(f"\nExplaining sample {sample_idx}...")
print("=" * 60)

# EEG modality: rank nodes by how much perturbing them changes the prediction
eeg_explanation = explainer.explain_sample(
    graphs=eeg_graphs,
    modality="eeg",
    sample_idx=sample_idx,
    top_k=10,
    n_samples=5,  # number of perturbation samples
    pbar=True,
)

print("\n✓ EEG explanation computed")
print(f"  Total nodes: {len(eeg_explanation.node_indices)}")
print(f"  Node importance shape: {eeg_explanation.node_importance.shape}")
print("  Top-5 important nodes:")
for position, (node_id, score) in enumerate(eeg_explanation.top_k_nodes[:5], start=1):
    print(f"    {position}. Node {node_id:3d}: {score:.4f}")

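## 6. Extract Node Importance Scores (Spectrogram Modality)

Run the same perturbation analysis on the spectrogram graphs. Each explanation also carries feature importance scores (node importance aggregated per feature), which are visualized in Section 7 to identify the most influential node features.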

In [None]:
# Spectrogram modality: same settings as the EEG run, same sample
spec_explanation = explainer.explain_sample(
    graphs=spec_graphs,
    modality="spec",
    sample_idx=sample_idx,
    top_k=10,
    n_samples=5,
    pbar=True,
)

print("\n✓ Spectrogram explanation computed")
print(f"  Total nodes: {len(spec_explanation.node_indices)}")
print(f"  Node importance shape: {spec_explanation.node_importance.shape}")
print("  Top-5 important nodes:")
for position, (node_id, score) in enumerate(spec_explanation.top_k_nodes[:5], start=1):
    print(f"    {position}. Node {node_id:3d}: {score:.4f}")

## 7. Visualize Explanations

In [None]:
# Print detailed explanations
print_explanation(eeg_explanation, "EEG Modality")
print_explanation(spec_explanation, "Spectrogram Modality")

In [None]:
# Side-by-side bar charts of per-feature importance for each modality
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

feature_panels = (
    (axes[0], eeg_explanation, 'steelblue', 'EEG Feature Importance'),
    (axes[1], spec_explanation, 'coral', 'Spectrogram Feature Importance'),
)
for ax, expl, colour, heading in feature_panels:
    scores = expl.feature_importance.cpu().numpy()
    ax.bar(range(len(scores)), scores, color=colour, alpha=0.8)
    ax.set_xlabel('Feature Index', fontsize=12)
    ax.set_ylabel('Importance Score', fontsize=12)
    ax.set_title(heading, fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Feature importance visualization complete")

In [None]:
# Horizontal bar charts of the highest-ranked nodes for each modality
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

k = 10  # number of top nodes to show per modality

top_node_panels = (
    (axes[0], eeg_explanation, 'steelblue', 'EEG'),
    (axes[1], spec_explanation, 'coral', 'Spectrogram'),
)
for ax, expl, colour, modality_name in top_node_panels:
    top_nodes = expl.top_k_nodes[:k]
    node_labels = [str(node_idx) for node_idx, _ in top_nodes]
    scores = [importance for _, importance in top_nodes]

    y_pos = np.arange(len(node_labels))
    ax.barh(y_pos, scores, color=colour, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(node_labels)
    ax.set_xlabel('Importance Score', fontsize=11)
    ax.set_ylabel('Node Index', fontsize=11)
    # Title reports how many nodes are actually drawn: it follows `k` and
    # stays correct if the explanation holds fewer than k nodes.
    ax.set_title(
        f'Top-{len(top_nodes)} Important {modality_name} Nodes',
        fontsize=12,
        fontweight='bold',
    )
    ax.invert_yaxis()  # first entry of top_k_nodes drawn at the top
    ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Top-k nodes visualization complete")

In [None]:
# Node x feature importance heatmaps, restricted to each modality's top nodes
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

heatmap_panels = (
    (axes[0], eeg_explanation, 'Blues', 'EEG Node-Feature Importance Heatmap'),
    (axes[1], spec_explanation, 'Oranges', 'Spectrogram Node-Feature Importance Heatmap'),
)
for ax, expl, cmap_name, heading in heatmap_panels:
    full_matrix = expl.node_importance.cpu().numpy()
    chosen_rows = [node_idx for node_idx, _ in expl.top_k_nodes[:10]]
    sub_matrix = full_matrix[chosen_rows]

    image = ax.imshow(sub_matrix, cmap=cmap_name, aspect='auto')
    ax.set_xlabel('Feature Index', fontsize=11)
    ax.set_ylabel('Node (Top-10)', fontsize=11)
    ax.set_title(heading, fontsize=12, fontweight='bold')
    ax.set_xticks(range(sub_matrix.shape[1]))
    ax.set_yticks(range(len(chosen_rows)))
    ax.set_yticklabels([f"Node {idx}" for idx in chosen_rows])
    plt.colorbar(image, ax=ax, label='Importance')

plt.tight_layout()
plt.show()

print("✓ Heatmap visualization complete")

## 8. Compare Modalities

In [None]:
# Textual side-by-side comparison from the example helpers
compare_modalities(eeg_explanation, spec_explanation)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Per-node importance: sum of each node's feature-importance row
eeg_node_imp_sum = eeg_explanation.node_importance.sum(dim=1).cpu().numpy()
spec_node_imp_sum = spec_explanation.node_importance.sum(dim=1).cpu().numpy()


def _cumulative_share(scores: np.ndarray) -> np.ndarray:
    """Cumulative fraction of total importance covered by the top-n nodes.

    Entry ``n - 1`` is the share held by the ``n`` highest-scoring nodes.
    Returns all-NaN when the total is zero, rather than dividing by zero.
    NOTE(review): if importances can be negative the curve may leave [0, 1];
    confirm the explainer's scores are non-negative.
    """
    ranked = np.sort(scores)[::-1]
    total = ranked.sum()
    if total == 0:
        return np.full(ranked.shape, np.nan)
    return np.cumsum(ranked) / total


comparison_panels = (
    (0, eeg_node_imp_sum, 'steelblue', 'EEG'),
    (1, spec_node_imp_sum, 'coral', 'Spectrogram'),
)
for col, node_scores, colour, modality_name in comparison_panels:
    # Top row: distribution of per-node importance
    hist_ax = axes[0, col]
    hist_ax.hist(node_scores, bins=15, alpha=0.7, color=colour, edgecolor='black')
    hist_ax.set_xlabel('Node Importance Score', fontsize=11)
    hist_ax.set_ylabel('Frequency', fontsize=11)
    hist_ax.set_title(
        f'Distribution: {modality_name} Node Importance', fontsize=12, fontweight='bold'
    )
    hist_ax.grid(axis='y', alpha=0.3)

    # Bottom row: cumulative importance vs. number of top nodes kept.
    # x is a node COUNT, so the first point is 1 node (not 0).
    cum_share = _cumulative_share(node_scores)
    n_top = np.arange(1, len(cum_share) + 1)
    cum_ax = axes[1, col]
    cum_ax.plot(n_top, cum_share, 'o-', color=colour, linewidth=2)
    cum_ax.axhline(y=0.8, color='red', linestyle='--', label='80% threshold')
    cum_ax.set_xlabel('Number of Top Nodes', fontsize=11)
    cum_ax.set_ylabel('Cumulative Importance', fontsize=11)
    cum_ax.set_title(
        f'Cumulative Importance: {modality_name}', fontsize=12, fontweight='bold'
    )
    cum_ax.legend()
    cum_ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Modality comparison visualization complete")

## 9. Evaluate Explanation Quality

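Define fidelity and sparsity metrics for assessing explanation quality. Only sparsity is reported below; fidelity requires extra model forward passes on masked inputs from both modalities.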

In [None]:
def _mask_nodes(graphs: List[Batch], keep: torch.Tensor) -> List[Batch]:
    """Return copies of ``graphs`` whose nodes outside ``keep`` have zeroed features.

    ``keep`` is a boolean mask over the nodes of all graphs concatenated in list
    order (graph 0's nodes first, then graph 1's, ...). ``clone()`` keeps every
    attribute (``edge_attr``, ``batch``, ...) and leaves the inputs untouched.
    """
    masked, offset = [], 0
    for g in graphs:
        n_nodes = g.x.size(0)
        g_masked = g.clone()
        drop = ~keep[offset:offset + n_nodes].to(g_masked.x.device)
        g_masked.x[drop] = 0.0
        masked.append(g_masked)
        offset += n_nodes
    return masked


def compute_fidelity(
    model: nn.Module,
    graphs: List[Batch],
    modality: str,
    explanation: ZORROExplanation,
    sample_idx: int,
    device: torch.device,
    k_values: Optional[List[int]] = None,
    other_graphs: Optional[List[Batch]] = None,
) -> Dict[int, float]:
    """Compute fidelity: confidence retained when only the top-k nodes are kept.

    For each k, every node of ``graphs`` that is not among the explanation's
    top-k has its features zeroed. The model is then re-run with the other
    modality left untouched. The softmax probability of the originally
    predicted class for ``sample_idx`` is divided by the same probability on
    the unmasked input. A value near 1 means the top-k nodes alone preserve
    the prediction.

    Parameters
    ----------
    model : nn.Module
        Multi-modal model, called as ``model(eeg_graphs, spec_graphs)`` and
        assumed to return ``(batch, num_classes)`` logits.
    graphs : List[Batch]
        Temporal graphs of the explained modality.
    modality : str
        ``"eeg"`` or ``"spec"``: which model input ``graphs`` feeds.
    explanation : ZORROExplanation
        Supplies ``prediction_original`` (its argmax is the tracked class) and
        ``top_k_nodes`` as ``(node_idx, importance)`` pairs.
        NOTE(review): node indices are assumed to be global over all temporal
        graphs concatenated in list order; confirm against ZORROExplainer.
    sample_idx : int
        Row of the model output to score.
    device : torch.device
        Device the (copied) graphs are moved to before the forward passes.
    k_values : list of int, optional
        Top-k sizes to evaluate; defaults to ``[5, 10, 20]``.
    other_graphs : List[Batch]
        Temporal graphs of the OTHER modality, fed unchanged. Required: the
        model needs both modalities.

    Returns
    -------
    Dict[int, float]
        Fidelity per k (0.0 when the unmasked probability is 0).

    Raises
    ------
    ValueError
        If ``other_graphs`` is missing or ``modality`` is not "eeg"/"spec".
    """
    if other_graphs is None:
        raise ValueError(
            "compute_fidelity needs `other_graphs` (the unmodified graphs of the "
            "other modality); the model takes both EEG and spectrogram inputs."
        )
    if modality not in ("eeg", "spec"):
        raise ValueError(f"modality must be 'eeg' or 'spec', got {modality!r}")
    if k_values is None:
        k_values = [5, 10, 20]

    original_pred = int(explanation.prediction_original.argmax().item())

    # Work on device-resident copies so the caller's graphs are not modified
    graphs = [g.clone().to(device) for g in graphs]
    other_graphs = [g.clone().to(device) for g in other_graphs]
    total_nodes = sum(g.x.size(0) for g in graphs)

    def _class_prob(modality_graphs: List[Batch]) -> float:
        """Softmax probability of the tracked class for ``sample_idx``."""
        if modality == "eeg":
            logits = model(modality_graphs, other_graphs)
        else:
            logits = model(other_graphs, modality_graphs)
        return torch.softmax(logits, dim=-1)[sample_idx, original_pred].item()

    fidelity_scores: Dict[int, float] = {}
    with torch.no_grad():
        # Baseline through the same forward pass + softmax as the masked runs,
        # so the ratio compares like with like.
        original_conf = _class_prob(graphs)

        for k in k_values:
            top_ids = [int(node_idx) for node_idx, _ in explanation.top_k_nodes[:k]]
            valid_ids = torch.tensor(
                [i for i in top_ids if 0 <= i < total_nodes], dtype=torch.long
            )
            keep = torch.zeros(total_nodes, dtype=torch.bool)
            keep[valid_ids] = True

            masked_conf = _class_prob(_mask_nodes(graphs, keep))
            fidelity_scores[k] = masked_conf / original_conf if original_conf > 0 else 0.0

    return fidelity_scores

def compute_sparsity(
    explanation: ZORROExplanation,
    k_values: Optional[List[int]] = None,
) -> Dict[int, float]:
    """Compute the fraction of the explanation's nodes that a top-k explanation keeps.

    For each k the score is ``min(k, total_nodes) / total_nodes``, with
    ``total_nodes = len(explanation.node_indices)``. Smaller values mean a
    sparser explanation (a smaller share of the graph is kept). The score
    depends only on k and the graph size, not on the importance values.

    Parameters
    ----------
    explanation : ZORROExplanation
        Explanation whose ``node_indices`` define the total node count.
    k_values : list of int, optional
        Top-k sizes to evaluate; defaults to ``[5, 10, 20]``.

    Returns
    -------
    Dict[int, float]
        Map from k to a fraction in [0, 1]. If the explanation has no nodes,
        every value is NaN instead of raising ZeroDivisionError.
    """
    if k_values is None:
        k_values = [5, 10, 20]

    total_nodes = len(explanation.node_indices)
    if total_nodes == 0:
        return {k: float("nan") for k in k_values}

    # A top-k explanation cannot keep more nodes than exist, so cap at 100%
    return {k: min(k, total_nodes) / total_nodes for k in k_values}

# Only sparsity is reported here; fidelity would need extra forward passes
# with masked inputs from both modalities.
print("Computing explanation quality metrics...")
print("\n" + "=" * 60)
print("Sparsity Metric")
print("=" * 60)

sparsity_eeg = compute_sparsity(eeg_explanation)
sparsity_spec = compute_sparsity(spec_explanation)

for header, per_k in (
    ("\nEEG Sparsity (% of nodes needed):", sparsity_eeg),
    ("\nSpectrogram Sparsity (% of nodes needed):", sparsity_spec),
):
    print(header)
    for top_k_size, fraction in per_k.items():
        print(f"  Top-{top_k_size:2d} nodes: {fraction:.1%}")

print("\n✓ Quality metrics computed")

## 10. Summary and Insights

Generate a comprehensive summary of the explanations.

In [None]:
# Final text summary of both modality explanations
banner = "=" * 70
print("\n" + banner)
print("ZORRO EXPLANATION SUMMARY")
print(banner)

print(f"\nSample Index: {sample_idx}")
print(f"Batch Size: {batch_size}")

print("\n--- Original Predictions ---")
print(f"EEG prediction class: {eeg_explanation.prediction_original.argmax().item()}")
print(f"Spectrogram prediction class: {spec_explanation.prediction_original.argmax().item()}")

print("\n--- Node Importance Summary ---")
eeg_total = eeg_explanation.node_importance.sum().item()
spec_total = spec_explanation.node_importance.sum().item()
# Guard the ratio: an all-zero EEG importance map would otherwise raise
# ZeroDivisionError and abort the summary; report NaN instead.
importance_ratio = spec_total / eeg_total if eeg_total != 0 else float("nan")
print(f"EEG total importance:         {eeg_total:>8.4f}")
print(f"Spectrogram total importance: {spec_total:>8.4f}")
print(f"Ratio (Spec/EEG):             {importance_ratio:>8.4f}")

for modality_name, expl in (("EEG", eeg_explanation), ("Spectrogram", spec_explanation)):
    print(f"\n--- Top-5 Important Nodes ({modality_name}) ---")
    for rank, (node_idx, importance) in enumerate(expl.top_k_nodes[:5], 1):
        print(f"  {rank}. Node {node_idx:3d}: {importance:>8.4f}")

print("\n--- Feature Importance (Top-5 Features) ---")
for heading, expl in (("EEG:", eeg_explanation), ("\nSpectrogram:", spec_explanation)):
    feat_scores = expl.feature_importance
    top_feat = torch.topk(feat_scores, k=min(5, len(feat_scores)))
    print(heading)
    for feat_idx, importance in zip(top_feat.indices, top_feat.values):
        print(f"  Feature {feat_idx.item():d}: {importance.item():.4f}")

print("\n" + banner)
print("✓ ZORRO explanation analysis complete!")
print(banner)