In [1]:
%load_ext autoreload
%autoreload 2



%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering  # Add this import
from utils import *


def load_model_essentials(base_path='/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/'):
    """
    Load all essential components saved with ``torch.save`` from ``base_path``.

    Files read: ``Y_tensor.pt``, ``E_matrix.pt``, ``G_matrix.pt`` and
    ``model_essentials.pt``.

    Parameters
    ----------
    base_path : str or os.PathLike
        Directory holding the files. A trailing path separator is optional.

    Returns
    -------
    tuple
        ``(Y, E, G, essentials)`` exactly as ``torch.load`` returns them.
        Callers in this notebook index ``essentials`` like a dict
        (e.g. ``essentials['prevalence_t']``).
    """
    import os

    print("Loading components...")

    def _load(filename):
        # os.path.join works with or without a trailing separator on base_path;
        # plain concatenation turned ".../data_for_running" + "Y_tensor.pt"
        # into a non-existent ".../data_for_runningY_tensor.pt".
        # NOTE(review): newer torch versions default torch.load to
        # weights_only=True; confirm model_essentials.pt still loads under that.
        return torch.load(os.path.join(base_path, filename))

    # Load large matrices
    Y = _load('Y_tensor.pt')
    E = _load('E_matrix.pt')
    G = _load('G_matrix.pt')

    # Load other components
    essentials = _load('model_essentials.pt')

    print("Loaded all components successfully!")

    return Y, E, G, essentials

# Load and initialize model inputs:
Y, E, G, essentials = load_model_essentials()

from clust_huge_amp import *

# Cohort rows analysed in this run. The tensor subset and the covariate
# DataFrame slice below both use this one pair so they cannot drift apart.
subset_start, subset_end = 0, 10000

# Subset the data
Y_100k, E_100k, G_100k, indices = subset_data(Y, E, G, start_index=subset_start, end_index=subset_end)

# Release the full outcome tensor; only the subset is used below.
del Y

# Load reference trajectories; only the 'signature_refs' entry is used here.
# NOTE(review): the original note said "signatures only, no healthy", but the
# model is constructed with healthy_reference=True — confirm which is intended.
refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')
signature_refs = refs['signature_refs']

# R interop setup.
# NOTE(review): nothing visible in this notebook reads an RDS file or uses
# robjects/pandas2ri — confirm this is still needed.
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
pandas2ri.activate()

import pandas as pd

# 'CloudStorage' with a capital S, like every other path in this notebook;
# the lower-case 'Cloudstorage' spelling only resolved on a case-insensitive
# filesystem.
fh_processed = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox/baselinagefamh.csv')

# Covariate rows aligned with the tensor subset above.
pce_df_subset = fh_processed.iloc[subset_start:subset_end].reset_index(drop=True)
sex = pce_df_subset['sex'].values
# Append sex as one extra covariate column next to the genetic covariates.
G_with_sex = np.column_stack([G_100k, sex])



import torch
import numpy as np
import cProfile
import pstats
from pstats import SortKey

import copy

# Cohort rows used in this run (assumed to match the E_100k / G_with_sex
# subset built earlier). Also embedded in the output filenames.
row_start, row_end = 0, 10000
dropbox_dir = '/Users/sarahurbut/Library/CloudStorage/Dropbox/'
data_dir = dropbox_dir + 'data_for_running/'

# Maps age_offset -> (pi_path, E_path) of the files written for that offset.
# Only paths are kept: holding every offset's pi tensor in memory at once
# would defeat the per-iteration `del` at the bottom of the loop.
age_predictions = {}

# Loop invariants, read once. Each model gets its own deep copy (below) so an
# in-place change made during one fit cannot leak into the next offset.
initial_psi_template = torch.load(data_dir + 'initial_psi_400k.pt')
initial_clusters_template = torch.load(data_dir + 'initial_clusters_400k.pt')

# Covariate rows aligned with E_100k; only enrollment ages are needed here.
pce_df_subset = fh_processed.iloc[row_start:row_end].reset_index(drop=True)
cohort_ages = pce_df_subset['age'].to_numpy(dtype=float)
if np.isnan(cohort_ages).any():
    # Previously a missing age silently capped every event time to 0.
    raise ValueError("Missing enrollment ages in the cohort subset; cannot censor by age.")
