## Imports and Helper Functions


In [1]:
import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri
import numpy as np
import pandas as pd
import torch
from scipy.ndimage import gaussian_filter1d
import os

# Activate automatic conversion between R and NumPy arrays
numpy2ri.activate()


---

## Part 1: AOU (All of Us)

### Step 1: Load AOU Data


AOU Y shape: torch.Size([10000, 348, 51])
AOU E shape: torch.Size([10000, 348])


### Step 2: Load Patient Names and Create Censor Info


Patient names shape: (10000, 2)
Columns: ['Unnamed: 0', 'x']

ICD10 data shape: 262001
Unique patients in ICD10: 10000

Matched 10000 / 10000 patients
Missing matches: 0
✓ Order already matches YandEpatientnames

✓ Saved AOU censor info to: aou_censor_info.csv


### Step 3: Correct AOU E Matrix


In [None]:
# Load censor info (if not already in memory)
censor_df_aou = pd.read_csv('/Users/sarahurbut/aladynoulli2/aou_censor_info.csv')

T_aou = Y_aou_tensor.shape[2]  # Number of timepoints

# Convert max_censor ages to timepoints (age 30 = timepoint 0)
max_timepoints_aou = torch.tensor(
    (censor_df_aou['max_censor'].values - 30).clip(0, T_aou-1).astype(int)
)

# Only update censored cases (where E == T-1, meaning right-censored at max time)
censored_mask_aou = (E_aou_tensor == T_aou - 1)  # Shape: (N, D)

# For each patient, cap censored diseases to their max_timepoint
# Expand max_timepoints to match E shape
max_timepoints_expanded_aou = max_timepoints_aou.unsqueeze(1).expand_as(E_aou_tensor)

# Update only censored positions
E_aou_corrected = torch.where(
    censored_mask_aou,
    torch.minimum(E_aou_tensor, max_timepoints_expanded_aou),
    E_aou_tensor  # Keep event times as-is
)

print(f"AOU E matrix correction complete:")
print(f"  Original E shape: {E_aou_tensor.shape}")
print(f"  Corrected E shape: {E_aou_corrected.shape}")
print(f"  Patients with corrections: {(censored_mask_aou.sum(dim=1) > 0).sum().item()} / {E_aou_tensor.shape[0]}")

# Save corrected E
torch.save(E_aou_corrected, '/Users/sarahurbut/aladynoulli2/aou_E_corrected.pt')
print(f"✓ Saved AOU corrected E to: aou_E_corrected.pt")


AOU E matrix correction complete:
  Original E shape: torch.Size([10000, 348])
  Corrected E shape: torch.Size([10000, 348])
  Patients with corrections: 10000 / 10000
✓ Saved AOU corrected E to: aou_E_corrected.pt


In [22]:
print("\nComputing AOU prevalence with at-risk filtering...")
aou_prevalence_corrected = compute_smoothed_prevalence_at_risk(
    Y=Y_aou_tensor, 
    E_corrected=E_aou_corrected, 
    window_size=5,
    smooth_on_logit=True
)

print(f"\nAOU prevalence shape: {aou_prevalence_corrected.shape}")
print(f"AOU prevalence range: [{aou_prevalence_corrected.min():.6f}, {aou_prevalence_corrected.max():.6f}]")

# Convert to logit scale
epsilon = 1e-8
aou_logit_prev = np.log((aou_prevalence_corrected + epsilon) / (1 - aou_prevalence_corrected + epsilon))

# Save results
torch.save(torch.tensor(aou_logit_prev), '/Users/sarahurbut/aladynoulli2/aou_logit_prev_corrected_E.pt')
torch.save(torch.tensor(aou_prevalence_corrected), '/Users/sarahurbut/aladynoulli2/aou_prevalence_corrected_E.pt')

print(f"\n✓ Saved AOU logit prevalence to: aou_logit_prev_corrected_E.pt")
print(f"✓ Saved AOU prevalence to: aou_prevalence_corrected_E.pt")
print(f"\n✓ AOU processing complete!")



Computing AOU prevalence with at-risk filtering...
Computing prevalence for 348 diseases, 51 timepoints...
  Processing disease 0/348...
  Processing disease 50/348...
  Processing disease 100/348...
  Processing disease 150/348...
  Processing disease 200/348...
  Processing disease 250/348...
  Processing disease 300/348...

AOU prevalence shape: (348, 51)
AOU prevalence range: [0.000000, 0.126083]

✓ Saved AOU logit prevalence to: aou_logit_prev_corrected_E.pt
✓ Saved AOU prevalence to: aou_prevalence_corrected_E.pt

