# Inferring Diagnosis for New Subjects

This notebook demonstrates how to use a **trained sLDA model** to infer:
1. **Diagnosis probabilities** for new subjects
2. **Topic mixtures** (co-pathology patterns)
3. **Comparison** to training subjects

## Workflow

```
Trained Model (β, η) + New Subject (X_new)
           ↓
    Infer θ_new (topic mixture)
           ↓
    Compute P(DX | θ_new)
           ↓
    Diagnosis Probabilities
```
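The workflow above can be sketched in plain NumPy/SciPy. This is a minimal illustration under stated assumptions, not the `CoPathologySLDA` implementation. It assumes β is a (topics × features) loading matrix and η is a (diagnoses × topics) coefficient matrix. It estimates θ as a point estimate from non-negative least squares, rescaled to sum to 1, and computes P(DX | θ) as a softmax of η θ. The function names `infer_theta` and `diagnosis_probs` are hypothetical. A Bayesian fit would more likely average these quantities over posterior samples of β and η.

```python
import numpy as np
from scipy.optimize import nnls


def infer_theta(x_new, beta):
    """Hypothetical point estimate of the topic mixture theta for one subject.

    Assumes x_new ~= theta @ beta, with beta of shape (n_topics, n_features).
    Solves a non-negative least-squares fit, then rescales so theta sums to 1.
    """
    weights, _residual = nnls(beta.T, x_new)  # beta.T has shape (n_features, n_topics)
    total = weights.sum()
    if total == 0:
        # No topic explains the data; fall back to a uniform mixture
        return np.full(beta.shape[0], 1.0 / beta.shape[0])
    return weights / total


def diagnosis_probs(theta, eta):
    """Hypothetical P(DX | theta): softmax of the linear predictor eta @ theta.

    eta has shape (n_diagnoses, n_topics).
    """
    logits = eta @ theta
    logits = logits - logits.max()  # subtract max for numerical stability
    p = np.exp(logits)
    return p / p.sum()


# Toy example: 2 topics over 3 features, 3 diagnosis classes (made-up numbers)
beta = np.array([[1.0, 0.0, 0.5],
                 [0.0, 1.0, 0.5]])
eta = np.array([[2.0, -1.0],
                [-1.0, 2.0],
                [0.0, 0.0]])
x_new = 0.7 * beta[0] + 0.3 * beta[1]

theta_new = infer_theta(x_new, beta)
p_dx = diagnosis_probs(theta_new, eta)
print(theta_new, p_dx)
```

Because the toy `x_new` is an exact 70/30 blend of the two topics, the recovered θ is that blend, and the diagnosis whose η row favors topic 0 gets the highest probability.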

## 1. Load Trained Model

First, we need a trained model. If you haven't trained one yet, run `slda_copathology_example.ipynb` first.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from preprocessing import load_wsev_data, prepare_slda_inputs, train_test_split_stratified
from slda_model import CoPathologySLDA
from inference_new_subjects import (
    infer_new_subject,
    print_inference_results,
    infer_batch_subjects,
    compare_subject_to_training
)

sns.set_style('whitegrid')
%matplotlib inline

### Option A: Train a new model

In [None]:
# Load data
data_path = '/home/coder/data/updated_WSEV/260108_wsev_final_df.csv'
df = load_wsev_data(data_path)
X, y, feature_names, dx_labels = prepare_slda_inputs(df)

print(f"Loaded {len(X)} subjects")
print(f"Features: {len(feature_names)}")
print(f"Diagnoses: {dx_labels}")

In [None]:
# Split into train/test to simulate new subjects
X_train, X_test, y_train, y_test = train_test_split_stratified(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set: {len(X_train)} subjects")
print(f"Test set (simulating new subjects): {len(X_test)} subjects")

In [None]:
# Train the model (or skip this cell and load a saved model with Option B)
# For faster testing, reduce n_samples and tune

model = CoPathologySLDA(n_topics=4, alpha_prior=1.0, random_state=42)

print("Training model on training set...")
print("Note: This may take 5-15 minutes. Reduce n_samples for faster testing.\n")

model.fit(
    X_train,
    y_train,
    n_samples=1000,  # Increase to 2000 for final analysis
    tune=500,        # Increase to 1000 for final analysis
    chains=2,        # Increase to 4 for final analysis
    target_accept=0.9
)

### Option B: Load a previously trained model

In [None]:
# Uncomment to load a saved model
# import pickle
# with open('trained_slda_model.pkl', 'rb') as f:
#     model = pickle.load(f)
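If the fitted model can be pickled (an assumption: objects holding compiled sampler state may not pickle cleanly, in which case saving the posterior samples is the safer route), a save/load round trip might look like this. `save_model` and `load_model` are hypothetical helpers; the default file name matches the commented cell above.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, path="trained_slda_model.pkl"):
    """Hypothetical helper: serialize a fitted model object with pickle."""
    with Path(path).open("wb") as f:
        pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(path="trained_slda_model.pkl"):
    """Hypothetical helper: load a pickled model. Only unpickle files you trust."""
    with Path(path).open("rb") as f:
        return pickle.load(f)


# Round-trip check with a stand-in object (a real run would pass the fitted model)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo_model.pkl"
    save_model({"n_topics": 4}, demo_path)
    restored = load_model(demo_path)
```

In practice you would call `save_model(model)` after Option A finishes and `model = load_model()` in a later session.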

## 2. Infer Single New Subject

Let's take a subject from the test set and infer their diagnosis and topic mixture.

In [None]:
# Select a new subject from test set
subject_idx = 0
X_new_subject = X_test[subject_idx]
true_diagnosis = dx_labels[y_test[subject_idx]]

print(f"Analyzing new subject (index {subject_idx})")
print(f"True diagnosis: {true_diagnosis}")
print(f"Features shape: {X_new_subject.shape}")

In [None]:
# Infer diagnosis and topic mixture
results = infer_new_subject(
    model,
    X_new_subject,
    feature_names,
    dx_labels,
    subject_id=f"Test_Subject_{subject_idx}"
)

# Print detailed results
print_inference_results(results, verbose=True)

In [None]:
# Visualize the topic mixture
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Topic mixture
topics = [f'Topic {i}' for i in range(len(results['topic_mixture']))]
ax1.bar(topics, results['topic_mixture'], color='steelblue', alpha=0.7)
ax1.set_ylabel('Proportion', fontsize=12)
ax1.set_title(f"Topic Mixture for {results['subject_id']}", fontsize=13, fontweight='bold')
ax1.set_ylim(0, 1)
ax1.grid(axis='y', alpha=0.3)

# Diagnosis probabilities
dx_names = list(results['diagnosis_breakdown'].keys())
dx_probs = list(results['diagnosis_breakdown'].values())
colors = ['green' if dx == results['predicted_diagnosis'] else 'gray' for dx in dx_names]

ax2.bar(dx_names, dx_probs, color=colors, alpha=0.7)
ax2.set_ylabel('Probability', fontsize=12)
ax2.set_title('Diagnosis Probabilities', fontsize=13, fontweight='bold')
ax2.set_ylim(0, 1)
ax2.axhline(0.5, color='red', linestyle='--', linewidth=1, alpha=0.5, label='50% threshold')
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTrue diagnosis: {true_diagnosis}")
print(f"Predicted diagnosis: {results['predicted_diagnosis']}")
print(f"Correct: {'✓' if true_diagnosis == results['predicted_diagnosis'] else '✗'}")

## 3. Interpret Co-Pathology

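The topic mixture θ shows how the subject's atrophy pattern is split across the learned topics. Substantial weight on more than one topic suggests co-pathology, and the top regions of each topic indicate which pattern each share corresponds to.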

In [None]:
# Get topic patterns from model
topic_patterns = model.get_topic_patterns()

print("\nTopic Interpretation for this Subject:\n")
print("="*70)

for topic_id, proportion in results['top_topics']:
    if proportion < 0.05:  # Skip negligible topics
        continue

    print(f"\nTopic {topic_id} ({proportion:.1%} of pathology):")

    # Get top regions for this topic
    top_regions = model.get_topic_top_regions(
        topic_id,
        feature_names,
        n_regions=5,
        absolute=True
    )

    print("  Top affected regions:")
    for region, weight in top_regions:
        region_clean = region.replace('ctx_lh_', 'L_').replace('ctx_rh_', 'R_')
        print(f"    - {region_clean}: {weight:+.3f}")
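Ranking the regions within a topic by absolute loading (what `absolute=True` appears to request) can be sketched as follows, assuming each topic is a vector of signed region weights aligned with `feature_names`. `top_regions_by_weight` is a hypothetical stand-in for `model.get_topic_top_regions`, not its actual implementation, and the weights below are made up.

```python
import numpy as np


def top_regions_by_weight(weights, feature_names, n_regions=5, absolute=True):
    """Hypothetical: return the n_regions (name, weight) pairs with the largest scores.

    With absolute=True, regions are ranked by |weight| so strong negative
    loadings are kept; otherwise they are ranked by the signed weight.
    """
    weights = np.asarray(weights, dtype=float)
    scores = np.abs(weights) if absolute else weights
    order = np.argsort(-scores)[:n_regions]  # indices, largest score first
    return [(feature_names[i], float(weights[i])) for i in order]


def clean_region_name(region):
    """Shorten FreeSurfer-style cortical labels, as in the cell above."""
    return region.replace('ctx_lh_', 'L_').replace('ctx_rh_', 'R_')


names = ['ctx_lh_precuneus', 'ctx_rh_precuneus', 'ctx_lh_insula', 'ctx_rh_insula']
topic_weights = [0.10, -0.90, 0.40, 0.05]
top_two = top_regions_by_weight(topic_weights, names, n_regions=2)
for region, weight in top_two:
    print(f"    - {clean_region_name(region)}: {weight:+.3f}")
```

Ranking by absolute value keeps the strongly negative right precuneus loading, which a signed ranking would push to the bottom.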

## 4. Compare to Training Subjects

Compare the new subject's topic mixture to typical patterns in each diagnosis group.

In [None]:
# Compare to training subjects
comparison = compare_subject_to_training(
    model,
    X_new_subject,
    X_train,
    y_train,
    dx_labels,
    subject_id=f"Test_Subject_{subject_idx}"
)

print("\nSimilarity to Training Diagnosis Groups:")
print("="*50)
for dx, similarity in sorted(comparison['similarities'].items(),
                             key=lambda x: x[1], reverse=True):
    bar = "█" * int(similarity * 40)
    print(f"{dx:<10} {similarity:.3f}  {bar}")

print(f"\nMost similar to: {comparison['most_similar_dx']}")
print(f"True diagnosis: {true_diagnosis}")
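The notebook does not show how `compare_subject_to_training` scores similarity, so the following is one plausible approach, offered as an assumption. It averages the training subjects' topic mixtures within each diagnosis, then takes the cosine similarity between the new subject's θ and each group mean. For non-negative mixtures this lies in [0, 1], which is consistent with the bar display above. `group_similarities` is hypothetical, and `theta_train_demo` stands in for training-set topic mixtures.

```python
import numpy as np


def group_similarities(theta_new, theta_train, y_train, dx_labels):
    """Hypothetical: cosine similarity between theta_new and each group's mean mixture.

    theta_train: (n_subjects, n_topics) topic mixtures for training subjects.
    y_train: integer diagnosis codes indexing into dx_labels.
    """
    theta_new = np.asarray(theta_new, dtype=float)
    y_train = np.asarray(y_train)
    sims = {}
    for code, dx in enumerate(dx_labels):
        members = theta_train[y_train == code]
        if len(members) == 0:
            continue  # no training subjects with this diagnosis
        centroid = members.mean(axis=0)
        denom = np.linalg.norm(theta_new) * np.linalg.norm(centroid)
        sims[dx] = float(theta_new @ centroid / denom) if denom > 0 else 0.0
    return sims


# Made-up mixtures: two subjects per diagnosis, two topics
theta_train_demo = np.array([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]])
y_train_demo = np.array([0, 0, 1, 1])
sims = group_similarities([0.85, 0.15], theta_train_demo, y_train_demo, ['AD', 'PD'])
most_similar = max(sims, key=sims.get)
```

Comparing against group means is simple but ignores within-group spread; a nearest-neighbour comparison would be an alternative design.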

## 5. Batch Inference for Multiple New Subjects

Infer diagnoses for all test subjects at once.

In [None]:
# Infer for all test subjects
test_subject_ids = [f"Test_Subject_{i}" for i in range(len(X_test))]

results_df = infer_batch_subjects(
    model,
    X_test,
    feature_names,
    dx_labels,
    subject_ids=test_subject_ids
)

# Add true diagnoses for comparison
results_df['True_DX'] = [dx_labels[label] for label in y_test]
results_df['Correct'] = results_df['Predicted_DX'] == results_df['True_DX']

print(f"\nBatch Inference Results for {len(results_df)} test subjects:")
print(results_df.head(10))

In [None]:
# Overall accuracy
accuracy = results_df['Correct'].mean()
print(f"\nOverall Accuracy on Test Set: {accuracy:.1%}")

# Per-diagnosis accuracy
print("\nPer-Diagnosis Accuracy:")
for dx in dx_labels:
    dx_mask = results_df['True_DX'] == dx
    if dx_mask.sum() > 0:
        dx_accuracy = results_df[dx_mask]['Correct'].mean()
        print(f"  {dx}: {dx_accuracy:.1%} ({dx_mask.sum()} subjects)")
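The per-diagnosis loop above computes each class's recall: the fraction of subjects with that true diagnosis who were predicted correctly. The same table can be produced with a pandas `groupby`; the toy DataFrame below only mimics the `True_DX` and `Correct` columns of `results_df`, with made-up rows. Averaging the per-class values gives balanced accuracy, which is less sensitive to class imbalance than overall accuracy.

```python
import pandas as pd

# Toy stand-in for results_df (same column names, made-up rows)
demo_df = pd.DataFrame({
    'True_DX': ['AD', 'AD', 'AD', 'PD', 'PD', 'HC'],
    'Correct': [True, True, False, True, False, True],
})

# One row per diagnosis: recall and number of subjects
per_dx = (demo_df.groupby('True_DX')['Correct']
          .agg(accuracy='mean', n_subjects='size'))
balanced_accuracy = per_dx['accuracy'].mean()

print(per_dx)
print(f"Balanced accuracy: {balanced_accuracy:.1%}")
```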

In [None]:
# Save results
results_df.to_csv('test_subjects_inference_results.csv', index=False)
print("\nResults saved to: test_subjects_inference_results.csv")

## 6. Identify Uncertain Cases

Find subjects with low confidence or mixed pathology.

In [None]:
# Find low confidence predictions
low_confidence_threshold = 0.5
uncertain_subjects = results_df[results_df['DX_Confidence'] < low_confidence_threshold]

print(f"\nSubjects with low confidence (< {low_confidence_threshold:.0%}):")
print(f"Found {len(uncertain_subjects)} uncertain cases\n")

if len(uncertain_subjects) > 0:
    display_cols = ['Subject_ID', 'True_DX', 'Predicted_DX', 'DX_Confidence', 'Correct']
    print(uncertain_subjects[display_cols])
else:
    print("No uncertain cases found - all predictions are confident!")
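A single probability threshold can miss cases where the top two diagnoses are nearly tied. Assuming `DX_Confidence` is the highest diagnosis probability (the notebook does not state this), a complementary check is the margin between the two most probable diagnoses. The probabilities and the 0.2 margin cutoff below are illustrative only; in practice the matrix would hold each subject's diagnosis probabilities.

```python
import numpy as np


def top_two_margin(probs):
    """Return (max probability, gap between the two largest probabilities) per row."""
    probs = np.asarray(probs, dtype=float)
    sorted_probs = np.sort(probs, axis=1)  # ascending within each row
    confidence = sorted_probs[:, -1]
    margin = sorted_probs[:, -1] - sorted_probs[:, -2]
    return confidence, margin


demo_probs = np.array([
    [0.55, 0.40, 0.05],  # passes a 0.5 threshold, but close to the runner-up
    [0.90, 0.05, 0.05],  # clear winner
])
confidence, margin = top_two_margin(demo_probs)
ambiguous = margin < 0.2  # illustrative cutoff, not from the notebook
```

The first subject clears the 50% threshold used above yet is flagged as ambiguous by the margin check.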

In [None]:
# Find subjects with mixed topic membership (co-pathology)
from scipy.stats import entropy

# Compute entropy for each subject's topic mixture
topic_cols = [col for col in results_df.columns if col.startswith('Topic_')]
topic_mixtures = results_df[topic_cols].values

results_df['Topic_Entropy'] = [entropy(mix) for mix in topic_mixtures]

# High entropy = weight spread across several topics, suggestive of co-pathology
high_entropy_subjects = results_df.nlargest(10, 'Topic_Entropy')

print("\nTop 10 Subjects by Topic Entropy (possible co-pathology):\n")
display_cols = ['Subject_ID', 'True_DX', 'Predicted_DX', 'Topic_Entropy'] + topic_cols
print(high_entropy_subjects[display_cols])
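`scipy.stats.entropy` normalizes its input and uses the natural logarithm by default, so raw values depend on the number of topics K (the maximum is ln K). Dividing by ln K puts the score on a 0 to 1 scale, which makes a fixed cutoff easier to reason about across models with different K. A short sketch with made-up mixtures (it requires at least two topics):

```python
import numpy as np
from scipy.stats import entropy


def normalized_topic_entropy(mixtures):
    """Entropy of each row divided by log(K): 0 = one topic, 1 = uniform over K topics."""
    mixtures = np.asarray(mixtures, dtype=float)
    n_topics = mixtures.shape[1]
    return entropy(mixtures, axis=1) / np.log(n_topics)


demo_mixtures = np.array([
    [1.00, 0.00, 0.00, 0.00],  # single topic
    [0.25, 0.25, 0.25, 0.25],  # evenly mixed
    [0.50, 0.50, 0.00, 0.00],  # split between two topics
])
scores = normalized_topic_entropy(demo_mixtures)
```

With four topics, an even split between two of them scores 0.5 (ln 2 / ln 4).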

## 7. Example: Real-World New Subject

If you have a completely new subject's data:

In [None]:
# Example: Manual input for a new subject
# Replace with actual atrophy values from your subject

# Load new subject from CSV (example)
# new_subject_df = pd.read_csv('/path/to/new_subject_data.csv')
# X_new = new_subject_df[feature_names].values[0]

# Or create manually (must have 62 cortical features)
# X_new = np.array([0.5, 0.3, ..., 0.8])  # 62 values

# Then infer:
# results = infer_new_subject(
#     model,
#     X_new,
#     feature_names,
#     dx_labels,
#     subject_id="PATIENT_NEW_001"
# )
# print_inference_results(results)

print("\nTo use with your own new subject:")
print("1. Extract 62 cortical features (ctx_lh_* and ctx_rh_*)")
print("2. Ensure features are in the same order as training data")
print("3. Use infer_new_subject() as shown above")
print("4. Interpret topic mixture for co-pathology patterns")
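Steps 1 and 2 can be enforced in code before calling `infer_new_subject()`. The sketch below uses a hypothetical helper, `align_features`, which is not part of the notebook's modules. It selects columns by name in the training order and fails loudly if any are missing, instead of relying on column positions. It assumes the new subject's CSV has columns named like the training features.

```python
import numpy as np
import pandas as pd


def align_features(new_subject_df, feature_names, row=0):
    """Hypothetical: return one subject's features as a 1-D array in training order."""
    missing = [name for name in feature_names if name not in new_subject_df.columns]
    if missing:
        raise ValueError(f"Missing {len(missing)} feature(s), e.g. {missing[:3]}")
    values = new_subject_df.loc[new_subject_df.index[row], list(feature_names)]
    x_new = values.to_numpy(dtype=float)
    if np.isnan(x_new).any():
        raise ValueError("New subject has missing (NaN) feature values")
    return x_new


# Toy check: columns arrive in a different order, plus an extra ID column
demo_features = ['ctx_lh_precuneus', 'ctx_rh_precuneus']
demo_subject_df = pd.DataFrame({'subject': ['P001'],
                                'ctx_rh_precuneus': [0.8],
                                'ctx_lh_precuneus': [0.5]})
x_demo = align_features(demo_subject_df, demo_features)
```

With real data you would pass the loaded CSV and the `feature_names` returned by `prepare_slda_inputs`.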

## Summary

### What We Can Infer for New Subjects:

1. **Diagnosis Probabilities**: P(AD), P(PD), P(DLB), P(SVAD), P(HC)
2. **Topic Mixture**: Proportion of each pathology pattern (θ)
3. **Co-Pathology**: Mixed topic membership indicates overlapping pathologies
4. **Similarity**: How closely the subject's topic mixture matches the typical pattern of each training diagnosis group

### Key Functions:

- `infer_new_subject()`: Single subject inference with detailed results
- `infer_batch_subjects()`: Batch inference returning DataFrame
- `compare_subject_to_training()`: Compare to training diagnosis groups
- `print_inference_results()`: Pretty-print results

### Interpretation:

- **High confidence + single dominant topic**: Clear single pathology
- **High confidence + mixed topics**: Clear co-pathology pattern
- **Low confidence**: Uncertain or atypical presentation
- **High entropy**: Multiple co-occurring pathology patterns