In [None]:
# train_model_with_cross_validation.py
# Purpose: Train a 1D-CNN model with 5-fold stratified cross-validation to improve robustness.
# Notes: StratifiedKFold is used to preserve class balance across folds.

import pandas as pd
import numpy as np
import tensorflow as tf
import random
import os
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input, Conv1D, MaxPooling1D, Flatten
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, roc_auc_score
import matplotlib.pyplot as plt
import json
from collections import defaultdict
import joblib

In [None]:
# --- Step 0: Global settings ---
SEED = 42

# NOTE: assigning PYTHONHASHSEED here only affects child processes; the hash seed
# of the already-running interpreter was fixed at startup. It is kept so that any
# subprocesses spawned later inherit a deterministic value.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Seed every RNG the pipeline touches (Python, NumPy, TensorFlow), in that order.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

In [None]:
# --- File path settings ---
# The defaults reproduce the original local layout. Set CV_BASE_DIR / CV_OUTPUT_DIR
# in the environment to run the notebook on another machine without editing it.
BASE_DIR = os.environ.get("CV_BASE_DIR", "D:/结直肠癌肝转移Biomarker 诊断/新的策略/Autoencoder")
EXPRESSION_FILE = os.path.join(BASE_DIR, "expression_data_combat_corrected.csv")
METADATA_FILE = os.path.join(BASE_DIR, "metadata_combined.csv")
FUNC_GENES_FILE = os.path.join(BASE_DIR, "functional_genes_620.txt")
OUTPUT_DIR = os.environ.get("CV_OUTPUT_DIR", "D:/temp_output_cv")  # cross-validation output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("--- Cross-validated CNN training pipeline ---")

In [None]:
# --- Step 1: Load data ---
print('--- Step 1: Load batch-corrected cleaned data ---')

expression_data = pd.read_csv(EXPRESSION_FILE, index_col=0)
metadata = pd.read_csv(METADATA_FILE, index_col=0)
# Align metadata rows to the expression sample order. Samples present in the
# expression matrix but absent from the metadata become all-NaN rows here.
metadata = metadata.reindex(expression_data.index)

# An unlabeled sample would otherwise reach LabelEncoder as NaN and be encoded as
# an extra class, silently corrupting the binary training target. Fail early instead.
missing_group = metadata['group'].isna()
if missing_group.any():
    missing_ids = metadata.index[missing_group].tolist()
    raise ValueError(
        f"{len(missing_ids)} expression samples have no 'group' label in the metadata "
        f"(first few: {missing_ids[:5]})"
    )

print(f"Loaded expression data: {expression_data.shape}")
print(f"Loaded metadata: {metadata.shape}")

In [None]:
# --- Step 2: Select functional genes ---
print('\n--- Step 2: Select functional genes ---')

# 'utf-8-sig' strips a UTF-8 byte-order mark if present; otherwise the BOM would be
# glued to the first gene name and that gene would silently never match a column.
with open(FUNC_GENES_FILE, 'r', encoding='utf-8-sig') as f:
    functional_genes = [line.strip() for line in f if line.strip()]

# dict.fromkeys de-duplicates while keeping first-occurrence order, so a gene listed
# twice does not produce duplicated feature columns in X.
available_functional_genes = [
    gene for gene in dict.fromkeys(functional_genes) if gene in expression_data.columns
]
print(f"Functional genes found in expression data: {len(available_functional_genes)}")

if not available_functional_genes:
    raise ValueError(
        f"None of the {len(functional_genes)} genes in {FUNC_GENES_FILE} "
        "were found in the expression data columns"
    )

X = expression_data[available_functional_genes]
y_raw = metadata['group']

In [None]:
# Encode labels. LabelEncoder sorts the class names, so classes_[1] is encoded as 1
# and is the "positive" class for precision, recall and AUC below.
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_raw)
print(f"Label encoding mapping: {dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))}")

