# In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import RepeatedStratifiedKFold, GridSearchCV, StratifiedKFold, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectFdr, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from xgboost import XGBClassifier
from sklearn.metrics import (roc_auc_score, accuracy_score, f1_score, 
                            recall_score, precision_score, confusion_matrix, 
                            roc_curve, auc)
from sklearn.decomposition import PCA
import scipy.stats as stats
from scipy.stats import t
import warnings
warnings.filterwarnings('ignore')
import shap
import os
import copy
from collections import Counter

# Set the random seed (reused by the CV splitters, searches and estimators below)
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# Cross-validation parameters
N_OUTER_REPEATS = 20  # repeats of the outer RepeatedStratifiedKFold
N_OUTER_SPLITS = 5  # folds per outer repeat
N_INNER_SPLITS = 5  # folds of the inner (hyper-parameter search) StratifiedKFold
FDR_ALPHA = 0.05  # alpha passed to SelectFdr
PCA_VARIANCE_THRESHOLD = 0.95  # passed to PCA as n_components (fraction of variance to keep)
FEATURE_FREQ_THRESHOLD = 0.6  # minimum selection frequency for a feature to count as "stable"
SHAP_NSAMPLES = 100  # nsamples passed to KernelExplainer.shap_values
SHAP_BACKGROUND_SAMPLES = 50  # background set size cap for KernelExplainer

# Model hyper-parameter search spaces
LOGISTIC_REG_PARAMS = {
    'C': [0.001, 0.01, 0.1, 1, 10],
    'penalty': ['l1', 'l2'],
    'solver': ['liblinear'],
    'class_weight': ['balanced', None]
}

SVM_PARAMS = {
    'C': [0.01, 0.1, 1, 10],
    'kernel': ['linear', 'rbf'],
    'gamma': ['scale', 'auto'],
    'class_weight': ['balanced', None]
}

DECISION_TREE_PARAMS = {
    'max_depth': [2, 3, 4, 5],
    'min_samples_split': [5, 10, 15],
    'min_samples_leaf': [3, 5, 7],
    'criterion': ['gini', 'entropy'],
    'class_weight': ['balanced', None]
}

RANDOM_FOREST_PARAMS = {
    'n_estimators': [50, 100, 150],
    'max_depth': [3, 4, 5],
    'min_samples_split': [5, 10],
    'min_samples_leaf': [3, 5],
    'max_features': ['sqrt', 'log2'],
    'class_weight': ['balanced', 'balanced_subsample', None]
}

KNN_PARAMS = {
    'n_neighbors': [3, 5, 7, 9, 11, 13],
    'weights': ['uniform', 'distance'],
    'metric': ['euclidean', 'manhattan']
}

NAIVE_BAYES_PARAMS = {
    'var_smoothing': [1e-11, 1e-9, 1e-7, 1e-5, 1e-3]
}

XGBOOST_PARAMS = {
    'n_estimators': [50, 100, 150],
    'max_depth': [2, 3, 4],
    'learning_rate': [0.01, 0.05, 0.1],
    'subsample': [0.7, 0.8, 0.9],
    'colsample_bytree': [0.7, 0.8, 0.9],
    'reg_alpha': [0, 0.1, 1],
    'reg_lambda': [0, 0.1, 1],
    'scale_pos_weight': [1, 1.5, 2]
}

# Models whose grids are explored with RandomizedSearchCV instead of an
# exhaustive GridSearchCV (read by get_model_configurations).
RANDOM_SEARCH_MODELS = {
    'DecisionTree': True,
    'RandomForest': True, 
    'XGBoost': True
}

# Number of parameter settings sampled by each RandomizedSearchCV run
RANDOM_SEARCH_ITER = 50

def save_pca_formulas(pca_model, feature_names, explained_variance_ratios, filename='./results/pca_components_formulas.txt'):
    """Write a human-readable description of every PCA component to a text file.

    For each component the file contains a "formula" built from the (at most)
    ten largest loadings by absolute value, skipping loadings whose absolute
    value is <= 0.01, followed by a numbered list of those top ten loadings.

    Args:
        pca_model: Fitted object exposing ``n_components_`` and ``components_``
            (one row of loadings per component).
        feature_names: Names of the input features, in the column order of
            ``pca_model.components_``.
        explained_variance_ratios: Per-component explained variance ratio,
            indexable by component number.
        filename: Destination path. Missing parent directories are created.
    """
    print("Saving PCA component formulas...")

    # Create the output directory up front so that a fresh checkout without
    # ./results does not fail with FileNotFoundError.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # The formulas contain the non-ASCII multiplication sign, so the encoding
    # is pinned instead of relying on the platform default.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("PCA Component Formulas\n")
        f.write("=" * 50 + "\n\n")

        for i in range(pca_model.n_components_):
            f.write(f"PC{i+1} (Variance Explained: {explained_variance_ratios[i]:.4f}):\n")
            f.write("Formula: PC{} = ".format(i+1))

            # Rank the loadings of this component by absolute magnitude.
            feature_weights = sorted(
                zip(feature_names, pca_model.components_[i]),
                key=lambda pair: abs(pair[1]),
                reverse=True,
            )
            top_weights = feature_weights[:10]  # Top 10 features per component

            # Only meaningful contributions make it into the formula line.
            terms = [
                f"{weight:.4f} × {feature}"
                for feature, weight in top_weights
                if abs(weight) > 0.01
            ]
            f.write(" + ".join(terms))
            f.write("\n\n")

            # The detailed list keeps all top-10 entries, including tiny ones.
            f.write("Top 10 feature contributions:\n")
            for rank, (feature, weight) in enumerate(top_weights, start=1):
                f.write(f"  {rank:2d}. {feature}: {weight:.4f}\n")
            f.write("\n" + "-" * 50 + "\n\n")

    print(f"PCA component formulas saved to {filename}")

