# Discovery vs Prediction Framework Overview

This notebook provides a comprehensive overview of the two modes of Aladynoulli:
1. **Discovery Mode**: Joint phi estimation for understanding disease biology
2. **Prediction Mode**: Fixed phi for clinical predictions

We also demonstrate:
- How thetas update over time with different washout periods
- How to compute theta AUC phenotypes for GWAS from signature loadings, without including genotypes (G) in the model fit


---
## Part 1: Discovery vs Prediction Framework

### 🔬 Discovery Mode: Joint Phi Estimation

**Purpose**: Learn disease signatures (phi) from data to understand disease connections

**How it works**:
- **Phi is learned** from the data (joint estimation with lambda)
- Model learns which diseases cluster together into signatures
- Used for understanding disease biology and relationships
- **Lambda** (individual signature loadings) is also learned
- **Theta** = softmax(lambda) = individual signature proportions over time
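
The relationship between lambda and theta can be sketched with a toy example. The numbers are made up, and the sketch assumes lambda is stored as an array of shape `[N, K, T]` (patients, signatures, timepoints) with the softmax taken across the signature axis, as in the `softmax` helper defined later in this notebook.

```python
import numpy as np

# Toy lambda: 2 patients, 3 signatures, 4 timepoints (values are made up)
rng = np.random.default_rng(0)
lam = rng.normal(size=(2, 3, 4))

# Softmax across the signature axis (axis=1), shifted by the max for stability
shifted = lam - lam.max(axis=1, keepdims=True)
theta = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# At every timepoint, each patient's signature proportions sum to 1
print(theta.shape)
print(np.allclose(theta.sum(axis=1), 1.0))
```

Because theta is a proportion, an increase in one signature for a patient at a given age necessarily comes with a decrease in the others.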

**Location**: `/Dropbox/enrollment_retrospective_full/`
- Generated by: `run_full_retrospective.sh` → `run_aladyn_batch.py`
- Outputs: `enrollment_model_W0.0001_batch_*_*.pt` (contains learned phi and lambda)

**Key Characteristics**:
- Phi can vary between batches (batch-specific disease signatures)
- Theta reflects individual signature loadings learned jointly with phi
- Best for: Understanding disease biology, pathway discovery, heritability analysis

---

### 🎯 Prediction Mode: Fixed Phi

**Purpose**: Make predictions using pre-learned signatures (stable, generalizable)

**How it works**:
- **Phi is fixed** from master checkpoints (pre-learned signatures)
- Only **lambda** (individual signature loadings) is estimated
- **Theta** = softmax(lambda) = individual signature proportions over time
- Used for clinical prediction and performance evaluation

**Location**: `/Dropbox/models_fromAWS_enrollment_predictions_fixedphi_RETROSPECTIVE_pooled/`
- Generated by: `run_aladyn_predict_with_master.py` (on AWS)
- Outputs: `pi_enroll_fixedphi_sex_*.pt` (pi tensors for predictions)

**Key Characteristics**:
- Phi is stable across all patients (from master checkpoint)
- Theta reflects individual signature loadings given fixed phi
- Best for: Clinical predictions, performance evaluation, washout analyses

---

### Comparison Table

| Aspect | Discovery Mode | Prediction Mode |
|--------|---------------|------------------|
| **Phi** | Learned (joint) | Fixed (from master) |
| **Lambda** | Learned (joint) | Learned (phi fixed) |
| **Theta** | softmax(lambda) | softmax(lambda) |
| **Purpose** | Disease biology | Clinical prediction |
| **Stability** | Batch-specific | Stable across all |
| **Use Cases** | Pathway discovery, heritability | AUC, washout, age offset |


---
## Part 2: Timeline of Theta Updates

### How Thetas Update Over Time

Thetas (signature proportions) update as more data becomes available. This is demonstrated using the **age offset** approach:

- **Age Offset 0**: Predict at enrollment age using data up to enrollment
- **Age Offset 5**: Predict at enrollment+5 years using data up to enrollment+5 years
- **Age Offset 9**: Predict at enrollment+9 years using data up to enrollment+9 years

As more data becomes available (higher age offset), the model learns more about each patient's signature loadings, and thetas update accordingly.
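
One way to picture the age offset idea is as a per-patient observation cutoff. The sketch below is an illustration only, not the pipeline's actual preprocessing: the helper `mask_events_by_offset`, the `event_ages` layout, and the use of `NaN` for "no event" are all assumptions made for the example.

```python
import numpy as np

def mask_events_by_offset(event_ages, enroll_ages, offset):
    """Return a boolean [N, D] mask of events observed by enrollment + offset.

    event_ages:  [N, D] age at first event per disease, NaN if no event (assumed layout)
    enroll_ages: [N] age at enrollment
    offset:      years after enrollment at which the prediction is made
    """
    cutoff = enroll_ages[:, None] + offset  # [N, 1] per-patient cutoff age
    # Comparisons with NaN are False, so missing events stay unobserved
    return event_ages <= cutoff

# Two hypothetical patients, three diseases
event_ages = np.array([[52.0, np.nan, 61.0],
                       [47.0, 58.0, np.nan]])
enroll_ages = np.array([50.0, 55.0])

for offset in (0, 5, 9):
    observed = mask_events_by_offset(event_ages, enroll_ages, offset)
    print(offset, observed.sum(axis=1))
```

In this toy setup, a larger offset can only add observed events for a patient, which is the sense in which later offsets give the model more information.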


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import pandas as pd

# Example: Load thetas from different age offsets
# This demonstrates how thetas update as more data becomes available

print("="*80)
print("THETA UPDATE TIMELINE DEMONSTRATION")
print("="*80)
print("\nThis shows how individual signature loadings (thetas) update")
print("as more data becomes available over time.")
print("\nAge Offset 0: Data up to enrollment age")
print("Age Offset 5: Data up to enrollment+5 years")
print("Age Offset 9: Data up to enrollment+9 years")
print("\nAs offset increases, thetas become more refined based on observed events.")


In [None]:
# Example visualization function for theta updates
def visualize_theta_updates(patient_idx, signature_idx, thetas_by_offset, timepoints):
    """
    Visualize how theta for a specific patient and signature updates across age offsets
    
    Parameters:
    - patient_idx: Patient index
    - signature_idx: Signature index to plot
    - thetas_by_offset: Dict {offset: theta_tensor} where theta_tensor is [N, K, T]
    - timepoints: Array of timepoints (ages)
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Theta trajectory over time for different offsets
    ax1 = axes[0]
    for offset, thetas in sorted(thetas_by_offset.items()):
        theta_traj = thetas[patient_idx, signature_idx, :]
        ax1.plot(timepoints, theta_traj, label=f'Offset {offset}', marker='o', markersize=3)
    
    ax1.set_xlabel('Age (years)', fontsize=12)
    ax1.set_ylabel(f'Theta (Signature {signature_idx})', fontsize=12)
    ax1.set_title(f'Patient {patient_idx}: Theta Updates Across Age Offsets', fontsize=14)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Theta AUC (area under curve) for different offsets
    ax2 = axes[1]
    offsets = sorted(thetas_by_offset.keys())
    theta_aucs = []
    
    for offset in offsets:
        thetas = thetas_by_offset[offset]
        theta_traj = thetas[patient_idx, signature_idx, :]
        # Calculate AUC using trapezoidal integration
        auc = np.trapz(theta_traj, timepoints)
        theta_aucs.append(auc)
    
    ax2.plot(offsets, theta_aucs, marker='o', linewidth=2, markersize=8)
    ax2.set_xlabel('Age Offset (years)', fontsize=12)
    ax2.set_ylabel(f'Theta AUC (Signature {signature_idx})', fontsize=12)
    ax2.set_title(f'Patient {patient_idx}: Theta AUC vs Age Offset', fontsize=14)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig

print("\nVisualization function created.")
print("\nTo use this function, you would load thetas from different age offsets:")
print("  thetas_by_offset = {")
print("      0: load_thetas(offset=0),  # Enrollment only")
print("      5: load_thetas(offset=5),  # Enrollment + 5 years")
print("      9: load_thetas(offset=9)   # Enrollment + 9 years")
print("  }")
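
The expected input structure can be illustrated with synthetic data. `load_thetas` is not defined in this notebook, so the sketch below substitutes random arrays of the documented shape `[N, K, T]`. The age grid and the helper name `make_fake_thetas` are assumptions for illustration only.

```python
import numpy as np

def make_fake_thetas(n_patients, n_signatures, n_timepoints, seed):
    """Random stand-in for loaded thetas: softmax of Gaussian noise, shape [N, K, T]."""
    rng = np.random.default_rng(seed)
    lam = rng.normal(size=(n_patients, n_signatures, n_timepoints))
    e = np.exp(lam - lam.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

timepoints = np.arange(30, 82)  # assumed age grid, one point per year
thetas_by_offset = {
    offset: make_fake_thetas(5, 4, len(timepoints), seed=offset)
    for offset in (0, 5, 9)
}

# Same per-offset summary that the right-hand panel plots, for one patient and signature
patient_idx, signature_idx = 0, 2
for offset, thetas in sorted(thetas_by_offset.items()):
    traj = thetas[patient_idx, signature_idx, :]
    # Trapezoidal rule written out explicitly
    auc = np.sum(0.5 * (traj[1:] + traj[:-1]) * np.diff(timepoints))
    print(f"offset={offset}: theta AUC = {auc:.3f}")
```

With a dictionary of this shape in hand, the plot would be produced by `visualize_theta_updates(patient_idx, signature_idx, thetas_by_offset, timepoints)`.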


---
## Part 3: GWAS AUC Calculation Without Genotypes (G)

### Overview

One useful feature of Aladynoulli is that signature loadings can serve as GWAS phenotypes **without genotypes (G) entering the model fit**. Genotypes are still required for the association step itself. The ingredients are:

1. **Lambda** (raw signature loadings) or **Theta** (softmax of lambda) from the model
2. **Theta AUC**: Area under the curve of each theta trajectory over time
3. **GWAS**: Regress theta AUCs on genotypes, one signature at a time

### Workflow

```
Step 1: Fit model (with or without G)
  ↓
Step 2: Extract lambda (signature loadings) for each patient
  ↓
Step 3: Calculate theta = softmax(lambda) [N, K, T]
  ↓
Step 4: Calculate theta AUC = ∫ theta(t) dt for each signature [N, K]
  ↓
Step 5: Run GWAS: theta_AUC ~ genotype for each signature
```

### Key Insight

The model learns signature loadings (lambda/theta) from disease data alone (Y, E). Because these loadings reflect both genetic and environmental influences on disease, they can be used as GWAS phenotypes even when G was not used in model fitting.
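
Step 5 can be sketched as an ordinary least squares regression of one signature's theta AUC on allele dosage, one SNP at a time. The genotypes and phenotype below are simulated, and `per_snp_ols` is a hypothetical helper. A real analysis would normally use dedicated GWAS software and adjust for covariates such as age, sex, and ancestry principal components.

```python
import numpy as np

def per_snp_ols(phenotype, dosages):
    """Simple linear regression of phenotype [N] on each column of dosages [N, M].

    Returns per-SNP slope estimates, standard errors, and t statistics.
    No covariates are included; this is an illustration, not a full GWAS.
    """
    n = phenotype.shape[0]
    y = phenotype - phenotype.mean()
    x = dosages - dosages.mean(axis=0)           # center each SNP
    sxx = (x ** 2).sum(axis=0)
    beta = (x * y[:, None]).sum(axis=0) / sxx    # slope per SNP
    resid = y[:, None] - x * beta                # residuals per SNP, [N, M]
    sigma2 = (resid ** 2).sum(axis=0) / (n - 2)  # residual variance
    se = np.sqrt(sigma2 / sxx)
    return beta, se, beta / se

rng = np.random.default_rng(1)
n_patients, n_snps = 2000, 50
dosages = rng.binomial(2, 0.3, size=(n_patients, n_snps)).astype(float)

# Simulated theta AUC for one signature: SNP 7 has a true effect, the rest do not
theta_auc_k = 0.5 * dosages[:, 7] + rng.normal(size=n_patients)

beta, se, t_stat = per_snp_ols(theta_auc_k, dosages)
print("top SNP:", int(np.argmax(np.abs(t_stat))))
```

Repeating this for each column of the `[N, K]` theta AUC matrix gives one set of association statistics per signature.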


In [None]:
def softmax(x):
    """
    Compute softmax values for each set of scores in x.
    x shape: (n_individuals, n_signatures, n_timepoints) or (n_individuals, n_signatures)
    """
    if x.ndim == 3:
        # Reshape to (n_individuals * n_timepoints, n_signatures)
        x_reshaped = x.transpose(0, 2, 1).reshape(-1, x.shape[1])
        
        # Compute softmax
        e_x = np.exp(x_reshaped - np.max(x_reshaped, axis=1, keepdims=True))
        softmax_x = e_x / np.sum(e_x, axis=1, keepdims=True)
        
        # Reshape back to original shape
        return softmax_x.reshape(x.shape[0], x.shape[2], x.shape[1]).transpose(0, 2, 1)
    elif x.ndim == 2:
        # Simple 2D case
        e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return e_x / np.sum(e_x, axis=1, keepdims=True)
    else:
        raise ValueError(f"Expected 2D or 3D array, got {x.ndim}D")


def calculate_theta_aucs(all_lambdas, timepoints=None):
    """
    Calculate theta AUCs from lambda values.
    
    Parameters:
    - all_lambdas: Array of shape (N, K, T) where N=patients, K=signatures, T=timepoints
    - timepoints: Array of timepoints (default: np.arange(T))
    
    Returns:
    - theta_aucs: Array of shape (N, K) with AUC values for each patient-signature pair
    - all_thetas: Array of shape (N, K, T) with theta = softmax(lambda)
    """
    N, K, T = all_lambdas.shape
    
    if timepoints is None:
        timepoints = np.arange(T)
    
    # Calculate thetas using softmax
    all_thetas = softmax(all_lambdas)  # Shape: (N, K, T)
    
    # Trapezoidal integration along the time axis for every patient-signature pair
    theta_aucs = np.trapz(all_thetas, timepoints, axis=2)  # Shape: (N, K)
    
    return theta_aucs, all_thetas


print("="*80)
print("GWAS AUC CALCULATION WITHOUT GENOTYPES")
print("="*80)
print("\nFunctions created:")
print("  1. softmax(): Converts lambda to theta")
print("  2. calculate_theta_aucs(): Calculates AUC for each patient-signature pair")
print("\nWorkflow:")
print("  1. Load lambda from model (shape: [N, K, T])")
print("  2. Calculate theta = softmax(lambda)")
print("  3. Calculate theta_AUC = ∫ theta(t) dt for each signature")
print("  4. Run GWAS: theta_AUC[:, signature] ~ genotype for each SNP")
print("\nThis allows GWAS on signature loadings without needing G in the model!")
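
One property is worth keeping in mind when using theta AUCs as phenotypes. Because thetas sum to 1 across signatures at every timepoint, the K theta AUCs for a patient sum to the length of the time interval, so the K phenotypes are linearly dependent. The check below uses random lambdas and an explicit trapezoidal rule; all values are synthetic.

```python
import numpy as np

rng = np.random.default_rng(2)
N, K, T = 4, 5, 20
lam = rng.normal(size=(N, K, T))
timepoints = np.linspace(30.0, 80.0, T)

# theta = softmax(lambda) across signatures
e = np.exp(lam - lam.max(axis=1, keepdims=True))
theta = e / e.sum(axis=1, keepdims=True)

# Trapezoidal AUC along the time axis -> shape [N, K]
dt = np.diff(timepoints)
theta_auc = np.sum(0.5 * (theta[..., 1:] + theta[..., :-1]) * dt, axis=-1)

# Summing over signatures recovers the interval length for every patient
total = theta_auc.sum(axis=1)
print(np.allclose(total, timepoints[-1] - timepoints[0]))
```

In practice this means association results across signatures are not independent of one another.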


In [None]:
# Example: Schematic visualization of the workflow
def create_gwas_workflow_schematic():
    """
    Create a schematic diagram showing the GWAS workflow without G
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis('off')
    
    # Define positions
    y_start = 0.9
    y_step = 0.15
    
    # Step 1: Model fitting
    ax.text(0.1, y_start, 'Step 1: Fit Model', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    ax.text(0.1, y_start - 0.05, 'Input: Y (diseases), E (event times)', fontsize=10)
    ax.text(0.1, y_start - 0.08, 'Optional: G (genotypes)', fontsize=10, style='italic')
    
    # Arrow
    ax.arrow(0.5, y_start - 0.1, 0, -0.05, head_width=0.02, head_length=0.01, fc='black')
    
    # Step 2: Extract lambda
    y_pos = y_start - y_step
    ax.text(0.1, y_pos, 'Step 2: Extract Lambda', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.7))
    ax.text(0.1, y_pos - 0.05, 'Lambda shape: [N, K, T]', fontsize=10)
    ax.text(0.1, y_pos - 0.08, 'N=patients, K=signatures, T=timepoints', fontsize=10)
    
    # Arrow
    ax.arrow(0.5, y_pos - 0.1, 0, -0.05, head_width=0.02, head_length=0.01, fc='black')
    
    # Step 3: Calculate theta
    y_pos = y_pos - y_step
    ax.text(0.1, y_pos, 'Step 3: Calculate Theta', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))
    ax.text(0.1, y_pos - 0.05, 'Theta = softmax(Lambda)', fontsize=10)
    ax.text(0.1, y_pos - 0.08, 'Theta shape: [N, K, T]', fontsize=10)
    
    # Arrow
    ax.arrow(0.5, y_pos - 0.1, 0, -0.05, head_width=0.02, head_length=0.01, fc='black')
    
    # Step 4: Calculate theta AUC
    y_pos = y_pos - y_step
    ax.text(0.1, y_pos, 'Step 4: Calculate Theta AUC', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.7))
    ax.text(0.1, y_pos - 0.05, 'Theta_AUC = ∫ theta(t) dt', fontsize=10)
    ax.text(0.1, y_pos - 0.08, 'Theta_AUC shape: [N, K]', fontsize=10)
    
    # Arrow
    ax.arrow(0.5, y_pos - 0.1, 0, -0.05, head_width=0.02, head_length=0.01, fc='black')
    
    # Step 5: GWAS
    y_pos = y_pos - y_step
    ax.text(0.1, y_pos, 'Step 5: Run GWAS', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightpink', alpha=0.7))
    ax.text(0.1, y_pos - 0.05, 'For each signature k:', fontsize=10)
    ax.text(0.1, y_pos - 0.08, '  Theta_AUC[:, k] ~ genotype for each SNP', fontsize=10)
    
    # Side note
    ax.text(0.65, 0.5, 'Key Insight:', fontsize=12, weight='bold')
    ax.text(0.65, 0.45, 'Model learns signature', fontsize=10)
    ax.text(0.65, 0.42, 'loadings from disease', fontsize=10)
    ax.text(0.65, 0.39, 'data (Y, E), which', fontsize=10)
    ax.text(0.65, 0.36, 'capture genetic and', fontsize=10)
    ax.text(0.65, 0.33, 'environmental factors.', fontsize=10)
    ax.text(0.65, 0.28, 'These can be used as', fontsize=10)
    ax.text(0.65, 0.25, 'phenotypes for GWAS', fontsize=10)
    ax.text(0.65, 0.22, 'even if G was not', fontsize=10)
    ax.text(0.65, 0.19, 'used in model fitting!', fontsize=10)
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('GWAS Workflow Without Genotypes (G)', fontsize=16, weight='bold', pad=20)
    
    plt.tight_layout()
    return fig

fig = create_gwas_workflow_schematic()
plt.show()

print("\nSchematic created showing the complete workflow.")


---
## Summary

### Key Takeaways

1. **Discovery Mode (Joint Phi)**:
   - Phi and lambda learned jointly from data
   - Best for understanding disease biology
   - Thetas reflect individual signature loadings learned with batch-specific phi

2. **Prediction Mode (Fixed Phi)**:
   - Phi fixed from master checkpoint
   - Lambda learned given fixed phi
   - Best for clinical predictions
   - Thetas reflect individual signature loadings given stable phi

3. **Theta Updates**:
   - Thetas update as more data becomes available (age offset approach)
   - Higher age offset = more data = more refined thetas
   - Theta AUC captures cumulative signature loading over time

4. **GWAS Without G**:
   - Model learns signature loadings from disease data (Y, E)
   - These loadings capture genetic and environmental factors
   - Can be used as phenotypes for GWAS even if G wasn't in the model
   - Workflow: Lambda → Theta → Theta AUC → GWAS

### Applications

- **Discovery Mode**: Pathway analysis, heritability, population stratification
- **Prediction Mode**: Clinical risk prediction, AUC evaluation, washout validation
- **Theta AUC**: GWAS on signature loadings, genetic architecture of disease signatures
