In [1]:
# SPIKE-AND-SLAB-LASSO

In [None]:
pip install pyro-ppl torch pandas matplotlib numpy

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import ClippedAdam
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter, CoxPHFitter
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import KFold
import time
from datetime import datetime
import os
import json
import warnings
warnings.filterwarnings('ignore')

# Configuration
config = {
    # Training parameters
    "max_epochs": 5000,
    "batch_size": 128,
    "initial_lr": 5e-4,
    "lr_decay": 0.2,
    "decay_step": 500,
    "clip_norm": 10.0,
    "early_stop_patience": 200,
    
    # Device and precision
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "precision": torch.float32,
    "pin_memory": False if torch.cuda.is_available() else True,
    "num_workers": 0,
    "cuda_deterministic": False,
    
    # Model parameters
    "elbo_particles": 10,
    "warmup_epochs": 100,
    "n_folds": 5,
    
    # Spike-and-slab lasso parameters
    "max_risk_score": 50.0,
    "min_scale": 1e-10,
    "init_spike_prob": 0.5,
    "slab_df": 1.0,
    "lasso_scale": 1.0,
    
    # Logging and checkpointing
    "log_freq": 100,
    "checkpoint_freq": 500,
    
    # Paths
    "data_path": r"D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv",
    "results_dir": f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
}