# The model has one sigmoid output trained with binary cross-entropy, and the
# sklearn metrics use their binary defaults, so exactly two classes are required.
if len(label_encoder.classes_) != 2:
    raise ValueError(
        f"Binary classification requires exactly 2 groups, found "
        f"{len(label_encoder.classes_)}: {list(label_encoder.classes_)}"
    )


In [None]:
# --- Step 3: Define model builder ---
def create_cnn_model(input_shape, learning_rate=0.001):
    """Build and compile a 1D-CNN binary classifier.

    Architecture: two Conv1D(kernel_size=5) -> MaxPooling1D(2) -> Dropout stages
    (32 then 64 filters), then Flatten, a 64-unit ReLU dense layer with dropout,
    and a single sigmoid output unit.

    Args:
        input_shape: Shape of one sample without the batch axis, e.g.
            ``(n_genes, 1)``. The convolutions use Keras' default 'valid' padding,
            and each conv+pool stage shrinks the sequence axis. The first axis must
            therefore be at least 16 for Flatten to receive a non-empty tensor.
        learning_rate: Adam learning rate. The default 0.001 is the value used
            previously.

    Returns:
        A compiled ``Sequential`` model with binary cross-entropy loss and an
        'accuracy' metric.
    """
    model = Sequential([
        Input(shape=input_shape),
        Conv1D(filters=32, kernel_size=5, activation='relu'),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),
        Conv1D(filters=64, kernel_size=5, activation='relu'),
        MaxPooling1D(pool_size=2),
        Dropout(0.4),
        Flatten(),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(1, activation='sigmoid')  # probability of the class encoded as 1
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
    return model

In [None]:
# --- Step 4: Cross-validation setup ---
print('\n--- Step 4: Start 5-fold cross-validation training ---')

# Hold out a stratified 20% final test set first. The remaining 80% is the pool
# that StratifiedKFold later splits into train/validation folds.
X_temp, X_final_test, y_temp, y_final_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,
    random_state=SEED,
)

print(f"CV pool: {X_temp.shape[0]} samples")
print(f"Final test set: {X_final_test.shape[0]} samples")

# 5-fold stratified CV over the pool, shuffled reproducibly
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

# Per-fold metric lists, one entry appended per fold. Despite their names, the
# 'test_*' keys are filled with each fold's *validation* metrics; the held-out
# final test set is only scored by the ensemble in Step 7.
cv_results = {
    metric_name: []
    for metric_name in (
        'fold',
        'train_acc', 'val_acc',
        'train_loss', 'val_loss',
        'test_acc', 'test_precision', 'test_recall', 'test_auc',
    )
}

# One trained model, fitted scaler and Keras History per fold (kept for the ensemble)
models, scalers, fold_histories = [], [], []