✓ AOU processing complete!


---

## Part 2: MGB (Mass General Brigham)

### Step 1: Load MGB Data


---

## Part 2: MGB (Mass General Brigham)

### Step 1: Load MGB Data



### Step 2: Load Patient Names and Create Censor Info


In [24]:
# Load Y and E matrices from R .rds files
mgb_data_path = "/Users/sarahurbut/Dropbox-Personal/mgbbtopic/"

Y_mgb = np.array(robjects.r['readRDS'](os.path.join(mgb_data_path, 'Y_sub.rds')))
E_mgb = np.array(robjects.r['readRDS'](os.path.join(mgb_data_path, 'E_sub.rds')))
E_mgb = E_mgb.astype(int)

# Convert to PyTorch tensors
Y_mgb_tensor = torch.FloatTensor(Y_mgb)
E_mgb_tensor = torch.FloatTensor(E_mgb)

print(f"MGB Y shape: {Y_mgb_tensor.shape}")
print(f"MGB E shape: {E_mgb_tensor.shape}")

MGB Y shape: torch.Size([34592, 346, 51])
MGB E shape: torch.Size([34592, 346])


In [21]:
# Load patient names (defines order of rows in Y/E)
YandEpatientnames_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/mgb_patientnames.csv')
print(f"Patient names shape: {YandEpatientnames_mgb.shape}")

# Load max_censor data
max_censor_mgb = pd.read_csv('/Users/sarahurbut/aladynoulli2/max_censor_mgb.csv')
print(f"Max censor data shape: {len(max_censor_mgb)}")

# Rename age to max_censor for consistency
if 'age' in max_censor_mgb.columns:
    max_censor_mgb = max_censor_mgb.rename(columns={'age': 'max_censor'})

# Merge with patient names to ensure correct order
censor_info_mgb = YandEpatientnames_mgb.merge(
    max_censor_mgb,
    left_on='x',  # Patient ID column in Y/E
    right_on='eid' if 'eid' in max_censor_mgb.columns else max_censor_mgb.columns[0],
    how='left'
)

# Fill missing patients with default max age
missing_mask = censor_info_mgb['max_censor'].isna()
if missing_mask.any():
    default_max_censor = max_censor_mgb['max_censor'].max() if len(max_censor_mgb) > 0 else 81
    censor_info_mgb.loc[missing_mask, 'max_censor'] = default_max_censor
    print(f"Filled {missing_mask.sum()} missing patients with max_censor={default_max_censor}")

# Verify order is preserved
order_preserved = (censor_info_mgb['x'].values == YandEpatientnames_mgb['x'].values).all()
if order_preserved:
    print(f"✓ Order preserved: censor_info_mgb matches YandEpatientnames order")
    print(f"  Matched {censor_info_mgb['max_censor'].notna().sum()} / {len(censor_info_mgb)} patients")
else:
    print(f"⚠ WARNING: Order NOT preserved!")

# CRITICAL: Ensure censor_df matches YandEpatientnames order
if len(censor_info_mgb) == len(YandEpatientnames_mgb) and 'x' in YandEpatientnames_mgb.columns:
    patient_id_col = 'eid' if 'eid' in censor_info_mgb.columns else 'index'
    
    if patient_id_col in censor_info_mgb.columns:
        order_matches = (censor_info_mgb[patient_id_col].values == YandEpatientnames_mgb['x'].values).all()
        
        if not order_matches:
            print(f"⚠ Reordering censor_info_mgb to match YandEpatientnames order...")
            
            # Reorder censor_df to match YandEpatientnames order (preserves Y/E order)
            censor_info_mgb = censor_info_mgb.set_index(patient_id_col).reindex(YandEpatientnames_mgb['x']).reset_index()
            
            # Rename index column back if needed
            if 'index' in censor_info_mgb.columns and patient_id_col != 'index':
                censor_info_mgb = censor_info_mgb.rename(columns={'index': patient_id_col})
            
            # Fill any missing patients with default max age
            missing_mask = censor_info_mgb['max_censor'].isna()
            if missing_mask.any():
                default_max_censor = censor_info_mgb['max_censor'].max() if censor_info_mgb['max_censor'].notna().any() else 81
                censor_info_mgb.loc[missing_mask, 'max_censor'] = default_max_censor
                print(f"   Filled {missing_mask.sum()} missing patients with max_censor={default_max_censor}")
            
            # Verify after reordering
            final_id_col = 'eid' if 'eid' in censor_info_mgb.columns else 'index'
            order_matches = (censor_info_mgb[final_id_col].values == YandEpatientnames_mgb['x'].values).all()
            if order_matches:
                print(f"✓ Order now matches YandEpatientnames")
            else:
                print(f"⚠ Still have mismatches - check patient ID alignment")
        else:
            print(f"✓ Order already matches YandEpatientnames")