class CoxPartialLikelihood(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-8
        self.max_exp = 20.0
        
    def forward(self, risk_scores, survival_time, event):
        event_mask = event == 1
        n_events = event_mask.sum()
        
        if n_events < 1:
            return torch.tensor(0.0, device=risk_scores.device)
        
        # Numerical stability: normalize risk scores
        max_risk = risk_scores.max()
        risk_scores = risk_scores - max_risk
        exp_risk = torch.exp(torch.clamp(risk_scores, min=-self.max_exp))
        
        # Compute risk matrix with stable operations
        risk_matrix = (survival_time.unsqueeze(0) >= 
                      survival_time[event_mask].unsqueeze(1)).float()
        
        # Compute cumulative hazard with numerical stability
        cumsum_exp_risk = torch.clamp(
            (risk_matrix * exp_risk.unsqueeze(0)).sum(dim=1),
            min=self.eps
        )
        
        log_sum_exp = torch.log(cumsum_exp_risk)
        
        # Compute partial likelihood
        events_risk_score = risk_scores[event_mask]
        partial_likelihood = (events_risk_score - log_sum_exp).sum()
        
        # Normalize by batch size for better scaling
        return partial_likelihood / risk_scores.size(0)

def model(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # Global shrinkage parameter (for the slab component)
    global_scale = pyro.sample("global_scale", 
                             dist.HalfCauchy(torch.tensor(1.0, device=device)))
    
    # Spike probability
    spike_prob = pyro.sample("spike_prob",
                            dist.Beta(torch.tensor(1.0, device=device),
                                    torch.tensor(1.0, device=device)))
    
    with pyro.plate("features", p, dim=-1):
        # Binary indicators for spike-and-slab
        indicators = pyro.sample("indicators",
                               dist.Bernoulli(probs=spike_prob))
        
        # Local scale parameters (for the slab component)
        local_scale = pyro.sample("local_scale", 
                                dist.HalfCauchy(torch.tensor(1.0, device=device)))
        
        # Combine spike and slab components with numerical stability
        scale = (indicators * global_scale * local_scale + 
                (1 - indicators) * torch.tensor(1e-10, device=device))
        
        # Sample coefficients
        beta = pyro.sample("beta", 
                          dist.Laplace(torch.zeros(p, device=device), scale))
    
    # Compute risk scores with numerical stability
    risk_scores = torch.clamp(X @ beta, min=-50.0, max=50.0)
    
    # Compute likelihood with stable implementation
    cox_likelihood = CoxPartialLikelihood()
    log_pl = cox_likelihood(risk_scores, survival_time, event)
    pyro.factor("log_pl", log_pl)
    
    return beta

def guide(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # Global scale parameters
    global_scale_loc = pyro.param(
        "global_scale_loc",
        torch.tensor(1.0, device=device),
        constraint=constraints.positive
    )
    global_scale_scale = pyro.param(
        "global_scale_scale",
        torch.tensor(0.1, device=device),
        constraint=constraints.positive
    )
    
    # Spike probability parameters
    spike_prob_alpha = pyro.param(
        "spike_prob_alpha",
        torch.tensor(1.0, device=device),
        constraint=constraints.positive
    )
    spike_prob_beta = pyro.param(
        "spike_prob_beta",
        torch.tensor(1.0, device=device),
        constraint=constraints.positive
    )
    
    # Sample global scale
    pyro.sample(
        "global_scale",
        dist.LogNormal(global_scale_loc.log(), global_scale_scale)
    )
    
    # Sample spike probability
    pyro.sample(
        "spike_prob",
        dist.Beta(spike_prob_alpha, spike_prob_beta)
    )
    
    with pyro.plate("features", p, dim=-1):
        # Indicator parameters
        indicator_probs = pyro.param(
            "indicator_probs",
            torch.ones(p, device=device) * config["init_spike_prob"],
            constraint=constraints.interval(0.0, 1.0)
        )
        
        # Local scale parameters
        local_scale_loc = pyro.param(
            "local_scale_loc",
            torch.ones(p, device=device),
            constraint=constraints.positive
        )
        local_scale_scale = pyro.param(
            "local_scale_scale",
            torch.ones(p, device=device) * 0.1,
            constraint=constraints.positive
        )
        
        # Beta parameters
        beta_loc = pyro.param(
            "beta_loc",
            torch.zeros(p, device=device)
        )
        beta_scale = pyro.param(
            "beta_scale",
            torch.ones(p, device=device) * config["lasso_scale"],
            constraint=constraints.positive
        )
        
        # Sample indicators
        pyro.sample("indicators", dist.Bernoulli(indicator_probs))
        
        # Sample local scales
        pyro.sample(
            "local_scale",
            dist.LogNormal(local_scale_loc.log(), local_scale_scale)
        )
        
        # Sample beta with Laplace distribution
        pyro.sample("beta", dist.Laplace(beta_loc, beta_scale))

def initialize_params():
    # This function ensures all parameters are registered before training
    device = config["device"]
    
    # Initialize with dummy tensors just to register parameters
    X_dummy = torch.zeros((10, 1000), device=device)
    time_dummy = torch.zeros(10, device=device)
    event_dummy = torch.zeros(10, device=device)
    
    # Run model and guide once with dummy data to register all parameters
    pyro.clear_param_store()
    try:
        model(X_dummy, time_dummy, event_dummy)
        guide(X_dummy, time_dummy, event_dummy)
    except Exception as e:
        print(f"Error initializing parameters: {e}")
        raise
    
    # Print registered parameters
    print("\nInitialized parameters:")
    for name, param in pyro.get_param_store().items():
        print(f"  {name}: {param.shape}")
    
    return True

#=========================
def setup_gpu():
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        
        print("\nGPU Information:")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Device: {torch.cuda.get_device_name(0)}")
        print(f"Device count: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Initial GPU Memory: {torch.cuda.memory_allocated(0)/1024**2:.2f} MB allocated, {torch.cuda.memory_reserved(0)/1024**2:.2f} MB reserved")
        
        torch.backends.cudnn.deterministic = config["cuda_deterministic"]
        torch.backends.cudnn.benchmark = not config["cuda_deterministic"]
        
        return True
    else:
        print("\nNo GPU available, using CPU")
        return False

def load_and_preprocess(path, config):
    print(f"Loading data from: {path}")
    start_time = time.time()
    
    try:
        data = pd.read_csv(path)
        print(f"Data shape: {data.shape}")
    except Exception as e:
        print(f"Error loading data: {e}")
        raise
    
    X = data.iloc[:, 1:-2].values
    time_values = data['time'].values
    event_values = data['event'].values
    
    X = (X - np.nanmean(X, axis=0)) / (np.nanstd(X, axis=0) + 1e-8)
    X = np.nan_to_num(X)
    
    X_tensor = torch.tensor(X, dtype=config["precision"]).to(config["device"])
    time_tensor = torch.tensor(time_values, dtype=config["precision"]).to(config["device"])
    event_tensor = torch.tensor(event_values, dtype=config["precision"]).to(config["device"])
    
    print(f"Data loading time: {time.time()-start_time:.2f}s")
    print(f"Data loaded on: {X_tensor.device}")
    print(f"Number of features: {X_tensor.shape[1]}")
    print(f"Number of samples: {X_tensor.shape[0]}")
    print(f"Event rate: {event_tensor.mean().item():.2%}")
    
    return X_tensor, time_tensor, event_tensor

class BayesianCoxModel:
    def __init__(self, config):
        self.config = config
        self.device = config["device"]
        pyro.clear_param_store()
        initialize_params()

    def train(self, train_loader, val_loader=None):
        optimizer = ClippedAdam({
            "lr": self.config["initial_lr"],
            "clip_norm": self.config["clip_norm"],
            "weight_decay": 1e-4
        })
        
        svi = SVI(model, guide, optimizer, 
                 loss=Trace_ELBO(num_particles=self.config["elbo_particles"]))
        
        best_loss = float('inf')
        losses = []
        val_losses = []
        patience = 0
        param_trajectories = {}
        
        for name in pyro.get_param_store().keys():
            param_trajectories[name] = []
        
        try:
            for epoch in range(self.config["max_epochs"]):
                epoch_loss = 0.0
                for batch_idx, (X_batch, t_batch, e_batch) in enumerate(train_loader):
                    X_batch = X_batch.to(self.device)
                    t_batch = t_batch.to(self.device)
                    e_batch = e_batch.to(self.device)
                    
                    loss = svi.step(X_batch, t_batch, e_batch)
                    epoch_loss += loss
                
                avg_loss = epoch_loss / len(train_loader)
                losses.append(avg_loss)
                
                # Store parameter trajectories
                for name, param in pyro.get_param_store().items():
                    param_trajectories[name].append(param.detach().cpu().numpy())
                
                if val_loader is not None:
                    val_loss = self.evaluate(svi, val_loader)
                    val_losses.append(val_loss)
                    
                    if val_loss < best_loss:
                        best_loss = val_loss
                        patience = 0
                        self.save_checkpoint(epoch, val_loss)
                    else:
                        patience += 1
                    
                    if patience >= self.config["early_stop_patience"]:
                        print(f"Early stopping at epoch {epoch}")
                        break
                
                if epoch % self.config["decay_step"] == 0 and epoch > 0:
                    current_lr = self.config["initial_lr"] * \
                               (self.config["lr_decay"] ** (epoch // self.config["decay_step"]))
                    optimizer.set_state({'lr': current_lr})
                    
                if epoch % self.config["log_freq"] == 0:
                    print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
                    if val_loader is not None:
                        print(f"Validation Loss = {val_losses[-1]:.4f}")
        
        except KeyboardInterrupt:
            print("\nTraining interrupted by user")
        except Exception as e:
            print(f"Error during training: {e}")
            raise
        
        return losses, val_losses, param_trajectories
    
    def evaluate(self, svi, loader):
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, t_batch, e_batch in loader:
                X_batch = X_batch.to(self.device)
                t_batch = t_batch.to(self.device)
                e_batch = e_batch.to(self.device)
                val_loss += svi.evaluate_loss(X_batch, t_batch, e_batch)
        return val_loss / len(loader)
    
    def save_checkpoint(self, epoch, val_loss):
        checkpoint = {
            'epoch': epoch,
            'val_loss': val_loss,
            'param_store': pyro.get_param_store().get_state()
        }
        torch.save(checkpoint, 
                  os.path.join(self.config["results_dir"], f'best_model.pth'))

def cross_validate(model_class, X_tensor, time_tensor, event_tensor, config):
    device = config["device"]
    kf = KFold(n_splits=config["n_folds"], shuffle=True, random_state=42)
    fold_results = []
    
    X_np = X_tensor.cpu().numpy()
    
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_np)):
        print(f"\nTraining fold {fold + 1}/{config['n_folds']}")
        
        train_idx = torch.tensor(train_idx, device=device)
        val_idx = torch.tensor(val_idx, device=device)
        
        train_loader = DataLoader(
            TensorDataset(X_tensor[train_idx], time_tensor[train_idx], event_tensor[train_idx]),
            batch_size=config["batch_size"],
            shuffle=True,
            pin_memory=False,
            generator=torch.Generator(device=device)
        )
        
        val_loader = DataLoader(
            TensorDataset(X_tensor[val_idx], time_tensor[val_idx], event_tensor[val_idx]),
            batch_size=config["batch_size"],
            pin_memory=False,
            generator=torch.Generator(device=device)
        )
        
        model = model_class(config)
        losses, val_losses, param_trajectories = model.train(train_loader, val_loader)
        
        with torch.no_grad():
            beta = pyro.param("beta_loc").detach()
            indicators = pyro.param("indicator_probs").detach()
            risk_scores = X_tensor[val_idx] @ beta
            
            metrics = {
                'c_index': concordance_index(
                    time_tensor[val_idx].cpu().numpy(),
                    -risk_scores.cpu().numpy(),
                    event_tensor[val_idx].cpu().numpy()
                ),
                'final_loss': float(losses[-1]),
                'param_trajectories': param_trajectories,
                'val_losses': val_losses,
                'last_epoch': len(losses),
                'selected_features': (indicators > 0.5).float().cpu().numpy()
            }
            
            fold_results.append(metrics)
        
        torch.cuda.empty_cache()
    
    return fold_results

def plot_diagnostics(results_dir, fold_results, losses, val_losses_list, param_trajectories):
    plots_dir = os.path.join(results_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    
    # Plot training and validation losses
    plt.figure(figsize=(12, 7))
    plt.plot(losses, label='Training Loss', color='blue', linewidth=2)
    
    if val_losses_list:
        for fold_idx, val_losses in enumerate(val_losses_list):
            plt.plot(val_losses, label=f'Validation Loss (Fold {fold_idx+1})', 
                     linestyle='--', alpha=0.7)
    
    plt.title('Training and Validation Loss Convergence', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('ELBO Loss', fontsize=12)
    plt.yscale('log')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'loss_convergence.png'), dpi=300)
    plt.close()
    
    # Plot feature selection probabilities
    if 'indicator_probs' in param_trajectories:
        plt.figure(figsize=(12, 7))
        final_probs = param_trajectories['indicator_probs'][-1]
        plt.hist(final_probs, bins=50, alpha=0.7)
        plt.title('Feature Selection Probabilities Distribution', fontsize=14)
        plt.xlabel('Selection Probability', fontsize=12)
        plt.ylabel('Count', fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, 'feature_selection_dist.png'), dpi=300)
        plt.close()
    
    # Plot selected features heatmap
    selected_features = np.stack([result['selected_features'] for result in fold_results])
    plt.figure(figsize=(14, 10))
    sns.heatmap(selected_features, cmap='coolwarm', center=0.5,
                xticklabels=False, yticklabels=range(1, len(fold_results) + 1))
    plt.title('Selected Features Across Folds', fontsize=14)
    plt.xlabel('Feature Index', fontsize=12)
    plt.ylabel('Fold', fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'selected_features.png'), dpi=300)
    plt.close()
    
    # Plot C-index distribution
    c_indices = [result['c_index'] for result in fold_results]
    plt.figure(figsize=(10, 7))
    plt.violinplot(c_indices, showmeans=True, showmedians=True)
    plt.plot([1] * len(c_indices), c_indices, 'o', color='blue', alpha=0.7)
    mean_c = np.mean(c_indices)
    std_c = np.std(c_indices)
    plt.text(1.2, min(c_indices), 
             f'Mean: {mean_c:.4f}\nStd: {std_c:.4f}', 
             fontsize=12, bbox=dict(facecolor='white', alpha=0.7))
    plt.title('Cross-validation C-index Distribution', fontsize=14)
    plt.ylabel('C-index', fontsize=12)
    plt.xticks([1], ['C-index'])
    plt.grid(True, linestyle='--', alpha=0.7, axis='y')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'cv_c_index.png'), dpi=300)
    plt.close()