In [None]:
# --- Step 5: Run cross-validation ---
for fold, (train_idx, val_idx) in enumerate(kfold.split(X_temp, y_temp), 1):
    print(f"\n=== Fold {fold} training ===")

    # Data split: X_temp is a DataFrame (positional .iloc), y_temp a NumPy array
    X_train_fold = X_temp.iloc[train_idx]
    X_val_fold = X_temp.iloc[val_idx]
    y_train_fold = y_temp[train_idx]
    y_val_fold = y_temp[val_idx]

    print(f"Train: {len(X_train_fold)} samples, Val: {len(X_val_fold)} samples")

    # Standardize: fit on the training fold only so validation statistics do not leak in
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_fold)
    X_val_scaled = scaler.transform(X_val_fold)

    # Reshape for 1D-CNN: (samples, genes) -> (samples, genes, 1 channel)
    X_train_cnn = np.expand_dims(X_train_scaled, axis=-1)
    X_val_cnn = np.expand_dims(X_val_scaled, axis=-1)

    # Build a fresh model for this fold
    model = create_cnn_model((X_train_cnn.shape[1], 1))
    # Callbacks: both monitor val_loss
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=0
    )
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        verbose=0,
        min_lr=1e-6
    )

    # Train model (epochs can be adjusted to save time)
    history = model.fit(
        X_train_cnn, y_train_fold,
        epochs=50,  # reduce epochs if needed
        batch_size=32,
        validation_data=(X_val_cnn, y_val_fold),
        callbacks=[early_stopping, reduce_lr],
        verbose=0
    )

    # Some Keras versions restore the best weights only when early stopping actually
    # fires. Restore explicitly so the stored model is always the lowest-val_loss epoch.
    best_weights = getattr(early_stopping, 'best_weights', None)
    if best_weights is not None:
        model.set_weights(best_weights)

    # Record the metrics of the epoch whose weights the model now holds. Indexing
    # with [-1] would report the last epoch run, up to `patience` epochs after the
    # restored one, and would disagree with the evaluation below.
    best_epoch = int(np.argmin(history.history['val_loss']))
    final_train_acc = history.history['accuracy'][best_epoch]
    final_val_acc = history.history['val_accuracy'][best_epoch]
    final_train_loss = history.history['loss'][best_epoch]
    final_val_loss = history.history['val_loss'][best_epoch]

    # Evaluate on validation set (threshold 0.5 on the positive-class probability)
    y_val_pred_proba = model.predict(X_val_cnn, verbose=0).flatten()
    y_val_pred = (y_val_pred_proba > 0.5).astype(int)

    val_acc = accuracy_score(y_val_fold, y_val_pred)
    # zero_division=0: a fold with no positive predictions scores 0 without a warning
    val_precision = precision_score(y_val_fold, y_val_pred, zero_division=0)
    val_recall = recall_score(y_val_fold, y_val_pred)
    val_auc = roc_auc_score(y_val_fold, y_val_pred_proba)

    # Store results ('test_*' keys hold this fold's validation metrics)
    cv_results['fold'].append(fold)
    cv_results['train_acc'].append(final_train_acc)
    cv_results['val_acc'].append(final_val_acc)
    cv_results['train_loss'].append(final_train_loss)
    cv_results['val_loss'].append(final_val_loss)
    cv_results['test_acc'].append(val_acc)
    cv_results['test_precision'].append(val_precision)
    cv_results['test_recall'].append(val_recall)
    cv_results['test_auc'].append(val_auc)

    # Save model and scaler for the ensemble
    models.append(model)
    scalers.append(scaler)
    fold_histories.append(history)

    print(f"Validation - Acc: {val_acc:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, AUC: {val_auc:.4f}")


In [None]:
# --- Step 6: Cross-validation summary ---
print('\n=== 5-fold Cross-validation Summary ===')
cv_df = pd.DataFrame(cv_results)

print('Per-fold results:')
print(cv_df.round(4))

print('\nAverage performance metrics:')
# Mean and (population) std across folds of each per-fold validation metric.
# Key order: Mean_X, Std_X for Accuracy, Precision, Recall, AUC in turn.
mean_results = {
    f'{stat_name}_{metric_label}': stat_fn(cv_results[source_key])
    for metric_label, source_key in (
        ('Accuracy', 'test_acc'),
        ('Precision', 'test_precision'),
        ('Recall', 'test_recall'),
        ('AUC', 'test_auc'),
    )
    for stat_name, stat_fn in (('Mean', np.mean), ('Std', np.std))
}

for metric, value in mean_results.items():
    print(f"{metric}: {value:.4f}")

In [None]:
# Optional: plot loss curves for each fold (solid = train, dashed = validation)
fig, ax = plt.subplots(figsize=(12, 8))
for fold_no, fold_history in enumerate(fold_histories, start=1):
    curves = fold_history.history
    ax.plot(curves['loss'], label=f'Fold {fold_no} Train Loss')
    ax.plot(curves['val_loss'], linestyle='--', label=f'Fold {fold_no} Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss vs Epochs for Each Fold')
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# --- Step 7: Final test set evaluation (ensemble) ---
print('\n=== Final test set evaluation (ensemble prediction) ===')