print(f"\ncensor_info_mgb ready with shape: {censor_info_mgb.shape}")
print(censor_info_mgb.head())


Patient names shape: (34592, 2)
Max censor data shape: 34592
✓ Order preserved: censor_info_mgb matches YandEpatientnames order
  Matched 34592 / 34592 patients
✓ Order already matches YandEpatientnames

censor_info_mgb ready with shape: (34592, 5)
   Unnamed: 0_x          x  Unnamed: 0_y        eid  max_censor
0             1  101790256         10689  101790256          53
1             2  101717153         10400  101717153          67
2             3  102456864         12942  102456864          45
3             4  100219007          2247  100219007          71
4             5  100230568          2417  100230568          69


### Step 3: Correct MGB E Matrix


MGB E matrix correction complete:
  Original E shape: torch.Size([34592, 346])
  Corrected E shape: torch.Size([34592, 346])
  Patients with corrections: 34592 / 34592
✓ Saved MGB corrected E to: mgb_E_corrected.pt


In [26]:
# Load old MGB prevalence for comparison
mgb_prevalence_corrected_old = torch.load('/Users/sarahurbut/aladynoulli2/mgb_prevalence_corrected_E.pt', map_location='cpu')


  mgb_prevalence_corrected_old = torch.load('/Users/sarahurbut/aladynoulli2/mgb_prevalence_corrected_E.pt', map_location='cpu')


In [27]:
print("\nComputing MGB prevalence with at-risk filtering...")
mgb_prevalence_corrected = compute_smoothed_prevalence_at_risk(
    Y=Y_mgb_tensor, 
    E_corrected=E_mgb_corrected, 
    window_size=5,
    smooth_on_logit=True
)

print(f"\nMGB prevalence shape: {mgb_prevalence_corrected.shape}")
print(f"MGB prevalence range: [{mgb_prevalence_corrected.min():.6f}, {mgb_prevalence_corrected.max():.6f}]")

# Convert to logit scale
epsilon = 1e-8
mgb_logit_prev = np.log((mgb_prevalence_corrected + epsilon) / (1 - mgb_prevalence_corrected + epsilon))

# Test: Compare with saved file (convert numpy to tensor for comparison)
mgb_prevalence_corrected_tensor = torch.tensor(mgb_prevalence_corrected)
matches = torch.allclose(mgb_prevalence_corrected_old, mgb_prevalence_corrected_tensor, rtol=1e-5, atol=1e-8)
max_diff = (mgb_prevalence_corrected_old - mgb_prevalence_corrected_tensor).abs().max().item()
mean_diff = (mgb_prevalence_corrected_old - mgb_prevalence_corrected_tensor).abs().mean().item()

print(f"\nMGB Prevalence Comparison:")
print(f"  Max difference: {max_diff:.10f}")
print(f"  Mean difference: {mean_diff:.10f}")
if matches:
    print(f"  ✓ PERFECT MATCH!")
else:
    print(f"  ✗ MISMATCH - values differ")

# Save results
torch.save(torch.tensor(mgb_logit_prev), '/Users/sarahurbut/aladynoulli2/mgb_logit_prev_corrected_E.pt')
torch.save(torch.tensor(mgb_prevalence_corrected), '/Users/sarahurbut/aladynoulli2/mgb_prevalence_corrected_E.pt')

print(f"\n✓ Saved MGB logit prevalence to: mgb_logit_prev_corrected_E.pt")
print(f"✓ Saved MGB prevalence to: mgb_prevalence_corrected_E.pt")
print(f"\n✓ MGB processing complete!")



Computing MGB prevalence with at-risk filtering...
Computing prevalence for 346 diseases, 51 timepoints...
  Processing disease 0/346...
  Processing disease 50/346...
  Processing disease 100/346...
  Processing disease 150/346...
  Processing disease 200/346...
  Processing disease 250/346...
  Processing disease 300/346...

MGB prevalence shape: (346, 51)
MGB prevalence range: [0.000000, 0.171433]

MGB Prevalence Comparison:
  Max difference: 0.0000000000
  Mean difference: 0.0000000000
  ✓ PERFECT MATCH!

✓ Saved MGB logit prevalence to: mgb_logit_prev_corrected_E.pt
✓ Saved MGB prevalence to: mgb_prevalence_corrected_E.pt

✓ MGB processing complete!