# Patients that have both an event-time row and a covariate row.
n_capped = min(E_100k.shape[0], len(cohort_ages))

for age_offset in range(0, 11):  # Ages 0-10 years after enrollment
    print(f"\n=== Predicting for age offset {age_offset} years ===")

    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Initialize fresh model for this age
    model = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
        N=Y_100k.shape[0],
        D=Y_100k.shape[1],
        T=Y_100k.shape[2],
        K=20,
        P=G_with_sex.shape[1],
        init_sd_scaler=1e-1,
        G=G_with_sex,
        Y=Y_100k,
        genetic_scale=1,
        W=0.0001,
        R=0,
        prevalence_t=essentials['prevalence_t'],
        signature_references=signature_refs,
        healthy_reference=True,
        disease_names=essentials['disease_names']
    )

    # Reset seeds for parameter initialization
    torch.manual_seed(0)
    np.random.seed(0)

    # Set initial parameters from fresh copies of the pre-loaded files
    initial_psi = copy.deepcopy(initial_psi_template)
    initial_clusters = copy.deepcopy(initial_clusters_template)
    model.initialize_params(true_psi=initial_psi)
    model.clusters = initial_clusters

    # Sanity check that assigning the clusters did not transform them
    clusters_match = np.array_equal(initial_clusters, model.clusters)
    print(f"Clusters match exactly: {clusters_match}")

    # Age-specific censoring: cap every event time at the patient's age at
    # this offset, expressed as years since age 30 and floored at 0.
    caps = np.maximum(cohort_ages[:n_capped] + age_offset - 30, 0)
    E_age_specific = E_100k.clone()
    # One cap per patient row, broadcast across the remaining dimensions.
    cap_column = torch.as_tensor(
        caps, dtype=E_age_specific.dtype, device=E_age_specific.device
    ).reshape(-1, *([1] * (E_age_specific.dim() - 1)))
    E_age_specific[:n_capped] = torch.minimum(E_age_specific[:n_capped], cap_column)

    # Censoring statistics for this age offset
    total_times_changed = int((E_age_specific[:n_capped] != E_100k[:n_capped]).sum().item())
    max_cap_applied = float(caps.max()) if n_capped else 0.0
    min_cap_applied = float(caps.min()) if n_capped else float('inf')

    # Print censoring verification
    print(f"Censoring verification for age offset {age_offset}:")
    print(f"  Total event times changed: {total_times_changed}")
    print(f"  Max cap applied: {max_cap_applied:.1f}")
    print(f"  Min cap applied: {min_cap_applied:.1f}")

    # Check a few specific patients
    test_patients = [0, 1, 100]
    for test_idx in test_patients:
        if test_idx >= n_capped:
            continue
        enrollment_age = cohort_ages[test_idx]
        current_age = enrollment_age + age_offset
        expected_cap = max(0, current_age - 30)

        # Largest remaining event time for this patient
        max_time = torch.max(E_age_specific[test_idx, :]).item()

        print(f"  Patient {test_idx}: enrollment={enrollment_age:.0f}, current={current_age:.0f}, "
              f"cap={expected_cap:.1f}, max_event_time={max_time:.1f}")

        # Verify cap was applied correctly
        if max_time > expected_cap + 0.01:  # Small tolerance
            print(f"    WARNING: Max time {max_time:.1f} exceeds cap {expected_cap:.1f}!")

    # Train model for this specific age (profile the fit only, not the plotting)
    print(f"Training model for age offset {age_offset}...")
    profiler = cProfile.Profile()
    profiler.enable()

    history_new = model.fit(
        E_age_specific,
        num_epochs=200,
        learning_rate=1e-1,
        lambda_reg=1e-2
    )

    profiler.disable()
    plot_training_evolution(history_new)
    stats = pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE)
    stats.print_stats(20)

    # Get predictions for this age
    with torch.no_grad():
        pi, _, _ = model.forward()

    # Save age-specific predictions and the censored event times used to fit them
    pi_path = f"{dropbox_dir}pi_enroll_age_offset_{age_offset}_sex_{row_start}_{row_end}.pt"
    e_path = f"{dropbox_dir}E_enroll_age_offset_{age_offset}_sex_{row_start}_{row_end}.pt"
    torch.save(pi, pi_path)
    torch.save(E_age_specific, e_path)
    print(f"Saved predictions to {pi_path} and censored event times to {e_path}")
    age_predictions[age_offset] = (pi_path, e_path)

    # Clean up to free memory
    del pi, model, E_age_specific
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Loading components...


  Y = torch.load(base_path + 'Y_tensor.pt')
  E = torch.load(base_path + 'E_matrix.pt')
  G = torch.load(base_path + 'G_matrix.pt')
  essentials = torch.load(base_path + 'model_essentials.pt')