def load_and_preprocess_data():
    """Load clinical, pathological, recurrence and radiomics tables and align them.

    Reads ``./data/clinical.csv``, ``./data/pathological.csv``,
    ``./data/recurrence.csv``, ``./results/radiomics_distinctive.csv`` and
    ``./results/radiomics.csv``. The first 50 entries of the ``Feature``
    column of the distinctive-features file select the radiomics columns.
    Only samples present in all four indexed tables are kept.

    Rows are ordered as they appear in the clinical table. The previous
    ``list(set(...))`` produced an order that depends on string hashing and
    therefore changed between interpreter runs, which made the downstream
    seeded CV splits non-reproducible.

    Returns:
        tuple: ``(all_features, y, clinical_end, patho_end, radio_end)`` where
        ``all_features`` is the column-wise concatenation clinical |
        pathological | radiomics, ``y`` is the first column of the recurrence
        table as an array, and the three integers are the exclusive end
        column positions of each feature group.
    """
    print("Loading data...")

    # Load clinical data
    clinical_df = pd.read_csv('./data/clinical.csv', index_col=0)
    print(f"Clinical data shape: {clinical_df.shape}")

    # Load pathological data
    patho_df = pd.read_csv('./data/pathological.csv', index_col=0)
    print(f"Pathological data shape: {patho_df.shape}")

    # Load recurrence data
    recurrence_df = pd.read_csv('./data/recurrence.csv', index_col=0)
    print(f"Recurrence data shape: {recurrence_df.shape}")

    # Extract radiomics features
    print("\nExtracting radiomics features...")
    distinctive_df = pd.read_csv('./results/radiomics_distinctive.csv')
    # The file is taken to be pre-sorted by importance; keep the first 50 rows.
    top_50_features = distinctive_df['Feature'].head(50).tolist()
    print(f"Selected {len(top_50_features)} distinctive radiomics features")

    radiomics_df = pd.read_csv('./results/radiomics.csv', index_col=0)
    radiomics_features = radiomics_df[top_50_features]
    print(f"Radiomics features shape: {radiomics_features.shape}")

    # Samples present in every table, in the (deterministic) clinical order.
    shared_ids = (set(patho_df.index) & set(recurrence_df.index)
                  & set(radiomics_features.index))
    common_ids = [sample_id for sample_id in clinical_df.index.unique()
                  if sample_id in shared_ids]
    print(f"Common IDs across all datasets: {len(common_ids)}")

    # Align all dataframes
    clinical_aligned = clinical_df.loc[common_ids]
    patho_aligned = patho_df.loc[common_ids]
    recurrence_aligned = recurrence_df.loc[common_ids]
    radiomics_aligned = radiomics_features.loc[common_ids]

    # Combine all features
    all_features = pd.concat([clinical_aligned, patho_aligned, radiomics_aligned], axis=1)
    print(f"All features combined shape: {all_features.shape}")

    # Get target (assume first column is recurrence)
    y = recurrence_aligned.iloc[:, 0].values
    print(f"Target distribution: {pd.Series(y).value_counts()}")

    # Column boundaries of the feature groups (exclusive end positions)
    clinical_end = clinical_aligned.shape[1]
    patho_end = clinical_end + patho_aligned.shape[1]
    radio_end = all_features.shape[1]

    print(f"\nFeature group indices:")
    print(f"Clinical features: 0 to {clinical_end-1} ({clinical_end} features)")
    print(f"Pathological features: {clinical_end} to {patho_end-1} ({patho_end-clinical_end} features)")
    print(f"Radiomics features: {patho_end} to {radio_end-1} ({radio_end-patho_end} features)")

    return all_features, y, clinical_end, patho_end, radio_end

def identify_feature_types(all_features):
    """Split the columns into categorical and numerical by distinct-value count.

    A column with fewer than six distinct non-null values is treated as
    categorical; every other column is treated as numerical. Both returned
    lists keep the column order of ``all_features``.
    """
    categorical_features, numerical_features = [], []

    for column_name in all_features.columns:
        is_low_cardinality = all_features[column_name].nunique() < 6
        bucket = categorical_features if is_low_cardinality else numerical_features
        bucket.append(column_name)

    print(f"Categorical features: {len(categorical_features)}")
    print(f"Numerical features: {len(numerical_features)}")

    return categorical_features, numerical_features

def handle_missing_values(all_features, categorical_features, numerical_features):
    """Impute missing values column by column, modifying ``all_features``.

    Categorical columns (and any non-numeric column) are filled with their
    most frequent value, or with 0 when the column has no non-null value at
    all. Remaining columns are filled with their median.

    The filled column is assigned back to the frame instead of calling
    ``fillna(..., inplace=True)`` on ``all_features[col]``: that chained
    in-place form operates on a temporary object under pandas Copy-on-Write
    and leaves the frame unchanged.

    Args:
        all_features: DataFrame to impute; updated in place and returned.
        categorical_features: Column names to impute with the mode.
        numerical_features: Unused; kept for interface compatibility.

    Returns:
        The same ``all_features`` object.
    """
    print("Handling missing values...")

    missing_values = all_features.isnull().sum()
    features_with_missing = missing_values[missing_values > 0]

    if len(features_with_missing) == 0:
        print("No missing values found.")
        return all_features

    print(f"Features with missing values: {features_with_missing}")

    categorical_lookup = set(categorical_features)
    # Only columns that actually contain gaps need to be touched.
    for col in features_with_missing.index:
        column = all_features[col]
        use_mode = col in categorical_lookup or not pd.api.types.is_numeric_dtype(column)
        if use_mode:
            # median() is undefined for non-numeric data, so those columns
            # share the mode strategy with the categorical ones.
            mode_val = column.mode()
            fill_value = mode_val.iloc[0] if len(mode_val) > 0 else 0
        else:
            fill_value = column.median()
        all_features[col] = column.fillna(fill_value)

    print("Missing values filled.")
    return all_features

def adaptive_pca(radio_features, variance_threshold=PCA_VARIANCE_THRESHOLD):
    """Standardize the radiomics features and project them with PCA.

    ``variance_threshold`` is handed to ``PCA`` as ``n_components``, so the
    number of retained components is decided by the fitted model.

    Returns:
        tuple: ``(radio_pca_df, pca, scaler)`` - the projected data as a
        DataFrame with columns ``PC1..PCk`` on the input index, plus the
        fitted PCA and StandardScaler objects.
    """
    print("Applying adaptive PCA to radiomics features...")
    print(f"Original radiomics features shape: {radio_features.shape}")

    # Standardize first, then project.
    scaler = StandardScaler()
    standardized = scaler.fit_transform(radio_features)

    pca = PCA(n_components=variance_threshold, random_state=RANDOM_STATE)
    projected = pca.fit_transform(standardized)

    n_components = projected.shape[1]
    radio_pca_df = pd.DataFrame(
        projected,
        columns=[f'PC{k}' for k in range(1, n_components + 1)],
        index=radio_features.index,
    )

    explained_variance = np.sum(pca.explained_variance_ratio_)
    print(f"PCA explained variance ratio: {explained_variance:.3f}")
    print(f"Number of PCA components: {n_components}")
    print(f"PCA features shape: {radio_pca_df.shape}")

    return radio_pca_df, pca, scaler

