# Multi-Modal Graph Learning for Alzheimer's Diagnosis

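This notebook loads the enriched ADNI tables, builds a patient graph, trains a weighted graph convolutional network (GCN), and then walks through evaluation, subgroup metrics, interpretability, and single-patient inference before saving the trained artefacts.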

## 1. Setup

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ConfusionMatrixDisplay

from alzheimer_pipeline import (
    prepare_dataset,
    train_gcn,
    summarise_performance,
    confusion_matrix_split,
    save_artifacts,
    explain_nodes,
    compute_subgroup_metrics,
    predict_patient,
    LABEL_NAMES,
)

# Seed every RNG the notebook may touch (Python's `random`, NumPy, PyTorch) so
# stochastic steps are repeatable when the notebook is re-run on a fresh kernel.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the echoed torch.Generator repr


## 2. Build the Patient Graph

In [None]:
DATA_DIR = Path("Datasets")
K_NEIGHBORS = 20  # passed to prepare_dataset as `k_neighbors` when building the graph

artifacts = prepare_dataset(DATA_DIR, k_neighbors=K_NEIGHBORS, seed=SEED)
data = artifacts.data
patient_df = artifacts.patient_table.copy()
patient_df["Diagnosis"] = patient_df["Label"].map(LABEL_NAMES)

# Fail fast on labels LABEL_NAMES does not cover: `.map` turns them into NaN,
# `value_counts()` silently drops NaN from the class table below, and the
# `LABEL_NAMES[...]` lookups later in the notebook would raise KeyError.
unmapped_labels = patient_df.loc[patient_df["Diagnosis"].isna(), "Label"].unique()
assert len(unmapped_labels) == 0, f"Labels not covered by LABEL_NAMES: {unmapped_labels.tolist()}"

print(f"Total baseline visits: {len(patient_df):,}")
print("Class distribution:")
display(patient_df["Diagnosis"].value_counts().rename_axis("Diagnosis").to_frame("Count"))

# Fraction of missing values per model feature, worst first.
missing_rates = patient_df[artifacts.feature_columns].isna().mean().sort_values(ascending=False)
print("Feature missingness (top 10):")
display(missing_rates.head(10).to_frame("Missing Rate").style.format({"Missing Rate": "{:.1%}"}))


## 3. Train the GCN

In [None]:
# Fit the weighted GCN on the graph built above; `history` records per-epoch
# metrics (epoch, train/val accuracy, val loss) that are plotted next.
model, history = train_gcn(
    data,
    lr=5e-4,
    epochs=400,
    patience=60,
)
history_df = pd.DataFrame(history)
history_df.tail(n=5)


In [None]:
# Left panel: train vs. validation accuracy; right panel: validation loss.
fig, (acc_ax, loss_ax) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

for column, series_label in (("train_acc", "Train"), ("val_acc", "Validation")):
    acc_ax.plot(history_df["epoch"], history_df[column], label=series_label)
acc_ax.set(title="Accuracy vs. Epoch", xlabel="Epoch", ylabel="Accuracy")
acc_ax.legend()

loss_ax.plot(history_df["epoch"], history_df["val_loss"], color="tab:red")
loss_ax.set(title="Validation Loss", xlabel="Epoch", ylabel="Loss")

fig.tight_layout()
plt.show()


## 4. Evaluation

In [None]:
# One classification-report table per data split, rounded for readability.
reports = summarise_performance(model, data)

for split_name, split_report in reports.items():
    print(f"{split_name.upper()} classification report")
    report_table = pd.DataFrame(split_report).round(3)
    display(report_table)


In [None]:
# ConfusionMatrixDisplay attaches display_labels to rows/columns by position, so
# order the class names by label id instead of relying on LABEL_NAMES' insertion order.
# NOTE(review): assumes confusion_matrix_split orders classes by ascending label id
# (sklearn's confusion_matrix default) — confirm in alzheimer_pipeline.
class_names = [LABEL_NAMES[label_id] for label_id in sorted(LABEL_NAMES)]
cm = confusion_matrix_split(model, data, split="test")

fig, ax = plt.subplots()
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names).plot(
    ax=ax, cmap="Blues", values_format="d"
)
ax.set_title("Test Confusion Matrix")
plt.show()


## 5. Subgroup Metrics

In [None]:
# Age bins are left-closed ([45, 55), ..., [85, inf)). The open upper edge keeps
# "85+" truthful: a finite cap would turn older patients into NaN and drop them.
AGE_BIN_EDGES = [45, 55, 65, 75, 85, np.inf]
AGE_BIN_LABELS = ["45-54", "55-64", "65-74", "75-84", "85+"]

# The carrier flag must be built from a Series even when APOE4Count is absent:
# `DataFrame.get(col, 0)` would return the int 0, and `(0 >= 1)` is a plain bool
# with no `.astype`.
if "APOE4Count" in patient_df.columns:
    apoe4_count = patient_df["APOE4Count"]
else:
    apoe4_count = pd.Series(0, index=patient_df.index)

# `assign` returns a new frame, so `patient_df` itself is left untouched.
patient_with_bins = patient_df.assign(
    AgeBin=pd.cut(patient_df["Age"], bins=AGE_BIN_EDGES, right=False, labels=AGE_BIN_LABELS),
    # NaN >= 1 evaluates False, so a missing APOE4 count is treated as a non-carrier.
    APOE4Carrier=(apoe4_count >= 1).astype(int),
)

SUBGROUP_BREAKDOWNS = [
    ("Gender breakdown", "GenderBinary"),
    ("APOE4 carrier breakdown", "APOE4Carrier"),
    ("Age bin breakdown", "AgeBin"),
]
for title, group_column in SUBGROUP_BREAKDOWNS:
    print(title)
    display(compute_subgroup_metrics(model, data, patient_with_bins, split="test", group_by=[group_column]))


## 6. Interpretability

In [None]:
# Explain a small, reproducible sample of test-set nodes.
test_indices = data.test_mask.cpu().numpy().nonzero()[0]
np.random.seed(SEED)
n_examples = min(3, len(test_indices))
selected_nodes = np.random.choice(test_indices, size=n_examples, replace=False)
explanations = explain_nodes(model, data, artifacts, selected_nodes, top_k=5)

for node_id, explanation in explanations.items():
    # Node ids index patient_table positionally.
    patient_row = artifacts.patient_table.iloc[node_id]
    rid = int(patient_row["RID"])
    diagnosis = LABEL_NAMES[int(patient_row["Label"])]
    print(f"Node {node_id} | RID {rid} | Diagnosis {diagnosis}")

    # Rank features by attribution magnitude, regardless of sign.
    ranked_features = sorted(
        explanation["feature_attributions"].items(),
        key=lambda item: abs(item[1]),
        reverse=True,
    )
    print("  Top feature contributions:")
    for feature_name, attribution in ranked_features[:5]:
        print(f"    {feature_name}: {attribution:.3f}")

    print("  Modality attribution:")
    for modality_name, modality_score in explanation["modality_attributions"].items():
        print(f"    {modality_name}: {modality_score:.2f}")

    print("  Nearest neighbours:")
    for neighbour in explanation["nearest_neighbors"][:3]:
        print(f"    RID {neighbour['RID']} ({neighbour['Diagnosis']}) | distance={neighbour['distance']:.3f}")
    print("-")


## 7. Example Inference

In [None]:
# Replay one reproducibly sampled patient through the single-patient inference path.
example_row = patient_df.sample(n=1, random_state=SEED).iloc[0]
input_features = example_row[artifacts.feature_columns].to_dict()

# nn.Module.cpu() moves parameters in place and returns the same module, so
# `model` itself remains on the CPU after this cell runs.
inference_model = model.cpu()
result = predict_patient(
    inference_model,
    data,
    artifacts.preprocessor,
    artifacts.neighbor_model,
    artifacts.modality_processors,
    artifacts.modality_weights,
    artifacts.distance_scale,
    artifacts.feature_columns,
    input_features,
    patient_table=artifacts.patient_table,
    return_explanations=True,
)

print("Probabilities:", result["probabilities"])
print("Modality contributions:", result["modality_contributions"])
print("Nearest neighbours:")
for neighbour_record in result["nearest_neighbors"][:5]:
    print(neighbour_record)


## 8. Persist Artefacts

In [None]:
# Hand the trained model, its training history and the dataset artefacts to
# save_artifacts, then report where they were written.
ARTIFACT_DIR = Path("artifacts")
save_artifacts(ARTIFACT_DIR, artifacts, model, history)
saved_location = ARTIFACT_DIR.resolve()
print(f"Saved artefacts to {saved_location}")
