In [1]:
%load_ext autoreload
%autoreload 2



%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering  # Add this import

def load_model_essentials(base_path='/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/'):
    """
    Load the saved model inputs stored under ``base_path``.

    Four files are read with ``torch.load``: ``Y_tensor.pt``, ``E_matrix.pt``,
    ``G_matrix.pt`` and ``model_essentials.pt``.

    Args:
        base_path: Directory holding the four files. May be a ``str`` or any
            path-like object; a trailing path separator is optional.

    Returns:
        Tuple ``(Y, E, G, essentials)`` with the deserialized contents of the
        four files, in that order. The notebook indexes ``essentials`` by
        string key (e.g. ``'prevalence_t'``, ``'disease_names'``).

    Raises:
        Whatever ``torch.load`` raises when a file is missing or unreadable.
    """
    # Function-scope import keeps the helper self-contained.
    import os

    def _load(filename):
        # os.path.join works with or without a trailing separator on
        # base_path, unlike plain string concatenation.
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        return torch.load(os.path.join(base_path, filename))

    print("Loading components...")

    # Load large matrices
    Y = _load('Y_tensor.pt')
    E = _load('E_matrix.pt')
    G = _load('G_matrix.pt')

    # Load other components
    essentials = _load('model_essentials.pt')

    print("Loaded all components successfully!")

    return Y, E, G, essentials

# Load the full-size inputs (Y, E, G and the dictionary of model essentials).
Y, E, G, essentials = load_model_essentials()

from clust_huge_amp import *

# Row range used for BOTH the tensor subset and the covariate table, so the two
# selections cannot drift apart.
SUBSET_START = 0
SUBSET_END = 10000

# Subset the data.
# NOTE: despite the `_100k` suffix these hold rows SUBSET_START..SUBSET_END
# (10,000 individuals); the names are kept because the prediction cell uses them.
Y_100k, E_100k, G_100k, indices = subset_data(
    Y, E, G, start_index=SUBSET_START, end_index=SUBSET_END
)

# Drop the name bound to the full-size Y; later code only uses the subset.
del Y

# Load references (signatures only, no healthy)
refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')
signature_refs = refs['signature_refs']

# NOTE(review): nothing in the visible cells uses rpy2; the import and the
# pandas conversion activation are kept in case other cells rely on them.
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
pandas2ri.activate()

# Load the baseline covariate table (CSV). The directory is spelled
# `CloudStorage`, matching every other path in this notebook, so the read also
# works on case-sensitive filesystems.
import pandas as pd
fh_processed = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox/baselinagefamh.csv')

# Take the same row range as the tensors above.
# NOTE(review): this assumes the CSV rows are in the same order as the rows of
# Y/E/G - confirm against how the tensors were built.
pce_df_subset = fh_processed.iloc[SUBSET_START:SUBSET_END].reset_index(drop=True)
if len(pce_df_subset) != len(G_100k):
    raise ValueError(
        f"Covariate table has {len(pce_df_subset)} rows but G_100k has "
        f"{len(G_100k)}; the two must be row-aligned."
    )

# Append sex as one extra covariate column to the genetic matrix.
sex = pce_df_subset['sex'].values
G_with_sex = np.column_stack([G_100k, sex])




Loading components...


  Y = torch.load(base_path + 'Y_tensor.pt')
  E = torch.load(base_path + 'E_matrix.pt')
  G = torch.load(base_path + 'G_matrix.pt')
  essentials = torch.load(base_path + 'model_essentials.pt')


Loaded all components successfully!


  refs = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/reference_trajectories.pt')


In [2]:
import torch
import numpy as np
import cProfile
import pstats
from pstats import SortKey

# Maps age_offset -> path of the saved prediction tensor for that offset.
# Only the path is kept: each `pi` is deleted after saving to free memory.
age_predictions = {}

# Initial parameter files shared by every age offset.
INITIAL_PSI_PATH = '/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt'
INITIAL_CLUSTERS_PATH = '/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt'

# Age that corresponds to event time 0 (caps are computed as age - START_AGE).
START_AGE = 30