def prepare_feature_combinations_with_pca(all_features, clinical_end, patho_end, radio_end):
    """Build the four feature sets, with radiomics replaced by PCA components.

    The column positions ``clinical_end``, ``patho_end`` and ``radio_end``
    split ``all_features`` into clinical, pathological and radiomics groups.
    The radiomics group is reduced with ``adaptive_pca`` and the component
    formulas are written out with ``save_pca_formulas``.

    Returns:
        tuple: ``(feature_combinations, pca_model, pca_scaler,
        clinical_features, patho_features, radio_features)``.
    """
    # Slice the three groups by column position.
    clinical_features = all_features.iloc[:, :clinical_end]
    patho_features = all_features.iloc[:, clinical_end:patho_end]
    radio_features = all_features.iloc[:, patho_end:radio_end]

    # Reduce radiomics and record how each component is composed.
    radio_pca_df, pca_model, pca_scaler = adaptive_pca(radio_features)
    save_pca_formulas(
        pca_model,
        radio_features.columns.tolist(),
        pca_model.explained_variance_ratio_,
    )

    def _join(*parts):
        # Column-wise concatenation, returned as an independent copy.
        return pd.concat(list(parts), axis=1).copy()

    feature_combinations = {
        'Clinical': clinical_features.copy(),
        'Clinical+Pathological': _join(clinical_features, patho_features),
        'Clinical+Radiomics': _join(clinical_features, radio_pca_df),
        'Clinical+Pathological+Radiomics': _join(clinical_features, patho_features, radio_pca_df),
    }

    print("\nFeature combinations with PCA:")
    for combo_name, combo_frame in feature_combinations.items():
        print(f"{combo_name}: {combo_frame.shape[1]} features")

    return (feature_combinations, pca_model, pca_scaler,
            clinical_features, patho_features, radio_features)

def scale_features(feature_combinations, numerical_features):
    """Standardize the numerical columns of every feature combination.

    Each combination is copied; the columns listed in ``numerical_features``
    that are present in the copy are replaced by the output of a
    ``StandardScaler`` fitted on that combination. Other columns are left
    untouched.

    Returns:
        dict: combination name -> scaled DataFrame copy.
    """
    print("Scaling numerical features...")

    scaled_combinations = {}
    for combo_name, frame in feature_combinations.items():
        scaled_frame = frame.copy()
        cols_to_scale = [col for col in numerical_features if col in scaled_frame.columns]
        if cols_to_scale:
            scaled_frame[cols_to_scale] = StandardScaler().fit_transform(scaled_frame[cols_to_scale])
        scaled_combinations[combo_name] = scaled_frame

    return scaled_combinations

def get_model_configurations():
    """Return the estimator, search grid and search strategy for each model.

    Returns:
        dict: model name -> ``{'model', 'params', 'use_random_search'}``.
        ``use_random_search`` is taken from ``RANDOM_SEARCH_MODELS`` for the
        tree-based models and is ``False`` for the others.
    """
    def _entry(estimator, grid, randomized=False):
        # One uniform record per model keeps the table below compact.
        return {
            'model': estimator,
            'params': grid,
            'use_random_search': randomized,
        }

    return {
        'LogisticRegression': _entry(
            LogisticRegression(random_state=RANDOM_STATE, max_iter=1000),
            LOGISTIC_REG_PARAMS,
        ),
        'SVM': _entry(
            SVC(random_state=RANDOM_STATE, probability=True),
            SVM_PARAMS,
        ),
        'DecisionTree': _entry(
            DecisionTreeClassifier(random_state=RANDOM_STATE),
            DECISION_TREE_PARAMS,
            RANDOM_SEARCH_MODELS['DecisionTree'],
        ),
        'RandomForest': _entry(
            RandomForestClassifier(random_state=RANDOM_STATE),
            RANDOM_FOREST_PARAMS,
            RANDOM_SEARCH_MODELS['RandomForest'],
        ),
        'KNN': _entry(KNeighborsClassifier(), KNN_PARAMS),
        'NaiveBayes': _entry(GaussianNB(), NAIVE_BAYES_PARAMS),
        'XGBoost': _entry(
            XGBClassifier(random_state=RANDOM_STATE, eval_metric='logloss'),
            XGBOOST_PARAMS,
            RANDOM_SEARCH_MODELS['XGBoost'],
        ),
    }

def run_single_cv_iteration(X_train, X_test, y_train, y_test, feature_names, model_config, apply_feature_selection=True):
    """Fit and evaluate one model on a single outer train/test split.

    Steps: optional univariate feature selection (``SelectFdr`` with
    ``f_classif``) fitted on the training part only, hyper-parameter search
    with an inner stratified CV scored by ROC AUC, then evaluation on the
    test part.

    If the selector raises, or if no feature passes the FDR threshold, the
    fold is trained on all features and all features are reported as
    selected. Without this an empty selection would hand a zero-column
    matrix to the estimator and abort the whole run.

    Args:
        X_train, X_test: Feature matrices of the split.
        y_train, y_test: Corresponding labels. Column 1 of ``predict_proba``
            is used as the positive-class score.
        feature_names: Names matching the columns of ``X_train``.
        model_config: Dict with ``'model'``, ``'params'`` and
            ``'use_random_search'`` as built by ``get_model_configurations``.
        apply_feature_selection: Whether to run the FDR selection step.

    Returns:
        tuple: ``(metrics, selected_features, best_model, y_pred_proba,
        y_test)`` where ``metrics`` has the keys ``auc``, ``acc``, ``f1``,
        ``sens`` and ``spec``.
    """
    # Default: train on every feature.
    selected_features = feature_names
    X_train_processed, X_test_processed = X_train, X_test

    if apply_feature_selection:
        selector = SelectFdr(score_func=f_classif, alpha=FDR_ALPHA)
        try:
            selector.fit(X_train, y_train)
            selected_indices = selector.get_support(indices=True)
        except Exception as exc:
            # Best-effort step: keep going with all features, but say so.
            # (A bare ``except:`` would also swallow KeyboardInterrupt.)
            print(f"Feature selection failed ({exc}); using all features for this fold.")
            selected_indices = []

        if len(selected_indices) > 0:
            selected_features = [feature_names[i] for i in selected_indices]
            X_train_processed = selector.transform(X_train)
            X_test_processed = selector.transform(X_test)

    inner_cv = StratifiedKFold(n_splits=N_INNER_SPLITS, shuffle=True, random_state=RANDOM_STATE)

    if len(model_config['params']) > 0:
        if model_config['use_random_search']:
            search = RandomizedSearchCV(
                model_config['model'],
                model_config['params'],
                cv=inner_cv,
                scoring='roc_auc',
                n_jobs=-1,
                n_iter=RANDOM_SEARCH_ITER,
                random_state=RANDOM_STATE
            )
        else:
            search = GridSearchCV(
                model_config['model'],
                model_config['params'],
                cv=inner_cv,
                scoring='roc_auc',
                n_jobs=-1
            )

        search.fit(X_train_processed, y_train)
        best_model = search.best_estimator_
    else:
        # Nothing to tune: fit a private copy so the shared template in
        # model_config stays unfitted.
        best_model = copy.deepcopy(model_config['model'])
        best_model.fit(X_train_processed, y_train)

    y_pred_proba = best_model.predict_proba(X_test_processed)[:, 1]
    y_pred = best_model.predict(X_test_processed)

    metrics = {
        'auc': roc_auc_score(y_test, y_pred_proba),
        'acc': accuracy_score(y_test, y_pred),
        'f1': f1_score(y_test, y_pred),
        'sens': recall_score(y_test, y_pred)
    }

    # Pin the label set so the matrix stays 2x2 even when a class is absent
    # from both y_test and y_pred in this fold.
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=best_model.classes_).ravel()
    n_negatives = tn + fp
    metrics['spec'] = tn / n_negatives if n_negatives > 0 else float('nan')

    return metrics, selected_features, best_model, y_pred_proba, y_test

