# BiomeML - GNN Model Interpretation

## Summary
- Load a trained model and test dataset for interpretation.
- Compute node importance, embeddings, and error analysis.
- Save interpretation artifacts under `{DISEASE}_analysis_output/`.


This notebook provides tools for interpreting trained GNN models:

1. Node importance analysis (which microbes drive predictions)
2. Prediction explanation for individual samples
3. Embedding visualization to inspect sample separation
4. Error analysis for misclassifications

## Prerequisites

- Trained model from `03_model_training.ipynb`
- Processed graphs and labels in `{DISEASE}_analysis_output/`
- Set the `EXPERIMENT_CONFIG_PATH` or `DISEASE` environment variable to select the target disease (falls back to `IBD` when neither provides one)


In [None]:
import os
import sys
import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_mean_pool, global_max_pool, global_add_pool

import matplotlib.pyplot as plt
import seaborn as sns

sys.path.insert(0, str(Path.cwd().parent))

                                                                                    
import yaml

# Optional experiment configuration, selected through the EXPERIMENT_CONFIG_PATH
# environment variable. EXPERIMENT_CONFIG stays None whenever no usable mapping
# could be loaded, so later cells can fall back to their own defaults.
EXPERIMENT_CONFIG = None
exp_config_path = os.environ.get('EXPERIMENT_CONFIG_PATH')

if exp_config_path and Path(exp_config_path).exists():
    print(f"Loading experiment config from: {exp_config_path}")
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(exp_config_path, 'r', encoding='utf-8') as f:
        EXPERIMENT_CONFIG = yaml.safe_load(f)
    if isinstance(EXPERIMENT_CONFIG, dict):
        print(f"  Loaded experiment: {EXPERIMENT_CONFIG.get('name', 'unknown')}")
    else:
        # An empty file parses to None and a list/scalar document has no .get();
        # neither is usable as a config, so behave as if none was supplied.
        print("  Config file is empty or not a mapping - using default settings")
        EXPERIMENT_CONFIG = None
elif exp_config_path:
    # The variable is set but the file is missing: say so explicitly.
    print(f"EXPERIMENT_CONFIG_PATH does not exist: {exp_config_path} - using default settings")
else:
    print("No EXPERIMENT_CONFIG_PATH set - using default settings")
from src.model_interpretation import (
    compute_node_gradients,
    compute_integrated_gradients,
    get_top_important_nodes,
    aggregate_importance_across_samples,
    analyze_prediction,
    batch_analyze_predictions,
    visualize_node_importance,
    visualize_embedding_space,
    save_interpretation_results
)

# Report the runtime environment once every import above has succeeded.
print(
    "Imports loaded successfully",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)


## 1. Load Model and Data
### Load model artifacts


In [None]:
# Name of the experiment directory under <project_root>/experiments/.
EXPERIMENT_NAME = "baseline"

# Resolve the target disease. Precedence: experiment config (top-level
# 'disease', then 'data_extraction.disease'), then the DISEASE environment
# variable, then the 'IBD' default.
_config = EXPERIMENT_CONFIG if isinstance(EXPERIMENT_CONFIG, dict) else {}
# `or {}` guards against a YAML section that is present but empty (parses to None).
_data_extraction = _config.get('data_extraction') or {}
DISEASE = (
    _config.get('disease')
    or _data_extraction.get('disease')
    or os.environ.get('DISEASE')
    or 'IBD'
)

print(f"Disease: {DISEASE}")

project_root = Path.cwd().parent
experiment_dir = project_root / "experiments" / EXPERIMENT_NAME
# Relative path: resolved against the current working directory.
output_dir = Path(f"{DISEASE}_analysis_output")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Artifacts this notebook reads from the analysis output directory.
model_path = output_dir / "models" / "best_model.pt"
dataset_path = output_dir / "pytorch_geometric" / f"{DISEASE.lower()}_dataset.pt"
graphs_path = output_dir / "graphs" / f"nx_graphs_{DISEASE}.pkl"
metadata_path = output_dir / "graphs" / f"graph_metadata_{DISEASE}.csv"

# Existence report only; nothing is loaded here.
print(f"\nChecking files:")
print(f"  Model: {model_path.exists()}")
print(f"  Dataset: {dataset_path.exists()}")
print(f"  Graphs: {graphs_path.exists()}")
print(f"  Metadata: {metadata_path.exists()}")


In [None]:
                                                                          
