In [None]:
# External validation using cross-validated CNN ensemble
# Purpose:
# 1. Load the ensemble of 5 CV-trained models
# 2. Predict TRS scores on an independent CRLM validation cohort
# 3. Evaluate performance and produce a detailed report

import pandas as pd
import numpy as np
import tensorflow as tf
import random
import os
import joblib
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score, precision_score, recall_score, confusion_matrix, classification_report
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import json
import warnings
warnings.filterwarnings('ignore')

# Announce the start of the notebook run in the cell output.
print("--- External validation of CRLM dataset "
      "using cross-validated ensemble models ---")

In [None]:
# --- Step 0: Global settings ---
# Seed every RNG this notebook touches (Python `random`, NumPy, TensorFlow)
# so stochastic steps are reproducible across runs.
SEED = 42
# NOTE: PYTHONHASHSEED only affects Python interpreters started after this point
# (e.g. subprocesses); the running kernel's hash seed was fixed at its startup.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Report the seed actually in use instead of a hard-coded literal.
print(f"--- Random seed set to {SEED} ---")

# Set fonts: SimHei / Arial Unicode MS provide CJK glyphs for Chinese text if needed.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']
# Draw the minus sign as an ASCII hyphen so it renders with CJK fonts.
plt.rcParams['axes.unicode_minus'] = False

In [None]:
# --- Step 1: File paths ---
print("\n--- Step 1: Configure file paths ---")

# Project root. Defaults to the original local path; set the CRLM_BASE_DIR
# environment variable to run the notebook elsewhere without editing it.
BASE_DIR = os.environ.get("CRLM_BASE_DIR", "D:/结直肠癌肝转移Biomarker 诊断/新的策略/Autoencoder")

# Cross-validation model output paths (root overridable via CRLM_CV_OUTPUT_DIR)
CV_OUTPUT_DIR = os.environ.get("CRLM_CV_OUTPUT_DIR", "D:/temp_output_cv")  # cross-validation root output
ENSEMBLE_MODEL_DIR = os.path.join(CV_OUTPUT_DIR, "ensemble_models")  # per-fold models, scalers and label encoder

# Validation data path
VALIDATION_DATA_DIR = os.path.join(BASE_DIR, "validation_datasets")

# Results output path (created if it does not exist yet)
OUTPUT_DIR = os.path.join(BASE_DIR, "crlm_validation_with_cv_model")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Input file paths
CV_LABEL_ENCODER_FILE = os.path.join(ENSEMBLE_MODEL_DIR, "label_encoder.pkl")
CV_GENES_FILE = os.path.join(CV_OUTPUT_DIR, "used_functional_genes_cv.txt")
CRLM_DATA_FILE = os.path.join(VALIDATION_DATA_DIR, "dat_crlm.csv")

print(f"Ensemble model directory: {ENSEMBLE_MODEL_DIR}")
print(f"CRLM validation data: {CRLM_DATA_FILE}")
print(f"Results output directory: {OUTPUT_DIR}")

# Existence checks are informational only: nothing is raised here, and later
# cells fail when they try to load a missing file.
required_files = [CV_LABEL_ENCODER_FILE, CV_GENES_FILE, CRLM_DATA_FILE]
if not os.path.exists(ENSEMBLE_MODEL_DIR):
    print(f"ERROR: Ensemble model directory not found: {ENSEMBLE_MODEL_DIR}")
    print("Please run the cross-validation training script to generate the model files.")
else:
    print("Ensemble model directory exists.")

missing_files = [f for f in required_files if not os.path.exists(f)]

if missing_files:
    print("ERROR: The following required files are missing:")
    for f in missing_files:
        print(f"  - {f}")
    print("Please check the paths or re-run the training script.")
else:
    print("All required files are present.")

In [None]:
# --- Step 2: Load ensemble models and preprocessors ---
print("\n--- Step 2: Load ensemble models and associated preprocessors ---")

# Number of cross-validation folds. Fold k (1..N_FOLDS) is expected to provide
# model_fold{k}.keras and scaler_fold{k}.pkl inside ENSEMBLE_MODEL_DIR.
N_FOLDS = 5