def unified_nested_cv(X, y, feature_names, model_config):
    """Run repeated nested CV twice: with per-fold selection, then on stable features.

    Pass 1 runs ``run_single_cv_iteration`` with feature selection on every
    outer split and records which features were selected. Features whose
    selection frequency reaches the stability threshold (see
    ``find_stable_features``) form the stable set. Pass 2 repeats the same
    outer splits on the stable columns only, without further selection.

    NOTE(review): the stable set is derived from all outer splits, including
    the samples that serve as test folds in pass 2, so the pass-2 metrics are
    not an unbiased estimate of generalization performance.

    If no feature reaches the threshold, pass 2 falls back to all features
    instead of failing on a zero-column matrix.

    Args:
        X: 2-D array of features (indexed positionally by the CV splitter).
        y: 1-D array of labels.
        feature_names: Names matching the columns of ``X``.
        model_config: Model configuration dict (see
            ``get_model_configurations``).

    Returns:
        tuple: ``(all_metrics, stable_metrics, stable_features,
        feature_frequencies, stable_y_true, stable_y_pred_proba)``. The two
        metric dicts map ``auc/acc/f1/sens/spec`` to per-split lists; the last
        two items are the pooled pass-2 test labels and scores.
    """
    # An integer random_state makes split() yield the same partitions on
    # every call, so both passes see identical folds.
    outer_cv = RepeatedStratifiedKFold(
        n_splits=N_OUTER_SPLITS,
        n_repeats=N_OUTER_REPEATS,
        random_state=RANDOM_STATE
    )

    metric_keys = ('auc', 'acc', 'f1', 'sens', 'spec')

    print(f"Running unified nested CV: {N_OUTER_REPEATS} repeats × {N_OUTER_SPLITS} folds")

    # ---- Pass 1: per-fold feature selection -------------------------------
    all_metrics = {key: [] for key in metric_keys}
    all_selected_features = []

    for train_idx, test_idx in outer_cv.split(X, y):
        metrics, selected_features, _, _, _ = run_single_cv_iteration(
            X[train_idx], X[test_idx], y[train_idx], y[test_idx],
            feature_names, model_config, apply_feature_selection=True
        )
        all_selected_features.extend(selected_features)
        for key in metric_keys:
            all_metrics[key].append(metrics[key])

    stable_features, feature_frequencies = find_stable_features(all_selected_features, feature_names)

    print(f"Found {len(stable_features)} stable features")

    if not stable_features:
        # An empty column set cannot be fitted; evaluate on everything.
        print("No feature reached the stability threshold; falling back to all features.")
        stable_features = list(feature_names)

    # ---- Pass 2: fixed (stable) feature set -------------------------------
    stable_lookup = set(stable_features)
    stable_indices = [i for i, feature in enumerate(feature_names) if feature in stable_lookup]
    X_stable = X[:, stable_indices]

    stable_metrics = {key: [] for key in metric_keys}
    stable_y_true = []
    stable_y_pred_proba = []

    for train_idx, test_idx in outer_cv.split(X_stable, y):
        metrics, _, _, y_pred_proba, y_test_vals = run_single_cv_iteration(
            X_stable[train_idx], X_stable[test_idx], y[train_idx], y[test_idx],
            stable_features, model_config, apply_feature_selection=False
        )
        stable_y_true.extend(y_test_vals)
        stable_y_pred_proba.extend(y_pred_proba)
        for key in metric_keys:
            stable_metrics[key].append(metrics[key])

    return all_metrics, stable_metrics, stable_features, feature_frequencies, stable_y_true, stable_y_pred_proba

def calculate_confidence_interval(data, confidence=0.95):
    """Return ``(mean, lower, upper)`` of a two-sided Student-t interval.

    The half-width is the standard error of the mean (sample standard
    deviation with ``ddof=1`` divided by ``sqrt(n)``) times the t quantile
    at ``(1 + confidence) / 2`` with ``n - 1`` degrees of freedom.
    """
    sample_size = len(data)
    center = np.mean(data)
    standard_error = np.std(data, ddof=1) / np.sqrt(sample_size)
    critical_t = t.ppf((1 + confidence) / 2, sample_size - 1)
    half_width = standard_error * critical_t
    return center, center - half_width, center + half_width

def find_stable_features(selected_features_list, feature_names, threshold=FEATURE_FREQ_THRESHOLD):
    """Compute selection frequencies and pick the features at or above ``threshold``.

    ``selected_features_list`` is the flat list of feature names selected
    over all outer iterations. Frequencies are counts divided by
    ``N_OUTER_REPEATS * N_OUTER_SPLITS``.

    Returns:
        tuple: ``(stable_features, feature_frequencies)`` - the names (in
        ``feature_names`` order) whose frequency is >= ``threshold``, and a
        dict mapping every name in ``feature_names`` to its frequency.
    """
    occurrences = Counter(selected_features_list)
    n_runs = N_OUTER_REPEATS * N_OUTER_SPLITS

    feature_frequencies = {
        name: occurrences.get(name, 0) / n_runs for name in feature_names
    }
    stable_features = [
        name for name in feature_names if feature_frequencies[name] >= threshold
    ]

    return stable_features, feature_frequencies

def save_stable_features_to_csv(feature_stability_results):
    """Build one frequency table per feature combination (nothing is written to disk).

    For every entry of ``feature_stability_results`` a DataFrame with the
    columns ``Feature``, ``Frequency`` and ``Stable`` is created, sorted by
    descending frequency; ``Stable`` flags frequencies at or above
    ``FEATURE_FREQ_THRESHOLD``.

    Returns:
        dict: combination name -> frequency DataFrame.
    """
    print(f"Processing stable features...")

    all_stable_features_data = {}

    for feature_name, stability_info in feature_stability_results.items():
        # Both entries are looked up, as before; only the frequencies are used.
        _stable_names = stability_info['stable_features']
        frequencies = stability_info['feature_frequencies']

        table = pd.DataFrame({
            'Feature': list(frequencies.keys()),
            'Frequency': list(frequencies.values())
        })
        table = table.sort_values('Frequency', ascending=False)
        table['Stable'] = table['Frequency'] >= FEATURE_FREQ_THRESHOLD

        all_stable_features_data[feature_name] = table
        print(f"Processed {len(table)} features for {feature_name}")

        n_stable = len(table[table['Stable']])
        print(f"Found {n_stable} stable features for {feature_name}")

    return all_stable_features_data