def main():
    setup_gpu()
    os.makedirs(config["results_dir"], exist_ok=True)
    
    with open(os.path.join(config["results_dir"], 'config.json'), 'w') as f:
        json.dump({k: str(v) if isinstance(v, (torch.dtype, type)) else v 
                  for k, v in config.items()}, f, indent=2)
    
    try:
        X_tensor, time_tensor, event_tensor = load_and_preprocess(config["data_path"], config)
        fold_results = cross_validate(BayesianCoxModel, X_tensor, time_tensor, event_tensor, config)
        
        all_val_losses = [result['val_losses'] for result in fold_results if 'val_losses' in result]
        cv_stopping_epochs = [result['last_epoch'] for result in fold_results if 'last_epoch' in result]
        
        optimal_epochs = int(np.mean(cv_stopping_epochs)) if cv_stopping_epochs else 1700
        print(f"\nUsing {optimal_epochs} epochs for final model based on cross-validation")
        
        final_config = config.copy()
        final_config["max_epochs"] = optimal_epochs
        
        dataset = TensorDataset(X_tensor, time_tensor, event_tensor)
        train_loader = DataLoader(
            dataset,
            batch_size=config["batch_size"],
            shuffle=True,
            pin_memory=False,
            generator=torch.Generator(device=config["device"])
        )
        
        print(f"\nTraining final model on full dataset for {optimal_epochs} epochs")
        final_model = BayesianCoxModel(final_config)
        losses, _, param_trajectories = final_model.train(train_loader)
        
        plot_diagnostics(config["results_dir"], fold_results, losses, all_val_losses, param_trajectories)
        
        results_summary = {
            'mean_c_index': float(np.mean([r['c_index'] for r in fold_results])),
            'std_c_index': float(np.std([r['c_index'] for r in fold_results])),
            'final_loss': float(losses[-1]),
            'n_epochs_final': len(losses),
            'optimal_epochs': optimal_epochs,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        
        with open(os.path.join(config["results_dir"], 'results_summary.json'), 'w') as f:
            json.dump(results_summary, f, indent=2)
            
    except Exception as e:
        print(f"Error during model training: {e}")
        torch.cuda.empty_cache()
        raise
    
    finally:
        # Clean up
        torch.cuda.empty_cache()
        print("\nTraining completed")

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"\nError in main execution: {e}")
        torch.cuda.empty_cache()
    finally:
        print("\nExecution finished")


GPU Information:
PyTorch version: 2.4.1+cu121
CUDA available: True
CUDA version: 12.1
Device: NVIDIA GeForce GTX 1650
Device count: 1
Current device: 0
Initial GPU Memory: 16.25 MB allocated, 42.00 MB reserved
Loading data from: D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv
Data shape: (800, 1003)
Data loading time: 0.29s
Data loaded on: cuda:0
Number of features: 1000
Number of samples: 800
Event rate: 58.88%

Training fold 1/5

Initialized parameters:
  global_scale_loc: torch.Size([])
  global_scale_scale: torch.Size([])
  spike_prob_alpha: torch.Size([])
  spike_prob_beta: torch.Size([])
  indicator_probs: torch.Size([1000])
  local_scale_loc: torch.Size([1000])
  local_scale_scale: torch.Size([1000])
  beta_loc: torch.Size([1000])
  beta_scale: torch.Size([1000])
Epoch 0: Loss = 4976091993944.2617
Validation Loss = 4920644808012.4580
Epoch 100: Loss = 3407639400252.8369
Validation Loss = 3415521334237.0576
Epoch 200: Loss = 2276986792789.

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import ClippedAdam
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter, CoxPHFitter
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import KFold
import time
from datetime import datetime
import os
import json
import warnings
warnings.filterwarnings('ignore')

# # Configuration
# config = {
#     # Training parameters
#     "max_epochs": 800,
#     "batch_size": 32,  # Reduced batch size for better stability
#     "initial_lr": 1e-3,  # Increased learning rate for faster initial convergence
#     "lr_decay": 0.5,   # More aggressive decay for better fine-tuning
#     "decay_step": 100,  # More frequent decay steps
#     "clip_norm": 0.1,  # Reduced clip norm for better gradient stability
#     "early_stop_patience": 50,
    