try:
    models = []
    scalers = []
    print("Loading models:")
    for i in range(1, N_FOLDS + 1):
        model_path = os.path.join(ENSEMBLE_MODEL_DIR, f"model_fold{i}.keras")
        scaler_path = os.path.join(ENSEMBLE_MODEL_DIR, f"scaler_fold{i}.pkl")
        if not os.path.exists(model_path) or not os.path.exists(scaler_path):
            raise FileNotFoundError(f"Model or scaler for fold {i} not found: {model_path} or {scaler_path}")
        models.append(tf.keras.models.load_model(model_path))
        # joblib.load unpickles the file (can execute arbitrary code):
        # only load artifacts produced by the trusted training pipeline.
        scalers.append(joblib.load(scaler_path))
        print(f"  - Loaded model and scaler for fold {i}")

    print(f"Loaded {len(models)} cross-validation models successfully.")

    # Label encoder fitted during training; presumably its class order defines
    # which class the models' output refers to (checked via the print below).
    cv_label_encoder = joblib.load(CV_LABEL_ENCODER_FILE)
    print("Label encoder loaded successfully.")

    # Gene list used for modeling: one gene per line, blank lines skipped.
    # Its order defines the feature-column order for prediction.
    with open(CV_GENES_FILE, 'r') as f:
        cv_model_genes = [line.strip() for line in f if line.strip()]
    print(f"Functional gene list loaded: {len(cv_model_genes)} genes")

    print(f"Model input shape: {models[0].input_shape}")
    print(f"Label classes: {cv_label_encoder.classes_}")

except Exception as e:
    # Report the failure in the cell output, then re-raise so Run All stops here.
    print(f"ERROR: Failed to load models or resources: {e}")
    raise

In [None]:
# --- Step 3: Load and preprocess CRLM validation data ---
print("\n--- Step 3: Load and preprocess CRLM validation data ---")

# The first CSV column is the sample index; the remaining columns hold the
# expression features plus the label column checked below.
crlm_data = pd.read_csv(CRLM_DATA_FILE, index_col=0)
print(f"CRLM data shape: {crlm_data.shape}")

# Check label column
label_col = 'status'
if label_col not in crlm_data.columns:
    print(f"ERROR: Label column '{label_col}' not found in CRLM data.")
    print(f"Available columns: {list(crlm_data.columns)}")
    raise SystemExit(1)

# Prepare labels (map to binary, consistent with training)
y_crlm_raw = crlm_data[label_col]
print("Original label distribution:")
print(y_crlm_raw.value_counts())

# value_counts() above hides missing labels, and they would silently be
# encoded as 0 below — surface them explicitly.
n_missing_labels = int(y_crlm_raw.isna().sum())
if n_missing_labels:
    print(f"WARNING: {n_missing_labels} sample(s) have a missing '{label_col}' value; they are encoded as 0 (Primary).")

# Map labels: any label containing "metastasis" (case-insensitive) -> 1, else 0.
# NOTE(review): a label such as "non-metastasis" would also match and map to 1 —
# check against the printed distribution above.
y_crlm = (
    y_crlm_raw.astype(str)
    .str.lower()
    .str.contains('metastasis', regex=False)
    .astype('int64')
)
print("\nEncoded label distribution:")
print(y_crlm.value_counts())

In [None]:
# --- Step 4: Prepare feature matrix ---
print("\n--- Step 4: Prepare feature matrix ---")

# The coverage percentages below divide by the gene count.
if not cv_model_genes:
    raise ValueError("The model gene list is empty; cannot build the feature matrix.")

available_genes = [gene for gene in cv_model_genes if gene in crlm_data.columns]
missing_genes = [gene for gene in cv_model_genes if gene not in crlm_data.columns]

print(f"Of {len(cv_model_genes)} required model genes:")
print(f"  - Found in CRLM data: {len(available_genes)} ({len(available_genes)/len(cv_model_genes)*100:.1f}%)")
print(f"  - Missing: {len(missing_genes)} ({len(missing_genes)/len(cv_model_genes)*100:.1f}%)")

if len(missing_genes) > 0:
    print("\nWARNING: Missing genes will be filled with zeros.")
    print(f"First 10 missing genes: {missing_genes[:10]}")
    # NOTE(review): 0 on the raw expression scale is not necessarily "neutral"
    # once each fold's scaler is applied — consider whether a training-derived
    # fill value would be more appropriate.

# Build the feature matrix in exactly the training gene order: reindex() copies
# available genes and creates missing ones filled with 0. astype(float) guarantees
# a numeric matrix (building an empty DataFrame and filling it column by column
# can leave object-dtype columns).
X_crlm = crlm_data.reindex(columns=cv_model_genes, fill_value=0).astype(float)

