In [1]:
%load_ext autoreload
%autoreload 2



%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering  # Add this import
import sys
import os
import gc

class DummyFile(object):
    def write(self, x): pass

def suppress_stdout():
    sys.stdout = DummyFile()

def enable_stdout():
    sys.stdout = sys.__stdout__


def load_model_essentials(base_path='/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/'):
    """
    Load all essential components
    """
    print("Loading components...")
    
    # Load large matrices
    Y = torch.load(base_path + 'Y_tensor.pt')
    E = torch.load(base_path + 'E_matrix.pt')
    G = torch.load(base_path + 'G_matrix.pt')
    
    # Load other components
    essentials = torch.load(base_path + 'model_essentials.pt')
    
    print("Loaded all components successfully!")
    
    return Y, E, G, essentials

# Load and initialize model:
Y, E, G, essentials = load_model_essentials()

from clust_huge_amp import *
# Subset the data

# Subset the data
Y_100k, E_100k, G_100k, indices = subset_data(Y, E, G, start_index=20000, end_index=30000)


del Y

# Load references (signatures only, no healthy)
refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')
signature_refs = refs['signature_refs']
# When initializing the model:
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
pandas2ri.activate()

# Load the RDS file

import pandas as pd
fh_processed=pd.read_csv('/Users/sarahurbut/Library/Cloudstorage/Dropbox/baselinagefamh.csv')
len(fh_processed)



pce_df_subset = fh_processed.iloc[20000:30000].reset_index(drop=True)
sex=pce_df_subset['sex'].values
G_with_sex = np.column_stack([G_100k, sex]) 




Loading components...


  Y = torch.load(base_path + 'Y_tensor.pt')
  E = torch.load(base_path + 'E_matrix.pt')
  G = torch.load(base_path + 'G_matrix.pt')
  essentials = torch.load(base_path + 'model_essentials.pt')


Loaded all components successfully!


  refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')


In [2]:
import torch
import numpy as np
import cProfile
import pstats
from pstats import SortKey

# Store predictions for each age
age_predictions = {}