#     # Device
#     "device": "cuda" if torch.cuda.is_available() else "cpu",
#     "precision": torch.float32,
    
#     # Model parameters
#     "elbo_particles": 5,
#     "n_folds": 5,
    
#     # Regularization parameters
#     "max_risk_score": 3.0,  # Reduced max risk score for better numerical stability
#     "min_scale": 1e-4,  # Increased min scale to prevent underflow
#     "init_spike_prob": 0.1,  # Increased initial probability for better feature selection
#     "global_scale_prior": 0.1,  # Increased prior scale for better initialization
#     "lasso_scale": 0.1,  # Increased lasso scale for smoother optimization
    
#     # Paths
#     "data_path": r"D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv",
#     "results_dir": f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
# }

# Configuration
config = {
    # Training parameters
    "max_epochs": 800,
    "batch_size": 32,
    "initial_lr": 1e-3,
    "lr_decay": 0.5,
    "decay_step": 100,
    "clip_norm": 0.1,
    "early_stop_patience": 50,
    "log_freq": 50,  # Add logging frequency parameter
    
    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "precision": torch.float32,
    
    # Model parameters
    "elbo_particles": 5,
    "n_folds": 5,
    
    # Regularization parameters
    "max_risk_score": 3.0,
    "min_scale": 1e-4,
    "init_spike_prob": 0.1,
    "global_scale_prior": 0.1,
    "lasso_scale": 0.1,
    
    # Paths
    "data_path": r"D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv",
    "results_dir": f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
}