def perform_statistical_tests(results):
    """Compare feature combinations with pairwise paired t-tests on per-split AUCs.

    Each pair is tested once with ``scipy.stats.ttest_rel`` and the two-sided
    p-value is stored symmetrically; the diagonal is set to 1.0. The printed
    table uses column widths derived from the longest combination name, so
    headers and cells line up (cells were previously wider than the headers).

    NOTE(review): AUCs from repeated CV splits share training data, so the
    independence assumption of the paired t-test is questionable and no
    multiple-comparison correction is applied - treat the p-values as
    indicative.

    Args:
        results: Dict mapping combination name to a dict that contains
            ``'metrics_fixed'['auc']`` (per-split AUC list) and ``'auc_mean'``.

    Returns:
        tuple: ``(p_values, best_feature)`` - the square p-value matrix in the
        iteration order of ``results`` and the name with the highest
        ``auc_mean``.
    """
    print("\nPERFORMING STATISTICAL TESTS")
    print("="*60)

    # Extract AUC scores for statistical comparison
    auc_data = {name: result['metrics_fixed']['auc'] for name, result in results.items()}
    feature_names = list(auc_data)
    n_combos = len(feature_names)

    # The two-sided paired test is symmetric, so test each pair only once.
    p_values = np.ones((n_combos, n_combos))
    for i in range(n_combos):
        for j in range(i + 1, n_combos):
            _, p_val = stats.ttest_rel(auc_data[feature_names[i]], auc_data[feature_names[j]])
            p_values[i, j] = p_values[j, i] = p_val

    # Size the columns to the longest name so the table stays aligned.
    longest_name = max((len(name) for name in feature_names), default=0)
    label_width = max(30, longest_name + 2)
    cell_width = max(15, longest_name + 2)

    print("\nPairwise AUC Comparisons (p-values):")
    print(f"{'':<{label_width}}" + "".join(f"{name:<{cell_width}}" for name in feature_names))

    for i, row_name in enumerate(feature_names):
        cells = []
        for j in range(n_combos):
            if i == j:
                cell = '--'
            else:
                p_val = p_values[i, j]
                cell = f"{p_val:.4f}{'*' if p_val < 0.05 else ''}"
            cells.append(f"{cell:<{cell_width}}")
        print(f"{row_name:<{label_width}}" + "".join(cells))

    # Find the best performing feature combination
    best_feature = max(results.keys(), key=lambda x: results[x]['auc_mean'])
    print(f"\nBest performing feature combination: {best_feature}")
    print(f"Best AUC: {results[best_feature]['auc_mean']:.3f}")

    # Compare best with others (reusing the matrix computed above)
    print("\nComparison with best model:")
    best_idx = feature_names.index(best_feature)
    for idx, feature_name in enumerate(feature_names):
        if idx == best_idx:
            continue
        p_val = p_values[best_idx, idx]
        significance = "SIGNIFICANT" if p_val < 0.05 else "not significant"
        print(f"{best_feature} vs {feature_name}: p = {p_val:.4f} ({significance})")

    return p_values, best_feature

def train_and_evaluate_with_nested_cv(feature_combinations, y):
    """Evaluate every model on every feature combination and keep the best per combination.

    For each combination all configured models go through
    ``unified_nested_cv``; 95% confidence intervals for AUC, sensitivity and
    specificity are computed from the stable-feature (second pass) metrics.
    The model with the highest mean AUC represents the combination. Finally
    the combinations are compared with ``perform_statistical_tests`` and the
    frequency tables are built with ``save_stable_features_to_csv``.

    Returns:
        tuple: ``(results, feature_stability_results, stable_features_data,
        p_values, best_feature)``.
    """
    print("NESTED CROSS-VALIDATION WITH FEATURE STABILITY ANALYSIS")

    results, feature_stability_results, prediction_data = {}, {}, {}

    for feature_name, features in feature_combinations.items():
        print(f"\nAnalyzing feature combination: {feature_name}")

        X = features.values
        column_names = features.columns.tolist()
        per_model = {}

        for model_name, config in get_model_configurations().items():
            print(f"Training {model_name}...")

            (metrics_with_fs, metrics_fixed, stable_features,
             feature_frequencies, y_true, y_pred_proba) = unified_nested_cv(
                X, y, column_names, config
            )

            auc_mean, auc_lo, auc_hi = calculate_confidence_interval(metrics_fixed['auc'])
            sens_mean, sens_lo, sens_hi = calculate_confidence_interval(metrics_fixed['sens'])
            spec_mean, spec_lo, spec_hi = calculate_confidence_interval(metrics_fixed['spec'])

            per_model[model_name] = {
                'metrics_with_fs': metrics_with_fs,
                'metrics_fixed': metrics_fixed,
                'stable_features': stable_features,
                'feature_frequencies': feature_frequencies,
                'auc_mean': auc_mean,
                'auc_ci': (auc_lo, auc_hi),
                'sens_mean': sens_mean,
                'sens_ci': (sens_lo, sens_hi),
                'spec_mean': spec_mean,
                'spec_ci': (spec_lo, spec_hi),
                'y_true': y_true,
                'y_pred_proba': y_pred_proba
            }

            print(f"AUC: {auc_mean:.3f} ({auc_lo:.3f}-{auc_hi:.3f})")

        if not per_model:
            continue

        # The combination is represented by its highest-AUC model.
        best_model_name = max(per_model, key=lambda name: per_model[name]['auc_mean'])
        winner = per_model[best_model_name]
        winner['best_model_name'] = best_model_name
        results[feature_name] = winner

        prediction_data[feature_name] = {
            'y_true': winner['y_true'],
            'y_pred_proba': winner['y_pred_proba']
        }

        feature_stability_results[feature_name] = {
            'feature_frequencies': winner['feature_frequencies'],
            'stable_features': winner['stable_features']
        }

        print(f"Best model: {best_model_name} (AUC: {winner['auc_mean']:.3f})")

    # Compare the combinations and assemble the frequency tables.
    p_values, best_feature = perform_statistical_tests(results)
    stable_features_data = save_stable_features_to_csv(feature_stability_results)

    return results, feature_stability_results, stable_features_data, p_values, best_feature