for age_offset in range(0, 10):  # Ages 0-10 years after enrollment
    print(f"\n=== Predicting for age offset {age_offset} years ===")
    suppress_stdout()
    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
    # Initialize fresh model for this age
    model = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
        N=Y_100k.shape[0],
        D=Y_100k.shape[1],
        T=Y_100k.shape[2],
        K=20,
        P=G_with_sex.shape[1],
        init_sd_scaler=1e-1,
        G=G_with_sex,
        Y=Y_100k,
        genetic_scale=1,
        W=0.0001,
        R=0,
        prevalence_t=essentials['prevalence_t'],
        signature_references=signature_refs,
        healthy_reference=True,
        disease_names=essentials['disease_names']
    )
    
    # Reset seeds for parameter initialization
    torch.manual_seed(0)
    np.random.seed(0)
    
    # Load and set initial parameters
    initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
    initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')
    model.initialize_params(true_psi=initial_psi)
    model.clusters = initial_clusters
    enable_stdout()
    # Verify clusters match
    clusters_match = np.array_equal(initial_clusters, model.clusters)
    print(f"Clusters match exactly: {clusters_match}")
    
    # Create age-specific event times
    E_age_specific = E_100k.clone()
    pce_df_subset = fh_processed.iloc[20000:30000].reset_index(drop=True)

     
    # Initialize tracking variables for this age offset
    total_times_changed = 0
    max_cap_applied = 0
    min_cap_applied = float('inf')

    
    for patient_idx, row in enumerate(pce_df_subset.itertuples()):
        if patient_idx >= E_age_specific.shape[0]:
            break
            
        # Current age = enrollment age + age_offset
        current_age = row.age + age_offset
        
        # Time since age 30 for this current age
        time_since_30 = max(0, current_age - 30)

        max_cap_applied = max(max_cap_applied, time_since_30)
        min_cap_applied = min(min_cap_applied, time_since_30)
        
        # Store original times for this patient
        original_times = E_age_specific[patient_idx, :].clone()
        
        # Cap event times to current age
        E_age_specific[patient_idx, :] = torch.minimum(
            E_age_specific[patient_idx, :],
            torch.full_like(E_age_specific[patient_idx, :], time_since_30)
        )

        times_changed = torch.sum(E_age_specific[patient_idx, :] != original_times).item()
        total_times_changed += times_changed
    
    # Print censoring verification
    print(f"Censoring verification for age offset {age_offset}:")
    print(f"  Total event times changed: {total_times_changed}")
    print(f"  Max cap applied: {max_cap_applied:.1f}")
    print(f"  Min cap applied: {min_cap_applied:.1f}")
    
    # Check a few specific patients
    test_patients = [0, 1, 100]  # Check patients 0, 1, and 100
    for test_idx in test_patients:
        if test_idx < len(pce_df_subset):
            row = pce_df_subset.iloc[test_idx]
            enrollment_age = row.age
            current_age = enrollment_age + age_offset
            expected_cap = max(0, current_age - 30)
            
            # Check max value in this patient's event times
            max_time = torch.max(E_age_specific[test_idx, :]).item()
            
            print(f"  Patient {test_idx}: enrollment={enrollment_age:.0f}, current={current_age:.0f}, "
                  f"cap={expected_cap:.1f}, max_event_time={max_time:.1f}")
            
            # Verify cap was applied correctly
            if max_time > expected_cap + 0.01:  # Small tolerance
                print(f"    WARNING: Max time {max_time:.1f} exceeds cap {expected_cap:.1f}!")
    
  
    
    # Train model for this specific age
    print(f"Training model for age offset {age_offset}...")
    profiler = cProfile.Profile()
    profiler.enable()
    suppress_stdout()
    history_new = model.fit(
        E_age_specific, 
        num_epochs=200, 
        learning_rate=1e-1, 
        lambda_reg=1e-2
    )
    
    profiler.disable()
    stats = pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE)
    stats.print_stats(20)
    enable_stdout()
    # Get predictions for this age
    with torch.no_grad():
        pi, _, _ = model.forward()
        
        # Save age-specific predictions
        filename = f"/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_{age_offset}_sex_20000_30000.pt"
        torch.save(pi, filename)
        print(f"Saved predictions to {filename}")
        
        # Store in dictionary for potential analysis
        
    
    # Clean up to free memory
    del pi
    del model
    del E_age_specific
    
    torch.cuda.empty_cache() if torch.cuda.is_available() else None



=== Predicting for age offset 0 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 0:
  Total event times changed: 3460799
  Max cap applied: 40.0
  Min cap applied: 10.0
  Patient 0: enrollment=54, current=54, cap=24.0, max_event_time=24.0
  Patient 1: enrollment=60, current=60, cap=30.0, max_event_time=30.0
  Patient 100: enrollment=46, current=46, cap=16.0, max_event_time=16.0
Training model for age offset 0...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_0_sex_20000_30000.pt

=== Predicting for age offset 1 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 1:
  Total event times changed: 3457980
  Max cap applied: 41.0
  Min cap applied: 11.0
  Patient 0: enrollment=54, current=55, cap=25.0, max_event_time=25.0
  Patient 1: enrollment=60, current=61, cap=31.0, max_event_time=31.0
  Patient 100: enrollment=46, current=47, cap=17.0, max_event_time=17.0
Training model for age offset 1...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_1_sex_20000_30000.pt

=== Predicting for age offset 2 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 2:
  Total event times changed: 3454742
  Max cap applied: 42.0
  Min cap applied: 12.0
  Patient 0: enrollment=54, current=56, cap=26.0, max_event_time=26.0
  Patient 1: enrollment=60, current=62, cap=32.0, max_event_time=32.0
  Patient 100: enrollment=46, current=48, cap=18.0, max_event_time=18.0
Training model for age offset 2...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_2_sex_20000_30000.pt

