# Initializing AOU and MGB Models with Corrected E and Prevalence

This notebook consolidates all code needed to:
1. Correct E matrices for AOU and MGB using max censor data
2. Compute corrected prevalence with at-risk filtering
3. Initialize AOU and MGB models with correct clusters, psi config, and signature references


In [7]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('/Users/sarahurbut/aladynoulli2/pyScripts/')
sys.path.append('/Users/sarahurbut/aladynoulli2/claudefile/aws_offsetmaster/')
sys.path.append('/Users/sarahurbut/aladynoulli2/pyScripts_forPublish/')

import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri, pandas2ri
import numpy as np
import pandas as pd
import torch
from scipy.ndimage import gaussian_filter1d
from scipy.special import logit
from statsmodels.nonparametric.smoothers_lowess import lowess
from clust_huge_amp_vectorized import *


In [8]:

numpy2ri.activate()
pandas2ri.activate()

torch.manual_seed(7)
np.random.seed(4)

print("Setup complete")


Setup complete


## Helper Functions


In [9]:
def compute_smoothed_prevalence_at_risk(Y, E_corrected, enrollment_ages, window_size=5, smooth_on_logit=True):
    """Compute smoothed prevalence with proper at-risk filtering."""
    if torch.is_tensor(Y):
        Y = Y.numpy()
    if torch.is_tensor(E_corrected):
        E_corrected = E_corrected.numpy()
    
    N, D, T = Y.shape
    prevalence_t = np.zeros((D, T))
    timepoint_ages = np.arange(T) + 30
    
    print(f"Computing prevalence for {D} diseases, {T} timepoints...")
    
    E_corrected_np = E_corrected.numpy() if torch.is_tensor(E_corrected) else E_corrected
    
    for d in range(D):
        if d % 50 == 0:
            print(f"  Processing disease {d}/{D}...")
        
        for t in range(T):
            at_risk_mask = (E_corrected_np[:, d] >= t)
            
            if at_risk_mask.sum() > 0:
                prevalence_t[d, t] = Y[at_risk_mask, d, t].mean()
            else:
                prevalence_t[d, t] = np.nan
        
        if smooth_on_logit:
            epsilon = 1e-8
            valid_mask = ~np.isnan(prevalence_t[d, :])
            if valid_mask.sum() > 0:
                logit_prev = np.full(T, np.nan)
                logit_prev[valid_mask] = np.log(
                    (prevalence_t[d, valid_mask] + epsilon) / 
                    (1 - prevalence_t[d, valid_mask] + epsilon)
                )
                smoothed_logit = gaussian_filter1d(np.nan_to_num(logit_prev, nan=0), sigma=window_size)
                smoothed_logit[~valid_mask] = np.nan
                prevalence_t[d, :] = 1 / (1 + np.exp(-smoothed_logit))
        else:
            prevalence_t[d, :] = gaussian_filter1d(np.nan_to_num(prevalence_t[d, :], nan=0), sigma=window_size)
    
    return prevalence_t


def create_reference_trajectories(Y_filtered, initial_clusters, K, healthy_prop=0, frac=0.3):
    """Create reference trajectories using LOWESS smoothing on logit scale"""
    T = Y_filtered.shape[2]
    Y_counts = Y_filtered.sum(dim=0)  
    signature_props = torch.zeros(K, T)
    total_counts = Y_counts.sum(dim=0) + 1e-8
    
    for k in range(K):
        cluster_mask = (initial_clusters == k)
        signature_props[k] = Y_counts[cluster_mask].sum(dim=0) / total_counts
    
    signature_props = torch.clamp(signature_props, min=1e-8, max=1-1e-8)
    signature_props = signature_props / signature_props.sum(dim=0, keepdim=True)
    signature_props *= (1 - healthy_prop)
    
    logit_props = torch.tensor(logit(signature_props.numpy()))
    signature_refs = torch.zeros_like(logit_props)
    
    times = np.arange(T)
    for k in range(K):
        smoothed = lowess(logit_props[k].numpy(), times, frac=frac, it=3, delta=0.0, return_sorted=False)
        signature_refs[k] = torch.tensor(smoothed)
    
    healthy_ref = torch.ones(T) * logit(torch.tensor(healthy_prop))
    return signature_refs, healthy_ref


## Part 1: AOU Model Initialization


In [10]:
# ============================================================================
# AOU: Load Data
# ============================================================================
print("=" * 60)
print("AOU: Loading data...")
print("=" * 60)

data_path = "/Users/sarahurbut/Library/CloudStorage/DB_backup_5132025941p/aou_fromdl/"

Y_aou = np.array(robjects.r['readRDS'](os.path.join(data_path, 'Y_binary.rds')))
E_aou = np.array(robjects.r['readRDS'](os.path.join(data_path, 'E_binary.rds')))
E_aou = E_aou.astype(int)
Y_tensor_aou = torch.FloatTensor(Y_aou)
E_tensor_aou = torch.FloatTensor(E_aou)

YandEpatientnames_aou = pd.read_csv(os.path.join(data_path, 'patient_names.csv'))
print(f"AOU Y shape: {Y_tensor_aou.shape}")
print(f"AOU E shape: {E_tensor_aou.shape}")


AOU: Loading data...
AOU Y shape: torch.Size([10000, 348, 51])
AOU E shape: torch.Size([10000, 348])


In [11]:
# ============================================================================
# AOU: Load and Align Censor Data
# ============================================================================
print("\n" + "=" * 60)
print("AOU: Loading and aligning censor data...")
print("=" * 60)