def cap_event_times_at_age(event_times, enrollment_ages, age_offset, start_age=30):
    """
    Censor each patient's event times at ``enrollment_age + age_offset``.

    For patient ``i`` the cap is ``max(0, enrollment_ages[i] + age_offset -
    start_age)``; every entry of row ``i`` larger than the cap is replaced by
    the cap. Rows without a matching age (when the two inputs differ in
    length) are left unchanged.

    Args:
        event_times: Tensor whose first dimension indexes patients.
        enrollment_ages: 1-D array-like of enrollment ages, one per patient.
        age_offset: Years after enrollment at which to censor.
        start_age: Age corresponding to event time 0.

    Returns:
        Tuple ``(capped, caps, n_changed)``: a capped copy of ``event_times``,
        the NumPy array of per-patient caps that were applied, and the number
        of tensor entries that differ from the input.
    """
    n_rows = min(event_times.shape[0], len(enrollment_ages))
    ages = np.asarray(enrollment_ages, dtype=float)[:n_rows]
    shifted = ages + age_offset - start_age
    # Mirrors Python's max(0, x): anything not strictly positive (NaN
    # included) becomes 0.
    caps = np.where(shifted > 0, shifted, 0.0)

    capped = event_times.clone()
    if n_rows:
        # Shape (n_rows, 1, ..., 1) so one cap broadcasts over a patient's row.
        cap_shape = (n_rows,) + (1,) * (event_times.dim() - 1)
        cap_tensor = torch.as_tensor(
            caps, dtype=event_times.dtype, device=event_times.device
        ).reshape(cap_shape)
        capped[:n_rows] = torch.minimum(event_times[:n_rows], cap_tensor)

    n_changed = int((capped[:n_rows] != event_times[:n_rows]).sum().item())
    return capped, caps, n_changed


# The covariate subset does not depend on the age offset: build it once.
pce_df_subset = fh_processed.iloc[0:10000].reset_index(drop=True)
enrollment_ages = pce_df_subset['age'].to_numpy(dtype=float)

for age_offset in range(0, 11):  # Ages 0-10 years after enrollment
    print(f"\n=== Predicting for age offset {age_offset} years ===")

    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Initialize fresh model for this age
    model = AladynSurvivalFixedKernelsAvgLoss_clust_logitInit_psitest(
        N=Y_100k.shape[0],
        D=Y_100k.shape[1],
        T=Y_100k.shape[2],
        K=20,
        P=G_with_sex.shape[1],
        init_sd_scaler=1e-1,
        G=G_with_sex,
        Y=Y_100k,
        genetic_scale=1,
        W=0.0001,
        R=0,
        prevalence_t=essentials['prevalence_t'],
        signature_references=signature_refs,
        healthy_reference=True,
        disease_names=essentials['disease_names']
    )

    # Reset seeds for parameter initialization
    torch.manual_seed(0)
    np.random.seed(0)

    # Load and set initial parameters.
    # Deliberately re-read for every offset: if the model keeps a reference to
    # these objects and updates them during training, a shared copy would leak
    # state from one age offset into the next.
    initial_psi = torch.load(INITIAL_PSI_PATH)
    initial_clusters = torch.load(INITIAL_CLUSTERS_PATH)
    model.initialize_params(true_psi=initial_psi)

    # Verify clusters match: compare what the model holds BEFORE it is
    # overwritten. Comparing after the assignment would always be True.
    clusters_before_override = getattr(model, 'clusters', None)
    clusters_match = np.array_equal(initial_clusters, clusters_before_override)
    model.clusters = initial_clusters
    print(f"Clusters match exactly: {clusters_match}")

    # Create age-specific event times: censor every patient at
    # (enrollment age + age_offset), expressed as time since START_AGE.
    E_age_specific, caps, total_times_changed = cap_event_times_at_age(
        E_100k, enrollment_ages, age_offset, start_age=START_AGE
    )
    max_cap_applied = float(caps.max()) if caps.size else 0
    min_cap_applied = float(caps.min()) if caps.size else float('inf')

    # Print censoring verification
    print(f"Censoring verification for age offset {age_offset}:")
    print(f"  Total event times changed: {total_times_changed}")
    print(f"  Max cap applied: {max_cap_applied:.1f}")
    print(f"  Min cap applied: {min_cap_applied:.1f}")

    # Check a few specific patients (0, 1 and 100), recomputing the expected
    # cap independently of the helper.
    n_capped = len(caps)
    for test_idx in (0, 1, 100):
        if test_idx >= n_capped:
            continue
        enrollment_age = enrollment_ages[test_idx]
        current_age = enrollment_age + age_offset
        expected_cap = max(0, current_age - START_AGE)

        # Check max value in this patient's event times
        max_time = torch.max(E_age_specific[test_idx, :]).item()

        print(f"  Patient {test_idx}: enrollment={enrollment_age:.0f}, current={current_age:.0f}, "
              f"cap={expected_cap:.1f}, max_event_time={max_time:.1f}")

        # Verify cap was applied correctly
        if max_time > expected_cap + 0.01:  # Small tolerance
            print(f"    WARNING: Max time {max_time:.1f} exceeds cap {expected_cap:.1f}!")

    # Train model for this specific age, profiling the fit call.
    print(f"Training model for age offset {age_offset}...")
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        history_new = model.fit(
            E_age_specific,
            num_epochs=200,
            learning_rate=1e-1,
            lambda_reg=1e-2
        )
    finally:
        # Always stop profiling, even if training raises.
        profiler.disable()

    stats = pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE)
    stats.print_stats(20)

    # Get predictions for this age
    with torch.no_grad():
        pi, _, _ = model.forward()

    # Save age-specific predictions
    filename = f"/Users/sarahurbut/Library/CloudStorage/Dropbox/pi_enroll_age_offset_{age_offset}_sex_0_10000.pt"
    torch.save(pi, filename)
    print(f"Saved predictions to {filename}")

    # Record where this offset's predictions live for later analysis.
    age_predictions[age_offset] = filename

    # Clean up to free memory
    del pi
    del model
    del E_age_specific
    if torch.cuda.is_available():
        torch.cuda.empty_cache()