Loaded all components successfully!


  refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')


In [None]:
# Spot-check that the assembled array agrees with the per-offset batch files:
# for person idx with enrollment time index t, follow-up year k is read from
# batch k at time index t + k (disease 0 as the example).
# Relies on enrollment_ages, pi_full, pi_batches, years_to_use and T being
# defined by a previously executed cell.
indices_to_check = [0, 1, 2]  # or use np.random.choice(N, 3, replace=False)

for idx in indices_to_check:
    t = int(enrollment_ages[idx] - 30)
    print(f"\nPerson {idx} (enrollment age: {enrollment_ages[idx]}, t={t}):")
    for k in range(min(years_to_use, len(pi_batches))):
        t_k = t + k
        # Bounds-check BOTH lookups. Previously only the batch read was guarded,
        # so pi_full[idx, 0, t_k] raised IndexError past the last time point,
        # and a negative t (enrollment before 30) silently wrapped around.
        if not 0 <= t_k < T:
            print(f"  Year {k}: t={t_k} outside [0, {T}), skipped")
            continue
        val_full = pi_full[idx, 0, t_k].item()
        val_batch = pi_batches[k][idx, 0, t_k].item()
        print(f"  Year {k}: pi_full={val_full:.6f}, pi_batch={val_batch:.6f}, match={np.isclose(val_full, val_batch, atol=1e-6)}")


Person 0 (enrollment age: 69, t=39):
  Year 0: pi_full=0.000230, pi_batch=0.000230, match=True
  Year 1: pi_full=0.000224, pi_batch=0.000224, match=True
  Year 2: pi_full=0.000229, pi_batch=0.000229, match=True
  Year 3: pi_full=0.000230, pi_batch=0.000230, match=True
  Year 4: pi_full=0.000224, pi_batch=0.000224, match=True
  Year 5: pi_full=0.000185, pi_batch=0.000185, match=True
  Year 6: pi_full=0.000180, pi_batch=0.000180, match=True
  Year 7: pi_full=0.000177, pi_batch=0.000177, match=True
  Year 8: pi_full=0.000179, pi_batch=0.000179, match=True
  Year 9: pi_full=0.000178, pi_batch=0.000178, match=True

Person 1 (enrollment age: 44, t=14):
  Year 0: pi_full=0.000143, pi_batch=0.000143, match=True
  Year 1: pi_full=0.000166, pi_batch=0.000166, match=True
  Year 2: pi_full=0.000192, pi_batch=0.000192, match=True
  Year 3: pi_full=0.000216, pi_batch=0.000216, match=True
  Year 4: pi_full=0.000253, pi_batch=0.000253, match=True
  Year 5: pi_full=0.000280, pi_batch=0.000280, match=T

In [None]:
import torch
import numpy as np
import pandas as pd

