In [1]:
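# SPIKE-AND-SLAB-LASSO
# Note: the model below currently uses a horseshoe-style global-local Normal prior, not a spike-and-slab mixture.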

In [None]:
pip install pyro-ppl torch pandas matplotlib numpy seaborn lifelines scikit-learn

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import ClippedAdam
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter, CoxPHFitter
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import KFold
import time
from datetime import datetime
import os
import json
import warnings
warnings.filterwarnings('ignore')

# Configuration
config = {
    # Training parameters
    "max_epochs": 5000,
    "batch_size": 128,
    "initial_lr": 5e-4,
    "lr_decay": 0.2,        # multiplicative factor applied every `decay_step` epochs
    "decay_step": 500,
    "clip_norm": 10.0,
    "early_stop_patience": 200,

    # Device and precision
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "precision": torch.float32,
    # Pinning only applies to host (CPU) tensors that will be copied to a GPU.
    # The data tensors are moved to `device` at load time and CUDA tensors
    # cannot be pinned, so pinning stays off. NOTE: the DataLoaders in this
    # file pass pin_memory=False explicitly rather than reading this key.
    "pin_memory": False,
    "num_workers": 0,  # Set to 0 for Windows

    "cuda_deterministic": False,

    # Model parameters
    "elbo_particles": 10,
    "warmup_epochs": 100,
    "n_folds": 5,

    # Logging and checkpointing
    "log_freq": 100,
    "checkpoint_freq": 500,

    # Paths. Set the COX_DATA_PATH environment variable to use a different
    # dataset without editing the notebook; the default is the original path.
    "data_path": os.environ.get(
        "COX_DATA_PATH",
        r"D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv",
    ),
    "results_dir": f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
}

def setup_gpu():
    """Seed torch and, when CUDA is available, configure it as the default device.

    Seeds torch's RNG with 42 on every path. Previously only the CUDA path was
    seeded, so CPU-only runs were not reproducible. When CUDA is available this
    also makes ``torch.cuda.FloatTensor`` the default tensor type, prints
    device and memory information, and sets the cuDNN flags from the
    module-level ``config["cuda_deterministic"]``.

    Returns:
        bool: True if a CUDA device was found and configured, False otherwise.
    """
    # Seed first and unconditionally so CPU runs are reproducible as well.
    # torch.manual_seed also seeds every CUDA device.
    torch.manual_seed(42)

    if not torch.cuda.is_available():
        print("\nNo GPU available, using CPU")
        return False

    # Make newly created float tensors default to CUDA.
    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
    # (set_default_device / set_default_dtype replace it). Other code here
    # builds DataLoader generators on config["device"], so confirm shuffling
    # still works before switching APIs.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(42)

    print("\nGPU Information:")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")

    # Initial GPU memory info (just once at startup)
    print(f"Initial GPU Memory: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB allocated, {torch.cuda.memory_reserved(0)/1024**2:.2f} MB reserved")

    # Reproducibility vs. speed trade-off for cuDNN, controlled by config.
    torch.backends.cudnn.deterministic = config["cuda_deterministic"]
    torch.backends.cudnn.benchmark = not config["cuda_deterministic"]

    return True

def load_and_preprocess(path, config):
    """Load a survival CSV and return standardized features and outcomes as tensors.

    The first CSV column is assumed to be an ID and is dropped. Columns named
    ``time`` and ``event`` are the outcomes, and every other column is a
    feature. Features are z-scored per column, ignoring NaNs. Remaining NaNs
    (missing values) are then set to 0, i.e. imputed at the column mean.

    Args:
        path: Path to the CSV file.
        config: Dict providing ``"precision"`` (a torch dtype) and ``"device"``.

    Returns:
        tuple: ``(X, time, event)`` tensors on ``config["device"]`` with dtype
        ``config["precision"]``. ``X`` has shape ``(n_samples, n_features)``.

    Raises:
        Exception: Whatever ``pandas.read_csv`` raises, re-raised after printing.
        KeyError: If the ``time`` or ``event`` column is missing.
    """
    print(f"Loading data from: {path}")
    start_time = time.time()

    try:
        data = pd.read_csv(path)
        print(f"Data shape: {data.shape}")
    except Exception as e:
        print(f"Error loading data: {e}")
        raise

    # Select features by name, not position. The old slice iloc[:, 1:-2]
    # silently used time/event as covariates whenever they were not the last
    # two columns, leaking the outcome into X.
    outcome_cols = ("time", "event")
    feature_cols = [col for col in data.columns[1:] if col not in outcome_cols]
    X = data[feature_cols].to_numpy(dtype=float)
    time_values = data['time'].to_numpy(dtype=float)
    event_values = data['event'].to_numpy(dtype=float)

    # Standardize ignoring NaNs (+1e-8 guards constant columns), then impute
    # missing entries with 0, i.e. the column mean after standardization.
    X = (X - np.nanmean(X, axis=0)) / (np.nanstd(X, axis=0) + 1e-8)
    X = np.nan_to_num(X)

    # Convert to tensors and move to the configured device.
    X_tensor = torch.tensor(X, dtype=config["precision"]).to(config["device"])
    time_tensor = torch.tensor(time_values, dtype=config["precision"]).to(config["device"])
    event_tensor = torch.tensor(event_values, dtype=config["precision"]).to(config["device"])

    print(f"Data loading time: {time.time()-start_time:.2f}s")
    print(f"Data loaded on: {X_tensor.device}")
    print(f"Number of features: {X_tensor.shape[1]}")
    print(f"Number of samples: {X_tensor.shape[0]}")
    print(f"Event rate: {event_tensor.mean().item():.2%}")

    return X_tensor, time_tensor, event_tensor

class CoxPartialLikelihood(nn.Module):
    """Cox partial log-likelihood (Breslow handling of ties), divided by batch size.

    For every subject with ``event == 1``, the risk set is all subjects whose
    survival time is ``>=`` that subject's time, so tied times are included.
    """

    def __init__(self):
        super().__init__()
        # Kept so existing attribute access still works. The log-sum-exp
        # formulation in forward() no longer needs an epsilon or an exp clamp.
        self.eps = 1e-8
        self.max_exp = 20.0

    def forward(self, X_beta, survival_time, event):
        """Return the summed log partial likelihood of the events divided by ``len(X_beta)``.

        Args:
            X_beta: Linear predictor per subject (1-D).
            survival_time: Observed time per subject, aligned with ``X_beta``.
            event: Event indicator per subject; entries equal to 1 are events.

        Returns:
            A 0-dim tensor. It is 0.0 when the batch contains no events.
        """
        event_mask = event == 1
        if not bool(event_mask.any()):
            return torch.tensor(0.0, device=X_beta.device)

        # risk_set[k, j] is True when subject j is still at risk at the k-th
        # event time. Each row contains the event subject itself, so no row
        # is entirely masked.
        risk_set = survival_time.unsqueeze(0) >= survival_time[event_mask].unsqueeze(1)

        # log(sum_{j in R_k} exp(x_j)) via logsumexp. The previous
        # log(sum(exp(.)) + eps) underflowed to log(eps) when a risk set's
        # predictors were far below the batch maximum.
        masked = X_beta.unsqueeze(0).masked_fill(~risk_set, float("-inf"))
        log_risk = torch.logsumexp(masked, dim=1)

        # NOTE(review): dividing by the batch size down-weights the likelihood
        # relative to the prior in the ELBO. Kept as-is to preserve behavior.
        return (X_beta[event_mask] - log_risk).sum() / X_beta.size(0)

def model(X, survival_time, event):
    """Generative model: horseshoe-style global-local Normal prior on Cox coefficients.

    Each coefficient has a Normal(0, tau * lambda_j) prior. ``tau``
    ("global_scale") and each ``lambda_j`` ("local_scale") are HalfCauchy(1).
    This is a single continuous shrinkage prior, not a spike-and-slab mixture.
    The data enter through the Cox partial log-likelihood, added to the joint
    density with ``pyro.factor``.

    Args:
        X: Feature matrix, unpacked as ``(n, p)``.
        survival_time: Per-row observed times.
        event: Per-row event indicators.

    Returns:
        The sampled coefficient vector ``beta`` of length p.
    """
    _, n_features = X.shape
    dev = X.device
    unit = torch.tensor(1.0, device=dev)

    # Global shrinkage shared by every coefficient.
    tau = pyro.sample("global_scale", dist.HalfCauchy(unit))

    with pyro.plate("features", n_features, dim=-1):
        # Per-feature local shrinkage, then coefficients scaled by tau * lambda.
        lam = pyro.sample("local_scale", dist.HalfCauchy(unit))
        beta = pyro.sample(
            "beta",
            dist.Normal(torch.zeros(n_features, device=dev), tau * lam),
        )

    partial_ll = CoxPartialLikelihood()(X @ beta, survival_time, event)
    pyro.factor("log_pl", partial_ll)

    return beta

def guide(X, survival_time, event):
    """Mean-field variational family for the latent sites of ``model``.

    - ``global_scale``: LogNormal whose median is the positive parameter
      ``global_scale_loc``, with log-space spread ``global_scale_scale``.
    - ``local_scale`` (per feature): LogNormal with medians
      ``local_scale_loc`` and spreads ``local_scale_scale``.
    - ``beta`` (per feature): Normal with mean ``beta_loc`` and standard
      deviation ``beta_scale``.

    The arguments mirror ``model``; only ``X``'s shape and device are used.
    """
    _, n_features = X.shape
    dev = X.device
    positive = constraints.positive

    # Global scale: parameterized by its median, hence the .log() for the LogNormal loc.
    g_median = pyro.param("global_scale_loc", torch.tensor(1.0, device=dev), constraint=positive)
    g_spread = pyro.param("global_scale_scale", torch.tensor(0.1, device=dev), constraint=positive)
    pyro.sample("global_scale", dist.LogNormal(g_median.log(), g_spread))

    with pyro.plate("features", n_features, dim=-1):
        # Local scales, one per feature, also parameterized by their medians.
        l_median = pyro.param(
            "local_scale_loc",
            torch.ones(n_features, device=dev),
            constraint=positive,
        )
        l_spread = pyro.param(
            "local_scale_scale",
            torch.ones(n_features, device=dev) * 0.1,
            constraint=positive,
        )
        pyro.sample("local_scale", dist.LogNormal(l_median.log(), l_spread))

        # Coefficients: unconstrained mean, positive standard deviation.
        b_mean = pyro.param("beta_loc", torch.zeros(n_features, device=dev))
        b_std = pyro.param(
            "beta_scale",
            torch.ones(n_features, device=dev) * 0.1,
            constraint=positive,
        )
        pyro.sample("beta", dist.Normal(b_mean, b_std))

def initialize_params(n_features=1000):
    """Clear Pyro's param store and register every model/guide parameter.

    Runs ``model`` and ``guide`` once on all-zero dummy data so that every
    ``pyro.param`` site exists before training. Then prints each registered
    name and shape.

    Args:
        n_features: Number of dummy feature columns. This fixes the length of
            the per-feature parameters (``local_scale_*``, ``beta_*``).
            Defaults to 1000, the previously hard-coded width. ``pyro.param``
            returns an already-registered value instead of re-initializing
            it, so this should equal the real data's feature count.
    """
    device = config["device"]
    n_rows = 10

    # Values are irrelevant (with no events the Cox likelihood returns 0);
    # only the shapes matter for parameter registration.
    X_dummy = torch.zeros((n_rows, n_features), device=device)
    time_dummy = torch.zeros(n_rows, device=device)
    event_dummy = torch.zeros(n_rows, device=device)

    # Run model and guide once with dummy data to register all parameters.
    pyro.clear_param_store()
    model(X_dummy, time_dummy, event_dummy)
    guide(X_dummy, time_dummy, event_dummy)

    print("Initialized parameters:")
    for name, param in pyro.get_param_store().items():
        print(f"  {name}: {param.shape}")

class BayesianCoxModel:
    """Train the Pyro ``model``/``guide`` pair for Bayesian Cox regression with SVI.

    The variational parameters live in Pyro's global parameter store, not on
    this object. The store is cleared and repopulated when an instance is
    created, so only one instance's parameters exist at a time.
    """

    def __init__(self, config):
        """Store the run configuration and reset the Pyro parameter store.

        Args:
            config: Run configuration dict. ``"device"`` is read here; the
                training keys are read in ``train``.
        """
        self.config = config
        self.device = config["device"]
        pyro.clear_param_store()

        # Register every parameter up front. initialize_params() uses dummy
        # data of a fixed width, so train() re-registers if that width does
        # not match the real feature count.
        initialize_params()

    def train(self, train_loader, val_loader=None):
        """Run SVI with step-decayed learning rate and optional early stopping.

        Args:
            train_loader: Iterable of ``(X, time, event)`` batches with ``len()``.
            val_loader: Optional loader of the same form. When given, its mean
                ELBO loss drives early stopping and best-model checkpointing.

        Returns:
            tuple: ``(losses, val_losses, param_trajectories)``.
            ``losses`` is the per-epoch mean training loss. ``val_losses`` is
            the per-epoch validation loss (empty without ``val_loader``).
            ``param_trajectories`` maps each param-store name to a list of
            per-epoch numpy snapshots.

        NOTE(review): after training, the param store holds the last epoch's
        values, not the checkpointed best ones. Callers that read
        ``pyro.param(...)`` afterwards see the former.
        """
        # pyro.param() keeps an existing entry rather than re-initializing it,
        # so placeholder params of the wrong length would break the guide.
        self._ensure_param_shapes(train_loader)

        optimizer = ClippedAdam({
            "lr": self.config["initial_lr"],
            "clip_norm": self.config["clip_norm"],
            "weight_decay": 1e-4,
        })

        svi = SVI(model, guide, optimizer,
                  loss=Trace_ELBO(num_particles=self.config["elbo_particles"]))

        best_loss = float('inf')
        losses = []
        val_losses = []
        patience = 0
        # One list of per-epoch snapshots per registered parameter.
        param_trajectories = {name: [] for name in pyro.get_param_store().keys()}

        try:
            for epoch in range(self.config["max_epochs"]):
                epoch_loss = 0.0
                for X_batch, t_batch, e_batch in train_loader:
                    # Ensure data is on the correct device
                    X_batch = X_batch.to(self.device)
                    t_batch = t_batch.to(self.device)
                    e_batch = e_batch.to(self.device)

                    epoch_loss += svi.step(X_batch, t_batch, e_batch)

                avg_loss = epoch_loss / len(train_loader)
                losses.append(avg_loss)

                # Snapshot every parameter (constrained values) on the CPU.
                for name, param in pyro.get_param_store().items():
                    param_trajectories.setdefault(name, []).append(
                        param.detach().cpu().numpy()
                    )

                if val_loader is not None:
                    val_loss = self.evaluate(svi, val_loader)
                    val_losses.append(val_loss)

                    if val_loss < best_loss:
                        best_loss = val_loss
                        patience = 0
                        self.save_checkpoint(epoch, val_loss)
                    else:
                        patience += 1

                    if patience >= self.config["early_stop_patience"]:
                        print(f"Early stopping at epoch {epoch}")
                        break

                # Step decay of the learning rate. PyroOptim.set_state() only
                # loads per-parameter optimizer state keyed by parameter name,
                # so the former optimizer.set_state({'lr': ...}) was silently
                # ignored and the rate never decayed.
                if epoch > 0 and epoch % self.config["decay_step"] == 0:
                    current_lr = self.config["initial_lr"] * (
                        self.config["lr_decay"] ** (epoch // self.config["decay_step"])
                    )
                    self._set_lr(optimizer, current_lr)

                if epoch % self.config["log_freq"] == 0:
                    print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
                    if val_loader is not None:
                        print(f"Validation Loss = {val_losses[-1]:.4f}")

        except KeyboardInterrupt:
            print("\nTraining interrupted by user")
        except Exception as e:
            print(f"Error during training: {e}")
            raise

        return losses, val_losses, param_trajectories

    def evaluate(self, svi, loader):
        """Return the mean ELBO loss of ``svi`` over the batches of ``loader``."""
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, t_batch, e_batch in loader:
                X_batch = X_batch.to(self.device)
                t_batch = t_batch.to(self.device)
                e_batch = e_batch.to(self.device)
                val_loss += svi.evaluate_loss(X_batch, t_batch, e_batch)
        return val_loss / len(loader)

    def save_checkpoint(self, epoch, val_loss):
        """Write the epoch, validation loss and param-store state to ``<results_dir>/best_model.pth``."""
        checkpoint = {
            'epoch': epoch,
            'val_loss': val_loss,
            'param_store': pyro.get_param_store().get_state()
        }
        results_dir = self.config["results_dir"]
        # Create the directory if needed so this also works outside main().
        os.makedirs(results_dir, exist_ok=True)
        torch.save(checkpoint, os.path.join(results_dir, 'best_model.pth'))

    @staticmethod
    def _set_lr(optimizer, lr):
        """Set ``lr`` on every per-parameter torch optimizer wrapped by a Pyro optimizer."""
        # PyroOptim keeps one torch optimizer per parameter in ``optim_objs``.
        for torch_optim in optimizer.optim_objs.values():
            for group in torch_optim.param_groups:
                group["lr"] = lr

    def _ensure_param_shapes(self, loader):
        """Re-register guide params if ``beta_loc``'s width differs from the data's feature count."""
        first_batch = next(iter(loader), None)
        if first_batch is None:
            return
        X0, t0, e0 = (tensor.to(self.device) for tensor in first_batch)
        store = pyro.get_param_store()
        if "beta_loc" in store and store["beta_loc"].shape[-1] == X0.shape[1]:
            return
        pyro.clear_param_store()
        # Only the guide declares pyro.param sites; one pass registers them all.
        guide(X0, t0, e0)

def cross_validate(model_class, X_tensor, time_tensor, event_tensor, config):
    """Run K-fold cross-validation of ``model_class`` and collect per-fold metrics.

    Each fold builds a fresh ``model_class(config)``, trains it with the held-out
    fold as validation data, and scores the held-out fold with the C-index of
    the linear predictor ``X @ beta_loc``. Risk scores are negated because
    lifelines' ``concordance_index`` treats higher predictions as longer
    survival.

    Args:
        model_class: Trainer class exposing ``train(train_loader, val_loader)``.
        X_tensor, time_tensor, event_tensor: Full dataset tensors on ``config["device"]``.
        config: Run configuration. ``n_folds``, ``batch_size`` and ``device`` are read.

    Returns:
        list[dict]: One dict per fold with keys ``'c_index'``,
        ``'final_loss'`` and ``'param_trajectories'``.
    """
    device = config["device"]
    n_folds = config["n_folds"]
    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=42)

    def make_loader(rows, shuffle):
        # Tensors already live on `device`, so pinning is off and the sampler
        # generator is created on that same device.
        subset = TensorDataset(X_tensor[rows], time_tensor[rows], event_tensor[rows])
        return DataLoader(
            subset,
            batch_size=config["batch_size"],
            shuffle=shuffle,
            pin_memory=False,
            generator=torch.Generator(device=device),
        )

    results = []
    # sklearn's KFold works on host arrays, so split a CPU copy of the features.
    host_features = X_tensor.cpu().numpy()

    for fold_number, (train_rows, val_rows) in enumerate(splitter.split(host_features), start=1):
        print(f"\nTraining fold {fold_number}/{n_folds}")

        train_rows = torch.tensor(train_rows, device=device)
        val_rows = torch.tensor(val_rows, device=device)

        train_loader = make_loader(train_rows, shuffle=True)
        val_loader = make_loader(val_rows, shuffle=False)

        # Named fold_model so it does not shadow the module-level `model` function.
        fold_model = model_class(config)
        losses, _val_losses, trajectories = fold_model.train(train_loader, val_loader)

        with torch.no_grad():
            beta = pyro.param("beta_loc").detach()
            risk = X_tensor[val_rows] @ beta
            c_index = concordance_index(
                time_tensor[val_rows].cpu().numpy(),
                -risk.cpu().numpy(),
                event_tensor[val_rows].cpu().numpy(),
            )
            results.append({
                'c_index': c_index,
                'final_loss': float(losses[-1]),
                'param_trajectories': trajectories,
            })

        # Release cached GPU blocks between folds.
        torch.cuda.empty_cache()

    return results

def plot_diagnostics(results_dir, fold_results, losses, param_trajectories):
    """Write diagnostic PNGs for a training run into ``<results_dir>/plots``.

    Produces the following files:

    - ``elbo_convergence.png``: per-epoch loss.
    - ``<param>_convergence.png``: each parameter's trajectory (first 10
      dimensions for vector parameters).
    - ``cv_c_index.png``: C-index boxplot across folds.
    - ``feature_stability.png``: heatmap of each fold's final ``beta_loc``.
    - ``feature_correlation.png``: correlation across folds of those
      coefficient vectors.

    The last three are skipped when ``fold_results`` is empty.

    Args:
        results_dir: Output directory; a ``plots`` subdirectory is created.
        fold_results: List of per-fold dicts as returned by ``cross_validate``.
        losses: Per-epoch training losses.
        param_trajectories: Mapping of parameter name to per-epoch snapshots.
    """
    plots_dir = os.path.join(results_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # 1. ELBO convergence. The loss is a Monte Carlo ELBO estimate and may be
    # zero or negative, which a log axis cannot show; use symlog in that case.
    loss_values = np.asarray(losses, dtype=float)
    plt.figure(figsize=(10, 6))
    plt.plot(loss_values)
    plt.title('ELBO Loss Convergence')
    plt.xlabel('Epoch')
    plt.ylabel('ELBO Loss')
    plt.yscale('log' if np.all(loss_values > 0) else 'symlog')
    plt.savefig(os.path.join(plots_dir, 'elbo_convergence.png'))
    plt.close()

    # 2. Parameter trajectories
    for param_name, trajectories in param_trajectories.items():
        plt.figure(figsize=(12, 6))
        trajectories_array = np.array(trajectories)

        if trajectories_array.ndim == 2:
            for i in range(min(10, trajectories_array.shape[1])):
                plt.plot(trajectories_array[:, i], alpha=0.5, label=f'Dim {i}')
            # Only labeled lines exist in this branch, so a legend is meaningful.
            plt.legend()
        else:
            plt.plot(trajectories_array)

        plt.title(f'{param_name} Convergence')
        plt.xlabel('Iteration')
        plt.ylabel('Value')
        plt.savefig(os.path.join(plots_dir, f'{param_name}_convergence.png'))
        plt.close()

    # The remaining plots summarize folds; np.stack fails on an empty list.
    if not fold_results:
        return

    # 3. Cross-validation results
    c_indices = [result['c_index'] for result in fold_results]
    plt.figure(figsize=(8, 6))
    plt.boxplot(c_indices)
    plt.title('Cross-validation C-index Distribution')
    plt.ylabel('C-index')
    plt.savefig(os.path.join(plots_dir, 'cv_c_index.png'))
    plt.close()

    # 4. Feature importance and stability: final beta_loc of each fold, (n_folds, p).
    beta_samples = np.stack([result['param_trajectories']['beta_loc'][-1]
                             for result in fold_results])

    plt.figure(figsize=(12, 6))
    sns.heatmap(beta_samples, cmap='coolwarm', center=0)
    plt.title('Feature Coefficient Stability Across Folds')
    plt.xlabel('Feature Index')
    plt.ylabel('Fold')
    plt.savefig(os.path.join(plots_dir, 'feature_stability.png'))
    plt.close()

    # 5. Correlation analysis: correlation between features' coefficient
    # values across folds (rows of beta_samples.T are features).
    plt.figure(figsize=(10, 10))
    corr_matrix = np.corrcoef(beta_samples.T)
    sns.heatmap(corr_matrix, cmap='coolwarm', center=0)
    plt.title('Feature Correlation Matrix')
    plt.savefig(os.path.join(plots_dir, 'feature_correlation.png'))
    plt.close()

def main():
    """Run the pipeline end to end.

    The steps are:

    1. Set up the device.
    2. Save the config.
    3. Load the data.
    4. Cross-validate.
    5. Train a final model on all data.
    6. Plot diagnostics.
    7. Write ``results_summary.json``.

    A data-loading failure is printed and ends the run early. A failure in
    any later step is printed and re-raised. The CUDA cache is emptied on
    the way out either way.
    """
    setup_gpu()

    results_dir = config["results_dir"]
    os.makedirs(results_dir, exist_ok=True)

    # torch dtypes and classes are not JSON-serializable; store their str() form.
    serializable_config = {}
    for key, value in config.items():
        serializable_config[key] = str(value) if isinstance(value, (torch.dtype, type)) else value
    with open(os.path.join(results_dir, 'config.json'), 'w') as handle:
        json.dump(serializable_config, handle, indent=2)

    try:
        X_tensor, time_tensor, event_tensor = load_and_preprocess(config["data_path"], config)
    except Exception as exc:
        print(f"Error in data preprocessing: {exc}")
        return

    try:
        fold_results = cross_validate(BayesianCoxModel, X_tensor, time_tensor, event_tensor, config)

        # Final model on the full dataset (no validation loader, so no early stopping).
        full_loader = DataLoader(
            TensorDataset(X_tensor, time_tensor, event_tensor),
            batch_size=config["batch_size"],
            shuffle=True,
            pin_memory=False,  # tensors already live on the target device
            generator=torch.Generator(device=config["device"]),
        )
        final_model = BayesianCoxModel(config)
        losses, _, param_trajectories = final_model.train(full_loader)

        plot_diagnostics(results_dir, fold_results, losses, param_trajectories)

        fold_c_indices = [fold['c_index'] for fold in fold_results]
        summary = {
            'mean_c_index': float(np.mean(fold_c_indices)),
            'std_c_index': float(np.std(fold_c_indices)),
            'final_loss': float(losses[-1]),
            'n_epochs': len(losses),
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'user': config.get('user', 'unknown'),
        }
        with open(os.path.join(results_dir, 'results_summary.json'), 'w') as handle:
            json.dump(summary, handle, indent=2)

    except Exception as exc:
        print(f"Error during model training: {exc}")
        torch.cuda.empty_cache()
        raise

    finally:
        # Release cached GPU memory whether or not training succeeded.
        torch.cuda.empty_cache()

if __name__ == "__main__":
    # Entry point: run the pipeline, always releasing cached GPU memory.
    try:
        main()
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"\nError in main execution: {e}")
        torch.cuda.empty_cache()
        # Re-raise so the full traceback is shown and a script run exits with
        # a non-zero status instead of appearing to succeed after a failure.
        raise