# Load ICD10 data and compute max censor
aou_df = pd.read_csv('/Users/sarahurbut/aladynoulli2/aou_sub.csv')
aou_max_followup = aou_df.groupby('eid')['age_diag'].max().reset_index()
aou_max_followup.columns = ['eid', 'max_censor']

# Match to Y/E patient order
censor_df_aou = YandEpatientnames_aou.merge(
    aou_max_followup, 
    left_on='x',
    right_on='eid',
    how='left'
)

# Fill missing
max_age_default = aou_max_followup['max_censor'].max() if len(aou_max_followup) > 0 else 81
censor_df_aou['max_censor'] = censor_df_aou['max_censor'].fillna(max_age_default)
censor_df_aou['age'] = censor_df_aou['max_censor']

# CRITICAL: Reorder to match YandEpatientnames order
if 'x' in YandEpatientnames_aou.columns and 'eid' in censor_df_aou.columns:
    order_matches = (censor_df_aou['eid'].values == YandEpatientnames_aou['x'].values).all()
    
    if not order_matches:
        print(f"⚠ Reordering censor_df to match YandEpatientnames order...")
        censor_df_aou = censor_df_aou.set_index('eid').reindex(YandEpatientnames_aou['x']).reset_index()
        
        # Handle potential KeyError if reset_index renamed 'eid' to 'index'
        if 'eid' not in censor_df_aou.columns and 'index' in censor_df_aou.columns:
            censor_df_aou = censor_df_aou.rename(columns={'index': 'eid'})
        
        # Fill any missing patients
        missing_mask = censor_df_aou['max_censor'].isna()
        if missing_mask.any():
            censor_df_aou.loc[missing_mask, 'max_censor'] = max_age_default
            censor_df_aou.loc[missing_mask, 'age'] = max_age_default
        
        order_matches = (censor_df_aou['eid'].values == YandEpatientnames_aou['x'].values).all()
        if order_matches:
            print(f"✓ Order now matches YandEpatientnames")
    else:
        print(f"✓ Order already matches YandEpatientnames")

print(f"Matched {censor_df_aou['max_censor'].notna().sum()} / {len(censor_df_aou)} patients")



AOU: Loading and aligning censor data...
✓ Order already matches YandEpatientnames
Matched 10000 / 10000 patients


In [12]:
# ============================================================================
# AOU: Correct E Matrix
# ============================================================================
print("\n" + "=" * 60)
print("AOU: Correcting E matrix...")
print("=" * 60)

T_aou = Y_tensor_aou.shape[2]

# Convert max_censor ages to timepoints (age 30 = timepoint 0)
max_timepoints_aou = torch.tensor(
    (censor_df_aou['max_censor'].values - 30).clip(0, T_aou-1).astype(int)
)

# Only update censored cases (where E == T-1)
censored_mask_aou = (E_tensor_aou == T_aou - 1)
max_timepoints_expanded_aou = max_timepoints_aou.unsqueeze(1).expand_as(E_tensor_aou)

# Update only censored positions
E_corrected_aou = torch.where(
    censored_mask_aou,
    torch.minimum(E_tensor_aou, max_timepoints_expanded_aou),
    E_tensor_aou
)

enrollment_ages_aou = censor_df_aou['max_censor'].values

print(f"E matrix correction complete:")
print(f"  Original E shape: {E_tensor_aou.shape}")
print(f"  Corrected E shape: {E_corrected_aou.shape}")
print(f"  Patients with corrections: {(censored_mask_aou.sum(dim=1) > 0).sum().item()} / {E_tensor_aou.shape[0]}")

# Save corrected E
torch.save(E_corrected_aou, '/Users/sarahurbut/aladynoulli2/aou_E_corrected.pt')
print(f"✓ Saved corrected E to: aou_E_corrected.pt")



AOU: Correcting E matrix...
E matrix correction complete:
  Original E shape: torch.Size([10000, 348])
  Corrected E shape: torch.Size([10000, 348])
  Patients with corrections: 10000 / 10000
✓ Saved corrected E to: aou_E_corrected.pt


In [13]:
# ============================================================================
# AOU: Compute Corrected Prevalence
# ============================================================================
print("\n" + "=" * 60)
print("AOU: Computing corrected prevalence...")
print("=" * 60)

new_prevalence_t_aou = compute_smoothed_prevalence_at_risk(
    Y=Y_tensor_aou, 
    E_corrected=E_corrected_aou, 
    enrollment_ages=enrollment_ages_aou,
    window_size=5,
    smooth_on_logit=True
)

print(f"New prevalence shape: {new_prevalence_t_aou.shape}")

# Convert to logit and save
if torch.is_tensor(new_prevalence_t_aou):
    logit_prev_np_aou = new_prevalence_t_aou.numpy()
else:
    logit_prev_np_aou = new_prevalence_t_aou

epsilon = 1e-8
logit_prev_aou = np.log((logit_prev_np_aou + epsilon) / (1 - logit_prev_np_aou + epsilon))

# Save prevalence files
torch.save(torch.tensor(logit_prev_aou), '/Users/sarahurbut/aladynoulli2/aou_logit_prev_corrected_E.pt')
torch.save(torch.tensor(logit_prev_np_aou), '/Users/sarahurbut/aladynoulli2/aou_prevalence_corrected_E.pt')
print(f"✓ Saved prevalence files")