def plot_feature_stability(feature_stability_results, stable_features_data):
    """Draw one selection-frequency bar chart per feature combination.

    Bars flagged ``Stable`` in the frequency table are red, the others light
    blue; a dashed line marks ``FEATURE_FREQ_THRESHOLD``.

    Returns:
        list: ``(figure_name, figure)`` tuples, one per combination.
    """
    print("Plotting feature stability...")

    figs = []

    for feature_name, stability_info in feature_stability_results.items():
        # Both entries are looked up, as before; the frequencies themselves
        # come from the pre-built table.
        _frequencies = stability_info['feature_frequencies']
        stable_features = stability_info['stable_features']
        freq_df = stable_features_data[feature_name]
        positions = range(len(freq_df))

        plt.figure(figsize=(12, 8))
        bars = plt.bar(positions, freq_df['Frequency'])

        # Colour-code each bar by its stability flag.
        for bar, (_, row) in zip(bars, freq_df.iterrows()):
            bar.set_color('red' if row['Stable'] else 'lightblue')

        plt.axhline(y=FEATURE_FREQ_THRESHOLD, color='r', linestyle='--',
                    label=f'Stability threshold ({FEATURE_FREQ_THRESHOLD*100}%)')
        plt.xlabel('Features')
        plt.ylabel('Selection Frequency')
        plt.title(f'Feature Stability Analysis - {feature_name}\n({len(stable_features)} stable features)')
        plt.xticks(positions, freq_df['Feature'], rotation=90)
        plt.legend()
        plt.tight_layout()

        figs.append((f"Feature_Stability_{feature_name}", plt.gcf()))

    return figs

def plot_performance_distribution(results):
    """Draw a box plot of the per-split AUCs of every feature combination.

    Tick labels are set with ``plt.xticks`` rather than through the
    ``labels`` argument of ``boxplot``, which Matplotlib 3.9 deprecated in
    favour of ``tick_labels``; this form works on both sides of that rename.
    Box colours cycle through the palette so that more than four
    combinations are still coloured.

    Args:
        results: Dict mapping combination name to a dict containing
            ``'metrics_fixed'['auc']``.

    Returns:
        The current Matplotlib figure.
    """
    print("Plotting performance distribution...")

    plt.figure(figsize=(12, 8))

    feature_names = list(results.keys())
    auc_data = [results[name]['metrics_fixed']['auc'] for name in feature_names]

    box_plot = plt.boxplot(auc_data, patch_artist=True)
    # boxplot places the boxes at positions 1..N by default.
    plt.xticks(range(1, len(feature_names) + 1), feature_names)

    colors = ['lightblue', 'lightgreen', 'lightcoral', 'lightyellow']
    for position, patch in enumerate(box_plot['boxes']):
        patch.set_facecolor(colors[position % len(colors)])

    plt.ylabel('AUC')
    plt.title('Performance Distribution Across Feature Combinations')
    plt.grid(True, alpha=0.3)

    return plt.gcf()

def plot_statistical_comparison(p_values, feature_names):
    """Draw an annotated heatmap of the pairwise p-value matrix.

    The diagonal (each combination compared with itself) is masked out. The
    mask was previously built but never handed to ``sns.heatmap``, so the
    placeholder diagonal values were drawn and annotated like real results.

    Args:
        p_values: Square matrix of p-values.
        feature_names: Labels for both axes, in matrix order.

    Returns:
        The current Matplotlib figure.
    """
    plt.figure(figsize=(10, 8))

    # True on the diagonal -> those cells are left blank.
    diagonal_mask = np.eye(len(feature_names), dtype=bool)

    sns.heatmap(p_values,
                mask=diagonal_mask,
                xticklabels=feature_names,
                yticklabels=feature_names,
                annot=True,
                fmt=".4f",
                cmap="RdYlBu_r",
                center=0.05,
                cbar_kws={'label': 'p-value'})

    plt.title('Statistical Comparison of Feature Combinations\n(Paired t-tests on AUC scores)')
    plt.tight_layout()

    return plt.gcf()

def _positive_class_shap_values(shap_values, data_shape):
    """Reduce explainer output to a 2-D array of positive-class SHAP values.

    Depending on the SHAP version and explainer, classifier explanations come
    back as a list with one ``(n_samples, n_features)`` array per class, as a
    3-D array with the outputs on the last axis, or already as a 2-D array.
    ``data_shape`` is the ``(n_samples, n_features)`` shape of the explained
    data and is used to tell the 3-D layouts apart.
    """
    if isinstance(shap_values, list):
        # One array per class: keep class 1 when present.
        shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]

    shap_values = np.asarray(shap_values)

    if shap_values.ndim == 3:
        data_shape = tuple(data_shape)
        if shap_values.shape[:2] == data_shape:
            # (n_samples, n_features, n_outputs)
            shap_values = shap_values[:, :, 1 if shap_values.shape[2] > 1 else 0]
        elif shap_values.shape[1:] == data_shape:
            # (n_outputs, n_samples, n_features)
            shap_values = shap_values[1 if shap_values.shape[0] > 1 else 0]

    return shap_values