sys.path.insert(0, str(Path.cwd().parent))
from src.models import (
    GNN_GCN, GNN_GINEConv, GNN_GAT, GNN_GraphSAGE, 
    EdgeCentricRGCN, MLP_Baseline, CNN_Baseline, get_model
)

                                           
# Registry of (architecture name, model class) pairs; the names are the keys
# looked up in MODEL_CLASSES.
_MODEL_REGISTRY = (
    ('GCN', GNN_GCN),
    ('GINEConv', GNN_GINEConv),
    ('GAT', GNN_GAT),
    ('GraphSAGE', GNN_GraphSAGE),
    ('EdgeCentricRGCN', EdgeCentricRGCN),
    ('MLP', MLP_Baseline),
    ('CNN', CNN_Baseline),
)
MODEL_CLASSES = dict(_MODEL_REGISTRY)

print("Model classes imported from src/models.py")
print(f"  Available models: {list(MODEL_CLASSES.keys())}")

                                          
                                                 
def get_model_params_from_config():
    """Return model hyperparameters, preferring values from EXPERIMENT_CONFIG.

    Reads the ``model_training.architecture`` section of the module-level
    ``EXPERIMENT_CONFIG`` when that global is defined and is a mapping. Any
    key that is missing, or present but null, falls back to the built-in
    default, as does a missing or null section.

    Returns:
        dict: A new dict with the keys ``hidden_dim``, ``num_layers``,
        ``dropout``, ``pooling`` and ``num_classes``.
    """
    defaults = {
        'hidden_dim': 128,
        'num_layers': 2,
        'dropout': 0.3,
        'pooling': 'mean',
        'num_classes': 2,
    }

    # globals().get avoids a NameError when the config cell has not been run.
    config = globals().get('EXPERIMENT_CONFIG')
    if not isinstance(config, dict):
        return defaults

    # A YAML key with no value parses to None, so `.get(key, {})` alone would
    # hand back None and the chained `.get` would raise AttributeError.
    training = config.get('model_training')
    arch = training.get('architecture') if isinstance(training, dict) else None
    if not isinstance(arch, dict):
        return defaults

    # Compare against None (not truthiness) so valid falsy values such as
    # dropout=0.0 are kept.
    return {
        key: (default if arch.get(key) is None else arch[key])
        for key, default in defaults.items()
    }