print(f"Feature matrix built, shape: {X_crlm.shape}")
print(f"Contains any NA: {X_crlm.isnull().any().any()}")


In [None]:
# --- Step 5: Data formatting for CNN ---
print("\n--- Step 5: Data formatting for CNN ---")

# Unscaled copy of the features in CNN layout (samples, genes, 1 channel).
# Note: scaling is applied per model in Step 6, which rebuilds its own input from
# X_crlm, so these arrays are for inspection only.
# dtype=float avoids producing an object array if X_crlm has object-typed columns.
X_crlm_cnn_base = X_crlm.to_numpy(dtype=float)
X_crlm_cnn = np.expand_dims(X_crlm_cnn_base, axis=-1)
print(f"CNN input shape prepared: {X_crlm_cnn.shape}")


In [None]:
# --- Step 6: Ensemble predictions using the 5 models ---
print("\n--- Step 6: Ensemble prediction with 5 CV models ---")

# Soft-voting ensemble: each fold model scores the cohort after the features are
# transformed by that fold's own scaler; the per-fold outputs are then averaged.
all_predictions = []
for fold_no, (model, scaler) in enumerate(zip(models, scalers), start=1):
    X_crlm_scaled = scaler.transform(X_crlm)
    # CNN layout: (samples, genes, 1 channel)
    X_crlm_cnn_scaled = np.expand_dims(X_crlm_scaled, axis=-1)
    pred_proba = model.predict(X_crlm_cnn_scaled, verbose=0).ravel()
    all_predictions.append(pred_proba)
    print(f"  - Model fold {fold_no} prediction completed")

# Sample-wise mean over folds.
trs_scores_raw = np.stack(all_predictions).mean(axis=0)
print(f"\nRaw ensemble score range: [{trs_scores_raw.min():.4f}, {trs_scores_raw.max():.4f}]")

# Semantic correction: flip the averaged output so that a higher TRS means a
# higher metastasis risk.
# NOTE(review): this assumes the models' output is the probability of the
# non-metastatic class — confirm against the `Label classes` printed in Step 2.
trs_scores_final = 1 - trs_scores_raw
print(f"Semantic-corrected TRS range: [{trs_scores_final.min():.4f}, {trs_scores_final.max():.4f}]")

print("Ensemble TRS scoring completed.")

In [None]:
# --- Step 7: Performance evaluation ---
print("\n--- Step 7: Evaluate predictive performance ---")

# Threshold-independent discrimination: ROC AUC plus the full ROC curve.
auc_score = roc_auc_score(y_crlm, trs_scores_final)
fpr, tpr, thresholds = roc_curve(y_crlm, trs_scores_final)

# Operating point = threshold maximizing Youden's J (TPR - FPR); samples scoring
# at or above it are called metastasis (1).
youden_scores = tpr - fpr
best_threshold_idx = np.argmax(youden_scores)
best_threshold = thresholds[best_threshold_idx]
y_crlm_pred = (trs_scores_final >= best_threshold).astype(int)

# Confusion-matrix cells with label order 0 = Primary, 1 = Metastasis.
cm = confusion_matrix(y_crlm, y_crlm_pred)
tn, fp, fn, tp = cm.ravel()

# Threshold-dependent metrics at the chosen operating point.
accuracy = accuracy_score(y_crlm, y_crlm_pred)
precision = precision_score(y_crlm, y_crlm_pred)
recall = recall_score(y_crlm, y_crlm_pred)
if precision + recall > 0:
    f1_score = 2 * precision * recall / (precision + recall)
else:
    f1_score = 0

sensitivity = tp / (tp + fn) if tp + fn else 0
specificity = tn / (tn + fp) if tn + fp else 0