=== Predicting for age offset 0 years ===


  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Cluster Sizes:
Cluster 0: 14 diseases
Cluster 1: 7 diseases
Cluster 2: 21 diseases
Cluster 3: 15 diseases
Cluster 4: 17 diseases
Cluster 5: 16 diseases
Cluster 6: 57 diseases
Cluster 7: 18 diseases
Cluster 8: 13 diseases
Cluster 9: 11 diseases
Cluster 10: 18 diseases
Cluster 11: 12 diseases
Cluster 12: 26 diseases
Cluster 13: 7 diseases
Cluster 14: 9 diseases
Cluster 15: 8 diseases
Cluster 16: 7 diseases
Cluster 17: 11 diseases
Cluster 18: 6 diseases
Cluster 19: 55 diseases

Calculating gamma for k=0:
Number of diseases in cluster: 14
Base value (first 5): tensor([-13.8155, -13.8155, -13.1095, -12.4036, -12.4036])
Base value centered (first 5): tensor([-0.3723, -0.3723,  0.3336,  1.0396,  1.0396])
Base value centered mean: 6.57081614008348e-07
Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 17.1567

Monitoring signature responses:

Disease 161 (signature 7, LR=32.14):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.74):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.45):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.42):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.001

Epoch 1
Loss: 663.9888

Monitoring signature responses:

Disease 161 (signature 7, LR=32.14):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.77):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=0:
Number of diseases in cluster: 14
Base value (first 5): tensor([-13.8155, -13.8155, -13.1095, -12.4036, -12.4036])
Base value centered (first 5): tensor([-0.3723, -0.3723,  0.3336,  1.0396,  1.0396])
Base value centered mean: 6.57081614008348e-07
Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gam

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 19.4218

Monitoring signature responses:

Disease 161 (signature 7, LR=32.16):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.74):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.42):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 666.0272

Monitoring signature responses:

Disease 161 (signature 7, LR=32.19):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.77):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=0:
Number of diseases in cluster: 14
Base value (first 5): tensor([-13.8155, -13.8155, -13.1095, -12.4036, -12.4036])
Base value centered (first 5): tensor([-0.3723, -0.3723,  0.3336,  1.0396,  1.0396])
Base value centered mean: 6.57081614008348e-07
Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gam

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 22.0458

Monitoring signature responses:

Disease 161 (signature 7, LR=32.20):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.74):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.44):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 668.3786

Monitoring signature responses:

Disease 161 (signature 7, LR=32.25):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.77):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.1155, -0.1155, -0.1155, -0.1155])
Base value centered mean: 1.0881424117314964e-07
G

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1505, -0.3201,  0.1505,  0.1505,  1.0918])
Base value centered mean: -1.8495559288567165e-06
Gamma init for k=1 (first 5): tensor([0.0044, 0.0012, 0.0007, 0.0026, 0.0014])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 24.9522

Monitoring signature responses:

Disease 161 (signature 7, LR=32.27):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.74):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.44):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 671.0048

Monitoring signature responses:

Disease 161 (signature 7, LR=32.31):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.77):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.1155, -0.1155, -0.1155, -0.1155])
Base value centered mean: 1.0881424117314964e-07
G

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=0:
Number of diseases in cluster: 16.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1879, -0.1879, -0.1879,  1.6653, -0.1879])
Base value centered mean: -3.345489574257954e-07
Gamma init for k=0 (first 5): tensor([ 0.0069,  0.0066, -0.0055,  0.0062,  0.0243])

Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1505, -0.3201,  0.1505,  0.1505,  1.0918])
Base value centered mean: -1.8495559288567165e-06
Gamma init for k=1 (first 5): tensor([0.0044, 0.0012, 0.0007, 0.0026, 0.0014])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 28.1000

Monitoring signature responses:

Disease 161 (signature 7, LR=32.29):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.76):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.45):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 673.8416

Monitoring signature responses:

Disease 161 (signature 7, LR=32.37):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.77):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.1155, -0.1155, -0.1155, -0.1155])
Base value centered mean: 1.0881424117314964e-07
G

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=0:
Number of diseases in cluster: 16.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1879, -0.1879, -0.1879,  1.6653, -0.1879])
Base value centered mean: -3.345489574257954e-07
Gamma init for k=0 (first 5): tensor([ 0.0069,  0.0066, -0.0055,  0.0062,  0.0243])

Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1505, -0.3201,  0.1505,  0.1505,  1.0918])
Base value centered mean: -1.8495559288567165e-06
Gamma init for k=1 (first 5): tensor([0.0044, 0.0012, 0.0007, 0.0026, 0.0014])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 31.2812

Monitoring signature responses:

Disease 161 (signature 7, LR=32.31):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.148
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.45):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 676.6797

Monitoring signature responses:

Disease 161 (signature 7, LR=32.43):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.1155, -0.1155, -0.1155, -0.1155])
Base value centered mean: 1.0881424117314964e-07
G

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=0:
Number of diseases in cluster: 16.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1879, -0.1879, -0.1879,  1.6653, -0.1879])
Base value centered mean: -3.345489574257954e-07
Gamma init for k=0 (first 5): tensor([ 0.0069,  0.0066, -0.0055,  0.0062,  0.0243])

Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1505, -0.3201,  0.1505,  0.1505,  1.0918])
Base value centered mean: -1.8495559288567165e-06
Gamma init for k=1 (first 5): tensor([0.0044, 0.0012, 0.0007, 0.0026, 0.0014])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 34.7708

Monitoring signature responses:

Disease 161 (signature 7, LR=32.32):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.147
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.45):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 679.8145

Monitoring signature responses:

Disease 161 (signature 7, LR=32.43):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=0:
Number of diseases in cluster: 14
Base value (first 5): tensor([-13.8155, -13.8155, -13.1095, -12.4036, -12.4036])
Base value centered (first 5): tensor([-0.3723, -0.3723,  0.3336,  1.0396,  1.0396])
Base value centered mean: 6.57081614008348e-07
Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gam

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 38.3543

Monitoring signature responses:

Disease 161 (signature 7, LR=32.36):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.147
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.46):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.52):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.86):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 683.0054

Monitoring signature responses:

Disease 161 (signature 7, LR=32.49):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.78):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b



Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.1155, -0.1155, -0.1155, -0.1155])
Base value centered mean: 1.0881424117314964e-07
G

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=0:
Number of diseases in cluster: 16.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1879, -0.1879, -0.1879,  1.6653, -0.1879])
Base value centered mean: -3.345489574257954e-07
Gamma init for k=0 (first 5): tensor([ 0.0069,  0.0066, -0.0055,  0.0062,  0.0243])

Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1505, -0.3201,  0.1505,  0.1505,  1.0918])
Base value centered mean: -1.8495559288567165e-06
Gamma init for k=1 (first 5): tensor([0.0044, 0.0012, 0.0007, 0.0026, 0.0014])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07


  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 41.9413

Monitoring signature responses:

Disease 161 (signature 7, LR=32.37):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.81):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.147
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.45):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.49):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.88):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 686.1978

Monitoring signature responses:

Disease 161 (signature 7, LR=32.55):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.80):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b


Number of diseases in cluster: 14
Base value (first 5): tensor([-13.8155, -13.8155, -13.1095, -12.4036, -12.4036])
Base value centered (first 5): tensor([-0.3723, -0.3723,  0.3336,  1.0396,  1.0396])
Base value centered mean: 6.57081614008348e-07
Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): t

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 45.8427

Monitoring signature responses:

Disease 161 (signature 7, LR=32.37):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.85):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.147
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.45):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.49):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.88):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 689.6732

Monitoring signature responses:

Disease 161 (signature 7, LR=32.55):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.82):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

  self.signature_refs = torch.tensor(signature_references, dtype=torch.float32)
  self.Y = torch.tensor(Y, dtype=torch.float32)
  ret = a @ b
  ret = a @ b
  ret = a @ b


Gamma init for k=0 (first 5): tensor([ 0.0008,  0.0071,  0.0117,  0.0152, -0.0106])

Calculating gamma for k=1:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7043, -0.7043, -0.7043, -0.7043, -0.7043])
Base value centered mean: -2.2621155437718699e-07
Gamma init for k=1 (first 5): tensor([ 0.0219,  0.0117,  0.0025, -0.0018, -0.0047])

Calculating gamma for k=2:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.9916, -13.3449])
Base value centered (first 5): tensor([-0.2393, -0.2393, -0.2393,  2.5846,  0.2313])
Base value centered mean: -3.040313742985745e-07
Gamma init for k=2 (first 5): tensor([ 0.0018, -0.0023,  0.0065,  0.0024, -0.0054])

Calculating gamma for k=3:
Number of diseases in cluster: 15
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1155, -0.

  initial_psi = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_psi_400k.pt')
  initial_clusters = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox/data_for_running/initial_clusters_400k.pt')



Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3849, -0.3849,  0.2740,  1.5918,  0.9329])
Base value centered mean: 9.290695288655115e-07
Gamma init for k=2 (first 5): tensor([-0.0001,  0.0092,  0.0113,  0.0160, -0.0109])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1026, -0.1026, -0.1026,  0.0179,  0.1384])
Base value centered mean: 4.7445297468584613e-07
Gamma init for k=3 (first 5): tensor([ 0.0011,  0.0003,  0.0017,  0.0019, -0.0006])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1133,  3.8402, -0.1133, -0.1133, -0.1133])
Base value centered mean: -2.841758714566822e-

  event_times_tensor = torch.tensor(event_times, dtype=torch.long)



Epoch 0
Loss: 50.0059

Monitoring signature responses:

Disease 161 (signature 7, LR=32.38):
  Theta for diagnosed: 0.150 ± 0.038
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.85):
  Theta for diagnosed: 0.153 ± 0.039
  Theta for others: 0.147
  Proportion difference: 0.005

Disease 260 (signature 8, LR=30.45):
  Theta for diagnosed: 0.097 ± 0.081
  Theta for others: 0.087
  Proportion difference: 0.010

Disease 347 (signature 3, LR=29.52):
  Theta for diagnosed: 0.149 ± 0.070
  Theta for others: 0.150
  Proportion difference: -0.001

Disease 50 (signature 15, LR=28.91):
  Theta for diagnosed: 0.016 ± 0.006
  Theta for others: 0.014
  Proportion difference: 0.002

Epoch 1
Loss: 693.3644

Monitoring signature responses:

Disease 161 (signature 7, LR=32.57):
  Theta for diagnosed: 0.150 ± 0.036
  Theta for others: 0.147
  Proportion difference: 0.003

Disease 76 (signature 7, LR=30.82):
  Theta for diagnosed: 0.153 ± 0.037
  Theta for others: 0

In [None]:
import torch
import numpy as np
import pandas as pd