def create_shap_plots(results, feature_combinations, y):
    """Refit the best model of each combination on all data and plot SHAP summaries.

    For every feature combination the best model type is re-tuned with a
    5-fold stratified search on the stable features (all features when the
    stable list is empty or none of its names is a column), explained with a
    Linear/Tree/Kernel explainer according to the estimator class, and drawn
    as a bar and a beeswarm summary plot. Errors during explanation or
    plotting are reported and the combination is skipped.

    The positive-class values are extracted by ``_positive_class_shap_values``.
    The earlier code indexed a per-class list along the feature axis and
    flattened 3-D arrays before truncating them, which mixed classes and
    features in the plotted values.

    Args:
        results: Dict mapping combination name to a result dict with
            ``'best_model_name'`` and ``'stable_features'``.
        feature_combinations: Dict mapping combination name to its DataFrame.
        y: Labels aligned with the rows of the DataFrames.

    Returns:
        list: ``(figure_name, figure)`` tuples.
    """
    print("Creating SHAP plots...")

    shap_figures = []

    for feature_name, result in results.items():
        model_name = result['best_model_name']
        stable_features = result['stable_features']

        features_df = feature_combinations[feature_name]

        if stable_features:
            # Filter features to only stable ones; keep everything if none of
            # them is present, since a zero-column frame cannot be fitted.
            available_stable = [f for f in stable_features if f in features_df.columns]
            if available_stable:
                features_df = features_df[available_stable]

        X_final = features_df.values
        y_final = y
        feature_names = features_df.columns.tolist()

        print(f"Training final model for SHAP explanation: {feature_name}")

        model_config = get_model_configurations()[model_name]

        if len(model_config['params']) > 0:
            final_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
            if model_config['use_random_search']:
                search = RandomizedSearchCV(
                    model_config['model'],
                    model_config['params'],
                    cv=final_cv,
                    scoring='roc_auc',
                    n_jobs=-1,
                    n_iter=RANDOM_SEARCH_ITER,
                    random_state=RANDOM_STATE
                )
            else:
                search = GridSearchCV(
                    model_config['model'],
                    model_config['params'],
                    cv=final_cv,
                    scoring='roc_auc',
                    n_jobs=-1
                )

            search.fit(X_final, y_final)
            final_model = search.best_estimator_
        else:
            final_model = copy.deepcopy(model_config['model'])
            final_model.fit(X_final, y_final)

        try:
            model_type = type(final_model).__name__

            if model_type in ['LogisticRegression']:
                explainer = shap.LinearExplainer(final_model, X_final)
                raw_values = explainer.shap_values(X_final)

            elif model_type in ['DecisionTreeClassifier', 'RandomForestClassifier',
                                'XGBClassifier']:
                explainer = shap.TreeExplainer(final_model)
                raw_values = explainer.shap_values(X_final)

            else:
                # Model-agnostic fallback; a sub-sampled background keeps
                # KernelExplainer affordable.
                if len(X_final) > SHAP_BACKGROUND_SAMPLES:
                    background_data = shap.sample(X_final, SHAP_BACKGROUND_SAMPLES)
                else:
                    background_data = X_final

                def model_predict(x):
                    return final_model.predict_proba(x)

                explainer = shap.KernelExplainer(model_predict, background_data)
                raw_values = explainer.shap_values(X_final, nsamples=SHAP_NSAMPLES)

            shap_values = _positive_class_shap_values(raw_values, X_final.shape)

            if shap_values.shape != X_final.shape:
                # Plotting misaligned values would be misleading; report and
                # skip via the handler below instead.
                raise ValueError(
                    f"unexpected SHAP value shape {shap_values.shape} "
                    f"for data of shape {X_final.shape}"
                )

            # Bounded by the number of plotted features, never zero.
            max_display = max(1, min(15, len(feature_names)))
            plot_title = f'{model_name} - {feature_name}\n({len(stable_features)} stable features)'

            plt.figure(figsize=(10, 8))
            shap.summary_plot(shap_values, X_final, feature_names=feature_names,
                              plot_type="bar", show=False, max_display=max_display)
            plt.title(f'SHAP Feature Importance - {plot_title}', fontsize=14)
            plt.tight_layout()
            shap_figures.append((f"SHAP_Bar_{feature_name}", plt.gcf()))

            plt.figure(figsize=(10, 8))
            shap.summary_plot(shap_values, X_final, feature_names=feature_names,
                              show=False, max_display=max_display)
            plt.title(f'SHAP Summary - {plot_title}', fontsize=14)
            plt.tight_layout()
            shap_figures.append((f"SHAP_Beeswarm_{feature_name}", plt.gcf()))

            print(f"SHAP plots created for {feature_name}")

        except Exception as e:
            print(f"Error creating SHAP plots for {feature_name}: {e}")
            continue

    return shap_figures

def print_detailed_results(results, p_values, best_feature):
    """Print a fixed-width summary table, one row per feature combination.

    Args:
        results: Mapping of feature-combination name to a result dict. Each
            result dict is read for 'best_model_name', 'stable_features' and
            the '<metric>_mean' / '<metric>_ci' entries of 'auc', 'sens'
            and 'spec'.
        p_values: Accepted for signature compatibility; not used here.
        best_feature: Name of the combination flagged with a trailing " *".
    """
    column_widths = (25, 20, 15, 20, 25, 25)
    column_titles = (
        'Feature Combination',
        'Algorithm',
        'Stable Features',
        'AUC (95% CI)',
        'Sensitivity (95% CI)',
        'Specificity (95% CI)',
    )

    def _format_row(cells):
        # Left-align every cell to its column width, single-space separated.
        return " ".join(
            format(cell, f"<{width}") for cell, width in zip(cells, column_widths)
        )

    def _mean_with_ci(entry, metric):
        # Render "mean (lower-upper)" with three decimals.
        mean = entry[metric + '_mean']
        bounds = entry[metric + '_ci']
        return f"{mean:.3f} ({bounds[0]:.3f}-{bounds[1]:.3f})"

    print("\nDETAILED RESULTS SUMMARY")
    print("=" * 100)
    print(_format_row(column_titles))
    print("-" * 120)

    for combo, entry in results.items():
        algorithm = entry['best_model_name']
        stable_count = len(entry['stable_features'])
        cells = (
            combo,
            algorithm,
            stable_count,
            _mean_with_ci(entry, 'auc'),
            _mean_with_ci(entry, 'sens'),
            _mean_with_ci(entry, 'spec'),
        )
        suffix = " *" if combo == best_feature else ""
        print(_format_row(cells) + suffix)

    print("\n* Best performing feature combination")

def save_results_to_csv(results, p_values, best_feature):
    """Write the per-combination results table and the p-value matrix to CSV.

    Two files are written under ``./results/``:
    ``nested_cv_results.csv`` (one row per feature combination) and
    ``nested_cv_statistical_comparison.csv`` (the p-value matrix labelled by
    feature-combination name on both axes).

    Args:
        results: Mapping of feature-combination name to a result dict. Each
            result dict is read for 'best_model_name', 'stable_features' and
            the '<metric>_mean' / '<metric>_ci' entries of 'auc', 'sens'
            and 'spec'.
        p_values: 2-D data passed to ``pd.DataFrame`` with the names from
            ``results`` as both index and columns.
        best_feature: Name of the combination marked ``Is_Best`` in the table.

    Returns:
        Tuple ``(results_df, p_values_df)`` of the DataFrames that were saved.
    """
    print("Saving results to CSV files...")

    # Create the output directory here as well, so the function also works
    # when it is called on its own (save_figures_to_png does the same for
    # ./figures).
    os.makedirs('./results', exist_ok=True)

    # Create results table
    table_data = []

    for feature_name, result in results.items():
        model_name = result['best_model_name']
        n_stable_features = len(result['stable_features'])

        # Human-readable "mean (lower-upper)" strings alongside the raw numbers.
        auc_str = f"{result['auc_mean']:.3f} ({result['auc_ci'][0]:.3f}-{result['auc_ci'][1]:.3f})"
        sens_str = f"{result['sens_mean']:.3f} ({result['sens_ci'][0]:.3f}-{result['sens_ci'][1]:.3f})"
        spec_str = f"{result['spec_mean']:.3f} ({result['spec_ci'][0]:.3f}-{result['spec_ci'][1]:.3f})"

        is_best = feature_name == best_feature

        table_data.append({
            'Feature_Combination': feature_name,
            'Algorithm': model_name,
            'Stable_Features': n_stable_features,
            'AUC': result['auc_mean'],
            'AUC_CI_Lower': result['auc_ci'][0],
            'AUC_CI_Upper': result['auc_ci'][1],
            'AUC_95_CI': auc_str,
            'Sensitivity': result['sens_mean'],
            'Sensitivity_CI_Lower': result['sens_ci'][0],
            'Sensitivity_CI_Upper': result['sens_ci'][1],
            'Sensitivity_95_CI': sens_str,
            'Specificity': result['spec_mean'],
            'Specificity_CI_Lower': result['spec_ci'][0],
            'Specificity_CI_Upper': result['spec_ci'][1],
            'Specificity_95_CI': spec_str,
            'Is_Best': is_best
        })

    # Convert to DataFrame and save
    results_df = pd.DataFrame(table_data)
    results_csv_path = './results/nested_cv_results.csv'
    results_df.to_csv(results_csv_path, index=False)
    print(f"Saved results to {results_csv_path}")

    # Save statistical comparison matrix
    feature_names = list(results.keys())
    p_values_df = pd.DataFrame(p_values, index=feature_names, columns=feature_names)
    p_values_csv_path = './results/nested_cv_statistical_comparison.csv'
    p_values_df.to_csv(p_values_csv_path)
    print(f"Saved statistical comparison to {p_values_csv_path}")

    return results_df, p_values_df