# Test function to verify age-specific censoring
def test_age_specific_censoring(E_100k, fh_processed, age_offset=5, test_patients=5, start_index=0):
    """
    Check that E_age_specific caps each patient's times at row.age + offset - 30.

    Every event time of patient ``i`` is capped at
    ``max(0, age_i + age_offset - 30)``, and a per-patient report is printed.

    Parameters
    ----------
    E_100k : torch.Tensor
        Event times, one row per patient; row ``i`` pairs with DataFrame row
        ``start_index + i``. Not modified.
    fh_processed : pandas.DataFrame
        Must have an ``age`` column (enrollment age).
    age_offset : int or float
        Years added to the enrollment age before capping.
    test_patients : int
        Number of leading patients to report on.
    start_index : int
        Row of ``fh_processed`` matching ``E_100k[0]`` (default 0, as before).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``(E_original, E_age_specific)``: an untouched clone and the capped clone.

    Raises
    ------
    ValueError
        If any age needed for capping is missing.
    """
    print(f"\n=== Testing Age-Specific Censoring (offset = {age_offset}) ===")

    # Create original and age-specific versions
    E_original = E_100k.clone()
    E_age_specific = E_100k.clone()

    # Covariate rows aligned with E_100k (previously hard-coded to 0:10000,
    # which ignored the tensor's size and any cohort offset).
    n_patients = E_age_specific.shape[0]
    pce_df_subset = fh_processed.iloc[start_index:start_index + n_patients].reset_index(drop=True)
    n_capped = len(pce_df_subset)  # < n_patients only if the DataFrame runs out of rows

    # Apply age-specific censoring to all rows at once
    ages = pce_df_subset['age'].to_numpy(dtype=float)
    if np.isnan(ages).any():
        raise ValueError("Missing enrollment ages; cannot apply age-specific censoring.")
    caps = np.maximum(ages + age_offset - 30, 0)
    cap_column = torch.as_tensor(
        caps, dtype=E_age_specific.dtype, device=E_age_specific.device
    ).reshape(-1, *([1] * (E_age_specific.dim() - 1)))
    E_age_specific[:n_capped] = torch.minimum(E_age_specific[:n_capped], cap_column)

    # Test specific patients
    print(f"\nTesting first {test_patients} patients:")
    print("=" * 80)

    for i in range(min(test_patients, n_capped)):
        row = pce_df_subset.iloc[i]

        enrollment_age = row.age
        current_age = enrollment_age + age_offset
        expected_cap = max(0, current_age - 30)

        # Get original and modified event times for this patient
        original_times = E_original[i, :].detach().cpu().numpy()
        modified_times = E_age_specific[i, :].detach().cpu().numpy()

        # Check if any times were actually capped
        times_changed = ~np.isclose(original_times, modified_times)

        print(f"\nPatient {i}:")
        print(f"  Enrollment age: {enrollment_age}")
        print(f"  Current age (enrollment + {age_offset}): {current_age}")
        print(f"  Expected cap (current_age - 30): {expected_cap}")
        print(f"  Times changed: {times_changed.sum()}/{len(times_changed)} diseases")

        if times_changed.any():
            # Show some examples of changed times
            changed_indices = np.where(times_changed)[0][:3]  # First 3 changes
            print("  Example changes:")
            for idx in changed_indices:
                # Arrow was mojibake ("â†’", a mis-decoded U+2192); use plain ASCII.
                print(f"    Disease {idx}: {original_times[idx]:.1f} -> {modified_times[idx]:.1f}")

        # Show unchanged diseases and their values
        unchanged_indices = np.where(~times_changed)[0]
        if len(unchanged_indices) > 0:
            print(f"  Unchanged diseases (indices): {unchanged_indices[:5]}")  # Show first 5
            print(f"  Their original times: {original_times[unchanged_indices[:5]]}")
            print(f"  Their modified times: {modified_times[unchanged_indices[:5]]}")

        # Verify all modified times are <= expected cap (small tolerance)
        all_capped_correctly = np.all(modified_times <= expected_cap + 1e-6)
        print(f"  All times correctly capped: {all_capped_correctly}")

        # Check that no times increased
        no_times_increased = np.all(modified_times <= original_times + 1e-6)
        print(f"  No times increased: {no_times_increased}")

    return E_original, E_age_specific


# The ``test_`` prefix makes pytest-style collectors treat this helper as a
# test (and error looking for fixtures named E_100k / fh_processed); opt out.
test_age_specific_censoring.__test__ = False

# Example usage: cap every patient at enrollment age + 5 years and print a
# report for the first few patients.
E_orig, E_modified = test_age_specific_censoring(
    E_100k,
    fh_processed,
    age_offset=5,
)