=== Predicting for age offset 3 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 3:
  Total event times changed: 3451299
  Max cap applied: 43.0
  Min cap applied: 13.0
  Patient 0: enrollment=54, current=57, cap=27.0, max_event_time=27.0
  Patient 1: enrollment=60, current=63, cap=33.0, max_event_time=33.0
  Patient 100: enrollment=46, current=49, cap=19.0, max_event_time=19.0
Training model for age offset 3...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_3_sex_20000_30000.pt

=== Predicting for age offset 4 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 4:
  Total event times changed: 3447717
  Max cap applied: 44.0
  Min cap applied: 14.0
  Patient 0: enrollment=54, current=58, cap=28.0, max_event_time=28.0
  Patient 1: enrollment=60, current=64, cap=34.0, max_event_time=34.0
  Patient 100: enrollment=46, current=50, cap=20.0, max_event_time=20.0
Training model for age offset 4...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_4_sex_20000_30000.pt

=== Predicting for age offset 5 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 5:
  Total event times changed: 3443876
  Max cap applied: 45.0
  Min cap applied: 15.0
  Patient 0: enrollment=54, current=59, cap=29.0, max_event_time=29.0
  Patient 1: enrollment=60, current=65, cap=35.0, max_event_time=35.0
  Patient 100: enrollment=46, current=51, cap=21.0, max_event_time=21.0
Training model for age offset 5...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_5_sex_20000_30000.pt

=== Predicting for age offset 6 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 6:
  Total event times changed: 3439533
  Max cap applied: 46.0
  Min cap applied: 16.0
  Patient 0: enrollment=54, current=60, cap=30.0, max_event_time=30.0
  Patient 1: enrollment=60, current=66, cap=36.0, max_event_time=36.0
  Patient 100: enrollment=46, current=52, cap=22.0, max_event_time=22.0
Training model for age offset 6...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_6_sex_20000_30000.pt

=== Predicting for age offset 7 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 7:
  Total event times changed: 3434810
  Max cap applied: 47.0
  Min cap applied: 17.0
  Patient 0: enrollment=54, current=61, cap=31.0, max_event_time=31.0
  Patient 1: enrollment=60, current=67, cap=37.0, max_event_time=37.0
  Patient 100: enrollment=46, current=53, cap=23.0, max_event_time=23.0
Training model for age offset 7...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_7_sex_20000_30000.pt

=== Predicting for age offset 8 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 8:
  Total event times changed: 3429757
  Max cap applied: 48.0
  Min cap applied: 18.0
  Patient 0: enrollment=54, current=62, cap=32.0, max_event_time=32.0
  Patient 1: enrollment=60, current=68, cap=38.0, max_event_time=38.0
  Patient 100: enrollment=46, current=54, cap=24.0, max_event_time=24.0
Training model for age offset 8...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_8_sex_20000_30000.pt

=== Predicting for age offset 9 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b
  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')


Clusters match exactly: True
Censoring verification for age offset 9:
  Total event times changed: 3424576
  Max cap applied: 49.0
  Min cap applied: 19.0
  Patient 0: enrollment=54, current=63, cap=33.0, max_event_time=33.0
  Patient 1: enrollment=60, current=69, cap=39.0, max_event_time=39.0
  Patient 100: enrollment=46, current=55, cap=25.0, max_event_time=25.0
Training model for age offset 9...


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)


Saved predictions to /Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_9_sex_20000_30000.pt


In [None]:
import torch
import numpy as np
import pandas as pd