AOU: Computing corrected prevalence...
Computing prevalence for 348 diseases, 51 timepoints...
  Processing disease 0/348...
  Processing disease 50/348...
  Processing disease 100/348...
  Processing disease 150/348...
  Processing disease 200/348...
  Processing disease 250/348...
  Processing disease 300/348...
New prevalence shape: (348, 51)
✓ Saved prevalence files


In [6]:
aou_checkpoint = torch.load('/Users/sarahurbut/aladynoulli2/aou_model_initialized.pt', map_location='cpu')
aou_checkpoint['model_state_dict']['psi']



  aou_checkpoint = torch.load('/Users/sarahurbut/aladynoulli2/aou_model_initialized.pt', map_location='cpu')


tensor([[-2.0042, -2.0507, -1.9973,  ..., -2.0180, -2.0831, -2.1101],
        [-2.0784, -1.4886,  1.7180,  ..., -3.4644,  1.9670, -2.4102],
        [-2.0018, -1.9618, -2.0133,  ..., -2.3200, -2.1821, -2.3418],
        ...,
        [-2.0001, -2.0827, -1.9617,  ..., -1.9554, -2.0996, -2.0250],
        [-1.8886, -2.0038, -1.8580,  ..., -1.9483, -1.9566, -2.1870],
        [-4.9788, -4.9464, -4.9114,  ..., -5.0275, -5.0312, -4.9012]])

In [44]:
# ============================================================================
# AOU: Initialize Model
# ============================================================================
print("\n" + "=" * 60)
print("AOU: Initializing model...")
print("=" * 60)

# Load old AOU checkpoint for clusters and G
aou_checkpoint_old = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_AOU.pt', map_location='cpu')
initial_clusters_aou = aou_checkpoint_old['clusters']
if isinstance(initial_clusters_aou, torch.Tensor):
    initial_clusters_aou = initial_clusters_aou.numpy()
else:
    initial_clusters_aou = np.array(initial_clusters_aou)

K_aou = int(initial_clusters_aou.max() + 1)
print(f"AOU: K={K_aou} signatures")

# Create signature references
signature_refs_aou, healthy_ref_aou = create_reference_trajectories(
    Y_tensor_aou, initial_clusters_aou, K=K_aou, healthy_prop=0, frac=0.3
)

# Get G and disease names
G_aou = aou_checkpoint_old['G']
disease_names_aou = aou_checkpoint_old['disease_names']
prevalence_t_aou = torch.load('/Users/sarahurbut/aladynoulli2/aou_prevalence_corrected_E.pt')

# Create model
model_aou = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
    N=Y_tensor_aou.shape[0],
    D=Y_tensor_aou.shape[1], 
    T=Y_tensor_aou.shape[2], 
    K=K_aou,
    P=G_aou.shape[1],
    init_sd_scaler=1e-1,
    G=G_aou, 
    Y=Y_tensor_aou,
    genetic_scale=1,
    W=0.0001,
    R=0,
    prevalence_t=prevalence_t_aou,
    signature_references=signature_refs_aou,
    healthy_reference=True,
    disease_names='disease_names'
)

# Set clusters and initialize
model_aou.clusters = initial_clusters_aou
psi_config = {'in_cluster': 1, 'out_cluster': -2, 'noise_in': 0.1, 'noise_out': 0.01}
model_aou.initialize_params(psi_config=psi_config)

# Verify
clusters_match = np.array_equal(initial_clusters_aou, model_aou.clusters)
print(f"✓ Clusters match: {clusters_match}")


history = model_aou.fit(E_corrected_aou,
num_epochs=200,learning_rate=1e-1,
lambda_reg=1e-2)

# Save initialized model
save_dict_aou = {
    'model_state_dict': model_aou.state_dict(),
    'clusters': initial_clusters_aou,
    'signature_refs': signature_refs_aou,
    'healthy_ref': healthy_ref_aou,
    'psi_config': psi_config,
    'hyperparameters': {
        'N': Y_tensor_aou.shape[0],
        'D': Y_tensor_aou.shape[1],
        'T': Y_tensor_aou.shape[2],
        'K': K_aou,
        'P': G_aou.shape[1],
        'init_sd_scaler': 1e-1,
        'genetic_scale': 1,
        'W': 0.0001,
        'R': 0,
    },
    'prevalence_t': prevalence_t_aou,
    'disease_names': disease_names_aou,
}


torch.save(save_dict_aou, '/Users/sarahurbut/aladynoulli2/aou_model_initialized.pt')
print(f"✓ Saved AOU initialized model to: aou_model_initialized.pt")



AOU: Initializing model...
AOU: K=20 signatures


  aou_checkpoint_old = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_AOU.pt', map_location='cpu')
  prevalence_t_aou = torch.load('/Users/sarahurbut/aladynoulli2/aou_prevalence_corrected_E.pt')
  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.G = torch.tensor(G, dtype=torch.float32)
  self.G = torch.tensor(G_scaled, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  self.prevalence_t = torch.tensor(prevalence_t, dtype=torch.float32)



Cluster Sizes:
Cluster 0: 4 diseases
Cluster 1: 38 diseases
Cluster 2: 10 diseases
Cluster 3: 18 diseases
Cluster 4: 8 diseases
Cluster 5: 67 diseases
Cluster 6: 20 diseases
Cluster 7: 23 diseases
Cluster 8: 19 diseases
Cluster 9: 17 diseases
Cluster 10: 6 diseases
Cluster 11: 8 diseases
Cluster 12: 36 diseases
Cluster 13: 13 diseases
Cluster 14: 13 diseases
Cluster 15: 8 diseases
Cluster 16: 21 diseases
Cluster 17: 4 diseases
Cluster 18: 3 diseases
Cluster 19: 12 diseases

Calculating gamma for k=0:
Number of diseases in cluster: 4
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.5078, -0.5078, -0.5078, -0.5078, -0.5078])
Base value centered mean: -2.5932311018550536e-06
Gamma init for k=0 (first 5): tensor([0.0042, 0.0134])