# Additional verification function
def compare_age_offsets(E_100k, fh_processed, patient_idx=0, start_index=0, age_offsets=(0, 2, 5, 10)):
    """
    Show how one patient's event times change across different age offsets.

    For each offset the patient's times are capped at
    ``max(0, age + offset - 30)`` and the first 5 entries are printed.

    Parameters
    ----------
    E_100k : torch.Tensor
        Event times, one row per patient. Not modified.
    fh_processed : pandas.DataFrame
        Must have an ``age`` column; row ``start_index + patient_idx`` is used.
    patient_idx : int
        Row of ``E_100k`` to inspect.
    start_index : int
        Row of ``fh_processed`` matching ``E_100k[0]`` (default 0, as before).
    age_offsets : iterable of int
        Offsets to display (default ``(0, 2, 5, 10)``, as before).
    """
    print(f"\n=== Patient {patient_idx} Across Different Age Offsets ===")

    row = fh_processed.iloc[start_index + patient_idx]
    enrollment_age = row.age

    # Work on this patient's row only; cloning the whole matrix per offset
    # (as before) cost O(N * D) time and memory for a one-row printout.
    patient_times = E_100k[patient_idx, :]

    print(f"Enrollment age: {enrollment_age}")
    print(f"Original event times (first 5 diseases): {patient_times[:5].numpy()}")
    print()

    for age_offset in age_offsets:
        current_age = enrollment_age + age_offset
        time_since_30 = max(0, current_age - 30)

        capped = torch.minimum(patient_times, torch.full_like(patient_times, time_since_30))

        print(f"Age offset {age_offset:2d} (age {current_age:2.0f}, cap at {time_since_30:2.0f}): {capped[:5].numpy()}")

# Example usage: show how patient 0's first event times are capped as the
# age offset grows.
compare_age_offsets(
    E_100k,
    fh_processed,
    patient_idx=0,
)

In [6]:
import torch
import numpy as np
import pandas as pd
# Cohort rows covered by this comparison. Every file path and the DataFrame
# slice below derive from this one pair so they cannot disagree.
row_start, row_end = 20000, 30000
dropbox_dir = '/Users/sarahurbut/Library/CloudStorage/Dropbox/'

# 'CloudStorage' with a capital S, matching the other paths; the lower-case
# spelling only resolved on a case-insensitive filesystem.
fh_processed = pd.read_csv(dropbox_dir + 'baselinagefamh.csv')

# Load your assembled full array
pi_full = torch.load(f"{dropbox_dir}pi_full_leakage_free_{row_start}_{row_end}.pt")  # or pi_test_full.pt

years_to_use = 10
# Load all batch arrays into a list: one file per age offset k
pi_batches = [
    torch.load(f"{dropbox_dir}pi_enroll_age_offset_{k}_sex_{row_start}_{row_end}.pt")
    for k in range(years_to_use)
]
pce_df_subset = fh_processed.iloc[row_start:row_end].reset_index(drop=True)

# Enrollment ages for this cohort, aligned with pi_full's first axis
enrollment_ages = pce_df_subset['age'].to_numpy()

# Parameters
N, D, T = pi_full.shape

# Pick a few random indices to check
np.random.seed(42)
indices_to_check = np.random.choice(N, 3, replace=False)
diseases_to_check = np.random.choice(D, 2, replace=False)
years_to_check = [0, 3, 7]  # e.g., enrollment, +3, +7 years

for idx in indices_to_check:
    t_enroll = int(enrollment_ages[idx] - 30)
    print(f"\nPerson {idx} (enrollment age: {enrollment_ages[idx]}, t_enroll: {t_enroll}):")
    for d in diseases_to_check:
        for k in years_to_check:
            t_full = t_enroll + k
            # Also reject negative indices: enrollment before age 30 would
            # otherwise wrap silently to the end of the time axis.
            if 0 <= t_full < T:
                val_full = pi_full[idx, d, t_full].item()
                val_batch = pi_batches[k][idx, d, t_full].item()
                print(f"  Disease {d}, year {k} after enrollment (t={t_full}): full={val_full:.6g}, batch={val_batch:.6g}, match={np.isclose(val_full, val_batch)}")

  pi_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_full_leakage_free_20000_30000.pt")  # or pi_test_full.pt
  torch.load(f"/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_{k}_sex_20000_30000.pt")  # update path/pattern



Person 6252 (enrollment age: 69, t_enroll: 39):
  Disease 294, year 0 after enrollment (t=39): full=0.000641907, batch=0.000641907, match=True
  Disease 294, year 3 after enrollment (t=42): full=0.000509142, batch=0.000509142, match=True
  Disease 294, year 7 after enrollment (t=46): full=0.000347784, batch=0.000347784, match=True
  Disease 132, year 0 after enrollment (t=39): full=0.00018116, batch=0.00018116, match=True
  Disease 132, year 3 after enrollment (t=42): full=0.000263541, batch=0.000263541, match=True
  Disease 132, year 7 after enrollment (t=46): full=0.000384317, batch=0.000384317, match=True