# Test function to verify age-specific censoring
def test_age_specific_censoring(E_100k, fh_processed, age_offset=5, test_patients=5):
    """
    Test that E_age_specific is correctly updated to reflect row.age + offset - 30
    """
    print(f"\n=== Testing Age-Specific Censoring (offset = {age_offset}) ===")
    
    # Create original and age-specific versions
    E_original = E_100k.clone()
    E_age_specific = E_100k.clone()
    
    # Get subset of patient data
    pce_df_subset = fh_processed.iloc[20000:30000].reset_index(drop=True)
    # Apply age-specific censoring
    for patient_idx, row in enumerate(pce_df_subset.itertuples()):
        if patient_idx >= E_age_specific.shape[0]:
            break
            
        # Current age = enrollment age + age_offset
        current_age = row.age + age_offset
        
        # Time since age 30 for this current age
        time_since_30 = max(0, current_age - 30)
        
        # Cap event times to current age
        E_age_specific[patient_idx, :] = torch.minimum(
            E_age_specific[patient_idx, :],
            torch.full_like(E_age_specific[patient_idx, :], time_since_30)
        )
    
    # Test specific patients
    print(f"\nTesting first {test_patients} patients:")
    print("=" * 80)
    
    for i in range(min(test_patients, len(pce_df_subset))):
        row = pce_df_subset.iloc[i]
        
        enrollment_age = row.age
        current_age = enrollment_age + age_offset
        expected_cap = max(0, current_age - 30)
        
        # Get original and modified event times for this patient
        original_times = E_original[i, :].numpy()
        modified_times = E_age_specific[i, :].numpy()
        
        # Check if any times were actually capped
        times_changed = ~np.isclose(original_times, modified_times)
        
        print(f"\nPatient {i}:")
        print(f"  Enrollment age: {enrollment_age}")
        print(f"  Current age (enrollment + {age_offset}): {current_age}")
        print(f"  Expected cap (current_age - 30): {expected_cap}")
        print(f"  Times changed: {times_changed.sum()}/{len(times_changed)} diseases")
        
        if times_changed.any():
            # Show some examples of changed times
            changed_indices = np.where(times_changed)[0][:3]  # First 3 changes
            print(f"  Example changes:")
            for idx in changed_indices:
                print(f"    Disease {idx}: {original_times[idx]:.1f} â†’ {modified_times[idx]:.1f}")
        
        # Show unchanged diseases and their values
        unchanged_indices = np.where(~times_changed)[0]
        if len(unchanged_indices) > 0:
            print(f"  Unchanged diseases (indices): {unchanged_indices[:5]}")  # Show first 5
            print(f"  Their original times: {original_times[unchanged_indices[:5]]}")
            print(f"  Their modified times: {modified_times[unchanged_indices[:5]]}")
            print(f"  DEBUG - E_original values: {E_original[i, unchanged_indices[:5]].numpy()}")
            print(f"  DEBUG - E_age_specific values: {E_age_specific[i, unchanged_indices[:5]].numpy()}")
        
        # Verify all modified times are <= expected cap
        all_capped_correctly = np.all(modified_times <= expected_cap + 1e-6)  # small tolerance
        print(f"  All times correctly capped: {all_capped_correctly}")
        
        # Check that no times increased
        no_times_increased = np.all(modified_times <= original_times + 1e-6)
        print(f"  No times increased: {no_times_increased}")
    
    return E_original, E_age_specific

# Example usage:
E_orig, E_modified = test_age_specific_censoring(E_100k, fh_processed, age_offset=5)

# Additional verification function
def compare_age_offsets(E_100k, fh_processed, patient_idx=0):
    """
    Show how one patient's event times change across different age offsets
    """
    print(f"\n=== Patient {patient_idx} Across Different Age Offsets ===")
    
    row = fh_processed.iloc[patient_idx]
    enrollment_age = row.age
    
    print(f"Enrollment age: {enrollment_age}")
    print(f"Original event times (first 5 diseases): {E_100k[patient_idx, :5].numpy()}")
    print()
    
    for age_offset in [0, 2, 5, 10]:
        E_test = E_100k.clone()
        current_age = enrollment_age + age_offset
        time_since_30 = max(0, current_age - 30)
        
        E_test[patient_idx, :] = torch.minimum(
            E_test[patient_idx, :],
            torch.full_like(E_test[patient_idx, :], time_since_30)
        )
        
        print(f"Age offset {age_offset:2d} (age {current_age:2.0f}, cap at {time_since_30:2.0f}): {E_test[patient_idx, :5].numpy()}")

# Example usage:
compare_age_offsets(E_100k, fh_processed, patient_idx=0)

In [None]:
E_100k[2,180]