In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [6]:
# Kaggle-optimized C²BA Implementation for UCI Heart Disease
# This script is designed to run directly in Kaggle notebooks

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Scikit-learn imports  
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.calibration import calibration_curve

# Other imports
import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
import warnings
import os
import sys
from pathlib import Path

warnings.filterwarnings('ignore')

# Kaggle-specific configurations
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
def set_random_seeds(seed=42):
    """Set all random seeds for reproducibility"""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # For deterministic behavior (may impact performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_random_seeds(42)

# Kaggle dataset loading function
def load_heart_disease_data():
    """Load UCI Heart Disease dataset - adaptable to Kaggle input paths"""
    
    # Try multiple potential Kaggle input paths
    potential_paths = [
        '/kaggle/input/heart-disease-data/heart_disease_uci.csv',
        '/kaggle/input/heart-disease/heart.csv', 
        '/kaggle/input/uci-heart-disease/heart.csv',
        '/kaggle/input/heart.csv',
        'heart.csv'  # If uploaded directly
    ]
    
    df = None
    for path in potential_paths:
        try:
            if os.path.exists(path):
                df = pd.read_csv(path)
                print(f"✓ Successfully loaded data from: {path}")
                break
        except Exception as e:
            continue
    
    # If no dataset found, create synthetic data for demonstration
    if df is None:
        print("⚠ No heart disease dataset found in standard Kaggle paths.")
        print("Creating synthetic dataset for demonstration...")
        df = create_synthetic_heart_data()
    
    return df

def create_synthetic_heart_data():
    """Create synthetic UCI Heart Disease-like dataset"""
    np.random.seed(42)
    n_samples = 1025  # Realistic size similar to UCI dataset
    
    # Create realistic synthetic data
    data = {
        'age': np.random.normal(54, 9, n_samples).clip(29, 77).astype(int),
        'sex': np.random.binomial(1, 0.68, n_samples),  # Male bias as in real data
        'cp': np.random.choice([0, 1, 2, 3], n_samples, p=[0.47, 0.16, 0.29, 0.08]),
        'trestbps': np.random.normal(131, 17, n_samples).clip(94, 200).astype(int),
        'chol': np.random.normal(246, 51, n_samples).clip(126, 564).astype(int),
        'fbs': np.random.binomial(1, 0.15, n_samples),
        'restecg': np.random.choice([0, 1, 2], n_samples, p=[0.48, 0.48, 0.04]),
        'thalach': np.random.normal(149, 22, n_samples).clip(71, 202).astype(int),
        'exang': np.random.binomial(1, 0.33, n_samples),
        'oldpeak': np.random.exponential(1.04, n_samples).clip(0, 6.2).round(1),
        'slope': np.random.choice([0, 1, 2], n_samples, p=[0.21, 0.14, 0.65]),
        'ca': np.random.choice([0, 1, 2, 3, 4], n_samples, p=[0.59, 0.21, 0.12, 0.06, 0.02]),
        'thal': np.random.choice([0, 1, 2, 3], n_samples, p=[0.02, 0.55, 0.36, 0.07])
    }
    
    # Create target variable with realistic medical correlations
    risk_factors = (
        0.02 * data['age'] +
        0.5 * (data['cp'] == 0) +  # Typical angina increases risk
        0.3 * data['exang'] +
        0.4 * (data['thal'] == 2) +  # Reversible defect
        0.2 * data['oldpeak'] +
        -0.01 * data['thalach'] +  # Higher heart rate = lower risk
        0.3 * (data['ca'] > 0) +  # Major vessels
        0.2 * data['sex'] +  # Male higher risk
        np.random.normal(0, 0.8, n_samples)
    )
    
    # Convert to binary (0=no disease, 1=disease) - more realistic for UCI data
    data['target'] = (risk_factors > np.percentile(risk_factors, 55)).astype(int)
    
    df = pd.DataFrame(data)
    
    # Add some missing values to make it realistic
    missing_cols = ['ca', 'thal']
    for col in missing_cols:
        missing_mask = np.random.random(len(df)) < 0.02  # 2% missing
        df.loc[missing_mask, col] = np.nan
    
    print(f"✓ Created synthetic dataset with {len(df)} samples")
    print(f"Target distribution: {df['target'].value_counts().to_dict()}")
    
    return df

class HeartDiseaseDataset(Dataset):
    """Custom dataset for UCI Heart Disease with feature engineering"""
    
    def __init__(self, features, targets, transform=None):
        self.features = torch.FloatTensor(features)
        self.targets = torch.LongTensor(targets)
        self.transform = transform
    
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.targets[idx]
        
        if self.transform:
            x = self.transform(x)
        
        return x, y