def save_figures_to_png(performance_fig, stability_figs, shap_figures, stats_fig):
    """Write every figure to ``./figures/`` as a 300-dpi PNG and close it.

    Args:
        performance_fig: Figure saved as ``performance_distribution.png``.
        stability_figs: Iterable of ``(name, figure)`` pairs; each is saved
            under its name with '+' and ' ' replaced by '_'.
        shap_figures: Iterable of ``(name, figure)`` pairs, handled the same
            way as ``stability_figs`` and saved after them.
        stats_fig: Figure saved as ``statistical_comparison.png``.
    """
    print("Saving figures as PNG files...")

    # Create figures directory if it doesn't exist
    os.makedirs('./figures', exist_ok=True)

    def _export(figure, label, target):
        # Save, release the figure, then report where it went.
        figure.savefig(target, dpi=300, bbox_inches='tight')
        plt.close(figure)
        print(f"Saved {label} as {target}")

    # The two fixed-name figures go first.
    _export(performance_fig, "performance distribution",
            './figures/performance_distribution.png')
    _export(stats_fig, "statistical comparison",
            './figures/statistical_comparison.png')

    # Named figures: stability plots first, then SHAP plots.
    to_underscore = str.maketrans({'+': '_', ' ': '_'})
    for named_figures in (stability_figs, shap_figures):
        for label, figure in named_figures:
            file_stem = label.translate(to_underscore)
            _export(figure, label, f'./figures/{file_stem}.png')

    print("All figures saved successfully to ./figures/ folder!")

def main():
    """Run the analysis pipeline end to end.

    Creates the output directories, then runs in order: data loading,
    feature-type identification, missing-value handling, feature
    combination with PCA, scaling, nested cross-validation, plotting,
    console summary, CSV export and PNG export. Any exception is caught,
    reported with its traceback, and not re-raised.
    """
    banner = "=" * 60
    try:
        # Make sure both output folders exist before anything is written.
        for out_dir in ('./results', './figures'):
            os.makedirs(out_dir, exist_ok=True)

        print(banner)
        print("MACHINE LEARNING ANALYSIS FOR RECURRENCE PREDICTION")
        print(banner)

        # 1. Load and preprocess data
        all_features, y, clinical_end, patho_end, radio_end = load_and_preprocess_data()

        # 2. Split columns into categorical / numerical
        categorical_features, numerical_features = identify_feature_types(all_features)

        # 3. Handle missing values
        all_features = handle_missing_values(
            all_features, categorical_features, numerical_features
        )

        # 4. Build the feature combinations (with PCA); only the combinations
        #    themselves are used below.
        (feature_combinations, _pca_model, _pca_scaler,
         _clinical_features, _patho_features, _radio_features) = (
            prepare_feature_combinations_with_pca(
                all_features, clinical_end, patho_end, radio_end
            )
        )

        # 5. Scale features
        scaled_feature_combinations = scale_features(
            feature_combinations, numerical_features
        )

        # 6. Nested cross-validation: training and evaluation
        (results, feature_stability_results, stable_features_data,
         p_values, best_feature) = train_and_evaluate_with_nested_cv(
            scaled_feature_combinations, y
        )

        # 7. Visualizations
        print("\nCreating visualizations...")
        performance_fig = plot_performance_distribution(results)
        combination_names = list(results.keys())
        stats_fig = plot_statistical_comparison(p_values, combination_names)
        stability_figs = plot_feature_stability(
            feature_stability_results, stable_features_data
        )
        shap_figures = create_shap_plots(results, scaled_feature_combinations, y)

        # 8. Console summary
        print_detailed_results(results, p_values, best_feature)

        # 9. CSV export
        save_results_to_csv(results, p_values, best_feature)

        # 10. PNG export
        save_figures_to_png(performance_fig, stability_figs, shap_figures, stats_fig)

        print("\n" + banner)
        print("ANALYSIS COMPLETED SUCCESSFULLY!")
        print(banner)
        print("Results saved to: ./results/")
        print("Figures saved to: ./figures/")

    except Exception as err:
        # Top-level boundary: report the failure with its traceback.
        print(f"Error in main analysis: {err}")
        import traceback
        traceback.print_exc()

# Run the full analysis only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

MACHINE LEARNING ANALYSIS FOR RECURRENCE PREDICTION
Loading data...
Clinical data shape: (142, 15)
Pathological data shape: (142, 7)
Recurrence data shape: (142, 2)

Extracting radiomics features...
Selected 50 distinctive radiomics features
Radiomics features shape: (205, 50)
Common IDs across all datasets: 142
All features combined shape: (142, 72)
Target distribution: 1    86
0    56
Name: count, dtype: int64

Feature group indices:
Clinical features: 0 to 14 (15 features)
Pathological features: 15 to 21 (7 features)
Radiomics features: 22 to 71 (50 features)
Categorical features: 15
Numerical features: 57
Handling missing values...
No missing values found.
Applying adaptive PCA to radiomics features...
Original radiomics features shape: (142, 50)
PCA explained variance ratio: 0.952
Number of PCA components: 9
PCA features shape: (142, 9)
Saving PCA component formulas...
PCA component formulas saved to ./results/pca_components_formulas.txt

Feature combinations with PCA:
Clinical: 1