In [None]:
import rpy2.robjects as robjects
import numpy as np
import os as os
from rpy2.robjects import numpy2ri
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt


%load_ext autoreload
%autoreload 2
# Specify the path to your data files
data_path = "/Users/sarahurbut/tensornoulli_ehr_new/data"

# Activate automatic conversion between R and NumPy arrays
numpy2ri.activate()# Load data saved as .rds files
Y = np.array(robjects.r['readRDS'](os.path.join(data_path, 'Y.rds')))
E = np.array(robjects.r['readRDS'](os.path.join(data_path, 'event_for_aladynoulli.rds')))
G = np.array(robjects.r['readRDS'](os.path.join(data_path, 'prs.rds')))

E = E.astype(int)

# G should be float64
G = G.astype(float)
G.shape
G = G.T
print("G shape after transposition:", G.shape)

# Convert to PyTorch tensors
Y_tensor = torch.FloatTensor(Y)
E_tensor = torch.FloatTensor(E)
G_tensor = torch.FloatTensor(G)

# Get dimensions
N, D, T = Y_tensor.shape
P = G_tensor.shape[1]
T = int(E_tensor.max() + 1)  # 0-indexed time
K = 10  # number of topics

# Print shapes to verify
print("Tensor shapes:")
print(f"Y: {Y_tensor.shape}")  # [N, D]
print(f"E: {E_tensor.shape}")  # [N, D]
print(f"G: {G_tensor.shape}")  # [N, P]
print(f"Time range: 0 to {T-1}")
print(T)
print(K)


import rpy2.robjects as robjects
import pandas as pd
from rpy2.robjects import pandas2ri
# Turn on automatic R <-> pandas conversion
pandas2ri.activate()

# Read the disease-name and PRS-name tables from their .rds files
_read_rds = robjects.r['readRDS']
disease_names = pd.DataFrame(_read_rds('/Users/sarahurbut/Dropbox (Personal)/disease_names.rds'))
prs_names = pd.DataFrame(_read_rds('/Users/sarahurbut/Dropbox (Personal)/prs_names.rds'))

# Column 0 of the disease table as a plain Python list
disease_names_list = disease_names[0].tolist()

In [2]:
from cluster_g import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.spatial.distance import pdist, squareform
from scipy.special import expit
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering  # Add this import

In [None]:
# Prevalence input for the model, computed from Y with window_size=5
# (compute_smoothed_prevalence is defined outside this notebook)
prevalence_t = compute_smoothed_prevalence(Y=Y, window_size=5)

# From this cell on K is 20, replacing the value set in the data-loading cell
K = 20

# Build the model (without disease names) and visualize its clusters
model = AladynSurvivalFixedKernelsAvgLoss_clust(
    N, D, T, K, P, G, Y, prevalence_t
)
model.visualize_clusters(disease_names_list)

In [None]:
# Rebuild the model, this time also passing the disease names
model = AladynSurvivalFixedKernelsAvgLoss_clust(
    N, D, T, K, P, G, Y, prevalence_t, disease_names_list
)

# Inspect the freshly initialized model
model.print_cluster_memberships()
model.plot_initial_params()

In [None]:
# Run the model's initialization visualisation (method defined outside this
# notebook), then leave psi as the last expression so it is shown as cell output.
model.visualize_initialization()
model.psi

In [None]:
# Bare expression: the shape of Y_tensor is displayed as the cell output.
Y_tensor.shape

In [None]:
# Initialize model
model = AladynSurvivalFixedKernelsAvgLoss_clust(N, D, T, K, P, G, Y, prevalence_t, disease_names_list)

# Snapshot psi before training so the fitted values can be compared against it
initial_psi = model.psi.detach().clone()

# Train; `history` is consumed by the plotting and checkpoint cells below
history = model.fit(E_tensor, num_epochs=1000, learning_rate=1e-4, lambda_reg=1e-2)

# Compare final vs initial psi: element-wise absolute difference, summarised
# by its mean and max so the header below is followed by actual numbers.
psi_abs_change = (model.psi.detach() - initial_psi).abs()
print("\nOverall psi changes:")
print(f"Mean absolute change: {psi_abs_change.mean().item():.4f}")
print(f"Max absolute change: {psi_abs_change.max().item():.4f}")

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: loss on the left, gradient magnitudes on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: the recorded loss values
ax1.plot(history['loss'])
ax1.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax1.grid(True)

# Right panel: one curve per tracked gradient series, in a fixed order
for _grad_key, _grad_label in (
    ('max_grad_lambda', 'Lambda'),
    ('max_grad_phi', 'Phi'),
    ('max_grad_gamma', 'Gamma'),
    ('max_grad_psi', 'Psi'),
):
    ax2.plot(history[_grad_key], label=_grad_label)