print("=" * 60)
print("Performance of cross-validated ensemble on CRLM dataset")
print("=" * 60)
print(f"AUC: {auc_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 score: {f1_score:.4f}")
print(f"Sensitivity: {sensitivity:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"Optimal threshold: {best_threshold:.4f}")
print(f"Youden index: {youden_scores[best_threshold_idx]:.4f}")

print("\nConfusion matrix (rows=true, cols=predicted):")
print("        Predicted")
print("       Primary  Metastasis")
print(f"True Primary   {tn:3d}      {fp:3d}")
print(f"True Metastasis{fn:3d}      {tp:3d}")

print("\nDetailed classification report:")
print(classification_report(y_crlm, y_crlm_pred, target_names=['Primary', 'Metastasis']))


In [None]:
# --- Step 8: Save predictions and metrics ---
print("\n--- Step 8: Save prediction results and performance summary ---")

# Per-sample predictions. TRS_Score and Predicted_Probability carry the same
# ensemble score (higher = higher metastasis risk).
results_df = pd.DataFrame({
    'Sample_ID': X_crlm.index,
    'True_Label_str': y_crlm_raw.values,
    'True_Label_int': y_crlm.values,
    'TRS_Score': trs_scores_final,
    'Predicted_Probability': trs_scores_final,
    'Predicted_Label_int': y_crlm_pred,
    'Predicted_Label_str': ['Metastasis' if pred == 1 else 'Primary' for pred in y_crlm_pred],
    'Classification_Correct': (y_crlm == y_crlm_pred).values,
})

# Persist the per-sample table; the final summary cell reports `results_file`,
# which was previously never defined (NameError there).
results_file = os.path.join(OUTPUT_DIR, "crlm_validation_predictions.csv")
results_df.to_csv(results_file, index=False)
print(f"Prediction results saved: {results_file}")

# Cohort-level summary; values are cast to built-in types so json can serialize them.
performance_summary = {
    'Model_Type': 'Cross_Validation_Ensemble_CNN',
    'Validation_Dataset': 'CRLM',
    'Total_Samples': len(y_crlm),
    'Primary_Samples': int((y_crlm == 0).sum()),
    'Metastasis_Samples': int((y_crlm == 1).sum()),
    'Features_Used': len(cv_model_genes),
    'Available_Features': len(available_genes),
    'Missing_Features': len(missing_genes),
    'Missing_Feature_Ratio': len(missing_genes) / len(cv_model_genes),
    'Performance_Metrics': {
        'AUC': float(auc_score),
        'Accuracy': float(accuracy),
        'Precision': float(precision),
        'Recall': float(recall),
        'F1_Score': float(f1_score),
        'Sensitivity': float(sensitivity),
        'Specificity': float(specificity),
        'Best_Threshold': float(best_threshold),
        'Youden_Index': float(youden_scores[best_threshold_idx])
    },
    'Confusion_Matrix': {'TN': int(tn), 'FP': int(fp), 'FN': int(fn), 'TP': int(tp)}
}

summary_file = os.path.join(OUTPUT_DIR, "crlm_validation_performance_summary.json")
with open(summary_file, 'w') as f:
    json.dump(performance_summary, f, indent=4)
print(f"Performance summary saved: {summary_file}")

print("\nPrediction preview:")
print(results_df[['Sample_ID', 'True_Label_str', 'TRS_Score', 'Predicted_Label_str', 'Classification_Correct']].head(10))

In [None]:
# --- Step 10: Gene correlation with TRS (Pearson) ---
print("\n--- Step 10: Compute gene correlations with TRS (Pearson r and p-value) ---")

# One (gene, r, p) record per model gene. A gene whose computation raises is
# recorded as NaN; every row with a NaN (including genes whose r is undefined,
# e.g. constant columns) is dropped when the table is built.
correlations = []
for gene in X_crlm.columns:
    try:
        corr, p_value = pearsonr(
            X_crlm[gene].astype(float),
            trs_scores_final,
        )
    except Exception:
        corr = p_value = np.nan
    correlations.append((gene, corr, p_value))

corr_df = pd.DataFrame(correlations, columns=['Gene', 'Correlation', 'P_Value']).dropna()

print(f"Gene correlation calculation completed: {len(corr_df)} genes analyzed.")
print("Top 5 positively correlated genes:")
print(corr_df.sort_values('Correlation', ascending=False).head())
print("\nTop 5 negatively correlated genes:")
print(corr_df.sort_values('Correlation', ascending=True).head())


In [None]:
# --- Top50 gene correlation visualization (25 pos + 25 neg) ---
print("\nGenerating gene-correlation visualization...")

top_pos_corr = corr_df.sort_values('Correlation', ascending=False).head(25)
top_neg_corr = corr_df.sort_values('Correlation', ascending=True).head(25)
# With fewer than 50 genes the two selections overlap; keep each gene once so
# every category on the y-axis corresponds to exactly one bar/row.
top_genes_vis = (
    pd.concat([top_pos_corr, top_neg_corr])
    .drop_duplicates(subset='Gene')
    .sort_values('Correlation')
)

plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(12, 8))