class DataProcessor:
    """Comprehensive data preprocessing pipeline"""
    
    def __init__(self):
        self.scaler = StandardScaler()
        self.label_encoders = {}
        self.feature_stats = {}
        self.embedding_dims = {
            'sex': 4, 'dataset': 6, 'cp': 6, 'restecg': 4, 'slope': 4, 'thal': 4
        }
        
    def create_embeddings(self, df, categorical_cols):
        """Create embedding matrices for categorical features"""
        embeddings = {}
        for col in categorical_cols:
            if col in df.columns:
                unique_vals = df[col].nunique()
                embed_dim = min(50, int(unique_vals**0.25) * 4)
                embeddings[col] = nn.Embedding(unique_vals, embed_dim)
        return embeddings
    
    def engineer_features(self, df):
        """Advanced feature engineering"""
        df_eng = df.copy()
        
        # Polynomial interactions
        interactions = [
            ('age', 'trestbps'),
            ('age', 'thalch'), 
            ('chol', 'trestbps'),
            ('oldpeak', 'slope'),
            ('cp', 'exang')
        ]
        
        for feat1, feat2 in interactions:
            if feat1 in df_eng.columns and feat2 in df_eng.columns:
                df_eng[f'{feat1}_{feat2}_interaction'] = df_eng[feat1] * df_eng[feat2]
        
        # Missing value indicators
        for col in df_eng.columns:
            if df_eng[col].isna().sum() > len(df_eng) * 0.05:  # >5% missing
                df_eng[f'{col}_missing'] = df_eng[col].isna().astype(int)
        
        return df_eng
    
    def preprocess_data(self, df, is_training=True):
        """Complete preprocessing pipeline - adapted for real UCI heart data structure"""
        df_processed = df.copy()
        
        # Handle common UCI heart dataset column names
        target_col = 'target' if 'target' in df.columns else 'num'
        
        # Standardize column names if needed
        column_mapping = {
            'num': 'target',  # UCI dataset sometimes uses 'num' for target
        }
        df_processed = df_processed.rename(columns=column_mapping)
        
        # Ensure target column exists
        if target_col not in df_processed.columns and 'target' not in df_processed.columns:
            if 'num' in df_processed.columns:
                df_processed['target'] = df_processed['num']
            else:
                raise ValueError("No target column found. Expected 'target' or 'num'")
        
        # For multi-class, convert to binary for simplicity in this demo
        if 'target' in df_processed.columns:
            max_target_value = df_processed['target'].max()
            if pd.isna(max_target_value):
                max_target_value = 0
            if max_target_value > 1:  # Now comparing scalar values
                df_processed['target'] = (df_processed['target'] > 0).astype(int)
                # Now max_target_value is a scalar, safe to compare
                if max_target_value > 1:
                    df_processed['target'] = (df_processed['target'] > 0).astype(int)
            else:
                # All values are NaN, set default
                df_processed['target'] = 0
            
        # Feature engineering
        df_processed = self.engineer_features(df_processed)
        
        # Separate feature types
        feature_cols = [col for col in df_processed.columns if col != 'target']
        continuous_cols = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
        binary_cols = ['sex', 'fbs', 'exang']
        categorical_cols = ['cp', 'restecg', 'slope', 'ca', 'thal']
        
        # Filter to existing columns
        continuous_cols = [col for col in continuous_cols if col in df_processed.columns]
        binary_cols = [col for col in binary_cols if col in df_processed.columns]
        categorical_cols = [col for col in categorical_cols if col in df_processed.columns]
        
        # Handle missing values
        for col in continuous_cols:
            if col in df_processed.columns:
                if is_training:
                    self.feature_stats[f'{col}_mean'] = df_processed[col].mean()
                fill_value = self.feature_stats.get(f'{col}_mean', df_processed[col].mean())
                df_processed[col].fillna(fill_value, inplace=True)
        
        for col in categorical_cols + binary_cols:
            if col in df_processed.columns:
                if is_training:
                    mode_val = df_processed[col].mode()
                    self.feature_stats[f'{col}_mode'] = mode_val[0] if len(mode_val) > 0 else 0
                fill_value = self.feature_stats.get(f'{col}_mode', 0)
                df_processed[col].fillna(fill_value, inplace=True)
        
        # Encode categorical variables
        for col in categorical_cols:
            if col in df_processed.columns:
                if is_training:
                    self.label_encoders[col] = LabelEncoder()
                    df_processed[col] = self.label_encoders[col].fit_transform(df_processed[col].astype(str))
                else:
                    try:
                        df_processed[col] = self.label_encoders[col].transform(df_processed[col].astype(str))
                    except ValueError:
                        # Handle unseen categories
                        df_processed[col] = 0
        
        # Scale features
        feature_cols = [col for col in df_processed.columns if col != 'target']
        X = df_processed[feature_cols].values
        
        if is_training:
            X_scaled = self.scaler.fit_transform(X)
        else:
            X_scaled = self.scaler.transform(X)
        
        y = df_processed['target'].values if 'target' in df_processed.columns else None
        
        return X_scaled, y