ax2.set(xlabel='Epoch', ylabel='Max Gradient Magnitude', title='Parameter Gradients')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Get the posterior phi values
phi_posterior = model.phi.detach().cpu().numpy()  # Shape should be [K, D, T]

# Size the subplot grid from the number of clusters instead of assuming 20:
# five columns, and as many rows as needed (ceiling division). For K = 20 this
# is the same 4 x 5 layout as before; for K > 20 it no longer raises.
n_cols = 5
n_rows = max(1, -(-model.K // n_cols))

# One panel per cluster showing the trajectories of its member diseases
plt.figure(figsize=(15, 10))
for k in range(model.K):
    # Diseases assigned to cluster k
    cluster_mask = (model.clusters == k)
    cluster_phis = phi_posterior[k, cluster_mask, :]

    plt.subplot(n_rows, n_cols, k+1)
    plt.plot(cluster_phis.T, alpha=0.3)  # one line per disease, time on the x-axis
    plt.title(f'Cluster {k}\n({np.sum(cluster_mask)} diseases)')
    plt.grid(True)

plt.tight_layout()
plt.show()

# Mean trajectory of each cluster's member diseases, all on one axis
plt.figure(figsize=(12, 6))
cluster_means = np.array([phi_posterior[k, model.clusters == k, :].mean(axis=0)
                         for k in range(model.K)])
plt.plot(cluster_means.T)
plt.title('Mean Phi Trajectories by Cluster')
plt.xlabel('Time')
plt.ylabel('Phi Value')
plt.legend([f'Cluster {k}' for k in range(model.K)])
plt.grid(True)
plt.show()

In [11]:
# Save current model
import torch

# Bundle the model weights, cluster assignments, training history and the
# hyperparameter values into one checkpoint dict, then write it to disk.
# The hyperparameter literals mirror the values passed to model.fit; keep
# them in sync by hand.
_checkpoint = dict(
    model_state_dict=model.state_dict(),
    clusters=model.clusters,
    history=history,
    hyperparameters=dict(
        learning_rate=1e-4,
        lambda_reg=1e-2,
    ),
)
torch.save(_checkpoint, 'model_lr1e-4_1124.pt')

In [None]:
# Plot training loss history on a log-scaled y-axis
plt.figure(figsize=(10, 5))
plt.plot(history['loss'])
plt.yscale('log')
plt.title('Training Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.grid(True)
plt.show()

# Visualize psi map.
# .cpu() matches how phi is extracted elsewhere in this notebook and keeps the
# conversion working when the model lives on an accelerator; it is a no-op
# for tensors already on the CPU.
plt.figure(figsize=(15, 8))
psi_np = model.psi.detach().cpu().numpy()
plt.imshow(psi_np, aspect='auto', cmap='RdBu_r')
plt.colorbar(label='ψ value')
plt.xlabel('Disease')
plt.ylabel('State/Cluster')
plt.title('Disease-State Deviations (ψ) After Mean Removal')

# If you have disease names, add them as x-axis labels
if disease_names_list:
    plt.xticks(range(len(disease_names_list)), disease_names_list, rotation=90)

plt.tight_layout()
plt.show()

# Print some summary statistics about psi
print("\nPsi Statistics:")
print(f"Mean: {psi_np.mean():.3f}")
print(f"Std: {psi_np.std():.3f}")
print(f"Min: {psi_np.min():.3f}")
print(f"Max: {psi_np.max():.3f}")

# For each row of psi, list the n_top entries with the largest absolute
# value, strongest first
n_top = 5  # Number of top associations to show
for k in range(psi_np.shape[0]):  # For each state/cluster
    # argsort is ascending, so the last n_top indices are the largest |psi|
    top_indices = np.argsort(np.abs(psi_np[k]))[-n_top:]
    print(f"\nTop diseases in State {k}:")
    for idx in top_indices[::-1]:
        disease_name = disease_names_list[idx] if disease_names_list else f"Disease {idx}"
        print(f"{disease_name}: {psi_np[k, idx]:.3f}")

In [None]:
# Pull psi out as a NumPy array. .cpu() matches the phi extraction used
# elsewhere in this notebook and is a no-op when psi is already on the CPU.
psi_np = model.psi.detach().cpu().numpy()

# Print basic info about psi matrix
print("Psi matrix shape:", psi_np.shape)
print("\nPsi value ranges:")
print(f"Min: {psi_np.min():.3f}")
print(f"Max: {psi_np.max():.3f}")
print(f"Mean: {psi_np.mean():.3f}")
print(f"Std: {psi_np.std():.3f}")

# Look at a small sample of values (top-left 5 x 5 corner)
print("\nSample of psi values (first 5 states, first 5 diseases):")
print(psi_np[:5, :5])