Calculating gamma for k=1:
Number of diseases in cluster: 38
Base value (first 5): tensor([-10.6881, -13.8155, -13.2943, -13.0337, -13.2943])
Base value centered (first 5): tensor([ 2.2550

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 162.4854

Monitoring signature responses:

Disease 293 (signature 5, LR=29.74):
  Theta for diagnosed: 0.066 ± 0.015
  Theta for others: 0.067
  Proportion difference: -0.000

Disease 55 (signature 1, LR=29.13):
  Theta for diagnosed: 0.151 ± 0.039
  Theta for others: 0.149
  Proportion difference: 0.002

Disease 283 (signature 7, LR=28.86):
  Theta for diagnosed: 0.042 ± 0.036
  Theta for others: 0.042
  Proportion difference: 0.000

Disease 139 (signature 5, LR=28.35):
  Theta for diagnosed: 0.067 ± 0.016
  Theta for others: 0.067
  Proportion difference: 0.000

Disease 294 (signature 13, LR=28.14):
  Theta for diagnosed: 0.049 ± 0.010
  Theta for others: 0.048
  Proportion difference: 0.000

Epoch 1
Loss: 782.9916

Monitoring signature responses:

Disease 293 (signature 5, LR=29.78):
  Theta for diagnosed: 0.066 ± 0.014
  Theta for others: 0.066
  Proportion difference: -0.000

Disease 55 (signature 1, LR=29.20):
  Theta for diagnosed: 0.151 ± 0.038
  Theta for others

## Part 2: MGB Model Initialization


In [45]:
# ============================================================================
# MGB: Load Data
# ============================================================================
print("\n" + "=" * 60)
print("MGB: Loading data...")
print("=" * 60)

data_path_mgb = "/Users/sarahurbut/Dropbox-Personal/mgbbtopic/"

# Load Y and original E
Y_mgb = np.array(robjects.r['readRDS'](os.path.join(data_path_mgb, 'Y_sub.rds')))
E_mgb = np.array(robjects.r['readRDS'](os.path.join(data_path_mgb, 'E_sub.rds')))
E_mgb = E_mgb.astype(int)
Y_tensor_mgb = torch.FloatTensor(Y_mgb)
E_tensor_mgb = torch.FloatTensor(E_mgb)

print(f"MGB Y shape: {Y_tensor_mgb.shape}")
print(f"MGB E shape: {E_tensor_mgb.shape}")



MGB: Loading data...
MGB Y shape: torch.Size([34592, 346, 51])
MGB E shape: torch.Size([34592, 346])


In [46]:
# ============================================================================
# MGB: Load and Align Censor Data
# ============================================================================
print("\n" + "=" * 60)
print("MGB: Loading and aligning censor data...")
print("=" * 60)

# Load patient names in Y/E order (must be loaded first to know the order)
YandEpatientnames_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/mgb_patientnames.csv')
print(f"MGB patient names shape: {YandEpatientnames_mgb.shape}")
print(f"MGB patient names columns: {YandEpatientnames_mgb.columns.tolist()}")
print(f"First few patient names:\n{YandEpatientnames_mgb.head()}")

# Load max censor data (created from R: mgb%>%group_by(eid)%>%summarise(max(age_diag)))
# Try different parsing approaches
try:
    max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv')
except:
    try:
        max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv')
    except:
        # If both fail, try without sep (auto-detect)
        max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv')

# Handle column name issues - check if we have a malformed header
print(f"\nMGB max censor shape (before cleaning): {max_censor_mgb.shape}")
print(f"MGB max censor columns (before cleaning): {max_censor_mgb.columns.tolist()}")

# Check for malformed header (like ',"eid","age"')
if len(max_censor_mgb.columns) == 1 and max_censor_mgb.columns[0].startswith(','):
    # Header is malformed, try to parse it
    print("⚠ Detected malformed header, attempting to fix...")
    # Read again with header=None and set proper column names
    max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv', sep='\t', header=0)
    # Try to extract column names from the first row if it's a header
    first_row = max_censor_mgb.iloc[0, 0] if len(max_censor_mgb.columns) == 1 else None
    if first_row and 'eid' in str(first_row).lower():
        # Skip first row and set proper column names
        max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv', sep='\t', skiprows=1, names=['eid', 'age'])

# Handle various column name formats
if 'X.eid' in max_censor_mgb.columns:
    max_censor_mgb = max_censor_mgb.rename(columns={'X.eid': 'eid'})
if 'X.age' in max_censor_mgb.columns:
    max_censor_mgb = max_censor_mgb.rename(columns={'X.age': 'age'})

# Drop 'Unnamed: 0' column if it exists (it's just the row index from R)
if 'Unnamed: 0' in max_censor_mgb.columns:
    max_censor_mgb = max_censor_mgb.drop(columns=['Unnamed: 0'])

# If we still don't have 'eid' and 'age', try to infer from column names
if 'eid' not in max_censor_mgb.columns or 'age' not in max_censor_mgb.columns:
    print(f"⚠ Column names issue. Available columns: {max_censor_mgb.columns.tolist()}")
    print(f"   First few rows:\n{max_censor_mgb.head()}")
    
    # Try to identify columns by content
    for col in max_censor_mgb.columns:
        if 'eid' in str(col).lower() or max_censor_mgb[col].dtype in ['int64', 'float64'] and max_censor_mgb[col].min() > 100000:
            if 'eid' not in max_censor_mgb.columns:
                max_censor_mgb = max_censor_mgb.rename(columns={col: 'eid'})
        elif 'age' in str(col).lower() or (max_censor_mgb[col].dtype in ['int64', 'float64'] and max_censor_mgb[col].max() < 100):
            if 'age' not in max_censor_mgb.columns:
                max_censor_mgb = max_censor_mgb.rename(columns={col: 'age'})

print(f"\nMGB max censor shape (after cleaning): {max_censor_mgb.shape}")
print(f"MGB max censor columns (after cleaning): {max_censor_mgb.columns.tolist()}")
print(f"First few entries:\n{max_censor_mgb.head()}")

# CRITICAL: Reorder max_censor_mgb to match YandEpatientnames order (like AOU)
# Determine the patient ID column name in YandEpatientnames
patient_id_col = None
for col in ['x', 'eid', 'patient_id', 'id']:
    if col in YandEpatientnames_mgb.columns:
        patient_id_col = col
        break

if patient_id_col is None:
    print(f"\n⚠ WARNING: Could not find patient ID column in YandEpatientnames!")
    print(f"  Available columns: {YandEpatientnames_mgb.columns.tolist()}")
    print(f"  Assuming order is already correct...")
else:
    print(f"\nUsing '{patient_id_col}' column from YandEpatientnames to match order")
    
    # Check if order matches
    if len(max_censor_mgb) == len(YandEpatientnames_mgb):
        # Check what column name max_censor_mgb uses for patient IDs
        patient_id_col_censor = 'eid' if 'eid' in max_censor_mgb.columns else 'index' if 'index' in max_censor_mgb.columns else None
        
        if patient_id_col_censor is None:
            print(f"⚠ Cannot find patient ID column in max_censor_mgb. Available columns: {max_censor_mgb.columns.tolist()}")
            print(f"   Assuming order is already correct...")
        else:
            # Try to match by patient ID
            order_matches = False
            try:
                order_matches = (max_censor_mgb[patient_id_col_censor].values == YandEpatientnames_mgb[patient_id_col].values).all()
            except:
                pass
            
            if not order_matches:
                print(f"⚠ Reordering max_censor_mgb to match YandEpatientnames order...")
                print(f"   Before: First 3 IDs in max_censor_mgb: {max_censor_mgb[patient_id_col_censor].head(3).tolist()}")
                print(f"   Before: First 3 IDs in YandEpatientnames: {YandEpatientnames_mgb[patient_id_col].head(3).tolist()}")
                
                # Use merge approach (like AOU) to preserve all columns
                # Create a temporary dataframe with the desired order
                order_df = pd.DataFrame({patient_id_col: YandEpatientnames_mgb[patient_id_col]})
                
                # Merge to reorder (preserves all columns from max_censor_mgb, including 'age')
                max_censor_mgb = order_df.merge(
                    max_censor_mgb,
                    left_on=patient_id_col,
                    right_on=patient_id_col_censor,
                    how='left'
                )
                
                # After merge, we'll have patient_id_col from order_df and possibly patient_id_col_censor from max_censor_mgb
                # Keep patient_id_col (which matches YandEpatientnames order) and drop the duplicate
                if patient_id_col_censor in max_censor_mgb.columns and patient_id_col in max_censor_mgb.columns:
                    # Both exist, drop the old one and rename the new one
                    max_censor_mgb = max_censor_mgb.drop(columns=[patient_id_col_censor])
                    max_censor_mgb = max_censor_mgb.rename(columns={patient_id_col: patient_id_col_censor})
                elif patient_id_col in max_censor_mgb.columns and patient_id_col != patient_id_col_censor:
                    # Only the merged column exists, rename it
                    max_censor_mgb = max_censor_mgb.rename(columns={patient_id_col: patient_id_col_censor})
                
                # Debug: Check columns after merge
                print(f"   After merge - columns: {max_censor_mgb.columns.tolist()}")
                print(f"   'age' column exists: {'age' in max_censor_mgb.columns}")
                
                # Fill any missing patients with default max age (check if 'age' exists first)
                if 'age' in max_censor_mgb.columns:
                    missing_mask = max_censor_mgb['age'].isna()
                    if missing_mask.any():
                        default_max_age = max_censor_mgb['age'].max() if max_censor_mgb['age'].notna().any() else 81
                        max_censor_mgb.loc[missing_mask, 'age'] = default_max_age
                        print(f"   Filled {missing_mask.sum()} missing patients with max_age={default_max_age}")
                else:
                    print(f"   ⚠ WARNING: 'age' column not found after reordering!")
                    print(f"   Available columns: {max_censor_mgb.columns.tolist()}")
                
                # Verify after reordering - check both possible column names
                final_id_col = 'eid' if 'eid' in max_censor_mgb.columns else 'index'
                order_matches = (max_censor_mgb[final_id_col].values == YandEpatientnames_mgb[patient_id_col].values).all()
                if order_matches:
                    print(f"✓ Order now matches YandEpatientnames (like AOU)")
                    print(f"   After: First 3 IDs match: {max_censor_mgb[final_id_col].head(3).tolist()}")
                else:
                    print(f"⚠ Still have mismatches - check patient ID alignment")
            else:
                print(f"✓ Order already matches YandEpatientnames")
    else:
        print(f"⚠ Length mismatch: max_censor_mgb has {len(max_censor_mgb)} rows, YandEpatientnames has {len(YandEpatientnames_mgb)} rows")

# Rename age to max_censor for consistency (like AOU)
censor_df_mgb = max_censor_mgb.rename(columns={'age': 'max_censor'})

# Ensure 'eid' column exists (might be 'index' after reset_index)
if 'index' in censor_df_mgb.columns and 'eid' not in censor_df_mgb.columns:
    censor_df_mgb = censor_df_mgb.rename(columns={'index': 'eid'})

# Fill missing patients with default max age (if any were created during reindexing)
missing_mask = censor_df_mgb['max_censor'].isna()
if missing_mask.any():
    default_max_censor = censor_df_mgb['max_censor'].max() if censor_df_mgb['max_censor'].notna().any() else 81
    censor_df_mgb.loc[missing_mask, 'max_censor'] = default_max_censor
    print(f"\nFilled {missing_mask.sum()} missing patients with max_censor={default_max_censor}")

# Verify order is preserved (like R's all.equal check, similar to AOU)
if patient_id_col is not None:
    # Check what column name censor_df_mgb uses for patient IDs
    final_id_col = 'eid' if 'eid' in censor_df_mgb.columns else 'index' if 'index' in censor_df_mgb.columns else None
    if final_id_col is not None:
        order_preserved = (censor_df_mgb[final_id_col].values == YandEpatientnames_mgb[patient_id_col].values).all()
        if order_preserved:
            print(f"✓ Order preserved: censor_df_mgb matches YandEpatientnames order")
            print(f"  Matched {censor_df_mgb['max_censor'].notna().sum()} / {len(censor_df_mgb)} patients")
        else:
            print(f"⚠ WARNING: Order NOT preserved!")
    else:
        print(f"⚠ Cannot verify order - patient ID column not found in censor_df_mgb")

# Verify final count
N_mgb = Y_tensor_mgb.shape[0]
if len(censor_df_mgb) != N_mgb:
    print(f"\n⚠ WARNING: censor_df_mgb has {len(censor_df_mgb)} patients but Y/E has {N_mgb} patients!")
else:
    print(f"\n✓ Final patient count matches: {N_mgb} patients")

# Now censor_df_mgb is in the correct order matching Y/E matrices
print(f"\ncensor_df_mgb ready with shape: {censor_df_mgb.shape}")
print(censor_df_mgb.head())


MGB: Loading and aligning censor data...
MGB patient names shape: (34592, 2)
MGB patient names columns: ['Unnamed: 0', 'x']
First few patient names:
   Unnamed: 0          x
0           1  101790256
1           2  101717153
2           3  102456864
3           4  100219007
4           5  100230568

MGB max censor shape (before cleaning): (34592, 3)
MGB max censor columns (before cleaning): ['Unnamed: 0', 'eid', 'age']

MGB max censor shape (after cleaning): (34592, 2)
MGB max censor columns (after cleaning): ['eid', 'age']
First few entries:
         eid  age
0  100000296   77
1  100000431   79
2  100000462   74
3  100001182   73
4  100001295   80

Using 'x' column from YandEpatientnames to match order
⚠ Reordering max_censor_mgb to match YandEpatientnames order...
   Before: First 3 IDs in max_censor_mgb: [100000296, 100000431, 100000462]
   Before: First 3 IDs in YandEpatientnames: [101790256, 101717153, 102456864]
   After merge - columns: ['eid', 'age']
   'age' column exists: Tru

In [37]:
# ============================================================================
# MGB: Correct E Matrix
# ============================================================================
print("\n" + "=" * 60)
print("MGB: Correcting E matrix...")
print("=" * 60)

T_mgb = Y_tensor_mgb.shape[2]

# Convert max_censor ages to timepoints (age 30 = timepoint 0)
# censor_df_mgb is now in the same order as Y/E
max_timepoints_mgb = torch.tensor(
    (censor_df_mgb['max_censor'].values - 30).clip(0, T_mgb-1).astype(int)
)

# Only update censored cases (where E == T-1)
censored_mask_mgb = (E_tensor_mgb == T_mgb - 1)
max_timepoints_expanded_mgb = max_timepoints_mgb.unsqueeze(1).expand_as(E_tensor_mgb)

# Update only censored positions
E_corrected_mgb = torch.where(
    censored_mask_mgb,
    torch.minimum(E_tensor_mgb, max_timepoints_expanded_mgb),
    E_tensor_mgb
)

enrollment_ages_mgb = censor_df_mgb['max_censor'].values

print(f"E matrix correction complete:")
print(f"  Original E shape: {E_tensor_mgb.shape}")
print(f"  Corrected E shape: {E_corrected_mgb.shape}")
print(f"  Patients with corrections: {(censored_mask_mgb.sum(dim=1) > 0).sum().item()} / {E_tensor_mgb.shape[0]}")

# Save corrected E
torch.save(E_corrected_mgb, '/Users/sarahurbut/aladynoulli2/mgb_E_corrected.pt')
print(f"✓ Saved corrected E to: mgb_E_corrected.pt")



MGB: Correcting E matrix...
E matrix correction complete:
  Original E shape: torch.Size([34592, 346])
  Corrected E shape: torch.Size([34592, 346])
  Patients with corrections: 34592 / 34592
✓ Saved corrected E to: mgb_E_corrected.pt


In [38]:
# ============================================================================
# MGB: Compute Corrected Prevalence
# ============================================================================
print("\n" + "=" * 60)
print("MGB: Computing corrected prevalence...")
print("=" * 60)

new_prevalence_t_mgb = compute_smoothed_prevalence_at_risk(
    Y=Y_tensor_mgb, 
    E_corrected=E_corrected_mgb, 
    enrollment_ages=enrollment_ages_mgb,
    window_size=5,
    smooth_on_logit=True
)

print(f"New prevalence shape: {new_prevalence_t_mgb.shape}")

# Convert to logit and save
if torch.is_tensor(new_prevalence_t_mgb):
    logit_prev_np_mgb = new_prevalence_t_mgb.numpy()
else:
    logit_prev_np_mgb = new_prevalence_t_mgb

epsilon = 1e-8
logit_prev_mgb = np.log((logit_prev_np_mgb + epsilon) / (1 - logit_prev_np_mgb + epsilon))

# Save prevalence files
torch.save(torch.tensor(logit_prev_mgb), '/Users/sarahurbut/aladynoulli2/mgb_logit_prev_corrected_E.pt')
torch.save(torch.tensor(logit_prev_np_mgb), '/Users/sarahurbut/aladynoulli2/mgb_prevalence_corrected_E.pt')
print(f"✓ Saved prevalence files")



MGB: Computing corrected prevalence...
Computing prevalence for 346 diseases, 51 timepoints...
  Processing disease 0/346...


  Processing disease 50/346...
  Processing disease 100/346...
  Processing disease 150/346...
  Processing disease 200/346...
  Processing disease 250/346...
  Processing disease 300/346...
New prevalence shape: (346, 51)
✓ Saved prevalence files


In [47]:
# ============================================================================
# MGB: Initialize Model
# ============================================================================
print("\n" + "=" * 60)
print("MGB: Initializing model...")
print("=" * 60)

# Load old MGB checkpoint for clusters and G
mgb_checkpoint_old = torch.load('/Users/sarahurbut/Dropbox-Personal/model_with_kappa_bigam_MGB.pt', map_location='cpu')
initial_clusters_mgb = mgb_checkpoint_old['clusters']
if isinstance(initial_clusters_mgb, torch.Tensor):
    initial_clusters_mgb = initial_clusters_mgb.numpy()
else:
    initial_clusters_mgb = np.array(initial_clusters_mgb)

K_mgb = int(initial_clusters_mgb.max() + 1)
print(f"MGB: K={K_mgb} signatures")

# Create signature references
signature_refs_mgb, healthy_ref_mgb = create_reference_trajectories(
    Y_tensor_mgb, initial_clusters_mgb, K=K_mgb, healthy_prop=0, frac=0.3
)

# Get G and disease names
G_mgb = mgb_checkpoint_old['G']
disease_names_mgb = pd.DataFrame(robjects.r['readRDS'](os.path.join(data_path_mgb, 'diagnames.rds')))
disease_names_mgb = disease_names_mgb[0].tolist()

# Use the corrected prevalence we just computed
prevalence_t_mgb = torch.tensor(logit_prev_np_mgb)  # Use the prevalence from cell above

# Create model
model_mgb = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
    N=Y_tensor_mgb.shape[0],
    D=Y_tensor_mgb.shape[1], 
    T=Y_tensor_mgb.shape[2], 
    K=20,
    P=G_mgb.shape[1],
    init_sd_scaler=1e-1,
    G=G_mgb, 
    Y=Y_tensor_mgb,
    genetic_scale=1,
    W=0.0001,
    R=0,
    prevalence_t=prevalence_t_mgb,
    signature_references=signature_refs_mgb,
    healthy_reference=True,
    disease_names='disease_names'
)

# Set clusters and initialize
model_mgb.clusters = initial_clusters_mgb
psi_config = {'in_cluster': 1, 'out_cluster': -2, 'noise_in': 0.1, 'noise_out': 0.01}
model_mgb.initialize_params(psi_config=psi_config)

# Verify
clusters_match = np.array_equal(initial_clusters_mgb, model_mgb.clusters)
print(f"✓ Clusters match: {clusters_match}")

history = model_mgb.fit(E_corrected_mgb,
num_epochs=200,learning_rate=1e-1,
lambda_reg=1e-2)
# Save initialized model
save_dict_mgb = {
    'model_state_dict': model_mgb.state_dict(),
    'clusters': initial_clusters_mgb,
    'signature_refs': signature_refs_mgb,
    'healthy_ref': healthy_ref_mgb,
    'psi_config': psi_config,
    'hyperparameters': {
        'N': Y_tensor_mgb.shape[0],
        'D': Y_tensor_mgb.shape[1],
        'T': Y_tensor_mgb.shape[2],
        'K': 20,
        'P': G_mgb.shape[1],
        'init_sd_scaler': 1e-1,
        'genetic_scale': 1,
        'W': 0.0001,
        'R': 0,
    },
    'prevalence_t': prevalence_t_mgb,
    'disease_names': disease_names_mgb,
}

torch.save(save_dict_mgb, '/Users/sarahurbut/aladynoulli2/mgb_model_initialized.pt')
print(f"✓ Saved MGB initialized model to: mgb_model_initialized.pt")
print("\n" + "=" * 60)
print("✓ All initialization complete!")
print("=" * 60)



MGB: Initializing model...


  mgb_checkpoint_old = torch.load('/Users/sarahurbut/Dropbox-Personal/model_with_kappa_bigam_MGB.pt', map_location='cpu')


MGB: K=20 signatures


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.G = torch.tensor(G, dtype=torch.float32)
  self.G = torch.tensor(G_scaled, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  self.prevalence_t = torch.tensor(prevalence_t, dtype=torch.float32)



Cluster Sizes:
Cluster 0: 6 diseases
Cluster 1: 28 diseases
Cluster 2: 24 diseases
Cluster 3: 11 diseases
Cluster 4: 24 diseases
Cluster 5: 20 diseases
Cluster 6: 15 diseases
Cluster 7: 16 diseases
Cluster 8: 19 diseases
Cluster 9: 13 diseases
Cluster 10: 7 diseases
Cluster 11: 7 diseases
Cluster 12: 68 diseases
Cluster 13: 7 diseases
Cluster 14: 13 diseases
Cluster 15: 12 diseases
Cluster 16: 16 diseases
Cluster 17: 8 diseases
Cluster 18: 12 diseases
Cluster 19: 20 diseases

Calculating gamma for k=0:
Number of diseases in cluster: 6
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7714, -0.7714, -0.7714, -0.7714, -0.7714])
Base value centered mean: -1.1301173117317376e-06
Gamma init for k=0 (first 5): tensor([ 0.0043,  0.0104, -0.0070,  0.0494,  0.0009])

