# Supervised LDA for Co-Pathology Analysis

This notebook demonstrates how to use the supervised Latent Dirichlet Allocation (sLDA) model to discover co-pathology patterns in neurodegenerative diseases from regional gray matter atrophy data.

## Overview

- **Goal**: Identify latent pathology patterns (topics) that explain regional atrophy and predict diagnosis
- **Data**: 160 patients with 4 diagnoses (AD, PD, DLB, SVAD) and 95 cortical region measurements
- **Model**: sLDA with continuous Normal likelihood for features and Categorical likelihood for diagnosis

## Key Concepts

- **Topics** = Latent pathology patterns (e.g., limbic atrophy, cortical atrophy)
- **Patient mixtures** = Each patient has a distribution over topics (co-pathology)
- **Topic patterns** = Each topic has characteristic regional atrophy signatures
- **Supervised component** = Topics predict diagnosis
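The concepts above fit together as a generative story. The NumPy sketch below simulates one, under assumptions that go beyond what this notebook states: each patient's regional features are Normal around the mixture θβ with a single noise scale σ, and the diagnosis is drawn from softmax(θη). The dimensions and values are arbitrary, and `CoPathologySLDA` may parameterize the model differently.

```python
import numpy as np

# Hypothetical dimensions, chosen only for illustration
n_patients, n_regions, n_topics, n_classes = 6, 5, 3, 4
rng = np.random.default_rng(0)

alpha = np.ones(n_topics)                            # symmetric Dirichlet prior (alpha = 1)
beta = rng.normal(0.0, 1.0, (n_topics, n_regions))   # topic-region patterns
eta = rng.normal(0.0, 1.0, (n_topics, n_classes))    # topic-diagnosis weights
sigma = 0.5                                          # assumed observation noise

# Patient-topic mixtures: each row lies on the simplex (sums to 1)
theta = rng.dirichlet(alpha, size=n_patients)

# Regional measurements: mixture of topic patterns plus Gaussian noise
X_sim = rng.normal(theta @ beta, sigma)

# Diagnosis: softmax of the mixture-weighted topic-diagnosis logits (assumed link)
logits = theta @ eta
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
y_sim = np.array([rng.choice(n_classes, p=p) for p in probs])
```

Fitting the model runs this story in reverse: given `X` and `y`, the sampler infers plausible values of θ, β, and η.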

## 1. Setup and Data Loading

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Import our custom modules
from preprocessing import load_wsev_data, prepare_slda_inputs, train_test_split_stratified
from slda_model import CoPathologySLDA
from visualization import (
    plot_topic_heatmap,
    plot_patient_topic_distribution,
    plot_topic_diagnosis_association,
    plot_brain_topic_pattern,
    plot_copathology_mixtures,
    plot_confusion_matrix
)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

# %matplotlib inline

In [2]:
# Load the WSEV dataset
data_path = '/home/coder/data/updated_WSEV/260108_wsev_final_df.csv'
df = load_wsev_data(data_path)

# Display basic info
print(f"\nDataset shape: {df.shape}")
print("\nDiagnosis distribution:")
print(df['DX'].value_counts())

Loaded 160 patients from /home/coder/data/updated_WSEV/260108_wsev_final_df.csv
Columns: 113

Dataset shape: (160, 113)

Diagnosis distribution:
DX
PD      56
AD      55
DLB     28
SVAD    21
Name: count, dtype: int64


In [3]:
# Prepare data for sLDA
X, y, feature_names, dx_labels = prepare_slda_inputs(df, standardize=False)

print(f"\nFeature matrix X: {X.shape}")
print(f"Diagnosis labels y: {y.shape}")
print(f"Number of features: {len(feature_names)}")
print(f"Diagnoses: {dx_labels}")

Diagnosis distribution:
  AD: 55 patients (class 0)
  DLB: 28 patients (class 1)
  PD: 56 patients (class 2)
  SVAD: 21 patients (class 3)

Final data shape:
  X: (160, 95) (patients × cortical regions)
  y: (160,) (patients,)
  Features: 95
  Diagnoses: 4

Feature matrix X: (160, 95)
Diagnosis labels y: (160,)
Number of features: 95
Diagnoses: ['AD', 'DLB', 'PD', 'SVAD']


In [4]:
# Optional: Split into train/test sets
# For this example, we'll train on all data for better topic discovery
# Uncomment below to use train/test split:

# X_train, X_test, y_train, y_test = train_test_split_stratified(X, y, test_size=0.2)
# X_model, y_model = X_train, y_train

# For full dataset:
X_model, y_model = X, y
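`train_test_split_stratified` comes from the project's `preprocessing` module, whose implementation is not shown here. Stratification keeps each diagnosis at roughly the same proportion in the training and test sets, which matters with classes as unbalanced as SVAD (21 patients) and PD (56). A minimal NumPy sketch of the idea, using a hypothetical `stratified_split` helper rather than the project's function (scikit-learn's `train_test_split(..., stratify=y)` provides the same behavior):

```python
import numpy as np

def stratified_split(X, y, test_size=0.2, seed=0):
    """Split rows of X, y so each class keeps roughly its overall proportion in both sets."""
    rng = np.random.default_rng(seed)
    test_idx = []
    for c in np.unique(y):
        idx = np.flatnonzero(y == c)          # rows belonging to class c
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
    test_mask = np.zeros(len(y), dtype=bool)
    test_mask[test_idx] = True
    return X[~test_mask], X[test_mask], y[~test_mask], y[test_mask]

# Class sizes from the diagnosis distribution above (AD, DLB, PD, SVAD)
y_demo = np.repeat([0, 1, 2, 3], [55, 28, 56, 21])
X_demo = np.zeros((len(y_demo), 95))
X_tr, X_te, y_tr, y_te = stratified_split(X_demo, y_demo, test_size=0.2)
print(np.bincount(y_te))   # test-set count per class
```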

## 2. Model Training

We'll fit the sLDA model with 4 topics. This may take several minutes depending on your system.

**Note**: Adjust `n_topics` to explore different numbers of pathology patterns (3-6 recommended).

In [5]:
# Initialize model
model = CoPathologySLDA(
    n_topics=4,           # Number of latent pathology patterns
    alpha_prior=1.0,      # Dirichlet concentration (1.0 = uniform)
    feature_prior_std=1.0, # Prior std for topic-region weights
    random_state=42
)

print("Model initialized with 4 topics")
print("\nNote: Sampling may take 5-15 minutes...")

Model initialized with 4 topics

Note: Sampling may take 5-15 minutes...
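The `alpha_prior` argument sets the Dirichlet concentration on each patient's topic mixture, which shapes how much co-pathology the model expects before seeing any data. Assuming a symmetric Dirichlet(α) prior over θ (the comment in the cell above suggests this, but the internals of `CoPathologySLDA` are not shown), the sketch below compares how mixed random draws are for a few values of α, using entropy scaled to [0, 1]:

```python
import numpy as np

rng = np.random.default_rng(0)
n_topics = 4

def mean_normalized_entropy(alpha, n_draws=5000):
    """Average entropy of Dirichlet draws, scaled to [0, 1] by dividing by log(n_topics)."""
    theta = rng.dirichlet(np.full(n_topics, alpha), size=n_draws)
    theta = np.clip(theta, 1e-12, 1.0)   # avoid log(0) for near-zero components
    h = -(theta * np.log(theta)).sum(axis=1)
    return h.mean() / np.log(n_topics)

for a in (0.1, 1.0, 10.0):
    print(f"alpha={a:>4}: mean normalized entropy = {mean_normalized_entropy(a):.2f}")
```

Values of α below 1 favor patients dominated by a single topic; values above 1 favor evenly mixed profiles. With α = 1, every mixture on the simplex is equally likely a priori.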


In [6]:
# Fit the model
# For faster testing, reduce n_samples to 500 and chains to 2
# For final analysis, use n_samples=2000 and chains=4

model.fit(
    X_model, 
    y_model,
    n_samples=2000,      # Number of MCMC samples per chain
    tune=1000,           # Number of tuning/burn-in samples
    chains=4,            # Number of parallel chains
    target_accept=0.9    # Target acceptance rate for NUTS
)

Fitting sLDA model:
  Patients: 160
  Features: 95
  Topics: 4
  Diagnoses: 4
  Sampling: 2000 samples × 4 chains


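AttributeError: module 'pymc' has no attribute 'Constant'

**Note**: The fit above failed, so none of the cells below (`In [None]`) have been run yet. The error indicates that `slda_model` calls `pm.Constant`, which current PyMC releases do not provide (it was removed in PyMC 4; `pm.DiracDelta` is the equivalent distribution). `slda_model.py` needs to be updated before the rest of the notebook can run.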

## 3. Extract Model Parameters

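After fitting, we extract the posterior means of the key parameters, where K is the number of topics, V the number of regional features, D the number of patients, and C the number of diagnoses:
- **β (beta)**: Topic-region patterns (K × V matrix)
- **θ (theta)**: Patient-topic mixtures (D × K matrix)
- **η (eta)**: Topic-diagnosis associations (K × C matrix)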
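The `get_*` methods belong to the project's `slda_model` module and are not shown here; presumably each averages the MCMC draws of one parameter. Assuming the usual PyMC/ArviZ `(chain, draw, ...)` layout, a posterior mean is a mean over the first two axes. The sketch uses random Dirichlet draws as a stand-in for the fitted θ (500 draws per chain rather than 2,000, to keep it light):

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_patients, n_topics = 4, 500, 160, 4

# Stand-in for posterior draws of theta with shape (chains, draws, patients, topics)
theta_draws = rng.dirichlet(np.ones(n_topics), size=(n_chains, n_draws, n_patients))

# Posterior mean: average over the chain and draw axes
theta_mean = theta_draws.mean(axis=(0, 1))   # (n_patients, n_topics)

# Averaging points on the simplex stays on the simplex, so rows still sum to 1
print(theta_mean.shape, np.allclose(theta_mean.sum(axis=1), 1.0))
```

One caveat: topics are identified only up to permutation, so if different chains settle on different topic orderings (label switching), pooling them blurs the topics together. Comparing per-chain means before averaging is a cheap safeguard.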

In [None]:
# Get posterior means
topic_patterns = model.get_topic_patterns()        # (n_topics, n_features)
patient_mixtures = model.get_patient_mixtures()    # (n_patients, n_topics)
diagnosis_weights = model.get_diagnosis_weights()  # (n_topics, n_classes)

print(f"Topic patterns (β): {topic_patterns.shape}")
print(f"Patient mixtures (θ): {patient_mixtures.shape}")
print(f"Diagnosis weights (η): {diagnosis_weights.shape}")

# Verify that every patient's topic mixture sums to 1
print(f"\nAll patient mixtures sum to 1: {np.allclose(patient_mixtures.sum(axis=1), 1.0)}")

## 4. Visualize Topic Patterns

### 4.1 Topic-Region Heatmap

Shows which brain regions are associated with each topic.

In [None]:
fig = plot_topic_heatmap(
    topic_patterns, 
    feature_names,
    figsize=(16, 6),
    save_path='topic_heatmap.png'
)
plt.show()

### 4.2 Top Regions per Topic

Identify the most important regions for each topic.

In [None]:
# Print top regions for each topic
for topic_id in range(model.n_topics):
    print(f"\n{'='*60}")
    print(f"Topic {topic_id} - Top 10 Regions")
    print(f"{'='*60}")
    
    top_regions = model.get_topic_top_regions(
        topic_id, 
        feature_names, 
        n_regions=10,
        absolute=True
    )
    
    for i, (region, weight) in enumerate(top_regions, 1):
        print(f"{i:2d}. {region:40s} {weight:+.3f}")
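`get_topic_top_regions` is part of the project's `slda_model` module. The sketch below shows what its `absolute=True` option most likely does (an assumption): rank regions by the magnitude of their weight, so strongly negative weights count as important too, while keeping the sign for interpretation. `top_regions` and the region names are hypothetical.

```python
import numpy as np

def top_regions(beta_row, names, n_regions=10, absolute=True):
    """Return (region, weight) pairs sorted by weight, or by |weight| when absolute=True."""
    key = np.abs(beta_row) if absolute else beta_row
    order = np.argsort(key)[::-1][:n_regions]   # indices in descending order of key
    return [(names[i], float(beta_row[i])) for i in order]

# Hypothetical weights for one topic over four regions
names = ["lh_entorhinal", "lh_precuneus", "rh_fusiform", "rh_insula"]
weights = np.array([0.8, -1.2, 0.3, -0.1])
for rank, (region, w) in enumerate(top_regions(weights, names, n_regions=3), 1):
    print(f"{rank:2d}. {region:20s} {w:+.3f}")
```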

### 4.3 Detailed Topic Visualization

Visualize top regions for each topic, separated by hemisphere.

In [None]:
# Visualize each topic
for topic_id in range(model.n_topics):
    fig = plot_brain_topic_pattern(
        topic_id,
        topic_patterns,
        feature_names,
        n_top_regions=12,
        figsize=(14, 5),
        save_path=f'topic_{topic_id}_brain_pattern.png'
    )
    plt.show()
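Separating regions by hemisphere depends on how the feature names encode laterality, which this notebook does not show. Assuming FreeSurfer-style `lh_`/`rh_` prefixes (hypothetical here), the grouping and a simple left-minus-right comparison could look like this:

```python
import numpy as np

def split_by_hemisphere(weights, names):
    """Group (region, weight) pairs by an assumed 'lh_'/'rh_' name prefix."""
    groups = {"left": [], "right": [], "other": []}
    for name, w in zip(names, weights):
        if name.startswith("lh_"):
            groups["left"].append((name[3:], float(w)))
        elif name.startswith("rh_"):
            groups["right"].append((name[3:], float(w)))
        else:
            groups["other"].append((name, float(w)))
    return groups

names = ["lh_entorhinal", "rh_entorhinal", "lh_precuneus", "rh_precuneus"]
weights = np.array([0.9, 0.7, -0.2, -0.4])
groups = split_by_hemisphere(weights, names)

# Asymmetry per region: left minus right weight, for regions present in both hemispheres
left, right = dict(groups["left"]), dict(groups["right"])
asymmetry = {r: left[r] - right[r] for r in left.keys() & right.keys()}
print(asymmetry)
```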

## 5. Analyze Patient Topic Mixtures

### 5.1 Topic Distribution by Diagnosis

Shows how each diagnosis group loads on different topics.

In [None]:
fig = plot_patient_topic_distribution(
    patient_mixtures,
    y_model,
    dx_labels,
    figsize=(14, 6),
    save_path='patient_topic_distribution.png'
)
plt.show()
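The figure compares topic loadings across diagnosis groups. A numeric companion is the mean topic mixture per group, a (diagnosis × topic) table that is easy to report alongside the plot. A minimal NumPy sketch on hypothetical mixtures; in the notebook, `patient_mixtures` and `y_model` would take their place:

```python
import numpy as np

def mean_mixture_by_class(theta, y, n_classes):
    """Average topic mixture for each diagnosis class; rows of the result sum to 1.

    A class with no patients would produce NaNs (mean of an empty slice).
    """
    return np.vstack([theta[y == c].mean(axis=0) for c in range(n_classes)])

# Hypothetical mixtures for five patients across two topics
theta = np.array([[0.9, 0.1], [0.7, 0.3], [0.2, 0.8], [0.4, 0.6], [0.5, 0.5]])
y = np.array([0, 0, 1, 1, 1])
print(mean_mixture_by_class(theta, y, n_classes=2))
```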

### 5.2 Co-Pathology Examples

Visualize topic mixtures for individual patients to see co-pathology patterns.

In [None]:
# Get patient IDs if available
patient_ids = df['PTID_ANONY'].values if 'PTID_ANONY' in df.columns else None

fig = plot_copathology_mixtures(
    patient_mixtures,
    y_model,
    dx_labels,
    patient_ids=patient_ids,
    n_patients_per_dx=5,
    figsize=(16, 8),
    save_path='copathology_examples.png'
)
plt.show()

### 5.3 Identify Mixed Pathology Patients

Find patients with high entropy (mixed topic membership), which may indicate co-pathology.

In [None]:
from scipy.stats import entropy

# Compute the entropy of each patient's topic mixture (row-wise)
topic_entropy = entropy(patient_mixtures, axis=1)

# Find the 10 patients with the highest entropy (most mixed), listed highest first
high_entropy_idx = np.argsort(topic_entropy)[-10:][::-1]

print("Patients with highest co-pathology (mixed topic membership):\n")
print(f"{'Patient':<15} {'Diagnosis':<10} {'Entropy':<10} {'Topic Mixture'}")
print("="*70)

for idx in high_entropy_idx:
    patient_id = patient_ids[idx] if patient_ids is not None else f"Patient_{idx}"
    dx = dx_labels[y_model[idx]]
    ent = topic_entropy[idx]
    mixture = patient_mixtures[idx]
    mixture_str = ' '.join([f"{m:.2f}" for m in mixture])
    print(f"{patient_id:<15} {dx:<10} {ent:<10.3f} [{mixture_str}]")
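Raw entropy depends on the number of topics: its maximum is log K (about 1.39 nats for K = 4). Dividing by log K puts every patient on a 0-to-1 scale (0 = a single topic, 1 = equal weight on all topics), which makes a cutoff easier to state and to compare across models with different `n_topics`. `scipy.stats.entropy` handles the 2-D case directly; this NumPy version, with a hypothetical `normalized_entropy` helper, just makes the formula explicit:

```python
import numpy as np

def normalized_entropy(theta, eps=1e-12):
    """Row-wise Shannon entropy of topic mixtures divided by log(K), giving values in [0, 1]."""
    theta = np.clip(theta, eps, 1.0)                 # avoid log(0)
    theta = theta / theta.sum(axis=1, keepdims=True)
    h = -(theta * np.log(theta)).sum(axis=1)
    return h / np.log(theta.shape[1])

theta = np.array([
    [1.0, 0.0, 0.0, 0.0],      # single topic: entropy near 0
    [0.5, 0.5, 0.0, 0.0],      # two equal topics: 0.5 when K = 4
    [0.25, 0.25, 0.25, 0.25],  # fully mixed: entropy 1
])
print(normalized_entropy(theta).round(3))
```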

## 6. Topic-Diagnosis Associations

Understand how topics predict diagnoses.

In [None]:
fig = plot_topic_diagnosis_association(
    diagnosis_weights,
    dx_labels,
    figsize=(9, 6),
    save_path='topic_diagnosis_association.png'
)
plt.show()

In [None]:
# Interpret topic-diagnosis associations
print("Topic-Diagnosis Associations (η matrix):\n")
print(f"{'Topic':<10}", end="")
for dx in dx_labels:
    print(f"{dx:>10}", end="")
print("\n" + "=" * (10 + 10 * len(dx_labels)))

for topic_id in range(model.n_topics):
    print(f"Topic {topic_id:<4}", end="")
    for dx_id in range(len(dx_labels)):
        weight = diagnosis_weights[topic_id, dx_id]
        print(f"{weight:>10.3f}", end="")
    print()

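print("\nInterpretation (assuming a softmax link on the topic mixture):")
print("- Compare weights within a row: higher weights mark diagnoses the topic makes relatively more likely")
print("- Lower weights within a row: the topic shifts probability away from those diagnoses")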
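Under a softmax link (a common sLDA choice, but an assumption here; check `slda_model.py` for the actual parameterization), a patient's diagnosis probabilities are softmax(θη). Two consequences for reading the η table: a topic influences predictions in proportion to the patient's share of it, and adding a constant to an entire topic row changes nothing, so only differences within a row are meaningful. A sketch with a hypothetical `diagnosis_proba` helper:

```python
import numpy as np

def diagnosis_proba(theta, eta):
    """Softmax over diagnoses of the mixture-weighted logits theta @ eta (assumed link)."""
    logits = theta @ eta
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

# Hypothetical weights: 2 topics x 3 diagnoses
eta = np.array([[2.0, 0.0, 0.0],
                [0.0, 1.0, 1.0]])
theta_a = np.array([[0.9, 0.1]])   # mostly topic 0
theta_b = np.array([[0.1, 0.9]])   # mostly topic 1
print(diagnosis_proba(theta_a, eta).round(3))
print(diagnosis_proba(theta_b, eta).round(3))

# Shifting each topic's row by its own constant leaves the probabilities unchanged
eta_shifted = eta + np.array([[5.0], [-3.0]])
print(np.allclose(diagnosis_proba(theta_a, eta), diagnosis_proba(theta_a, eta_shifted)))
```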

## 7. Model Predictions (Optional)

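Evaluate the model's ability to predict diagnoses. Because the model was fit on all patients, these are in-sample predictions and are likely to overestimate accuracy on new patients; use the train/test split from Section 1 for a fairer estimate.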

In [None]:
# Predict diagnoses for training data
y_pred = model.predict_diagnosis(X_model)
y_pred_proba = model.predict_diagnosis_proba(X_model)

print(f"Predicted diagnoses shape: {y_pred.shape}")
print(f"Prediction probabilities shape: {y_pred_proba.shape}")

In [None]:
# Plot confusion matrix
fig = plot_confusion_matrix(
    y_model,
    y_pred,
    dx_labels,
    figsize=(9, 8),
    save_path='confusion_matrix.png'
)
plt.show()

In [None]:
# Per-class precision, recall, and F1
from sklearn.metrics import classification_report

print("Classification Report:\n")
print(classification_report(y_model, y_pred, target_names=dx_labels))
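`classification_report` lists per-class precision, recall, and F1; what is usually meant by "per-class accuracy" is recall, the diagonal of the confusion matrix divided by each row total. With unbalanced classes (SVAD has 21 patients, PD has 56), the mean of the per-class recalls, known as balanced accuracy, is a fairer single number than overall accuracy; scikit-learn's `balanced_accuracy_score` computes the same quantity. A NumPy sketch on hypothetical labels:

```python
import numpy as np

def confusion_and_recall(y_true, y_pred, n_classes):
    """Confusion matrix (rows = true class, columns = predicted) and per-class recall."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)            # count each (true, predicted) pair
    support = cm.sum(axis=1)
    recall = np.divide(np.diag(cm), support,
                       out=np.zeros(n_classes), where=support > 0)
    return cm, recall

y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 0, 2, 2])
cm, recall = confusion_and_recall(y_true, y_pred, n_classes=3)
print(cm)
print(recall.round(2))
print(f"balanced accuracy: {recall.mean():.3f}")
```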

## 8. Topic Interpretation Guide

Use the following analysis to interpret what each topic represents:

### Expected Topic Patterns (Examples)

**Topic 0: Limbic/Temporal Pattern**
- High weights: entorhinal, parahippocampal, fusiform, temporal regions
- Associated diagnoses: AD, SVAD (Alzheimer's pathology)

**Topic 1: Cortical/Parietal Pattern**
- High weights: precuneus, parietal, posterior cingulate
- Associated diagnoses: AD, DLB (cortical atrophy)

**Topic 2: Subcortical Sparing Pattern**
- Lower overall atrophy weights
- Associated diagnoses: PD (less cortical involvement)

**Topic 3: Minimal Atrophy**
- Very low atrophy across regions
- Associated diagnoses: HC (healthy controls), if included; the current dataset contains no HC group

### Co-Pathology Interpretation

Patients with mixed topic memberships may have:
- **AD + DLB**: Mixed limbic and cortical patterns
- **SVAD**: AD-like pattern plus vascular features
- **PD with cognitive impairment**: PD pattern plus cortical involvement
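One way to turn these patterns into candidate co-pathology labels is to look at each patient's two largest topic shares: the primary topic suggests the main pattern, and a sizeable secondary share suggests a second process. A sketch with a hypothetical `flag_mixed` helper and an arbitrary cutoff of 0.3, which is an illustration and not a validated clinical threshold:

```python
import numpy as np

def flag_mixed(theta, min_secondary=0.3):
    """Flag patients whose second-largest topic share is at least min_secondary."""
    order = np.argsort(theta, axis=1)[:, ::-1]   # topic indices, largest share first
    primary = order[:, 0]
    secondary = order[:, 1]
    secondary_share = theta[np.arange(len(theta)), secondary]
    return secondary_share >= min_secondary, primary, secondary

theta = np.array([[0.85, 0.10, 0.05],
                  [0.50, 0.40, 0.10],
                  [0.30, 0.45, 0.25]])
mixed, primary, secondary = flag_mixed(theta)
print(mixed, primary, secondary)
```

Which topic corresponds to which pathology still has to be read off the heatmap and the η table, since topic indices are arbitrary.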

## 9. Save Results

In [None]:
# Save topic patterns
topic_df = pd.DataFrame(
    topic_patterns.T,
    index=feature_names,
    columns=[f'Topic_{i}' for i in range(model.n_topics)]
)
topic_df.to_csv('topic_patterns.csv')
print("Saved topic patterns to topic_patterns.csv")

# Save patient mixtures
mixture_df = pd.DataFrame(
    patient_mixtures,
    columns=[f'Topic_{i}' for i in range(model.n_topics)]
)
mixture_df['Diagnosis'] = [dx_labels[i] for i in y_model]
if patient_ids is not None:
    mixture_df['Patient_ID'] = patient_ids
mixture_df.to_csv('patient_topic_mixtures.csv', index=False)
print("Saved patient mixtures to patient_topic_mixtures.csv")

# Save diagnosis weights
dx_weights_df = pd.DataFrame(
    diagnosis_weights,
    index=[f'Topic_{i}' for i in range(model.n_topics)],
    columns=dx_labels
)
dx_weights_df.to_csv('topic_diagnosis_weights.csv')
print("Saved diagnosis weights to topic_diagnosis_weights.csv")

## 10. Summary and Next Steps

### Summary

This notebook demonstrated:
1. Loading and preprocessing neurodegenerative disease data
2. Fitting a supervised LDA model with continuous features
3. Extracting and interpreting latent pathology patterns
4. Analyzing co-pathology through patient topic mixtures
5. Understanding topic-diagnosis associations

### Next Steps

1. **Validate convergence**: Check R-hat values and trace plots
2. **Sensitivity analysis**: Try different numbers of topics (3-6)
3. **Clinical validation**: Compare topics with known pathology patterns
4. **Longitudinal analysis**: Apply model to follow-up scans
5. **Feature expansion**: Include subcortical structures
6. **External validation**: Test on independent dataset
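Checking convergence usually starts with R-hat for each parameter. If `CoPathologySLDA` exposes its PyMC trace (this notebook does not show the attribute), `arviz.summary` and `arviz.plot_trace` report R-hat and trace plots directly. To make the diagnostic itself concrete, here is a minimal split-R-hat in NumPy on simulated chains, using the classic Gelman-Rubin formula rather than ArviZ's rank-normalized variant:

```python
import numpy as np

def split_rhat(draws):
    """Split-R-hat for one scalar parameter; draws has shape (n_chains, n_draws).

    Classic Gelman-Rubin formula applied to chain halves
    (without the rank normalization used by ArviZ).
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split each chain into two halves so within-chain drift also inflates R-hat
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()   # within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_hat = (n - 1) / n * W + B / n
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)
good = rng.normal(0.0, 1.0, size=(4, 2000))            # well-mixed chains
bad = good + np.array([[0.0], [0.0], [0.0], [3.0]])    # one chain stuck elsewhere
print(f"R-hat (mixed chains): {split_rhat(good):.3f}")
print(f"R-hat (one chain offset): {split_rhat(bad):.3f}")
```

Values near 1.0 indicate the chains agree; a common modern cutoff is 1.01, while older work used 1.1. For topic parameters, a high R-hat can also come from label switching between chains rather than poor mixing, so compare per-chain topic patterns before concluding that the sampler failed.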