def ensemble_predict_proba(fold_models, fold_scalers, X_new):
    """Average the predicted positive-class probability of every fold model.

    Each model is paired with the StandardScaler fitted on its own training fold,
    so X_new is re-scaled separately for each model before prediction.

    Args:
        fold_models: Sequence of fitted Keras models, one per CV fold.
        fold_scalers: Sequence of fitted StandardScalers, in the same order.
        X_new: Feature matrix with the same gene columns used during training.

    Returns:
        1-D NumPy array with the mean predicted probability for each row of X_new.
    """
    per_model_proba = []
    for fold_model, fold_scaler in zip(fold_models, fold_scalers):
        X_scaled = fold_scaler.transform(X_new)
        X_cnn = np.expand_dims(X_scaled, axis=-1)  # (samples, genes, 1 channel)
        per_model_proba.append(fold_model.predict(X_cnn, verbose=0).flatten())
    return np.mean(per_model_proba, axis=0)


# Mean predicted probabilities across the fold models, thresholded at 0.5
ensemble_pred_proba = ensemble_predict_proba(models, scalers, X_final_test)
ensemble_pred = (ensemble_pred_proba > 0.5).astype(int)

# Final test set performance (zero_division=0 avoids a warning when no sample is
# predicted positive)
final_acc = accuracy_score(y_final_test, ensemble_pred)
final_precision = precision_score(y_final_test, ensemble_pred, zero_division=0)
final_recall = recall_score(y_final_test, ensemble_pred)
final_auc = roc_auc_score(y_final_test, ensemble_pred_proba)

print(f"Final test set - Accuracy: {final_acc:.4f}")
print(f"Final test set - Precision: {final_precision:.4f}")
print(f"Final test set - Recall: {final_recall:.4f}")
print(f"Final test set - AUC: {final_auc:.4f}")

print('\nDetailed classification report:')
# target_names must be strings; group labels read from CSV could be numeric
print(classification_report(
    y_final_test,
    ensemble_pred,
    target_names=[str(name) for name in label_encoder.classes_],
    zero_division=0,
))

In [None]:
# --- Step 8: Save results ---
print('\n--- Step 8: Save cross-validation results ---')

# Save per-fold CV results ('test_*' columns are per-fold validation metrics)
cv_df.to_csv(os.path.join(OUTPUT_DIR, "cross_validation_results.csv"), index=False)

# Save summary results
summary_results = {
    **mean_results,
    'Final_Test_Accuracy': final_acc,
    'Final_Test_Precision': final_precision,
    'Final_Test_Recall': final_recall,
    'Final_Test_AUC': final_auc,
    'Total_Samples': len(X),
    'Features_Used': len(available_functional_genes),
    # Derived from the splitter so it cannot drift from the actual configuration
    'CV_Folds': kfold.get_n_splits(),
}

# json cannot serialize every NumPy scalar type (e.g. np.float32), so convert
# NumPy scalars to plain Python numbers first. Float64 values serialize identically.
summary_results = {
    key: (value.item() if isinstance(value, np.generic) else value)
    for key, value in summary_results.items()
}

with open(os.path.join(OUTPUT_DIR, "cv_summary_results.json"), 'w', encoding='utf-8') as f:
    json.dump(summary_results, f, indent=4)


In [None]:
# --- Save ensemble models (one model and one scaler per CV fold) ---
ensemble_model_dir = os.path.join(OUTPUT_DIR, "ensemble_models")
os.makedirs(ensemble_model_dir, exist_ok=True)