Calculating gamma for k=1:
Number of diseases in cluster: 28
Base value (first 5): tensor([-13.8155, -12.7544, -13.8155, -13.4618, -13.8155])
Base value cent

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 150.9510

Monitoring signature responses:

Disease 187 (signature 12, LR=31.28):
  Theta for diagnosed: 0.065 ± 0.012
  Theta for others: 0.066
  Proportion difference: -0.001

Disease 141 (signature 12, LR=28.70):
  Theta for diagnosed: 0.066 ± 0.011
  Theta for others: 0.066
  Proportion difference: -0.001

Disease 163 (signature 14, LR=28.19):
  Theta for diagnosed: 0.030 ± 0.011
  Theta for others: 0.030
  Proportion difference: 0.000

Disease 274 (signature 19, LR=27.57):
  Theta for diagnosed: 0.032 ± 0.026
  Theta for others: 0.032
  Proportion difference: 0.001

Disease 256 (signature 12, LR=27.43):
  Theta for diagnosed: 0.066 ± 0.012
  Theta for others: 0.066
  Proportion difference: -0.000

Epoch 1
Loss: 763.2976

Monitoring signature responses:

Disease 187 (signature 12, LR=31.24):
  Theta for diagnosed: 0.065 ± 0.010
  Theta for others: 0.066
  Proportion difference: -0.001