# Test function to verify age-specific censoring
def test_age_specific_censoring(E_100k, fh_processed, age_offset=5, test_patients=5):
    """
    Test that E_age_specific is correctly updated to reflect row.age + offset - 30
    """
    print(f"\n=== Testing Age-Specific Censoring (offset = {age_offset}) ===")
    
    # Create original and age-specific versions
    E_original = E_100k.clone()
    E_age_specific = E_100k.clone()
    
    # Get subset of patient data
    pce_df_subset = fh_processed.iloc[0:10000].reset_index(drop=True)
    
    # Apply age-specific censoring
    for patient_idx, row in enumerate(pce_df_subset.itertuples()):
        if patient_idx >= E_age_specific.shape[0]:
            break
            
        # Current age = enrollment age + age_offset
        current_age = row.age + age_offset
        
        # Time since age 30 for this current age
        time_since_30 = max(0, current_age - 30)
        
        # Cap event times to current age
        E_age_specific[patient_idx, :] = torch.minimum(
            E_age_specific[patient_idx, :],
            torch.full_like(E_age_specific[patient_idx, :], time_since_30)
        )
    
    # Test specific patients
    print(f"\nTesting first {test_patients} patients:")
    print("=" * 80)
    
    for i in range(min(test_patients, len(pce_df_subset))):
        row = pce_df_subset.iloc[i]
        
        enrollment_age = row.age
        current_age = enrollment_age + age_offset
        expected_cap = max(0, current_age - 30)
        
        # Get original and modified event times for this patient
        original_times = E_original[i, :].numpy()
        modified_times = E_age_specific[i, :].numpy()
        
        # Check if any times were actually capped
        times_changed = ~np.isclose(original_times, modified_times)
        
        print(f"\nPatient {i}:")
        print(f"  Enrollment age: {enrollment_age}")
        print(f"  Current age (enrollment + {age_offset}): {current_age}")
        print(f"  Expected cap (current_age - 30): {expected_cap}")
        print(f"  Times changed: {times_changed.sum()}/{len(times_changed)} diseases")
        
        if times_changed.any():
            # Show some examples of changed times
            changed_indices = np.where(times_changed)[0][:3]  # First 3 changes
            print(f"  Example changes:")
            for idx in changed_indices:
                print(f"    Disease {idx}: {original_times[idx]:.1f} → {modified_times[idx]:.1f}")
        
        # Show unchanged diseases and their values
        unchanged_indices = np.where(~times_changed)[0]
        if len(unchanged_indices) > 0:
            print(f"  Unchanged diseases (indices): {unchanged_indices[:5]}")  # Show first 5
            print(f"  Their original times: {original_times[unchanged_indices[:5]]}")
            print(f"  Their modified times: {modified_times[unchanged_indices[:5]]}")
            print(f"  DEBUG - E_original values: {E_original[i, unchanged_indices[:5]].numpy()}")
            print(f"  DEBUG - E_age_specific values: {E_age_specific[i, unchanged_indices[:5]].numpy()}")
        
        # Verify all modified times are <= expected cap
        all_capped_correctly = np.all(modified_times <= expected_cap + 1e-6)  # small tolerance
        print(f"  All times correctly capped: {all_capped_correctly}")
        
        # Check that no times increased
        no_times_increased = np.all(modified_times <= original_times + 1e-6)
        print(f"  No times increased: {no_times_increased}")
    
    return E_original, E_age_specific

# Example usage: run the censoring check with a 5-year offset on the notebook-level
# E_100k / fh_processed objects; the call returns the untouched clone and the capped
# clone, bound here to E_orig and E_modified.
E_orig, E_modified = test_age_specific_censoring(E_100k, fh_processed, age_offset=5)

# Additional verification function
def compare_age_offsets(E_100k, fh_processed, patient_idx=0, age_offsets=(0, 2, 5, 10)):
    """
    Print how one patient's first five event times change as the age offset grows.

    For each offset the patient's row is capped at
    max(0, enrollment age + offset - 30). Nothing is returned and ``E_100k``
    is never modified.

    Parameters
    ----------
    E_100k : torch.Tensor
        Event-time tensor indexed as [patient, disease].
    fh_processed : pandas.DataFrame
        Patient table with an ``age`` column; row ``patient_idx`` is read
        positionally with ``iloc``.
    patient_idx : int, default 0
        Row to inspect in both the table and the tensor.
    age_offsets : iterable of int, default (0, 2, 5, 10)
        Offsets to try. They must be integers because they are printed
        with the ``d`` format code.
    """
    print(f"\n=== Patient {patient_idx} Across Different Age Offsets ===")

    row = fh_processed.iloc[patient_idx]
    enrollment_age = row.age
    # Only this patient's row is needed; indexing yields a view, no copy of the full tensor
    patient_times = E_100k[patient_idx, :]

    print(f"Enrollment age: {enrollment_age}")
    print(f"Original event times (first 5 diseases): {patient_times[:5].numpy()}")
    print()

    for age_offset in age_offsets:
        current_age = enrollment_age + age_offset
        time_since_30 = max(0, current_age - 30)

        # torch.minimum allocates a fresh tensor, so the input row stays untouched
        capped_times = torch.minimum(
            patient_times,
            torch.full_like(patient_times, time_since_30)
        )

        print(f"Age offset {age_offset:2d} (age {current_age:2.0f}, cap at {time_since_30:2.0f}): {capped_times[:5].numpy()}")

# Example usage: print how the first five event times of patient 0 are capped
# under each age offset the function iterates over.
compare_age_offsets(E_100k, fh_processed, patient_idx=0)

In [None]:
# Spot-check a single entry of E_100k (row 2, column 180); presumably the event
# time of patient 2 for disease 180 — TODO confirm the axis meaning.
E_100k[2,180]