# Supervised LDA with ADVI (Fast Variational Inference)

This notebook uses **ADVI (Automatic Differentiation Variational Inference)** instead of MCMC for faster model fitting. This is useful for:

- **Quick prototyping** and testing model functionality
- **Hypothesis exploration** before running full MCMC
- **Large datasets** where MCMC is too slow

## ADVI vs MCMC Trade-offs

| Aspect | ADVI | MCMC (NUTS) |
|--------|------|-------------|
| **Speed** | ~30-60 seconds | ~10-15 minutes |
| **Posterior** | Approximate (Gaussian) | Asymptotically exact samples |
| **Uncertainty** | May underestimate | Accurate once chains have converged |
| **Multi-modal** | May miss modes | Can visit multiple modes, but may mix slowly between them |
| **Use case** | Exploration, large data | Final analysis, publication |

**Recommendation**: Use ADVI for initial exploration, then run MCMC for final results.
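
The "may underestimate" entry can be made concrete in a case with a closed-form answer. For a correlated Gaussian target, the fully factorized (mean-field) Gaussian that minimizes KL(q‖p) has variances equal to the reciprocal of the diagonal of the precision matrix, which is never larger than the true marginal variance. The sketch below uses a made-up 2-D covariance with correlation 0.9. It is a toy illustration, not the sLDA posterior, and it only applies if the ADVI fit is of the mean-field kind.

```python
import numpy as np

# Toy target: a 2-D Gaussian with strongly correlated components (illustrative values).
rho = 0.9
cov = np.array([[1.0, rho],
                [rho, 1.0]])
precision = np.linalg.inv(cov)

# True marginal variances are the diagonal of the covariance matrix.
true_marginal_var = np.diag(cov)

# The KL(q||p)-optimal factorized Gaussian has variance 1 / precision[i, i].
mean_field_var = 1.0 / np.diag(precision)

print("True marginal variance:", true_marginal_var)
print("Mean-field variance:   ", mean_field_var)
```

The stronger the posterior correlation between parameters, the larger the gap, which is one reason to confirm interval widths with MCMC.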

## 1. Setup and Data Loading

In [15]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Import our custom modules
from preprocessing import load_wsev_data, prepare_slda_inputs, train_test_split_stratified
from slda_model import CoPathologySLDA
from visualization import (
    plot_topic_heatmap,
    plot_patient_topic_distribution,
    plot_topic_diagnosis_association,
    plot_brain_topic_pattern,
    plot_copathology_mixtures,
    plot_confusion_matrix
)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

%matplotlib inline

In [16]:
# Load the WSEV dataset
data_path = '/home/coder/data/updated_WSEV/260108_wsev_final_df.csv'
df = load_wsev_data(data_path)

# Display basic info
print(f"\nDataset shape: {df.shape}")
print(f"\nDiagnosis distribution:")
print(df['DX'].value_counts())

Loaded 160 patients from /home/coder/data/updated_WSEV/260108_wsev_final_df.csv
Columns: 113

Dataset shape: (160, 113)

Diagnosis distribution:
DX
PD      56
AD      55
DLB     28
SVAD    21
Name: count, dtype: int64


In [17]:
# Prepare data for sLDA
X, y, feature_names, dx_labels = prepare_slda_inputs(df, standardize=False)

print(f"\nFeature matrix X: {X.shape}")
print(f"Diagnosis labels y: {y.shape}")
print(f"Number of features: {len(feature_names)}")
print(f"Diagnoses: {dx_labels}")

Diagnosis distribution:
  AD: 55 patients (class 0)
  DLB: 28 patients (class 1)
  PD: 56 patients (class 2)
  SVAD: 21 patients (class 3)

Final data shape:
  X: (160, 95) (patients × cortical regions)
  y: (160,) (patients,)
  Features: 95
  Diagnoses: 4

Feature matrix X: (160, 95)
Diagnosis labels y: (160,)
Number of features: 95
Diagnoses: ['AD', 'DLB', 'PD', 'SVAD']


In [18]:
# Use full dataset
X_model, y_model = X, y

## 2. Model Training with ADVI

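ADVI fits a variational approximation to the posterior by maximizing the ELBO with stochastic gradient steps. It is much faster than MCMC, but the result is only an approximation: a Gaussian in the unconstrained (transformed) parameter space rather than draws from the true posterior.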
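
To show what "optimizing a variational approximation" means in practice, here is a minimal sketch of the idea on a one-dimensional toy target. It fits q(z) = Normal(mu, sigma²) by stochastic gradient ascent on the ELBO using reparameterized draws. The function name `fit_gaussian_vi`, the toy target, and the step size are all illustrative assumptions; this is not how `CoPathologySLDA` is implemented.

```python
import numpy as np

def fit_gaussian_vi(grad_log_p, n_steps=2000, n_draws=100, lr=0.1, seed=0):
    """Fit q(z) = Normal(mu, sigma^2) by stochastic gradient ascent on the ELBO."""
    rng = np.random.default_rng(seed)
    mu, omega = 0.0, 0.0  # omega = log(sigma) keeps sigma positive
    for _ in range(n_steps):
        eps = rng.standard_normal(n_draws)
        sigma = np.exp(omega)
        z = mu + sigma * eps          # reparameterized draws from q
        g = grad_log_p(z)             # d log p(z) / dz at each draw
        grad_mu = g.mean()
        # The +1 is the gradient of the Gaussian entropy with respect to omega.
        grad_omega = (g * eps * sigma).mean() + 1.0
        mu += lr * grad_mu
        omega += lr * grad_omega
    return mu, float(np.exp(omega))

# Toy target density: Normal(3, 2^2), known only through its log-density gradient.
target_mean, target_std = 3.0, 2.0

def toy_grad_log_p(z):
    return -(z - target_mean) / target_std**2

mu_hat, sigma_hat = fit_gaussian_vi(toy_grad_log_p)
print(f"q mean = {mu_hat:.2f}, q std = {sigma_hat:.2f}")
```

Because the toy target is itself Gaussian, q can match it closely. For the sLDA posterior the Gaussian family is a restriction, which is where the approximation error comes from.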

In [19]:
# Initialize model
model = CoPathologySLDA(
    n_topics=4,           # Number of latent pathology patterns
    alpha_prior=1.0,      # Dirichlet concentration (1.0 = uniform)
    feature_prior_std=1.0, # Prior std for topic-region weights
    random_state=42
)

print("Model initialized with 4 topics")
print("\nUsing ADVI - this should take ~30-60 seconds...")

Model initialized with 4 topics

Using ADVI - this should take ~30-60 seconds...
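
The `alpha_prior` comment above says 1.0 corresponds to a uniform Dirichlet. The sketch below draws topic mixtures from symmetric Dirichlet priors with three illustrative concentrations to show how the setting shapes the mixtures before any data is seen. The helper `mean_largest_share` is hypothetical and independent of the model class.

```python
import numpy as np

def mean_largest_share(alpha, n_topics=4, n_draws=5000, seed=42):
    """Average size of the dominant topic under a symmetric Dirichlet(alpha) prior."""
    rng = np.random.default_rng(seed)
    theta = rng.dirichlet(np.full(n_topics, alpha), size=n_draws)
    return float(theta.max(axis=1).mean())

for alpha in (0.1, 1.0, 10.0):
    share = mean_largest_share(alpha)
    print(f"alpha={alpha:>4}: mean largest topic share = {share:.2f}")
```

Small values favor patients dominated by a single topic, while large values favor evenly mixed patients, so this is a natural hyperparameter to vary during quick ADVI runs.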


In [None]:
%%time

# Fit the model using ADVI
# model.fit(
#     X_model, 
#     y_model,
#     inference='advi',        # Use variational inference
#     n_advi_iterations=30000, # Optimization iterations
#     n_samples=1000           # Samples from approximate posterior
# )

model.fit(
    X_model, 
    y_model,
    inference='advi',
    n_advi_iterations=10000,
    n_samples=1000
)

Fitting sLDA model with ADVI:
  Patients: 160
  Features: 95
  Topics: 4
  Diagnoses: 4

Starting ADVI (10000 max iterations)...



## 3. Check ADVI Convergence

The ELBO (Evidence Lower Bound) should level off as optimization converges. If the curve is still trending at the final iteration, increase `n_advi_iterations` and refit.
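
Besides inspecting the plot, a simple numeric check can flag runs that stopped too early. The sketch below compares the mean of the last window of the loss history with the window before it. It assumes the history is available as a 1-D array; how `CoPathologySLDA` stores it is not shown in this notebook, so the function name, window size, and tolerance are illustrative.

```python
import numpy as np

def elbo_has_stabilized(loss_history, window=500, rel_tol=1e-3):
    """Return True if the last two windows of the history have nearly equal means."""
    losses = np.asarray(loss_history, dtype=float)
    if losses.size < 2 * window:
        return False  # not enough iterations to judge
    previous = losses[-2 * window:-window].mean()
    recent = losses[-window:].mean()
    return bool(abs(recent - previous) <= rel_tol * abs(previous))

# Synthetic loss curve: exponential decay toward a plateau plus small noise.
rng = np.random.default_rng(0)
steps = np.arange(10000)
synthetic_loss = 1000 + 5000 * np.exp(-steps / 500) + rng.normal(0, 1, steps.size)

print("Full run stabilized: ", elbo_has_stabilized(synthetic_loss))
print("Early stop stabilized:", elbo_has_stabilized(synthetic_loss[:1500]))
```

Averaging over windows matters because individual ELBO estimates are noisy, so comparing single iterations would be unreliable.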

In [None]:
# Plot ELBO convergence
fig = model.plot_elbo(figsize=(12, 4), save_path='advi_elbo.png')
plt.show()

## 4. Extract Model Parameters

In [None]:
# Get posterior means
topic_patterns = model.get_topic_patterns()        # (n_topics, n_features)
patient_mixtures = model.get_patient_mixtures()    # (n_patients, n_topics)
diagnosis_weights = model.get_diagnosis_weights()  # (n_topics, n_classes)

print(f"Topic patterns (beta): {topic_patterns.shape}")
print(f"Patient mixtures (theta): {patient_mixtures.shape}")
print(f"Diagnosis weights (eta): {diagnosis_weights.shape}")

# Verify that every patient's topic mixture sums to 1
mixture_sums = patient_mixtures.sum(axis=1)
print(f"\nPatient mixture sums (should be ~1.0): min={mixture_sums.min():.4f}, max={mixture_sums.max():.4f}")

## 5. Visualize Topic Patterns

In [None]:
fig = plot_topic_heatmap(
    topic_patterns, 
    feature_names,
    figsize=(16, 6),
    save_path='advi_topic_heatmap.png'
)
plt.show()

In [None]:
# Print top regions for each topic
for topic_id in range(model.n_topics):
    print(f"\n{'='*60}")
    print(f"Topic {topic_id} - Top 10 Regions")
    print(f"{'='*60}")
    
    top_regions = model.get_topic_top_regions(
        topic_id, 
        feature_names, 
        n_regions=10,
        absolute=True
    )
    
    for i, (region, weight) in enumerate(top_regions, 1):
        print(f"{i:2d}. {region:40s} {weight:+.3f}")
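
For reference, ranking with `absolute=True` presumably sorts regions by the magnitude of their weight while keeping the sign for display. The sketch below shows that logic on made-up weights. `top_regions_by_abs_weight` is a hypothetical stand-in, not the actual body of `get_topic_top_regions`.

```python
import numpy as np

def top_regions_by_abs_weight(topic_weights, names, n_regions=10):
    """Return (name, signed weight) pairs ordered by decreasing |weight|."""
    weights = np.asarray(topic_weights, dtype=float)
    order = np.argsort(-np.abs(weights))[:n_regions]
    return [(names[i], float(weights[i])) for i in order]

# Made-up weights for four placeholder regions.
demo_weights = np.array([0.2, -1.5, 0.7, -0.1])
demo_names = ["region_a", "region_b", "region_c", "region_d"]

top2 = top_regions_by_abs_weight(demo_weights, demo_names, n_regions=2)
for name, weight in top2:
    print(f"{name:10s} {weight:+.3f}")
```

Sorting by magnitude surfaces strongly negative regions as well, which a plain descending sort would push to the bottom.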

## 6. Patient Topic Mixtures

In [None]:
fig = plot_patient_topic_distribution(
    patient_mixtures,
    y_model,
    dx_labels,
    figsize=(14, 6),
    save_path='advi_patient_topic_distribution.png'
)
plt.show()

## 7. Topic-Diagnosis Associations

In [None]:
fig = plot_topic_diagnosis_association(
    diagnosis_weights,
    dx_labels,
    figsize=(9, 6),
    save_path='advi_topic_diagnosis_association.png'
)
plt.show()

In [None]:
# Interpret topic-diagnosis associations
print("Topic-Diagnosis Associations (eta matrix):\n")
print(f"{'Topic':<10}", end="")
for dx in dx_labels:
    print(f"{dx:>10}", end="")
print("\n" + "="*60)

for topic_id in range(model.n_topics):
    print(f"Topic {topic_id:<4}", end="")
    for dx_id in range(len(dx_labels)):
        weight = diagnosis_weights[topic_id, dx_id]
        print(f"{weight:>10.3f}", end="")
    print()

print("\nInterpretation:")
print("- Positive weights: Topic increases probability of diagnosis")
print("- Negative weights: Topic decreases probability of diagnosis")
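
One caveat on reading the signs: if diagnosis probabilities are computed as a softmax over `theta @ eta` (an assumption here, since the model internals are not shown), then only the differences between diagnoses within a topic are identified. The sketch below uses random stand-in matrices to show that adding the same constant to every diagnosis weight of one topic leaves the predicted probabilities unchanged.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with the usual max-shift for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
theta_demo = rng.dirichlet(np.ones(4), size=3)  # (patients, topics), rows sum to 1
eta_demo = rng.normal(size=(4, 4))              # (topics, diagnoses)

probs = softmax(theta_demo @ eta_demo)

# Shift all diagnosis weights of topic 0 by the same constant.
eta_shifted = eta_demo.copy()
eta_shifted[0, :] += 5.0
probs_shifted = softmax(theta_demo @ eta_shifted)

print("Probabilities unchanged:", np.allclose(probs, probs_shifted))
```

Under that assumption, a positive weight is best read as "favors this diagnosis relative to the others for this topic" rather than as an absolute effect.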

## 8. Model Predictions

In [None]:
# Predict diagnoses for the training data (in-sample, so performance estimates are optimistic)
y_pred = model.predict_diagnosis(X_model)
y_pred_proba = model.predict_diagnosis_proba(X_model)

print(f"Predicted diagnoses shape: {y_pred.shape}")
print(f"Prediction probabilities shape: {y_pred_proba.shape}")

In [None]:
# Plot confusion matrix
fig = plot_confusion_matrix(
    y_model,
    y_pred,
    dx_labels,
    figsize=(9, 8),
    save_path='advi_confusion_matrix.png'
)
plt.show()

In [None]:
# Per-class precision, recall, and F1
from sklearn.metrics import classification_report

print("Classification Report (ADVI):\n")
print(classification_report(y_model, y_pred, target_names=dx_labels))
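
Because the report above is computed on the same patients used for fitting, a held-out split gives a more honest estimate. The notebook imports `train_test_split_stratified`, but its signature is not shown, so the sketch below uses a hypothetical NumPy-only helper, `stratified_split_indices`, that keeps class proportions roughly equal in both parts. The demo labels mimic the class counts reported earlier (55, 28, 56, 21).

```python
import numpy as np

def stratified_split_indices(labels, test_fraction=0.2, seed=0):
    """Split indices so each class contributes about test_fraction of its members to the test set."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    test_idx = []
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.flatnonzero(labels == cls))
        n_test = max(1, int(round(test_fraction * cls_idx.size)))
        test_idx.extend(cls_idx[:n_test])
    test_idx = np.sort(np.array(test_idx))
    train_idx = np.setdiff1d(np.arange(labels.size), test_idx)
    return train_idx, test_idx

# Stand-in labels with the same class sizes as the dataset above.
y_demo = np.repeat([0, 1, 2, 3], [55, 28, 56, 21])
train_idx, test_idx = stratified_split_indices(y_demo, test_fraction=0.2)

print("Train size:", train_idx.size, "Test size:", test_idx.size)
print("Test class counts:", np.bincount(y_demo[test_idx]))
```

With the real data, the model would be fit on `X[train_idx], y[train_idx]` and the classification report computed on the test rows only.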

## 9. Summary

### What we learned from ADVI

ADVI provides a quick approximation that can help you:
1. **Verify the model runs** without errors
2. **Get rough topic patterns** to check if they make biological sense
3. **Get a rough, in-sample estimate of classification performance** before investing in full MCMC
4. **Tune hyperparameters** (n_topics, alpha) quickly

### When to use full MCMC

For final analysis, run the original notebook with MCMC to get:
- Accurate uncertainty estimates
- Proper R-hat convergence diagnostics
- Full posterior distributions for credible intervals
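
R-hat has no ADVI counterpart because it compares several independent MCMC chains. As a reminder of what it measures, here is a sketch of the classic Gelman-Rubin statistic for one scalar parameter. It is the basic, non-split version on synthetic chains; libraries such as ArviZ default to a more robust rank-normalized split variant, so treat this only as an illustration.

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()            # mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)   # between-chain variance
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(1)
mixed = rng.normal(size=(4, 1000))                        # four chains from one distribution
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])    # two chains sit in another mode

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print(f"R-hat, well-mixed chains: {rhat_mixed:.3f}")
print(f"R-hat, chains in different modes: {rhat_stuck:.3f}")
```

Values close to 1 indicate that the chains agree. Values well above 1 mean the chains have not reached the same distribution.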

### Comparing ADVI vs MCMC results

If ADVI and MCMC give very different results, this may indicate:
- The posterior is multi-modal (ADVI may miss modes)
- ADVI's Gaussian approximation is too restrictive
- More ADVI iterations are needed
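
Before concluding that two fits disagree, check that their topics are in the same order: topic labels are arbitrary, so topic 0 from ADVI may correspond to topic 2 from MCMC. The sketch below aligns two topic-pattern matrices by searching all permutations for the one with the highest total correlation, which is cheap for four topics. `align_topics` is a hypothetical helper, and the matrices are synthetic stand-ins for the two `get_topic_patterns()` outputs.

```python
from itertools import permutations

import numpy as np

def align_topics(reference, other):
    """Find the row order of `other` that best matches the rows of `reference`."""
    n_topics = reference.shape[0]
    # Cross-correlation block: corr[i, j] = corr(reference[i], other[j]).
    corr = np.corrcoef(reference, other)[:n_topics, n_topics:]
    best = max(
        permutations(range(n_topics)),
        key=lambda perm: sum(corr[i, j] for i, j in enumerate(perm)),
    )
    return np.array(best), corr

# Synthetic example: the second fit is a shuffled, slightly noisy copy of the first.
rng = np.random.default_rng(0)
beta_a = rng.normal(size=(4, 95))
shuffle = np.array([2, 0, 3, 1])
beta_b = beta_a[shuffle] + 0.1 * rng.normal(size=(4, 95))

order, corr = align_topics(beta_a, beta_b)
matched_corr = corr[np.arange(4), order]
print("Row order of second fit matching the first:", order)
print("Per-topic correlation after alignment:", np.round(matched_corr, 3))
```

Low per-topic correlations after alignment point to a genuine disagreement between the two fits. High correlations mean the apparent mismatch was only a relabeling.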