class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention for tabular data"""
    
    def __init__(self, d_model, n_heads=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model)
        
        return self.W_o(attn_output)

class FoundationModel(nn.Module):
    """Deep tabular neural network with attention"""
    
    def __init__(self, input_dim, hidden_dims=[128, 96], output_dim=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        # Input processing
        self.input_bn = nn.BatchNorm1d(input_dim)
        self.input_dropout = nn.Dropout(0.1)
        
        # Projection layer for residual connection
        self.input_proj = nn.Linear(input_dim, hidden_dims[0])
        
        # Layer 1: Linear transformation with residual
        self.layer1 = nn.Linear(input_dim, hidden_dims[0])
        self.layer1_bn = nn.BatchNorm1d(hidden_dims[0])
        
        # Layer 2: Multi-head attention
        self.attention = MultiHeadAttention(hidden_dims[0], n_heads=4)
        self.attn_norm = nn.LayerNorm(hidden_dims[0])
        
        # Layer 3: Compression layer
        self.layer3 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.layer3_bn = nn.BatchNorm1d(hidden_dims[1])
        self.layer3_proj = nn.Linear(hidden_dims[0], hidden_dims[1])
        
        # Final representation layer
        self.output_layer = nn.Linear(hidden_dims[1], output_dim)
        
        self._init_weights()
    
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        # Input processing
        x_bn = self.input_bn(x)
        x_drop = self.input_dropout(x_bn)
        
        # Layer 1 with residual connection
        h1 = F.relu(self.layer1_bn(self.layer1(x_drop)))
        h1 = h1 + self.input_proj(x_drop)
        
        # Layer 2: Multi-head attention
        h1_unsqueezed = h1.unsqueeze(1)  # Add sequence dimension
        attn_out = self.attention(h1_unsqueezed).squeeze(1)
        h2 = self.attn_norm(attn_out + h1)
        
        # Layer 3: Compression with residual
        h3 = F.relu(self.layer3_bn(self.layer3(h2)))
        h3 = h3 + self.layer3_proj(h2)
        
        # Final representation
        output = torch.tanh(self.output_layer(h3))
        
        return output

class BayesianAdapter(nn.Module):
    """Low-rank Bayesian linear layer with Horseshoe prior"""
    
    def __init__(self, input_dim, output_dim, rank=8):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.rank = rank
        
        # Variational parameters for U matrix
        self.U_mean = nn.Parameter(torch.randn(input_dim, rank) * 0.1)
        self.U_logvar = nn.Parameter(torch.ones(input_dim, rank) * (-2))
        
        # Variational parameters for V matrix  
        self.V_mean = nn.Parameter(torch.randn(output_dim, rank) * 0.1)
        self.V_logvar = nn.Parameter(torch.ones(output_dim, rank) * (-2))
        
        # Global and local shrinkage parameters
        self.tau_mean = nn.Parameter(torch.tensor(0.0))
        self.tau_logvar = nn.Parameter(torch.tensor(-1.0))
        
        self.lambda_mean = nn.Parameter(torch.zeros(rank))
        self.lambda_logvar = nn.Parameter(torch.ones(rank) * (-1))
        
    def sample_weights(self, num_samples=1):
        """Sample weight matrices using reparameterization trick"""
        # Sample shrinkage parameters
        tau_std = torch.exp(0.5 * self.tau_logvar)
        tau_samples = self.tau_mean + tau_std * torch.randn(num_samples, device=self.tau_mean.device)
        
        lambda_std = torch.exp(0.5 * self.lambda_logvar)
        lambda_samples = self.lambda_mean + lambda_std * torch.randn(num_samples, self.rank, device=self.lambda_mean.device)
        
        # Compute effective variances
        tau_expanded = tau_samples.unsqueeze(-1)  # [num_samples, 1]
        lambda_expanded = lambda_samples  # [num_samples, rank]
        effective_var = tau_expanded * lambda_expanded  # [num_samples, rank]
        
        weights = []
        for i in range(num_samples):
            # Sample U and V matrices
            U_std = torch.exp(0.5 * self.U_logvar) * effective_var[i]
            V_std = torch.exp(0.5 * self.V_logvar) * effective_var[i]
            
            U_sample = self.U_mean + U_std * torch.randn_like(self.U_mean)
            V_sample = self.V_mean + V_std * torch.randn_like(self.V_mean)
            
            # Compute weight matrix W = UV^T
            W_sample = torch.mm(U_sample, V_sample.t())
            weights.append(W_sample)
        
        return torch.stack(weights)
    
    def kl_divergence(self):
        """Compute KL divergence with Horseshoe prior"""
        # KL for tau (global shrinkage)
        tau_var = torch.exp(self.tau_logvar)
        kl_tau = 0.5 * (tau_var + self.tau_mean**2 - 1 - self.tau_logvar)
        
        # KL for lambda (local shrinkage)
        lambda_var = torch.exp(self.lambda_logvar)
        kl_lambda = 0.5 * torch.sum(lambda_var + self.lambda_mean**2 - 1 - self.lambda_logvar)
        
        # KL for U and V matrices (approximated with unit variance prior)
        U_var = torch.exp(self.U_logvar)
        V_var = torch.exp(self.V_logvar)
        
        kl_U = 0.5 * torch.sum(U_var + self.U_mean**2 - 1 - self.U_logvar)
        kl_V = 0.5 * torch.sum(V_var + self.V_mean**2 - 1 - self.V_logvar)
        
        return kl_tau + kl_lambda + kl_U + kl_V
    
    def forward(self, x, num_samples=1):
        """Forward pass with Monte Carlo sampling"""
        if self.training:
            weights = self.sample_weights(num_samples)
            outputs = []
            for i in range(num_samples):
                outputs.append(torch.mm(x, weights[i]))
            return torch.stack(outputs).mean(dim=0)
        else:
            # Use mean weights for inference
            W_mean = torch.mm(self.U_mean, self.V_mean.t())
            return torch.mm(x, W_mean)

class DistributionShiftDetector(nn.Module):
    """Detect and quantify distribution shift"""
    
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, x):
        return self.classifier(x)
    
    def compute_mmd(self, X_train, X_test, gamma=1.0):
        """Compute Maximum Mean Discrepancy"""
        n, m = X_train.size(0), X_test.size(0)
        
        # RBF kernel computation
        def rbf_kernel(X, Y, gamma):
            X_norm = (X**2).sum(1).view(-1, 1)
            Y_norm = (Y**2).sum(1).view(1, -1)
            K = torch.exp(-gamma * (X_norm + Y_norm - 2 * torch.mm(X, Y.t())))
            return K
        
        # MMD computation
        Kxx = rbf_kernel(X_train, X_train, gamma)
        Kyy = rbf_kernel(X_test, X_test, gamma)
        Kxy = rbf_kernel(X_train, X_test, gamma)
        
        mmd_squared = (Kxx.sum() / (n * n) + Kyy.sum() / (m * m) - 
                      2 * Kxy.sum() / (n * m))
        
        return torch.sqrt(torch.clamp(mmd_squared, min=1e-8))
    
    def compute_energy_distance(self, X_train, X_test):
        """Compute energy distance between distributions"""
        n, m = X_train.size(0), X_test.size(0)
        
        # Pairwise distances
        def pairwise_distances(X, Y):
            return torch.cdist(X, Y, p=2)
        
        # Energy distance components
        d_xy = pairwise_distances(X_train, X_test).mean()
        d_xx = pairwise_distances(X_train, X_train).mean()
        d_yy = pairwise_distances(X_test, X_test).mean()
        
        return 2 * d_xy - d_xx - d_yy

class CalibrationSystem(nn.Module):
    """Temperature scaling with density ratio correction"""
    
    def __init__(self, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        self.temperature = nn.Parameter(torch.ones(1))
        self.density_ratio_weight = nn.Parameter(torch.tensor(0.1))
        
    def temperature_scale(self, logits):
        """Apply temperature scaling"""
        return logits / self.temperature
    
    def forward(self, logits, density_ratios=None):
        """Apply calibration with optional density ratio correction"""
        scaled_logits = self.temperature_scale(logits)
        
        if density_ratios is not None and self.training:
            # Apply density ratio correction
            correction = self.density_ratio_weight * torch.log(density_ratios + 1e-8)
            scaled_logits = scaled_logits + correction.unsqueeze(-1)
        
        return F.softmax(scaled_logits, dim=-1)

class C2BAModel(nn.Module):
    """Complete Counterfactually-Calibrated Bayesian Adapter model"""
    
    def __init__(self, input_dim, num_classes=2, foundation_dim=32, adapter_rank=8):
        super().__init__()
        self.num_classes = num_classes
        self.foundation = FoundationModel(input_dim, output_dim=foundation_dim)
        self.bayesian_adapter = BayesianAdapter(foundation_dim, num_classes, rank=adapter_rank)
        self.shift_detector = DistributionShiftDetector(foundation_dim)
        self.calibration = CalibrationSystem(num_classes)
        
        # Store training data statistics for shift detection
        self.register_buffer('train_features_mean', torch.zeros(foundation_dim))
        self.register_buffer('train_features_std', torch.ones(foundation_dim))
        
    def forward(self, x, compute_uncertainty=True, num_mc_samples=5):
        # Extract foundation features
        features = self.foundation(x)
        
        if compute_uncertainty and self.training:
            # Multiple forward passes for uncertainty estimation
            predictions = []
            for _ in range(num_mc_samples):
                logits = self.bayesian_adapter(features, num_samples=1)
                predictions.append(logits)
            
            logits = torch.stack(predictions).mean(dim=0)
            uncertainty = torch.stack(predictions).std(dim=0).mean(dim=-1)
        else:
            logits = self.bayesian_adapter(features)
            uncertainty = None
        
        # Detect distribution shift
        shift_score = self.shift_detector(features).sigmoid()
        
        # Apply calibration
        calibrated_probs = self.calibration(logits, shift_score)
        
        return {
            'logits': logits,
            'probabilities': calibrated_probs,
            'features': features,
            'shift_score': shift_score,
            'uncertainty': uncertainty
        }
    
    def compute_loss(self, outputs, targets, train_features=None):
        """Compute total loss including ELBO and calibration terms"""
        logits = outputs['logits']
        shift_scores = outputs['shift_score']
        
        # Classification loss
        ce_loss = F.cross_entropy(logits, targets)
        
        # Bayesian adapter KL divergence
        kl_loss = self.bayesian_adapter.kl_divergence()
        
        # Shift detection loss (when training data is available)
        shift_loss = torch.tensor(0.0, device=logits.device)
        if train_features is not None:
            # Create labels for shift detection
            batch_size = logits.size(0)
            train_batch_size = train_features.size(0)
            
            shift_labels = torch.cat([
                torch.zeros(train_batch_size),  # training data
                torch.ones(batch_size)  # current batch (potentially shifted)
            ]).to(logits.device)
            
            combined_features = torch.cat([train_features, outputs['features']], dim=0)
            shift_logits = self.shift_detector(combined_features).squeeze()
            shift_loss = F.binary_cross_entropy_with_logits(shift_logits, shift_labels)
        
        # Total loss with weighting
        total_loss = ce_loss + 0.01 * kl_loss + 0.5 * shift_loss
        
        return {
            'total_loss': total_loss,
            'ce_loss': ce_loss,
            'kl_loss': kl_loss,
            'shift_loss': shift_loss
        }

class ModelTrainer:
    """Comprehensive training pipeline for C²BA model"""
    
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        
    def train_epoch(self, train_loader, optimizer, kl_weight=1.0):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(self.device), targets.to(self.device)
            
            optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(data, compute_uncertainty=True)
            
            # Compute losses
            loss_dict = self.model.compute_loss(outputs, targets)
            
            # Apply KL annealing
            loss_dict['total_loss'] = (loss_dict['ce_loss'] + 
                                     kl_weight * loss_dict['kl_loss'] + 
                                     loss_dict['shift_loss'])
            
            # Backward pass
            loss_dict['total_loss'].backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            
            # Statistics
            total_loss += loss_dict['total_loss'].item()
            pred = outputs['probabilities'].argmax(dim=1)
            correct += pred.eq(targets).sum().item()
            total += targets.size(0)
            
        return {
            'loss': total_loss / len(train_loader),
            'accuracy': 100. * correct / total
        }
    
    def evaluate(self, val_loader):
        """Evaluate model on validation set"""
        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0
        all_probs = []
        all_targets = []
        
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                
                outputs = self.model(data, compute_uncertainty=False)
                loss_dict = self.model.compute_loss(outputs, targets)
                
                val_loss += loss_dict['total_loss'].item()
                pred = outputs['probabilities'].argmax(dim=1)
                correct += pred.eq(targets).sum().item()
                total += targets.size(0)
                
                all_probs.extend(outputs['probabilities'].cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
        
        all_probs = np.array(all_probs)
        all_targets = np.array(all_targets)
        
        # Compute additional metrics
        pred_classes = np.argmax(all_probs, axis=1)
        f1 = f1_score(all_targets, pred_classes, average='weighted')
        
        return {
            'loss': val_loss / len(val_loader),
            'accuracy': 100. * correct / total,
            'f1_score': f1,
            'predictions': all_probs,
            'targets': all_targets
        }
    
    def train(self, train_loader, val_loader, num_epochs=100, patience=10):
        """Complete training loop with early stopping"""
        # Optimizer with different learning rates for different components
        foundation_params = list(self.model.foundation.parameters())
        adapter_params = list(self.model.bayesian_adapter.parameters())
        other_params = (list(self.model.shift_detector.parameters()) + 
                       list(self.model.calibration.parameters()))
        
        optimizer = torch.optim.AdamW([
            {'params': foundation_params, 'lr': 1e-3},
            {'params': adapter_params, 'lr': 5e-3},
            {'params': other_params, 'lr': 1e-3}
        ], weight_decay=1e-5)
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs, eta_min=1e-6)
        
        history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
        
        for epoch in range(num_epochs):
            # KL annealing
            kl_weight = min(1.0, epoch / 50.0)
            
            # Training
            train_stats = self.train_epoch(train_loader, optimizer, kl_weight)
            
            # Validation
            val_stats = self.evaluate(val_loader)
            
            # Learning rate scheduling
            scheduler.step()
            
            # Record history
            history['train_loss'].append(train_stats['loss'])
            history['train_acc'].append(train_stats['accuracy'])
            history['val_loss'].append(val_stats['loss'])
            history['val_acc'].append(val_stats['accuracy'])
            
            # Early stopping
            if val_stats['loss'] < self.best_val_loss:
                self.best_val_loss = val_stats['loss']
                self.patience_counter = 0
                # Save best model
                torch.save(self.model.state_dict(), 'best_c2ba_model.pt')
            else:
                self.patience_counter += 1
                
            if self.patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
                
            if (epoch + 1) % 10 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}:')
                print(f'  Train Loss: {train_stats["loss"]:.4f}, Train Acc: {train_stats["accuracy"]:.2f}%')
                print(f'  Val Loss: {val_stats["loss"]:.4f}, Val Acc: {val_stats["accuracy"]:.2f}%')
                print(f'  Val F1: {val_stats["f1_score"]:.4f}')
        
        # Load best model
        self.model.load_state_dict(torch.load('best_c2ba_model.pt'))
        
        return history

class MetricsCalculator:
    """Calculate comprehensive evaluation metrics"""
    
    @staticmethod
    def expected_calibration_error(y_prob, y_true, n_bins=10):
        """Calculate Expected Calibration Error (ECE)"""
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]
        
        ece = 0
        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            # Find samples in bin
            in_bin = (y_prob > bin_lower) & (y_prob <= bin_upper)
            prop_in_bin = in_bin.mean()
            
            if prop_in_bin > 0:
                accuracy_in_bin = y_true[in_bin].mean()
                avg_confidence_in_bin = y_prob[in_bin].mean()
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        
        return ece
    
    @staticmethod
    def brier_score(y_prob, y_true, num_classes):
        """Calculate Brier Score"""
        y_true_onehot = np.eye(num_classes)[y_true]
        return np.mean(np.sum((y_prob - y_true_onehot)**2, axis=1))
    
    @staticmethod
    def compute_comprehensive_metrics(y_prob, y_true, num_classes=5):
        """Compute all evaluation metrics"""
        y_pred = np.argmax(y_prob, axis=1)
        y_prob_max = np.max(y_prob, axis=1)
        y_true_binary = (y_pred == y_true).astype(int)
        
        # Classification metrics
        accuracy = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='weighted')
        
        # Calibration metrics
        ece = MetricsCalculator.expected_calibration_error(y_prob_max, y_true_binary)
        brier = MetricsCalculator.brier_score(y_prob, y_true, num_classes)
        
        # Confidence metrics
        avg_confidence = np.mean(y_prob_max)
        confidence_std = np.std(y_prob_max)
        
        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'ece': ece,
            'brier_score': brier,
            'avg_confidence': avg_confidence,
            'confidence_std': confidence_std
        }

def create_distribution_shifts(X, y, shift_type='age'):
    """Create different types of distribution shifts - adapted for binary classification"""
    
    if shift_type == 'age':
        # Age-based split (assuming age is first column after scaling)
        # Use median split
        age_median = np.median(X[:, 0])  # Age is typically first feature
        train_mask = X[:, 0] <= age_median  # Younger patients for training
        test_mask = X[:, 0] > age_median    # Older patients for testing
    
    elif shift_type == 'gender':
        # Gender-based split (assuming sex is in features)
        # Find sex column (should be binary 0/1)
        sex_col = None
        for i in range(min(10, X.shape[1])):  # Check first 10 features
            unique_vals = np.unique(X[:, i])
            if len(unique_vals) == 2 and set(unique_vals).issubset({0, 1}):
                sex_col = i
                break
        
        if sex_col is not None:
            # Create gender imbalance
            male_indices = np.where(X[:, sex_col] > 0.5)[0]  # After scaling, male might not be exactly 1
            female_indices = np.where(X[:, sex_col] <= 0.5)[0]
            
            # Train: 70% male, Test: balanced
            n_train = len(X) // 2
            if len(male_indices) > 0 and len(female_indices) > 0:
                train_male_count = min(int(0.7 * n_train), len(male_indices))
                train_female_count = min(n_train - train_male_count, len(female_indices))
                
                train_indices = np.concatenate([
                    np.random.choice(male_indices, train_male_count, replace=False),
                    np.random.choice(female_indices, train_female_count, replace=False)
                ])
                
                train_mask = np.zeros(len(X), dtype=bool)
                train_mask[train_indices] = True
                test_mask = ~train_mask
            else:
                # Fallback to random split
                train_mask = np.random.random(len(X)) < 0.6
                test_mask = ~train_mask
        else:
            # Fallback to random split
            train_mask = np.random.random(len(X)) < 0.6
            test_mask = ~train_mask
    
    elif shift_type == 'severity':
        # Severity-based split for binary classification
        # Train on one class predominantly, test on balanced
        positive_indices = np.where(y == 1)[0]
        negative_indices = np.where(y == 0)[0]
        
        if len(positive_indices) > 0 and len(negative_indices) > 0:
            # Train: 80% negative cases (mild), Test: balanced
            n_train = len(X) // 2
            train_neg_count = min(int(0.8 * n_train), len(negative_indices))
            train_pos_count = min(n_train - train_neg_count, len(positive_indices))
            
            train_indices = np.concatenate([
                np.random.choice(negative_indices, train_neg_count, replace=False),
                np.random.choice(positive_indices, train_pos_count, replace=False)
            ])
            
            train_mask = np.zeros(len(X), dtype=bool)
            train_mask[train_indices] = True
            test_mask = ~train_mask
        else:
            # Fallback to random split
            train_mask = np.random.random(len(X)) < 0.6
            test_mask = ~train_mask
    
    elif shift_type == 'feature':
        # Feature-based shift: split based on a continuous feature
        # Use cholesterol or blood pressure if available
        feature_col = min(2, X.shape[1] - 1)  # Use 3rd feature (likely chol or trestbps)
        feature_median = np.median(X[:, feature_col])
        train_mask = X[:, feature_col] <= feature_median
        test_mask = X[:, feature_col] > feature_median
    
    else:
        # Random split as baseline
        train_mask = np.random.random(len(X)) < 0.6
        test_mask = ~train_mask
    
    # Ensure both sets have both classes for binary classification
    train_classes = np.unique(y[train_mask])
    test_classes = np.unique(y[test_mask])
    
    if len(train_classes) < 2 or len(test_classes) < 2:
        print("⚠ Warning: Unbalanced class distribution detected. Using stratified split.")
        # Fallback to stratified split
        train_indices, test_indices = train_test_split(
            np.arange(len(X)), test_size=0.4, random_state=42, stratify=y
        )
        train_mask = np.zeros(len(X), dtype=bool)
        test_mask = np.zeros(len(X), dtype=bool)
        train_mask[train_indices] = True
        test_mask[test_indices] = True
    
    return train_mask, test_mask

def load_and_preprocess_data():
    """Load and preprocess UCI Heart Disease dataset"""
    # Load the heart disease dataset
    df = load_heart_disease_data()
    
    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")
    
    # Display basic statistics
    target_col = 'target' if 'target' in df.columns else 'num'
    if target_col in df.columns:
        print(f"Target distribution: {df[target_col].value_counts().sort_index().to_dict()}")
    
    # If we have a multi-class target (num column), convert to binary for this implementation
    if 'num' in df.columns and df['num'].max() > 1:
        # Convert multi-class to binary (0 = no disease, 1+ = disease)
        df['target'] = (df['num'] > 0).astype(int)
        print(f"Converted to binary classification. New target distribution: {df['target'].value_counts().to_dict()}")
    elif 'target' not in df.columns:
        # If no target column exists, create one for demonstration
        print("No target column found, creating synthetic target...")
        np.random.seed(42)
        # Create realistic target based on some features
        if 'age' in df.columns and 'cp' in df.columns:
            risk_score = (df['age'] - df['age'].mean()) / df['age'].std()
            if 'cp' in df.columns:
                risk_score += 0.5 * (df['cp'] == df['cp'].mode()[0])
            df['target'] = (risk_score > 0).astype(int)
        else:
            df['target'] = np.random.binomial(1, 0.45, len(df))
        print(f"Created synthetic target distribution: {df['target'].value_counts().to_dict()}")
    
    return df

def main():
    """Main training and evaluation pipeline"""
    print("=== C²BA Heart Disease Classification ===")
    print("Loading and preprocessing data...")
    
    # Load data
    df = load_and_preprocess_data()
    print(f"Dataset shape: {df.shape}")
    print(f"Class distribution: {df['num'].value_counts().sort_index().to_dict()}")
    
    # Initialize data processor
    processor = DataProcessor()
    
    # Preprocess data
    X, y = processor.preprocess_data(df, is_training=True)
    
    print(f"Preprocessed feature shape: {X.shape}")
    print(f"Number of classes: {len(np.unique(y))}")
    
    # Create distribution shift
    train_mask, test_mask = create_distribution_shifts(X, y, shift_type='geographical')
    
    X_train_shift, y_train_shift = X[train_mask], y[train_mask]
    X_test_shift, y_test_shift = X[test_mask], y[test_mask]
    
    # Standard train/validation split on shifted training data
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_shift, y_train_shift, test_size=0.2, random_state=42, stratify=y_train_shift
    )
    
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set (shifted): {X_test_shift.shape[0]} samples")
    
    # Create datasets and data loaders
    train_dataset = HeartDiseaseDataset(X_train, y_train)
    val_dataset = HeartDiseaseDataset(X_val, y_val)
    test_dataset = HeartDiseaseDataset(X_test_shift, y_test_shift)
    
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
    
    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = C2BAModel(
        input_dim=X_train.shape[1],
        num_classes=5,
        foundation_dim=32,
        adapter_rank=8
    )
    
    # Initialize trainer
    trainer = ModelTrainer(model, device=device)
    
    print("\n=== Training Phase ===")
    history = trainer.train(train_loader, val_loader, num_epochs=100, patience=15)
    
    print("\n=== Evaluation Phase ===")
    
    # Evaluate on validation set
    print("Validation Set Evaluation:")
    val_results = trainer.evaluate(val_loader)
    val_metrics = MetricsCalculator.compute_comprehensive_metrics(
        val_results['predictions'], val_results['targets']
    )
    
    for metric, value in val_metrics.items():
        print(f"  {metric}: {value:.4f}")
    
    # Evaluate on test set (distribution shifted)
    print("\nTest Set Evaluation (Distribution Shifted):")
    test_results = trainer.evaluate(test_loader)
    test_metrics = MetricsCalculator.compute_comprehensive_metrics(
        test_results['predictions'], test_results['targets']
    )
    
    for metric, value in test_metrics.items():
        print(f"  {metric}: {value:.4f}")
    
    # Distribution shift analysis
    print("\n=== Distribution Shift Analysis ===")
    
    # Compute shift detection metrics
    model.eval()
    with torch.no_grad():
        # Get features from both sets
        train_features_list = []
        test_features_list = []
        
        for data, _ in val_loader:
            data = data.to(device)
            outputs = model(data, compute_uncertainty=False)
            train_features_list.append(outputs['features'].cpu())
        
        for data, _ in test_loader:
            data = data.to(device)
            outputs = model(data, compute_uncertainty=False)
            test_features_list.append(outputs['features'].cpu())
        
        train_features = torch.cat(train_features_list, dim=0)
        test_features = torch.cat(test_features_list, dim=0)
    
    # Compute MMD
    shift_detector = model.shift_detector
    mmd_score = shift_detector.compute_mmd(train_features, test_features).item()
    energy_distance = shift_detector.compute_energy_distance(train_features, test_features).item()
    
    print(f"Maximum Mean Discrepancy: {mmd_score:.6f}")
    print(f"Energy Distance: {energy_distance:.6f}")
    
    # Performance degradation analysis
    accuracy_drop = val_metrics['accuracy'] - test_metrics['accuracy']
    calibration_degradation = test_metrics['ece'] - val_metrics['ece']
    
    print(f"\nPerformance Impact:")
    print(f"  Accuracy drop: {accuracy_drop:.2f}%")
    print(f"  Calibration degradation (ECE): {calibration_degradation:.4f}")
    print(f"  F1-score drop: {val_metrics['f1_score'] - test_metrics['f1_score']:.4f}")
    
    # Uncertainty quantification analysis
    print("\n=== Uncertainty Quantification Analysis ===")
    
    # Get uncertainty estimates
    model.eval()
    uncertainties = []
    confidences = []
    correct_preds = []
    
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data, compute_uncertainty=True, num_mc_samples=10)
            
            probs = outputs['probabilities'].cpu().numpy()
            preds = np.argmax(probs, axis=1)
            confidence = np.max(probs, axis=1)
            correct = (preds == targets.cpu().numpy()).astype(int)
            
            if outputs['uncertainty'] is not None:
                uncertainties.extend(outputs['uncertainty'].cpu().numpy())
            confidences.extend(confidence)
            correct_preds.extend(correct)
    
    confidences = np.array(confidences)
    correct_preds = np.array(correct_preds)
    
    # Analyze uncertainty-accuracy relationship
    if uncertainties:
        uncertainties = np.array(uncertainties)
        # Correlation between uncertainty and correctness
        uncertainty_accuracy_corr = np.corrcoef(uncertainties, 1 - correct_preds)[0, 1]
        print(f"Uncertainty-Error Correlation: {uncertainty_accuracy_corr:.4f}")
        
        # High uncertainty samples statistics
        high_uncertainty_threshold = np.percentile(uncertainties, 80)
        high_uncertainty_mask = uncertainties > high_uncertainty_threshold
        high_uncertainty_accuracy = correct_preds[high_uncertainty_mask].mean()
        
        print(f"High Uncertainty Samples Accuracy: {high_uncertainty_accuracy:.4f}")
    
    # Confidence-accuracy relationship
    confidence_accuracy_corr = np.corrcoef(confidences, correct_preds)[0, 1]
    print(f"Confidence-Accuracy Correlation: {confidence_accuracy_corr:.4f}")
    
    # Calibration reliability diagram
    print("\n=== Calibration Reliability Analysis ===")
    
    # Bin predictions by confidence
    n_bins = 10
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    
    print("Confidence Bin | Accuracy | Count | ECE Contribution")
    print("-" * 55)
    
    total_ece_contribution = 0
    for i in range(n_bins):
        bin_lower = bin_boundaries[i]
        bin_upper = bin_boundaries[i + 1]
        
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        
        if in_bin.sum() > 0:
            bin_accuracy = correct_preds[in_bin].mean()
            bin_confidence = confidences[in_bin].mean()
            bin_count = in_bin.sum()
            bin_weight = bin_count / len(confidences)
            ece_contribution = abs(bin_confidence - bin_accuracy) * bin_weight
            total_ece_contribution += ece_contribution
            
            print(f"{bin_lower:.1f}-{bin_upper:.1f}     | {bin_accuracy:.3f}    | {bin_count:5d} | {ece_contribution:.4f}")
    
    print("-" * 55)
    print(f"Total ECE: {total_ece_contribution:.4f}")
    
    # Save results summary
    results_summary = {
        'validation_metrics': val_metrics,
        'test_metrics': test_metrics,
        'distribution_shift': {
            'mmd_score': mmd_score,
            'energy_distance': energy_distance,
            'accuracy_drop': accuracy_drop,
            'calibration_degradation': calibration_degradation
        },
        'uncertainty_analysis': {
            'confidence_accuracy_correlation': confidence_accuracy_corr,
        }
    }
    
    if uncertainties:
        results_summary['uncertainty_analysis'].update({
            'uncertainty_accuracy_correlation': uncertainty_accuracy_corr,
            'high_uncertainty_accuracy': high_uncertainty_accuracy
        })
    
    print("\n=== Training Complete ===")
    print("Model saved as 'best_c2ba_model.pt'")
    print("\nKey Findings:")
    print(f"- Base model accuracy: {val_metrics['accuracy']:.2f}%")
    print(f"- Shifted data accuracy: {test_metrics['accuracy']:.2f}%")
    print(f"- Calibration quality (ECE): {test_metrics['ece']:.4f}")
    print(f"- Distribution shift detected (MMD): {mmd_score:.6f}")
    
    if test_metrics['accuracy'] > val_metrics['accuracy'] * 0.9:  # Less than 10% drop
        print("✓ Model shows good robustness to distribution shift")
    else:
        print("⚠ Significant performance degradation detected")
    
    if test_metrics['ece'] < 0.1:  # Well-calibrated
        print("✓ Model maintains good calibration under shift")
    else:
        print("⚠ Calibration degraded under distribution shift")
    
    return results_summary, model, trainer

if __name__ == "__main__":
    # Run the complete pipeline
    results, trained_model, trainer = main()

=== C²BA Heart Disease Classification ===
Loading and preprocessing data...
✓ Successfully loaded data from: /kaggle/input/heart-disease-data/heart_disease_uci.csv
Dataset shape: (920, 16)
Columns: ['id', 'age', 'sex', 'dataset', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalch', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'num']
Target distribution: {0: 411, 1: 265, 2: 109, 3: 107, 4: 28}
Converted to binary classification. New target distribution: {1: 509, 0: 411}
Dataset shape: (920, 17)
Class distribution: {0: 411, 1: 265, 2: 109, 3: 107, 4: 28}


ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().