for fold_no, (fold_model, fold_scaler) in enumerate(zip(models, scalers), 1):
    fold_model.save(os.path.join(ensemble_model_dir, f"model_fold{fold_no}.keras"))
    # joblib files are pickles: only load them back from a trusted location
    joblib.dump(fold_scaler, os.path.join(ensemble_model_dir, f"scaler_fold{fold_no}.pkl"))

# Save label encoder (needed to map predicted 0/1 back to group names)
joblib.dump(label_encoder, os.path.join(ensemble_model_dir, "label_encoder.pkl"))

# Report the actual number of saved folds rather than a hard-coded count
n_models = len(models)
print(f"Saved {n_models} k-fold models and scalers to: {ensemble_model_dir}")
print(f"For new sample prediction, run all {n_models} models and average predicted probabilities for ensemble result.")

# Save list of used genes (column order expected by the saved scalers/models)
with open(os.path.join(OUTPUT_DIR, "used_functional_genes_cv.txt"), 'w', encoding='utf-8') as f:
    f.writelines(f"{gene}\n" for gene in available_functional_genes)

In [None]:
# --- Save final test set ensemble predictions to CSV ---
# One row per held-out sample: encoded true label, mean ensemble probability and
# the thresholded (0.5) prediction.
test_results_path = os.path.join(OUTPUT_DIR, "test_predictions_with_labels.csv")
test_results_df = pd.DataFrame({'Sample_ID': X_final_test.index}).assign(
    True_Label=y_final_test,
    Pred_Probability=ensemble_pred_proba,
    Pred_Label=ensemble_pred,
)
test_results_df.to_csv(test_results_path, index=False)
print(f"Test set ensemble predictions saved: {test_results_path}")

In [None]:
# Polished CV training visualization and save to PDF.
# Each panel shows one recorded metric per fold (x axis = fold number, not epochs).
fold_numbers = cv_results['fold']

# (axes position, panel title, y label, [(cv_results key, fmt, color, legend label), ...])
panel_specs = [
    ((0, 0), 'Model Loss (per fold)', 'Loss', [
        ('train_loss', 'o-', '#d62728', 'Training Loss'),
        ('val_loss', 's--', '#1f77b4', 'Validation Loss'),
    ]),
    ((0, 1), 'Model Accuracy (per fold)', 'Accuracy', [
        ('train_acc', 'o-', '#2ca02c', 'Training Accuracy'),
        ('val_acc', 's--', '#ff7f0e', 'Validation Accuracy'),
    ]),
    ((1, 0), 'Model AUC (per fold)', 'AUC', [
        ('test_auc', 'd-', '#9467bd', 'Validation AUC'),
    ]),
    ((1, 1), 'Model Precision & Recall (per fold)', 'Value', [
        ('test_precision', '^-', '#8c564b', 'Validation Precision'),
        ('test_recall', 'v--', '#e377c2', 'Validation Recall'),
    ]),
]

fig, axs = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Final Model Training History', fontsize=18, fontweight='bold')

for (row, col), panel_title, y_label, series in panel_specs:
    ax = axs[row, col]
    for metric_key, fmt, color, legend_label in series:
        ax.plot(fold_numbers, cv_results[metric_key], fmt, color=color,
                label=legend_label, linewidth=2, markersize=7)
    ax.set_xlabel('Fold', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(panel_title, fontsize=14)
    # Folds are discrete; without explicit ticks matplotlib may label 1.5, 2.5, ...
    ax.set_xticks(fold_numbers)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

fig.tight_layout(rect=[0, 0.03, 1, 0.95])
pdf_path = os.path.join(OUTPUT_DIR, "cross_validation_training_curves.pdf")
fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
plt.show()
print(f"Training curves saved to PDF: {pdf_path}")

print(f"\n🎉 5-fold cross-validation completed! All results saved to: {OUTPUT_DIR}")
print(f"CV mean AUC: {np.mean(cv_results['test_auc']):.4f} ± {np.std(cv_results['test_auc']):.4f}")
print(f"Final test set AUC: {final_auc:.4f}")