# Red = positive correlation ("risk" gene), Blue = non-positive ("protective" gene)
colors = ['#c23616' if c > 0 else '#192a56' for c in top_genes_vis['Correlation']]
ax.barh(top_genes_vis['Gene'], top_genes_vis['Correlation'], color=colors)

ax.set_xlabel('Pearson Correlation with TRS (Metastasis Risk Score)', fontsize=12, fontweight='bold')
ax.set_ylabel('Gene', fontsize=12, fontweight='bold')
# Title reports the number of genes actually plotted (50 in the normal case).
ax.set_title(f'Top {len(top_genes_vis)} Genes Most Correlated with TRS\nRed=Risk gene, Blue=Protective gene\nHigher TRS = Higher mCRC Risk', fontsize=15, fontweight='bold', pad=20)
ax.tick_params(axis='y', labelsize=11)
ax.grid(axis='x', linestyle='--', alpha=0.6)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

In [None]:
# Add numeric labels, then save and display the figure built in the previous cell.
# `fig`/`ax` are used explicitly: Jupyter's inline backend closes all figures when a
# cell finishes, so plt.tight_layout()/plt.savefig() here would operate on a new,
# empty figure and write a blank PNG.
for bar, value in zip(ax.patches, top_genes_vis['Correlation']):
    # Anchor each label on its own bar's vertical centre rather than a row index.
    y_center = bar.get_y() + bar.get_height() / 2
    ax.text(value + 0.01 if value > 0 else value - 0.01, y_center, f'{value:.3f}',
            ha='left' if value > 0 else 'right',
            va='center',
            fontweight='medium',
            fontsize=10)

fig.tight_layout()
CORRELATION_PLOT_FILE = os.path.join(OUTPUT_DIR, "top50_gene_correlations.png")
fig.savefig(CORRELATION_PLOT_FILE, dpi=300, bbox_inches='tight')
print(f"Correlation plot saved to: {CORRELATION_PLOT_FILE}")

corr_table_file = os.path.join(OUTPUT_DIR, "trs_gene_correlation in CRLM Samples.csv")
corr_df.to_csv(corr_table_file, index=False)
print(f"All gene correlations table saved to: {corr_table_file}")

# Bare expression as the last line: Jupyter renders the annotated figure via rich display.
fig

In [None]:
# --- Step 11: Summary and completion ---
print("\n" + "=" * 80)
print("🎉 External validation of CV-ensemble model on CRLM dataset completed!")
print("=" * 80)

print("\nKey results summary:")
print(f"• Validation dataset: CRLM ({len(y_crlm)} samples)")
print("• Model type: Cross-validated ensemble CNN")
print("• Main performance metrics:")
print(f"  - AUC: {auc_score:.3f}")
print(f"  - Accuracy: {accuracy:.3f}")
print(f"  - Sensitivity: {sensitivity:.3f}")
print(f"  - Specificity: {specificity:.3f}")

# NOTE(review): `results_file` (the per-sample prediction CSV path) must be
# defined by an earlier cell, otherwise this line raises NameError.
print("\nGenerated files:")
print(f"• Prediction results: {os.path.basename(results_file)}")
print(f"• Performance summary: {os.path.basename(summary_file)}")
print(f"• Gene correlations table: {os.path.basename(corr_table_file)}")
print(f"• Top50 correlation plot: {os.path.basename(CORRELATION_PLOT_FILE)}")

print(f"\nAll outputs are saved to: {OUTPUT_DIR}")

# Coarse verbal grading of discrimination from the AUC.
print("\nModel assessment:")
if auc_score > 0.8:
    print("✅ Excellent: AUC > 0.8, strong discriminative ability.")
elif auc_score > 0.7:
    print("✅ Good: AUC > 0.7, reasonable discriminative ability.")
elif auc_score > 0.6:
    print("⚠️ Moderate: AUC > 0.6, limited discriminative ability.")
else:
    # Previously repeated the "limited" wording of the Moderate tier.
    print("❌ Poor: AUC ≤ 0.6, weak or no discriminative ability.")

print("\nScript finished.")