Person 4684 (enrollment age: 44, t_enroll: 14):
  Disease 294, year 0 after enrollment (t=14): full=0.000445189, batch=0.000445189, match=True
  Disease 294, year 3 after enrollment (t=17): full=0.000580731, batch=0.000580731, match=True
  Disease 294, year 7 after enrollment (t=21): full=0.000843879, batch=0.000843879, match=True
  Disease 132, year 0 after enrollment (t=14): fu

In [7]:
import torch
import numpy as np
import pandas as pd
# Cohort rows covered by this comparison. Every file path and the DataFrame
# slice below derive from this one pair so they cannot disagree.
row_start, row_end = 0, 10000
dropbox_dir = '/Users/sarahurbut/Library/CloudStorage/Dropbox/'

# 'CloudStorage' with a capital S, matching the other paths; the lower-case
# spelling only resolved on a case-insensitive filesystem.
fh_processed = pd.read_csv(dropbox_dir + 'baselinagefamh.csv')

# Load your assembled full array
pi_full = torch.load(f"{dropbox_dir}pi_full_leakage_free_{row_start}_{row_end}.pt")  # or pi_test_full.pt

years_to_use = 10
# Load all batch arrays into a list: one file per age offset k
pi_batches = [
    torch.load(f"{dropbox_dir}pi_enroll_age_offset_{k}_sex_{row_start}_{row_end}.pt")
    for k in range(years_to_use)
]
pce_df_subset = fh_processed.iloc[row_start:row_end].reset_index(drop=True)

# Enrollment ages for this cohort, aligned with pi_full's first axis
enrollment_ages = pce_df_subset['age'].to_numpy()

# Parameters
N, D, T = pi_full.shape

# Pick a few random indices to check
np.random.seed(42)
indices_to_check = np.random.choice(N, 3, replace=False)
diseases_to_check = np.random.choice(D, 2, replace=False)
years_to_check = [0, 3, 7]  # e.g., enrollment, +3, +7 years

for idx in indices_to_check:
    t_enroll = int(enrollment_ages[idx] - 30)
    print(f"\nPerson {idx} (enrollment age: {enrollment_ages[idx]}, t_enroll: {t_enroll}):")
    for d in diseases_to_check:
        for k in years_to_check:
            t_full = t_enroll + k
            # Also reject negative indices: enrollment before age 30 would
            # otherwise wrap silently to the end of the time axis.
            if 0 <= t_full < T:
                val_full = pi_full[idx, d, t_full].item()
                val_batch = pi_batches[k][idx, d, t_full].item()
                print(f"  Disease {d}, year {k} after enrollment (t={t_full}): full={val_full:.6g}, batch={val_batch:.6g}, match={np.isclose(val_full, val_batch)}")

  pi_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_full_leakage_free_0_10000.pt")  # or pi_test_full.pt
  torch.load(f"/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_{k}_sex_0_10000.pt")  # update path/pattern



Person 6252 (enrollment age: 63, t_enroll: 33):
  Disease 294, year 0 after enrollment (t=33): full=0.00118963, batch=0.00118963, match=True
  Disease 294, year 3 after enrollment (t=36): full=0.000917143, batch=0.000917143, match=True
  Disease 294, year 7 after enrollment (t=40): full=0.000951202, batch=0.000951202, match=True
  Disease 132, year 0 after enrollment (t=33): full=0.000104718, batch=0.000104718, match=True
  Disease 132, year 3 after enrollment (t=36): full=0.000169442, batch=0.000169442, match=True
  Disease 132, year 7 after enrollment (t=40): full=0.000287315, batch=0.000287315, match=True

Person 4684 (enrollment age: 67, t_enroll: 37):
  Disease 294, year 0 after enrollment (t=37): full=0.000699053, batch=0.000699053, match=True
  Disease 294, year 3 after enrollment (t=40): full=0.000563404, batch=0.000563404, match=True
  Disease 294, year 7 after enrollment (t=44): full=0.000300118, batch=0.000300118, match=True
  Disease 132, year 0 after enrollment (t=37): fu