class GNN_Model(nn.Module):
    """DEPRECATED: Use models from src/models.py instead.

    Legacy GINEConv graph classifier, kept for backwards compatibility only.

    A stack of ``GINEConv`` layers (each followed by batch norm, ReLU and
    dropout) is pooled into one vector per graph and fed to an MLP classifier.
    Constructor arguments left as ``None`` are filled in from
    ``get_model_params_from_config()``.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of GINEConv layers; at least one is always built.
        dropout: Dropout probability used after each conv layer and inside
            the classifier.
        pooling: Graph readout, one of ``'mean'``, ``'max'`` or ``'sum'``.
        num_classes: ``2`` builds a single-output head ending in ``Sigmoid``;
            any other value builds a head with ``num_classes`` outputs and no
            final activation.
    """

    # Supported graph-level readouts, keyed by the `pooling` argument.
    _POOLING_FNS = {
        'mean': global_mean_pool,
        'max': global_max_pool,
        'sum': global_add_pool,
    }

    def __init__(self, input_dim=1, hidden_dim=None, num_layers=None, dropout=None,
                 pooling=None, num_classes=None):
        super().__init__()

        # Fill every argument left as None from the experiment config / defaults.
        defaults = get_model_params_from_config()
        hidden_dim = hidden_dim if hidden_dim is not None else defaults['hidden_dim']
        num_layers = num_layers if num_layers is not None else defaults['num_layers']
        dropout = dropout if dropout is not None else defaults['dropout']
        pooling = pooling if pooling is not None else defaults['pooling']
        num_classes = num_classes if num_classes is not None else defaults['num_classes']

        self.num_classes = num_classes
        self.dropout = nn.Dropout(dropout)
        self.pooling = pooling

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # The first layer maps input_dim -> hidden_dim; the remaining
        # (num_layers - 1) layers map hidden_dim -> hidden_dim. Module order
        # and attribute names are unchanged, so existing checkpoints still load.
        layer_in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        for in_dim in layer_in_dims:
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.convs.append(GINEConv(mlp, edge_dim=1))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Shared classifier trunk; only the output layer depends on num_classes.
        head = [
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        if num_classes == 2:
            head += [nn.Linear(hidden_dim // 2, 1), nn.Sigmoid()]
        else:
            head.append(nn.Linear(hidden_dim // 2, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward_embedding(self, data):
        """Get graph-level embeddings before classification.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` attributes.

        Returns:
            The node features after all conv layers, pooled per graph with
            the configured readout.

        Raises:
            ValueError: If ``self.pooling`` is not a supported readout.
                Without this check an unknown value would skip pooling and
                return node-level features of the wrong shape.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index, edge_attr)
            x = bn(x)
            x = F.relu(x)
            x = self.dropout(x)

        pool = self._POOLING_FNS.get(self.pooling)
        if pool is None:
            raise ValueError(
                f"Unsupported pooling '{self.pooling}'; "
                f"expected one of {sorted(self._POOLING_FNS)}"
            )
        return pool(x, batch)

    def forward(self, data):
        """Return the classifier output with a trailing size-1 dimension squeezed."""
        x = self.forward_embedding(data)
        out = self.classifier(x)
        return out.squeeze(-1)

print("Model class defined")


In [None]:
print("Loading dataset...")
# NOTE(review): the saved object is a dict of Python objects, not only tensors.
# Recent PyTorch releases default torch.load to weights_only=True and may
# refuse such files; if that happens, pass weights_only=False - for trusted
# files only, since that unpickles arbitrary objects.
dataset = torch.load(dataset_path, map_location=device)

test_dataset = dataset['test_dataset']
test_labels = np.array(dataset['test_labels'])

print(f"Test samples: {len(test_dataset)}")
print(f"{DISEASE} positive: {sum(test_labels)}")
print(f"Control: {len(test_labels) - sum(test_labels)}")

import yaml

# Fallback architecture settings used when no config file supplies a value.
# NOTE(review): hidden_dim falls back to 160 here but to 128 in
# get_model_params_from_config(); confirm which matches the trained checkpoint.
arch_defaults = {
    'model_type': 'GINEConv',
    'hidden_dim': 160,
    'num_layers': 2,
    'dropout': 0.3,
    'pooling': 'mean',
    'num_classes': 2,
}

# Prefer the experiment's own config.yaml and fall back to the project-level
# one. Checking the files themselves (not just the experiment directory) means
# an experiment directory without a config.yaml still picks up the project config.
config_candidates = [experiment_dir / "config.yaml", project_root / "config.yaml"]
exp_config_yaml_path = next((p for p in config_candidates if p.exists()), None)

arch_config = {}
if exp_config_yaml_path is not None:
    with open(exp_config_yaml_path, encoding='utf-8') as f:
        # An empty YAML file parses to None.
        exp_config_yaml = yaml.safe_load(f) or {}
    # `or {}` also covers sections that are present but empty (parsed as None).
    arch_config = (exp_config_yaml.get('model_training') or {}).get('architecture') or {}
    print(f"Architecture settings read from: {exp_config_yaml_path}")
else:
    print("No config.yaml found - using default architecture settings")

# Missing or null entries fall back to the defaults above.
arch_settings = {
    key: (default if arch_config.get(key) is None else arch_config[key])
    for key, default in arch_defaults.items()
}
model_type = arch_settings['model_type']
hidden_dim = arch_settings['hidden_dim']
num_layers = arch_settings['num_layers']
dropout = arch_settings['dropout']
pooling = arch_settings['pooling']
num_classes = arch_settings['num_classes']

print(f"\nLoading model type: {model_type}")

# Keyword arguments accepted by every constructor used below.
common_kwargs = dict(hidden_dim=hidden_dim, num_classes=num_classes, dropout=dropout)
graph_kwargs = dict(input_dim=1, num_layers=num_layers, pooling=pooling, **common_kwargs)

if model_type in MODEL_CLASSES:
    model_class = MODEL_CLASSES[model_type]
    if model_type in ['MLP', 'CNN']:
        # Baselines are built without num_layers / pooling.
        model_kwargs = dict(input_dim=1, **common_kwargs)
    elif model_type == 'EdgeCentricRGCN':
        # Built without input_dim / num_layers.
        model_kwargs = dict(pooling=pooling, **common_kwargs)
    else:
        model_kwargs = graph_kwargs
else:
    print(f"Unknown model type '{model_type}', falling back to legacy GNN_Model")
    model_class = GNN_Model
    model_kwargs = graph_kwargs

model = model_class(**model_kwargs).to(device)

checkpoint = torch.load(model_path, map_location=device)
# The weights are expected under 'model_state_dict'; a checkpoint that is
# itself the state dict is accepted as well.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()

print(f"Model loaded: {model_type}, hidden_dim={hidden_dim}, layers={num_layers}")


## 2. Node Importance Analysis
### Gradient scoring


In [None]:
print("Computing node importance across test samples...")

# Analyse at most 50 test graphs.
n_samples_to_analyze = min(50, len(test_dataset))
if n_samples_to_analyze == 0:
    # Fail with a clear message instead of the opaque error min() raises on
    # an empty sequence further down.
    raise ValueError("Test dataset is empty - cannot compute node importance")

# One importance vector per analysed graph.
all_importance_scores = [
    compute_node_gradients(model, test_dataset[i], device)
    for i in range(n_samples_to_analyze)
]

print(f"Analyzed {n_samples_to_analyze} samples")

# Score vectors can only be stacked when they share a length, so each one is
# cut to the shortest. Aggregating by position assumes node i means the same
# thing in every graph; make truncation visible rather than silent.
node_counts = [len(imp) for imp in all_importance_scores]
min_nodes = min(node_counts)
if max(node_counts) != min_nodes:
    print(
        f"WARNING: graphs have between {min_nodes} and {max(node_counts)} nodes; "
        f"scores are truncated to the first {min_nodes} nodes per graph"
    )
trimmed_scores = [imp[:min_nodes] for imp in all_importance_scores]
stacked_scores = np.stack(trimmed_scores)

# Per-node statistics across the analysed graphs (axis 0 = graph).
mean_importance = np.mean(stacked_scores, axis=0)
std_importance = np.std(stacked_scores, axis=0)

print(f"Computed importance for {min_nodes} nodes")
print(f"Mean importance range: [{mean_importance.min():.4f}, {mean_importance.max():.4f}]")


In [None]:
top_k = 20

fig, ax = plt.subplots(figsize=(12, 8))

# Rank nodes from most to least important and keep the first top_k.
descending_order = np.argsort(mean_importance)[::-1]
top_indices = descending_order[:top_k]
top_scores, top_stds = mean_importance[top_indices], std_importance[top_indices]

labels = [f"Node {idx}" for idx in top_indices]
y_pos = np.arange(len(top_scores))

# Horizontal bars with the per-node standard deviation as error bars.
bar_style = dict(align='center', color='steelblue', alpha=0.8, capsize=3)
ax.barh(y_pos, top_scores, xerr=top_stds, **bar_style)
ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.invert_yaxis()  # put the highest-ranked node at the top
ax.set_xlabel('Mean Importance Score (± std)', fontsize=12)
ax.set_title(f'Top {top_k} Important Nodes (Microbes)', fontsize=14, fontweight='bold')
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
# Directory for every figure written by this notebook.
viz_dir = output_dir / 'visualizations'
viz_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(viz_dir / 'node_importance.png', dpi=150, bbox_inches='tight')
plt.show()

# Text version of the same ranking.
print(f"\nTop {top_k} most important nodes:")
for rank, (idx, score, spread) in enumerate(zip(top_indices, top_scores, top_stds), start=1):
    print(f"  {rank}. Node {idx}: {score:.4f} (±{spread:.4f})")


## 3. Embedding Space Visualization
### t-SNE projection


In [None]:
print("Extracting embeddings...")

loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Collect graph-level embeddings batch by batch, in dataset order
# (shuffle=False), with gradients disabled.
model.eval()
with torch.no_grad():
    embeddings = [
        model.forward_embedding(batch.to(device)).cpu().numpy()
        for batch in loader
    ]

embeddings = np.vstack(embeddings)
print(f"Embedding shape: {embeddings.shape}")

from sklearn.manifold import TSNE

print("Running t-SNE...")
# Keep perplexity below the number of samples.
perplexity = min(30, len(embeddings) - 1)
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
embeddings_2d = tsne.fit_transform(embeddings)

fig, ax = plt.subplots(figsize=(10, 8))

# Index 0 -> control, index 1 -> disease.
colors = ['#2ecc71', '#e74c3c']
labels_text = ['Control', DISEASE]

for label, (color, name) in enumerate(zip(colors, labels_text)):
    mask = test_labels == label
    xs, ys = embeddings_2d[mask, 0], embeddings_2d[mask, 1]
    ax.scatter(xs, ys, c=color, label=name, alpha=0.7, s=60)

ax.set_xlabel('t-SNE 1', fontsize=12)
ax.set_ylabel('t-SNE 2', fontsize=12)
ax.set_title('Model Embedding Space (t-SNE)', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(viz_dir / 'embedding_tsne.png', dpi=150, bbox_inches='tight')
plt.show()


## 4. Error Analysis
### Misclassification review


In [None]:
print("Analyzing predictions...")

# Score every test graph individually; the model output is thresholded at 0.5
# to obtain the predicted class.
predictions = []
with torch.no_grad():
    for i in range(len(test_dataset)):
        sample = test_dataset[i].to(device)
        pred_prob = model(sample).cpu().item()
        true_class = int(test_labels[i])
        pred_class = int(pred_prob > 0.5)
        predictions.append(dict(
            idx=i,
            prob=pred_prob,
            pred=pred_class,
            true=true_class,
            correct=(pred_class == true_class),
        ))

predictions_df = pd.DataFrame(predictions)
accuracy = predictions_df['correct'].mean()
n_correct = predictions_df['correct'].sum()

print(f"\nTest Accuracy: {accuracy:.4f}")
print(f"Correct: {n_correct}")
print(f"Incorrect: {len(predictions_df) - n_correct}")


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: distribution of predicted probabilities, split by correctness.
ax1 = axes[0]
correct_probs = predictions_df[predictions_df['correct']]['prob']
incorrect_probs = predictions_df[~predictions_df['correct']]['prob']

ax1.hist(correct_probs, bins=20, alpha=0.7, label='Correct', color='green')
ax1.hist(incorrect_probs, bins=20, alpha=0.7, label='Incorrect', color='red')
ax1.axvline(x=0.5, color='black', linestyle='--', label='Decision boundary')
ax1.set_xlabel('Prediction Probability (Disease)', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.set_title('Prediction Confidence Distribution', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(alpha=0.3)

# Right panel: confusion matrix (rows = true class, columns = predicted class).
ax2 = axes[1]
# groupby/unstack only yields the classes that actually occur, so a model that
# never predicts one class (or a test set with a single class) would give a
# table that is not 2x2 and the fixed tick labels below would no longer match.
# Reindex to the full binary label set so rows/columns are always [0, 1].
class_ids = [0, 1]
error_types = (
    predictions_df.groupby(['true', 'pred']).size().unstack(fill_value=0)
    .reindex(index=class_ids, columns=class_ids, fill_value=0)
)
sns.heatmap(error_types, annot=True, fmt='d', cmap='Blues', ax=ax2,
            xticklabels=['Control', DISEASE], yticklabels=['Control', DISEASE])
ax2.set_xlabel('Predicted', fontsize=12)
ax2.set_ylabel('True', fontsize=12)
ax2.set_title('Confusion Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(viz_dir / 'error_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

# Off-diagonal counts of the confusion matrix.
fp = ((predictions_df['true'] == 0) & (predictions_df['pred'] == 1)).sum()
fn = ((predictions_df['true'] == 1) & (predictions_df['pred'] == 0)).sum()
print(f"\nError Statistics:")
print(f"  False Positives (Control -> Disease): {fp}")
print(f"  False Negatives (Disease -> Control): {fn}")


## 5. Save Interpretation Results
### Save summary outputs


In [None]:
# Assemble the JSON-serialisable summary piece by piece; numpy scalars are
# converted to built-in int/float so json can encode them.
model_config_summary = {
    'hidden_dim': hidden_dim,
    'num_layers': num_layers,
    'dropout': dropout,
    'pooling': pooling
}

ranked_nodes = []
for position, node_idx in enumerate(top_indices[:20], start=1):
    ranked_nodes.append({
        'rank': position,
        'node_index': int(node_idx),
        'importance': float(mean_importance[node_idx])
    })

is_correct = predictions_df['correct']
error_summary = {
    'n_correct': int(is_correct.sum()),
    'n_incorrect': int((~is_correct).sum()),
    'false_positives': int(fp),
    'false_negatives': int(fn)
}

interpretation_results = {
    'experiment': EXPERIMENT_NAME,
    'timestamp': datetime.now().isoformat(),
    'model_config': model_config_summary,
    'test_accuracy': float(accuracy),
    'top_important_nodes': ranked_nodes,
    'error_analysis': error_summary
}

# Write the summary under <output_dir>/results/.
results_dir = output_dir / 'results'
results_dir.mkdir(parents=True, exist_ok=True)
results_path = results_dir / 'interpretation_results.json'
with open(results_path, 'w') as f:
    f.write(json.dumps(interpretation_results, indent=2))

best_node = top_indices[0]
print(f"Interpretation results saved to: {results_path}")
print(f"\n=== Summary ===")
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Most important node: Node {best_node} (score: {mean_importance[best_node]:.4f})")
print(f"Total misclassifications: {len(predictions_df) - predictions_df['correct'].sum()}")