Disease 141 (signature 12, LR=28.70):
  Theta for diagnosed: 0.065 ± 0.009
  Theta 

In [1]:
%run /Users/sarahurbut/aladynoulli2/pyScripts_forPublish/verify_corrected_E_and_prevalence_match

VERIFYING CORRECTED E MATRICES AND PREVALENCE MATCH MODEL CHECKPOINTS

AOU VERIFICATION

Standalone files:
  aou_E_corrected.pt shape: torch.Size([10000, 348])
  aou_prevalence_corrected_E.pt shape: torch.Size([348, 51])

Model checkpoint keys: ['model_state_dict', 'clusters', 'signature_refs', 'healthy_ref', 'psi_config', 'hyperparameters', 'prevalence_t', 'disease_names']

Model checkpoint prevalence_t shape: torch.Size([348, 51])

Prevalence comparison:
  Max difference: 0.0000000000
  Mean difference: 0.0000000000
  ✓ PERFECT MATCH!

MGB VERIFICATION

Standalone files:
  mgb_E_corrected.pt shape: torch.Size([34592, 346])
  mgb_prevalence_corrected_E.pt shape: torch.Size([346, 51])

Model checkpoint keys: ['model_state_dict', 'clusters', 'signature_refs', 'healthy_ref', 'psi_config', 'hyperparameters', 'prevalence_t', 'disease_names']

Model checkpoint prevalence_t shape: torch.Size([346, 51])

Prevalence comparison:
  Max difference: 0.0000000000
  Mean difference: 0.0000000000
  ✓

  aou_E_corrected = torch.load('/Users/sarahurbut/aladynoulli2/aou_E_corrected.pt', map_location='cpu')
  aou_prevalence_corrected = torch.load('/Users/sarahurbut/aladynoulli2/aou_prevalence_corrected_E.pt', map_location='cpu')
  aou_model = torch.load('/Users/sarahurbut/aladynoulli2/aou_model_initialized.pt', map_location='cpu')
  mgb_E_corrected = torch.load('/Users/sarahurbut/aladynoulli2/mgb_E_corrected.pt', map_location='cpu')
  mgb_prevalence_corrected = torch.load('/Users/sarahurbut/aladynoulli2/mgb_prevalence_corrected_E.pt', map_location='cpu')
  mgb_model = torch.load('/Users/sarahurbut/aladynoulli2/mgb_model_initialized.pt', map_location='cpu')