class CoxPartialLikelihood(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-8
        self.max_exp = 20.0
        
    def forward(self, risk_scores, survival_time, event):
        event_mask = event == 1
        n_events = event_mask.sum()
        
        if n_events < 1:
            return torch.tensor(0.0, device=risk_scores.device)
        
        risk_scores = risk_scores - risk_scores.mean()
        risk_scores = torch.clamp(risk_scores, min=-self.max_exp, max=self.max_exp)
        
        sorted_times, indices = torch.sort(survival_time, descending=True)
        risk_scores = risk_scores[indices]
        event_mask = event_mask[indices]
        
        risk_scores_exp = torch.exp(risk_scores)
        cumsum_risk = torch.cumsum(risk_scores_exp, dim=0)
        log_cumsum_risk = torch.log(cumsum_risk + self.eps)
        
        event_likelihood = risk_scores[event_mask] - log_cumsum_risk[event_mask]
        
        # Remove the division by n_events
        return event_likelihood.sum()

def model(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # More conservative scale for global shrinkage
    global_scale = pyro.sample("global_scale", 
                             dist.HalfCauchy(torch.tensor(0.1, device=device)))
    
    # More informative prior for spike probability
    spike_prob = pyro.sample("spike_prob",
                            dist.Beta(torch.tensor(2.0, device=device),
                                    torch.tensor(2.0, device=device)))
    
    with pyro.plate("features", p, dim=-1):
        # Binary indicators for spike-and-slab
        indicators = pyro.sample("indicators",
                               dist.Bernoulli(probs=spike_prob))
        
        # More conservative scale for local shrinkage
        local_scale = pyro.sample("local_scale", 
                                dist.HalfCauchy(torch.tensor(0.1, device=device)))
        
        # Combine spike and slab components with better numerical stability
        scale = torch.clamp(
            indicators * global_scale * local_scale + 
            (1 - indicators) * config["min_scale"],
            min=config["min_scale"],
            max=1.0
        )
        
        # Use Normal distribution instead of Laplace for better stability
        beta = pyro.sample("beta", 
                          dist.Normal(torch.zeros(p, device=device), scale))
    
    # Compute risk scores with better numerical stability
    risk_scores = torch.clamp(X @ beta, min=-config["max_risk_score"], max=config["max_risk_score"])
    
    # Compute likelihood
    cox_likelihood = CoxPartialLikelihood()
    log_pl = cox_likelihood(risk_scores, survival_time, event)
    pyro.factor("log_pl", log_pl)
    
    return beta

def guide(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # More conservative initialization for global scale
    global_scale_loc = pyro.param(
        "global_scale_loc",
        torch.tensor(0.1, device=device),
        constraint=constraints.positive
    )
    global_scale_scale = pyro.param(
        "global_scale_scale",
        torch.tensor(0.01, device=device),
        constraint=constraints.positive
    )
    
    # More informative initialization for spike probability
    spike_prob_alpha = pyro.param(
        "spike_prob_alpha",
        torch.tensor(2.0, device=device),
        constraint=constraints.positive
    )
    spike_prob_beta = pyro.param(
        "spike_prob_beta",
        torch.tensor(2.0, device=device),
        constraint=constraints.positive
    )
    
    pyro.sample(
        "global_scale",
        dist.LogNormal(global_scale_loc.log(), global_scale_scale)
    )
    
    pyro.sample(
        "spike_prob",
        dist.Beta(spike_prob_alpha, spike_prob_beta)
    )
    
    with pyro.plate("features", p, dim=-1):
        # More conservative initialization for indicators
        indicator_probs = pyro.param(
            "indicator_probs",
            torch.ones(p, device=device) * 0.1,  # Start with sparse model
            constraint=constraints.interval(0.0, 1.0)
        )
        
        # More conservative initialization for local scales
        local_scale_loc = pyro.param(
            "local_scale_loc",
            torch.ones(p, device=device) * 0.1,
            constraint=constraints.positive
        )
        local_scale_scale = pyro.param(
            "local_scale_scale",
            torch.ones(p, device=device) * 0.01,
            constraint=constraints.positive
        )
        
        beta_loc = pyro.param(
            "beta_loc",
            torch.zeros(p, device=device)
        )
        beta_scale = pyro.param(
            "beta_scale",
            torch.ones(p, device=device) * 0.1,
            constraint=constraints.positive
        )
        
        pyro.sample("indicators", dist.Bernoulli(indicator_probs))
        
        pyro.sample(
            "local_scale",
            dist.LogNormal(local_scale_loc.log(), local_scale_scale)
        )
        
        # Use Normal distribution instead of Laplace
        pyro.sample("beta", dist.Normal(beta_loc, beta_scale))

def initialize_params():
    # This function ensures all parameters are registered before training
    device = config["device"]
    
    # Initialize with dummy tensors just to register parameters
    X_dummy = torch.zeros((10, 1000), device=device)
    time_dummy = torch.zeros(10, device=device)
    event_dummy = torch.zeros(10, device=device)
    
    # Run model and guide once with dummy data to register all parameters
    pyro.clear_param_store()
    try:
        model(X_dummy, time_dummy, event_dummy)
        guide(X_dummy, time_dummy, event_dummy)
    except Exception as e:
        print(f"Error initializing parameters: {e}")
        raise
    
    # Print registered parameters
    print("\nInitialized parameters:")
    for name, param in pyro.get_param_store().items():
        print(f"  {name}: {param.shape}")
    
    return True

#=========================
def setup_gpu():
    if torch.cuda.is_available():
        # Set deterministic settings using the new PyTorch 2.x syntax
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.deterministic = True  # Changed from cuda_deterministic
        torch.backends.cuda.matmul.allow_tf32 = False  # Disable TF32 for deterministic results
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        
        # Print GPU information
        print(f"\nGPU: {torch.cuda.get_device_name(0)}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Current device: {torch.cuda.current_device()}")
    else:
        print("\nUsing CPU")

def load_and_preprocess(path):
    print(f"\nLoading data from: {path}")
    data = pd.read_csv(path)
    print(f"Data shape: {data.shape}")
    
    X = data.iloc[:, 1:-2].values.astype(np.float32)
    time_values = data['time'].values.astype(np.float32)
    event_values = data['event'].values.astype(np.float32)
    
    X = (X - np.nanmean(X, axis=0)) / (np.nanstd(X, axis=0) + 1e-8)
    X = np.nan_to_num(X)
    
    # Move all data to GPU upfront
    X_tensor = torch.tensor(X, dtype=config["precision"]).to(config["device"])
    time_tensor = torch.tensor(time_values, dtype=config["precision"]).to(config["device"])
    event_tensor = torch.tensor(event_values, dtype=config["precision"]).to(config["device"])
    
    print(f"Number of features: {X_tensor.shape[1]}")
    print(f"Number of samples: {X_tensor.shape[0]}")
    print(f"Event rate: {event_tensor.float().mean().item():.2%}")
    
    return X_tensor, time_tensor, event_tensor

class BayesianCoxModel:
    def __init__(self):
        self.device = config["device"]
        pyro.clear_param_store()
        self._initialize_params()

    def _initialize_params(self):
        X_dummy = torch.zeros((2, 1000), device=self.device, dtype=config["precision"])
        time_dummy = torch.zeros(2, device=self.device, dtype=config["precision"])
        event_dummy = torch.zeros(2, device=self.device, dtype=config["precision"])
        
        model(X_dummy, time_dummy, event_dummy)
        guide(X_dummy, time_dummy, event_dummy)

    def train(self, train_loader, val_loader=None):
        optimizer = ClippedAdam({
            "lr": config["initial_lr"],
            "clip_norm": config["clip_norm"],
            "weight_decay": 1e-4,
            "betas": (0.9, 0.999)
        })
        
        svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=config["elbo_particles"]))
        
        best_loss = float('inf')
        losses = []
        val_losses = []
        patience = 0
        
        try:
            for epoch in range(config["max_epochs"]):
                epoch_loss = 0.0
                for batch in train_loader:
                    X_batch, t_batch, e_batch = [b.to(self.device) for b in batch]
                    
                    # Scale the batch data
                    X_batch = (X_batch - X_batch.mean(0)) / (X_batch.std(0) + 1e-8)
                    
                    loss = svi.step(X_batch, t_batch, e_batch)
                    epoch_loss += loss
                
                avg_loss = epoch_loss / len(train_loader)
                losses.append(avg_loss)
                
                if val_loader:
                    val_loss = self._evaluate(val_loader)
                    val_losses.append(val_loss)
                    
                    if val_loss < best_loss:
                        best_loss = val_loss
                        patience = 0
                    else:
                        patience += 1
                    
                    if patience >= config["early_stop_patience"]:
                        print(f"\nEarly stopping at epoch {epoch}")
                        break
                
                if epoch % config["decay_step"] == 0 and epoch > 0:
                    optimizer.set_state({'lr': config["initial_lr"] * (config["lr_decay"] ** (epoch // config["decay_step"]))})
                
                if epoch % config["log_freq"] == 0:  # Use log_freq from config
                    print(f"Epoch {epoch:4d} | Train Loss: {avg_loss:.2f}" + 
                          (f" | Val Loss: {val_loss:.2f}" if val_loader else ""))
        
        except KeyboardInterrupt:
            print("\nTraining stopped by user")
        except Exception as e:
            print(f"\nError during training: {str(e)}")
        
        return losses, val_losses

    def _evaluate(self, loader):
        total_loss = 0.0
        try:
            with torch.no_grad():
                for batch in loader:
                    X_batch, t_batch, e_batch = [b.to(self.device) for b in batch]
                    # Scale the batch data
                    X_batch = (X_batch - X_batch.mean(0)) / (X_batch.std(0) + 1e-8)
                    total_loss += pyro.infer.Trace_ELBO(num_particles=5).loss(
                        model, guide, X_batch, t_batch, e_batch
                    )
            return total_loss / len(loader)
        except Exception as e:
            print(f"\nError during evaluation: {str(e)}")
            return float('inf')
            
    def save_checkpoint(self, epoch, val_loss):
        checkpoint = {
            'epoch': epoch,
            'val_loss': val_loss,
            'param_store': pyro.get_param_store().get_state()
        }
        torch.save(checkpoint, 
                  os.path.join(self.config["results_dir"], f'best_model.pth'))

def cross_validate(X, time, event):
    kf = KFold(n_splits=config["n_folds"], shuffle=True, random_state=42)
    fold_results = []
    device = config["device"]
    
    # Keep original data on GPU
    X = X.to(device)
    time = time.to(device)
    event = event.to(device)
    
    # Create CUDA generator
    generator = torch.Generator(device=device)
    generator.manual_seed(42)
    
    for fold, (train_idx, test_idx) in enumerate(kf.split(X.cpu().numpy())):
        print(f"\nTraining fold {fold+1}/{config['n_folds']}")
        try:
            # Convert indices to GPU tensors
            train_idx = torch.tensor(train_idx, device=device)
            test_idx = torch.tensor(test_idx, device=device)
            
            # Create subsets directly on GPU
            X_train = X[train_idx]
            time_train = time[train_idx]
            event_train = event[train_idx]
            
            X_test = X[test_idx]
            time_test = time[test_idx]
            event_test = event[test_idx]
            
            # Create datasets
            train_dataset = TensorDataset(X_train, time_train, event_train)
            test_dataset = TensorDataset(X_test, time_test, event_test)
            
            # Create data loaders without generator
            train_loader = DataLoader(
                train_dataset,
                batch_size=config["batch_size"],
                shuffle=True,
                pin_memory=False
            )
            
            test_loader = DataLoader(
                test_dataset,
                batch_size=config["batch_size"],
                shuffle=False,
                pin_memory=False
            )
            
            # Initialize and train model
            model_instance = BayesianCoxModel()
            train_losses, val_losses = model_instance.train(train_loader, test_loader)
            
            # Compute metrics
            with torch.no_grad():
                beta = pyro.param("beta_loc")
                risk_scores = X_test @ beta
                
                c_index = concordance_index(
                    time_test.cpu().numpy(),
                    -risk_scores.cpu().numpy(),
                    event_test.cpu().numpy()
                )
                
                probs = pyro.param("indicator_probs").cpu().numpy()
            
            fold_results.append({
                'c_index': c_index,
                'selected_features': (probs > 0.5).astype(int),
                'train_loss': train_losses,
                'val_loss': val_losses,
                'last_epoch': len(train_losses)
            })
            
            print(f"Fold {fold+1} completed - C-index: {c_index:.3f}")
            
        except Exception as e:
            print(f"\nError during fold {fold+1}: {str(e)}")
            continue
    
    return fold_results
    
def plot_diagnostics(results_dir, fold_results, losses, val_losses_list, param_trajectories):
    plots_dir = os.path.join(results_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    
    # Plot training and validation losses
    plt.figure(figsize=(12, 7))
    plt.plot(losses, label='Training Loss', color='blue', linewidth=2)
    
    if val_losses_list:
        for fold_idx, val_losses in enumerate(val_losses_list):
            plt.plot(val_losses, label=f'Validation Loss (Fold {fold_idx+1})', 
                     linestyle='--', alpha=0.7)
    
    plt.title('Training and Validation Loss Convergence', fontsize=14)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('ELBO Loss', fontsize=12)
    plt.yscale('log')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'loss_convergence.png'), dpi=300)
    plt.close()
    
    # Plot feature selection probabilities
    if 'indicator_probs' in param_trajectories:
        plt.figure(figsize=(12, 7))
        final_probs = param_trajectories['indicator_probs'][-1]
        plt.hist(final_probs, bins=50, alpha=0.7)
        plt.title('Feature Selection Probabilities Distribution', fontsize=14)
        plt.xlabel('Selection Probability', fontsize=12)
        plt.ylabel('Count', fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, 'feature_selection_dist.png'), dpi=300)
        plt.close()
    
    # Plot selected features heatmap
    selected_features = np.stack([result['selected_features'] for result in fold_results])
    plt.figure(figsize=(14, 10))
    sns.heatmap(selected_features, cmap='coolwarm', center=0.5,
                xticklabels=False, yticklabels=range(1, len(fold_results) + 1))
    plt.title('Selected Features Across Folds', fontsize=14)
    plt.xlabel('Feature Index', fontsize=12)
    plt.ylabel('Fold', fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'selected_features.png'), dpi=300)
    plt.close()
    
    # Plot C-index distribution
    c_indices = [result['c_index'] for result in fold_results]
    plt.figure(figsize=(10, 7))
    plt.violinplot(c_indices, showmeans=True, showmedians=True)
    plt.plot([1] * len(c_indices), c_indices, 'o', color='blue', alpha=0.7)
    mean_c = np.mean(c_indices)
    std_c = np.std(c_indices)
    plt.text(1.2, min(c_indices), 
             f'Mean: {mean_c:.4f}\nStd: {std_c:.4f}', 
             fontsize=12, bbox=dict(facecolor='white', alpha=0.7))
    plt.title('Cross-validation C-index Distribution', fontsize=14)
    plt.ylabel('C-index', fontsize=12)
    plt.xticks([1], ['C-index'])
    plt.grid(True, linestyle='--', alpha=0.7, axis='y')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'cv_c_index.png'), dpi=300)
    plt.close()

