# Statistical Insights & GNN Explainability

**Goal:** Interpret the GNN model's predictions to uncover biological insights.

## Key Questions
1. **Which cells are driving the decision?** (GAT Attention Weights)
2. **Which genes are differentially expressed** between the predicted high-risk and low-risk groups?
3. **What gene modules are associated with AD?**

In [None]:
import torch
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import sys
import os

# Add src to path
sys.path.append(os.path.abspath('../src'))
from models.gnn_models import CellGNN

sns.set_style('whitegrid')

### 1. Load Data & Model

In [None]:
# Load Graph
data_path = '../data/processed/graphs/cell_graph.pt'
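# weights_only=False is needed on PyTorch >= 2.6 to unpickle a PyG Data object;
# only use it for files you trust.
data = torch.load(data_path, weights_only=False)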
print(f"Loaded graph: {data}")

# Init model. NOTE: these are untrained, randomly initialized weights for demonstration.
# In a real run, load the trained checkpoint instead, e.g.:
# model.load_state_dict(torch.load('checkpoints/best.ckpt')['state_dict'])
model = CellGNN(in_channels=data.num_features, hidden_channels=64, out_channels=3)
model.eval()
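The commented-out loading line reads `['state_dict']` from a `.ckpt` file, which matches the PyTorch Lightning checkpoint layout. If the GNN was held inside a LightningModule as an attribute named `model` (an assumption; the training code is not shown here), every key carries a `model.` prefix and `load_state_dict` on a bare `CellGNN` would report missing and unexpected keys. A minimal sketch of a prefix-stripping helper; `strip_prefix` and the key names in the example are hypothetical:

```python
def strip_prefix(state_dict, prefix="model."):
    """Keep only keys starting with `prefix`, with the prefix removed.

    Keys outside the prefix (e.g. loss or metric modules saved by the
    training wrapper) are dropped because the bare model does not own them.
    """
    return {key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)}


# Toy stand-in for ckpt['state_dict']; real values would be tensors and the
# key names depend on how CellGNN names its layers (these are hypothetical).
ckpt_state = {"model.conv1.lin.weight": "W1", "model.conv1.bias": "b1", "criterion.weight": "w"}
print(strip_prefix(ckpt_state))
# {'conv1.lin.weight': 'W1', 'conv1.bias': 'b1'}
```

In the notebook this would be used as `model.load_state_dict(strip_prefix(ckpt['state_dict']))`, after checking which prefix the checkpoint actually uses (e.g. by printing a few keys).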

### 2. Predictions & Confusion Matrix

In [None]:
with torch.no_grad():
    out = model(data.x, data.edge_index)
    preds = out.argmax(dim=1)

# Compare with ground truth (data.y)
if data.y is not None:
    cm = confusion_matrix(data.y.numpy(), preds.numpy())
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix (Untrained Random Weights for Demo)')
    plt.show()
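The confusion matrix above scores every node, including any used for training, so once the model is trained it overstates how well predictions generalize. If the graph stores a boolean split mask such as `data.test_mask` (a common PyG convention, assumed here and not confirmed by this notebook), the counts can be restricted to held-out cells. A numpy-only sketch; `masked_confusion` and `row_normalize` are illustrative helper names:

```python
import numpy as np


def masked_confusion(y_true, y_pred, mask, num_classes):
    """Confusion matrix (rows = actual, cols = predicted) over masked nodes only."""
    mask = np.asarray(mask, dtype=bool)
    y_true = np.asarray(y_true)[mask]
    y_pred = np.asarray(y_pred)[mask]
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (actual, predicted) pair
    return cm


def row_normalize(cm):
    """Divide each row by its total so the diagonal is per-class recall.

    Rows with no samples are left as zeros instead of dividing by zero.
    """
    totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, totals, out=np.zeros(cm.shape, dtype=float), where=totals > 0)


# Toy example: six nodes, of which four are in the (hypothetical) test split.
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]
test_mask = [True, True, True, False, True, False]
cm = masked_confusion(y_true, y_pred, test_mask, num_classes=3)
recall_view = row_normalize(cm)
```

With real data this would be `masked_confusion(data.y.numpy(), preds.numpy(), data.test_mask.numpy(), num_classes=3)`. The row-normalized view is easier to read than raw counts when the classes are imbalanced.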

### 3. GNN Explainability (Attention Analysis)
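Goal: identify which edges (neighboring cells) receive the most attention in the GAT layers. The attention-extraction step in the cell below is still a placeholder; the cell currently plots the model's predicted AD risk score on the UMAP embedding.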
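For reference, this is what the extracted weights represent. In a GAT layer, each edge j → i gets a score `e_ij = LeakyReLU(a^T [W h_i || W h_j])`, which is then normalized with a softmax over all edges entering node i, so each cell's incoming attention sums to 1. The numpy sketch below recomputes single-head coefficients, with random `W`, `a_src`, `a_dst` standing in for learned parameters (illustrative only, not taken from `CellGNN`), and then sums per cell the attention it receives when acting as a neighbor, one simple way to rank influential cells:

```python
import numpy as np


def gat_attention(x, edge_index, W, a_src, a_dst, negative_slope=0.2):
    """Single-head GAT attention coefficients, one per edge.

    edge_index[0] holds source nodes j and edge_index[1] target nodes i
    (the PyG convention). Coefficients are softmax-normalized over each
    target's incoming edges, so they sum to 1 per target node.
    """
    src, dst = edge_index
    h = x @ W                                   # projected features W h
    e = h[dst] @ a_dst + h[src] @ a_src         # a^T [W h_i || W h_j], split in two
    e = np.where(e > 0, e, negative_slope * e)  # LeakyReLU
    # Numerically stable softmax grouped by target node.
    e_max = np.full(x.shape[0], -np.inf)
    np.maximum.at(e_max, dst, e)
    w = np.exp(e - e_max[dst])
    denom = np.zeros(x.shape[0])
    np.add.at(denom, dst, w)
    return w / denom[dst]


def attention_received(alpha, edge_index, num_nodes):
    """Total attention each node receives when acting as a neighbor (source)."""
    totals = np.zeros(num_nodes)
    np.add.at(totals, edge_index[0], alpha)
    return totals


# Toy graph: node 0 listens to nodes 1-3; nodes 1 and 2 each listen only to node 0.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
edge_index = np.array([[1, 2, 3, 0, 0],
                       [0, 0, 0, 1, 2]])
W = rng.normal(size=(5, 8))
a_src = rng.normal(size=8)
a_dst = rng.normal(size=8)
alpha = gat_attention(x, edge_index, W, a_src, a_dst)
received = attention_received(alpha, edge_index, num_nodes=4)
```

With PyG, `GATConv` returns `(out, (edge_index, alpha))` when called with `return_attention_weights=True`. The returned `edge_index` includes self-loops by default and `alpha` has one column per head, so average over heads (e.g. `alpha.mean(dim=1)`) before aggregating. Summed totals also grow with a node's out-degree; dividing by degree gives a per-edge average instead.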

In [None]:
with torch.no_grad():
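    # Placeholder: attention extraction is not implemented yet.
    # PyG's GATConv returns (out, (edge_index, alpha)) when called with
    # return_attention_weights=True; CellGNN would need to expose this from its
    # GAT layer(s), e.g. via an extra argument to its forward() method.
    pass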

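# Visualize the predicted risk score of each cell on the UMAP embedding
# (demo only: the model above is untrained).
adata = sc.read_h5ad('../data/processed/rosmap_proc.h5ad')
# Graph nodes must correspond one-to-one (same order) with the cells in adata.
assert adata.n_obs == out.shape[0], "graph nodes and adata cells are misaligned"
# softmax gives probabilities whether the model returns logits or log-probabilities
adata.obs['pred_risk'] = out.softmax(dim=1)[:, 2].numpy()  # P(class 2 = AD)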

# Compute UMAP if missing (preprocessing only did PCA)
if 'X_umap' not in adata.obsm:
    print("Computing UMAP...")
    sc.pp.neighbors(adata, n_neighbors=15, n_pcs=30)
    sc.tl.umap(adata)

sc.pl.umap(adata, color=['pred_risk'], cmap='magma', title='Predicted AD Risk Score')
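How to turn the model output into probabilities depends on what `CellGNN` returns, which this notebook does not show. Calling `exp()` is correct only if the final layer applies `log_softmax` (an assumption), whereas a softmax over the class dimension yields the same probabilities from raw logits and from log-probabilities, because softmax is unchanged when a per-row constant is subtracted. A small numpy check of that identity:

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def log_softmax(z):
    """Row-wise log-softmax: z minus the row's log-sum-exp."""
    m = z.max(axis=1, keepdims=True)
    return z - (m + np.log(np.exp(z - m).sum(axis=1, keepdims=True)))


logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
log_probs = log_softmax(logits)

# softmax recovers the same probabilities from either representation...
assert np.allclose(softmax(log_probs), softmax(logits))
# ...while exp() is only valid on log-probabilities.
assert np.allclose(np.exp(log_probs), softmax(logits))
assert not np.allclose(np.exp(logits), softmax(logits))
```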

### 4. Differential Expression: High Risk vs Low Risk
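Identifying genes that are differentially expressed between cells the model scores as high risk and cells it scores as low risk. These genes are associated with the predictions; they are not necessarily causal drivers of them.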

In [None]:
# Divide cells into High Risk (Top 20%) vs Low Risk (Bottom 20%)
risk_scores = adata.obs['pred_risk']
high_risk_idx = risk_scores > risk_scores.quantile(0.8)
low_risk_idx = risk_scores < risk_scores.quantile(0.2)

adata.obs['risk_group'] = 'Mid'
adata.obs.loc[high_risk_idx, 'risk_group'] = 'High'
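adata.obs.loc[low_risk_idx, 'risk_group'] = 'Low'
# rank_genes_groups expects a categorical grouping column
adata.obs['risk_group'] = adata.obs['risk_group'].astype('category')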

sc.tl.rank_genes_groups(adata, 'risk_group', groups=['High'], reference='Low', method='t-test')
sc.pl.rank_genes_groups(adata, n_genes=20, sharey=False)
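To our understanding, `rank_genes_groups` with `method='t-test'` scores genes with a Welch (unequal-variance) t-statistic and, by default, adjusts p-values with Benjamini-Hochberg (`corr_method='benjamini-hochberg'`). To make those scores easier to interpret, the numpy sketch below recomputes the per-gene statistic for two boolean cell masks and implements the BH adjustment. Converting t into p-values needs the t distribution (e.g. `scipy.stats`) and is left out. It assumes a dense cells × genes array; `welch_t` and `benjamini_hochberg` are illustrative names, not scanpy functions:

```python
import numpy as np


def welch_t(X, in_a, in_b):
    """Per-gene Welch t-statistic, group A vs group B (X is cells x genes)."""
    A, B = X[in_a], X[in_b]
    mean_a, mean_b = A.mean(axis=0), B.mean(axis=0)
    var_a, var_b = A.var(axis=0, ddof=1), B.var(axis=0, ddof=1)  # sample variances
    se = np.sqrt(var_a / A.shape[0] + var_b / B.shape[0])
    return (mean_a - mean_b) / se


def benjamini_hochberg(pvals):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    ranked = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity from the largest p-value downwards, then cap at 1.
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    adjusted = np.empty(m)
    adjusted[order] = np.minimum(ranked, 1.0)
    return adjusted


# Toy data: 6 cells x 2 genes; the last three cells form the "High" group.
X = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 1.0],
              [5.0, 1.0], [6.0, 2.0], [7.0, 2.0]])
high = np.array([False, False, False, True, True, True])
t = welch_t(X, high, ~high)
p_adj = benjamini_hochberg([0.01, 0.04, 0.03, 0.005])
```

On the notebook's data this would be called with masks such as `(adata.obs['risk_group'] == 'High').values`; a sparse `adata.X` must be densified (e.g. `.toarray()`) first. Genes with zero variance in both groups produce a division by zero.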