def main():
    setup_gpu()
    os.makedirs(config["results_dir"], exist_ok=True)
    
    # Save configuration
    with open(os.path.join(config["results_dir"], 'config.json'), 'w') as f:
        json.dump({k: str(v) if isinstance(v, (torch.dtype, type)) else v 
                  for k, v in config.items()}, f, indent=2)
    
    try:
        print(f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Current User's Login: {os.getlogin()}")
        
        # Load and preprocess data
        X_tensor, time_tensor, event_tensor = load_and_preprocess(config["data_path"])
        
        # Run cross-validation
        fold_results = cross_validate(X_tensor, time_tensor, event_tensor)
        
        # Process cross-validation results
        all_val_losses = [result['val_losses'] for result in fold_results if 'val_losses' in result]
        cv_stopping_epochs = [result['last_epoch'] for result in fold_results if 'last_epoch' in result]
        
        optimal_epochs = int(np.mean(cv_stopping_epochs)) if cv_stopping_epochs else config["max_epochs"]
        print(f"\nUsing {optimal_epochs} epochs for final model based on cross-validation")
        
        # Update config for final model
        final_config = config.copy()
        final_config["max_epochs"] = optimal_epochs
        
        # Create dataset for final training
        dataset = TensorDataset(X_tensor, time_tensor, event_tensor)
        train_loader = DataLoader(
            dataset,
            batch_size=config["batch_size"],
            shuffle=True,
            pin_memory=False  # Removed generator
        )
        
        # Train final model
        print(f"\nTraining final model on full dataset for {optimal_epochs} epochs")
        final_model = BayesianCoxModel()
        losses, _ = final_model.train(train_loader)
        
        # Save results
        results_summary = {
            'mean_c_index': float(np.mean([r['c_index'] for r in fold_results])),
            'std_c_index': float(np.std([r['c_index'] for r in fold_results])),
            'final_loss': float(losses[-1]) if losses else float('nan'),
            'n_epochs_final': len(losses),
            'optimal_epochs': optimal_epochs,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        
        with open(os.path.join(config["results_dir"], 'results_summary.json'), 'w') as f:
            json.dump(results_summary, f, indent=2)
            
    except Exception as e:
        print(f"Error during model training: {e}")
        torch.cuda.empty_cache()
        raise
    
    finally:
        # Clean up
        torch.cuda.empty_cache()
        print("\nTraining completed")



if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"\nError in main execution: {e}")
        torch.cuda.empty_cache()
    finally:
        print("\nExecution finished")

  from .autonotebook import tqdm as notebook_tqdm



GPU: NVIDIA GeForce GTX 1650
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Current device: 0
Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): 2025-03-10 23:29:37
Current User's Login: CHIKOMANA

Loading data from: D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv
Data shape: (800, 1003)
Number of features: 1000
Number of samples: 800
Event rate: 58.88%

Training fold 1/5
Epoch    0 | Train Loss: 442121988.98 | Val Loss: 433382492.14
Epoch   50 | Train Loss: 51391597.47 | Val Loss: 49756325.93
Epoch  100 | Train Loss: 5228408.21 | Val Loss: 5122794.27
Epoch  150 | Train Loss: 595467.10 | Val Loss: 585467.14
Epoch  200 | Train Loss: 121392.58 | Val Loss: 123495.56
Epoch  250 | Train Loss: 45958.49 | Val Loss: 44624.07
Epoch  300 | Train Loss: 27288.88 | Val Loss: 27833.89
Epoch  350 | Train Loss: 22884.63 | Val Loss: 21726.77
Epoch  400 | Train Loss: 21106.68 | Val Loss: 22002.78
Epoch  450 | Train Loss: 21732.67 | Val Loss: 22513.28



In [None]:
import pandas as pd
import numpy as np
import torch
from datetime import datetime
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader
from lifelines.utils import concordance_index
from sklearn.model_selection import KFold
import os
import json
import warnings
from torch.distributions import constraints  # Changed this line
warnings.filterwarnings('ignore')

# Rest of the code remains the same...

# Configuration
config = {
    # Training parameters
    "max_epochs": 800,
    "batch_size": 64,
    "initial_lr": 1e-5,
    "lr_decay": 0.1,
    "decay_step": 200,
    "clip_norm": 1.0,
    "early_stop_patience": 50,
    
    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "precision": torch.float32,
    
    # Model parameters
    "elbo_particles": 5,
    "n_folds": 5,
    
    # Regularization parameters
    "max_risk_score": 5.0,
    "min_scale": 1e-6,
    "init_spike_prob": 0.01,
    "global_scale_prior": 0.01,
    "lasso_scale": 0.01,
    
    # Paths
    "data_path": r"D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv",
    "results_dir": f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
}

class CoxPartialLikelihood(nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-8
        self.max_exp = 5.0
        
    def forward(self, risk_scores, survival_time, event):
        event_mask = event == 1
        n_events = event_mask.sum()
        
        if n_events < 1:
            return torch.tensor(0.0, device=risk_scores.device)
        
        # Normalize risk scores
        risk_scores = (risk_scores - risk_scores.mean()) / (risk_scores.std() + self.eps)
        risk_scores = torch.clamp(risk_scores, -self.max_exp, self.max_exp)
        
        sorted_times, indices = torch.sort(survival_time, descending=True)
        risk_scores = risk_scores[indices]
        event_mask = event_mask[indices]

        with torch.no_grad():
            max_rs = risk_scores.max()
        
        risk_exp = torch.exp(risk_scores - max_rs)
        cumsum_risk = torch.cumsum(risk_exp, 0)
        log_cumsum = torch.log(cumsum_risk + self.eps)

        pl = risk_scores[event_mask] - log_cumsum[event_mask]
        return -pl.mean()  # Negative log-likelihood

def model(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # Move scalar tensors to device
    global_scale = pyro.sample("global_scale", 
                             dist.HalfNormal(torch.tensor(config["global_scale_prior"], device=device)))
    spike_prob = pyro.sample("spike_prob", 
                           dist.Beta(torch.tensor(1.05, device=device), 
                                   torch.tensor(10.0, device=device)))
    
    with pyro.plate("features", p):
        indicators = pyro.sample("indicators", dist.Bernoulli(spike_prob))
        local_scale = pyro.sample("local_scale", 
                                dist.HalfNormal(torch.tensor(0.05, device=device)))
        
        scale = torch.clamp(
            indicators * global_scale * local_scale + (1 - indicators) * config["min_scale"],
            min=config["min_scale"],
            max=1.0
        )
        
        beta = pyro.sample("beta", 
                          dist.Laplace(torch.zeros(p, device=device), 
                                     scale * config["lasso_scale"]))
    
    risk_scores = torch.clamp(X @ beta, -config["max_risk_score"], config["max_risk_score"])
    log_pl = CoxPartialLikelihood()(risk_scores, survival_time, event)
    pyro.factor("log_pl", log_pl)

def guide(X, survival_time, event):
    n, p = X.shape
    device = X.device
    
    # Move all parameters to device
    global_scale_loc = pyro.param("global_scale_loc", 
                                torch.tensor(0.01, device=device),
                                constraint=constraints.positive)
    global_scale_scale = pyro.param("global_scale_scale",
                                  torch.tensor(0.001, device=device),
                                  constraint=constraints.positive)
    
    global_scale = pyro.sample("global_scale",
                             dist.LogNormal(global_scale_loc, global_scale_scale))
    
    spike_alpha = pyro.param("spike_prob_alpha",
                           torch.tensor(1.05, device=device),
                           constraint=constraints.positive)
    spike_beta = pyro.param("spike_prob_beta",
                          torch.tensor(10.0, device=device),
                          constraint=constraints.positive)
    
    spike_prob = pyro.sample("spike_prob",
                           dist.Beta(spike_alpha, spike_beta))
    
    with pyro.plate("features", p):
        indicator_probs = pyro.param(
            "indicator_probs",
            torch.full((p,), config["init_spike_prob"], device=device),
            constraint=constraints.interval(0.0, 1.0)
        )
        
        local_scale_loc = pyro.param("local_scale_loc",
                                   torch.full((p,), 0.05, device=device),
                                   constraint=constraints.positive)
        local_scale_scale = pyro.param("local_scale_scale",
                                     torch.full((p,), 0.01, device=device),
                                     constraint=constraints.positive)
        
        beta_loc = pyro.param("beta_loc",
                            torch.zeros(p, device=device))
        beta_scale = pyro.param("beta_scale",
                              torch.full((p,), 0.1, device=device),
                              constraint=constraints.positive)
        
        pyro.sample("indicators", dist.Bernoulli(indicator_probs))
        pyro.sample("local_scale", dist.LogNormal(local_scale_loc, local_scale_scale))
        pyro.sample("beta", dist.Normal(beta_loc, beta_scale))

def setup_gpu():
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    else:
        print("\nUsing CPU")

def load_and_preprocess(path):
    print(f"\nLoading data from: {path}")
    data = pd.read_csv(path)
    print(f"Data shape: {data.shape}")
    
    X = data.iloc[:, 1:-2].values.astype(np.float32)
    time_values = data['time'].values.astype(np.float32)
    event_values = data['event'].values.astype(np.float32)
    
    X = (X - np.nanmean(X, axis=0)) / (np.nanstd(X, axis=0) + 1e-8)
    X = np.nan_to_num(X)
    
    # Move all data to GPU upfront
    X_tensor = torch.tensor(X, dtype=config["precision"]).to(config["device"])
    time_tensor = torch.tensor(time_values, dtype=config["precision"]).to(config["device"])
    event_tensor = torch.tensor(event_values, dtype=config["precision"]).to(config["device"])
    
    print(f"Number of features: {X_tensor.shape[1]}")
    print(f"Number of samples: {X_tensor.shape[0]}")
    print(f"Event rate: {event_tensor.float().mean().item():.2%}")
    
    return X_tensor, time_tensor, event_tensor

class BayesianCoxModel:
    def __init__(self):
        self.device = config["device"]
        pyro.clear_param_store()
        self._initialize_params()

    def _initialize_params(self):
        X_dummy = torch.zeros((2, 1000), device=self.device, dtype=config["precision"])
        time_dummy = torch.zeros(2, device=self.device, dtype=config["precision"])
        event_dummy = torch.zeros(2, device=self.device, dtype=config["precision"])
        
        model(X_dummy, time_dummy, event_dummy)
        guide(X_dummy, time_dummy, event_dummy)

    def train(self, train_loader, val_loader=None):
        optimizer = ClippedAdam({
            "lr": config["initial_lr"],
            "clip_norm": config["clip_norm"],
            "weight_decay": 1e-4,
            "betas": (0.9, 0.999),  # Add momentum parameters
        })
        
        svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=config["elbo_particles"]))
        
        best_loss = float('inf')
        losses = []
        val_losses = []
        patience = 0
        
        try:
            for epoch in range(config["max_epochs"]):
                epoch_loss = 0.0
                for batch in train_loader:
                    X_batch, t_batch, e_batch = [b.to(self.device) for b in batch]
                    
                    # Scale the batch data
                    X_batch = (X_batch - X_batch.mean(0)) / (X_batch.std(0) + 1e-8)
                    
                    loss = svi.step(X_batch, t_batch, e_batch)
                    epoch_loss += loss
                
                avg_loss = epoch_loss / len(train_loader)
                losses.append(avg_loss)
                
                if val_loader:
                    val_loss = self._evaluate(val_loader)
                    val_losses.append(val_loss)
                    
                    if val_loss < best_loss:
                        best_loss = val_loss
                        patience = 0
                    else:
                        patience += 1
                    
                    if patience >= config["early_stop_patience"]:
                        print(f"\nEarly stopping at epoch {epoch}")
                        break
                
                if epoch % config["decay_step"] == 0 and epoch > 0:
                    optimizer.set_state({'lr': config["initial_lr"] * (config["lr_decay"] ** (epoch // config["decay_step"]))})
                
                if epoch % 50 == 0:
                    print(f"Epoch {epoch:4d} | Train Loss: {avg_loss:.2f}" + 
                          (f" | Val Loss: {val_loss:.2f}" if val_loader else ""))
        
        except KeyboardInterrupt:
            print("\nTraining stopped by user")
        
        return losses, val_losses

    def _evaluate(self, loader):
        total_loss = 0.0
        with torch.no_grad():
            for batch in loader:
                X_batch, t_batch, e_batch = [b.to(self.device) for b in batch]
                total_loss += pyro.infer.Trace_ELBO(num_particles=5).loss(
                    model, guide, X_batch, t_batch, e_batch
                )
        return total_loss / len(loader)

def cross_validate(X, time, event):
    kf = KFold(n_splits=config["n_folds"], shuffle=True, random_state=42)
    fold_results = []
    device = config["device"]
    
    # Keep original data on GPU
    X = X.to(device)
    time = time.to(device)
    event = event.to(device)
    
    # Generate splits
    for fold, (train_idx, test_idx) in enumerate(kf.split(X.cpu().numpy())):
        print(f"\n=== Fold {fold+1}/{config['n_folds']} ===")
        
        # Convert indices to GPU tensors
        train_idx = torch.tensor(train_idx, device=device)
        test_idx = torch.tensor(test_idx, device=device)
        
        # Create subsets directly on GPU
        X_train = X[train_idx]
        time_train = time[train_idx]
        event_train = event[train_idx]
        
        X_test = X[test_idx]
        time_test = time[test_idx]
        event_test = event[test_idx]
        
        # Create datasets (data already on GPU)
        train_dataset = TensorDataset(X_train, time_train, event_train)
        test_dataset = TensorDataset(X_test, time_test, event_test)
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=config["batch_size"],
            shuffle=True
        )
        
        test_loader = DataLoader(
            test_dataset,
            batch_size=config["batch_size"]
        )
        
        # Initialize and train model
        model = BayesianCoxModel()
        train_losses, val_losses = model.train(train_loader, test_loader)
        
        # Compute metrics
        with torch.no_grad():
            beta = pyro.param("beta_loc")  # Already on GPU
            risk_scores = X_test @ beta
            
            # Move to CPU only for concordance index calculation
            c_index = concordance_index(
                time_test.cpu().numpy(),
                -risk_scores.cpu().numpy(),
                event_test.cpu().numpy()
            )
            
            probs = pyro.param("indicator_probs").cpu().numpy()
        
        fold_results.append({
            'c_index': c_index,
            'selected_features': (probs > 0.5).astype(int),
            'train_loss': train_losses,
            'val_loss': val_losses
        })
    
    return fold_results

def save_results(results_dir, fold_results):
    with open(os.path.join(results_dir, 'results.json'), 'w') as f:
        json.dump({
            'c_indices': [res['c_index'] for res in fold_results],
            'mean_c_index': float(np.mean([res['c_index'] for res in fold_results])),
            'std_c_index': float(np.std([res['c_index'] for res in fold_results])),
            'selected_features': [res['selected_features'].tolist() for res in fold_results]
        }, f, indent=2)
    
    with open(os.path.join(results_dir, 'config.json'), 'w') as f:
        safe_config = {k: str(v) if isinstance(v, torch.dtype) else v for k, v in config.items()}
        json.dump(safe_config, f, indent=2)
    
    plt.figure(figsize=(10, 6))
    for i, res in enumerate(fold_results):
        plt.plot(res['train_loss'], label=f'Fold {i+1} Train')
        plt.plot(res['val_loss'], '--', label=f'Fold {i+1} Val')
    plt.title("Training/Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(os.path.join(results_dir, 'loss_curves.png'))
    plt.close()
    
    plt.figure(figsize=(8, 5))
    plt.boxplot([res['c_index'] for res in fold_results])
    plt.title("Cross-Validation C-Index")
    plt.ylabel("C-Index")
    plt.savefig(os.path.join(results_dir, 'c_indices.png'))
    plt.close()

def main():
    setup_gpu()
    os.makedirs(config["results_dir"], exist_ok=True)
    
    try:
        X, time, event = load_and_preprocess(config["data_path"])
        fold_results = cross_validate(X, time, event)
        save_results(config["results_dir"], fold_results)
        
        print("\nTraining completed successfully!")
        print(f"Results saved to: {config['results_dir']}")
        print(f"Average C-index: {np.mean([res['c_index'] for res in fold_results]):.3f}")
        
    except Exception as e:
        print(f"\nError: {str(e)}")
    finally:
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm



GPU: NVIDIA GeForce GTX 1650

Loading data from: D:\cox-model-imputation\error-in-r-code-for-mcar\datasets\vb-cox\realistic_cox_data.csv
Data shape: (800, 1003)
Number of features: 1000
Number of samples: 800
Event rate: 58.88%

=== Fold 1/5 ===
Epoch    0 | Train Loss: 7879570247.65 | Val Loss: 7898330759.41
Epoch   50 | Train Loss: 7828337776.54 | Val Loss: 7885955143.18
Epoch  100 | Train Loss: 7840899969.42 | Val Loss: 7853007554.40

Early stopping at epoch 113

=== Fold 2/5 ===
Epoch    0 | Train Loss: 7921804795.10 | Val Loss: 7920152238.15
Epoch   50 | Train Loss: 7819492229.12 | Val Loss: 7891488643.13
