In [None]:
from sklearn.model_selection import StratifiedKFold



# Stratified k-fold cross-validation driver.
#
# For every fold: hold out the fold's test split, carve a stratified
# validation split out of the remaining data, train a fresh model on the
# clean + augmented training samples, then record the training curves and the
# held-out test metrics. A mean/std summary is printed at the end.
#
# Relies on names defined elsewhere in the notebook: X, y, tf, np,
# train_test_split, augment_fn, build_model, EarlyStopping, ReduceLROnPlateau.
# NOTE(review): X and y are indexed with integer index arrays below, so they
# are presumably NumPy arrays (not DataFrames) - confirm.

# Parameters
k_folds = 5
num_epochs = 100
batch_size = 32

# Save metrics (one entry per fold)
fold_train_histories = []
fold_val_histories = []
fold_test_accuracies = []
fold_test_aucs = []


skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    print(f"\n--- Fold {fold+1}/{k_folds} ---")

    # Release the Keras global state left over from the previous fold so that
    # memory does not grow with every model built in this loop and so that
    # auto-generated layer/metric names start fresh each fold.
    tf.keras.backend.clear_session()

    # Outer split: this fold's held-out test data vs. everything else.
    X_train_fold, X_test_fold = X[train_idx], X[test_idx]
    y_train_fold, y_test_fold = y[train_idx], y[test_idx]

    # Inner split: 20% of the remaining data becomes the validation set used
    # by the callbacks. Seeded with the fold index so it is reproducible but
    # differs from fold to fold.
    X_train_fold, X_val_fold, y_train_fold, y_val_fold = train_test_split(
        X_train_fold, y_train_fold, test_size=0.2, stratify=y_train_fold, random_state=fold)

    # Training data = clean samples followed by one augmented copy of each.
    train_clean = tf.data.Dataset.from_tensor_slices((X_train_fold, y_train_fold))
    train_aug = train_clean.map(augment_fn)

    # The concatenated dataset yields all clean samples first and all
    # augmented samples afterwards. The shuffle buffer must therefore span
    # the whole dataset; a smaller buffer only shuffles within a sliding
    # window and leaves early batches clean-only and late batches
    # augmented-only. Cost: the buffer holds every training element in memory.
    shuffle_buffer_size = 2 * len(X_train_fold)
    train_dataset = (
        train_clean.concatenate(train_aug)
        .shuffle(shuffle_buffer_size)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    val_dataset = tf.data.Dataset.from_tensor_slices((X_val_fold, y_val_fold)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = tf.data.Dataset.from_tensor_slices((X_test_fold, y_test_fold)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # A fresh model per fold so no weights leak between folds.
    model = build_model()  # add your model here

    # Update if early stopping is implemented, otherwise comment out
    early_stop = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
    # Update if using adaptive learning rate, otherwise comment out
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)

    # Train
    history = model.fit(train_dataset,
                        validation_data=val_dataset,
                        epochs=num_epochs,
                        callbacks=[early_stop, reduce_lr],
                        verbose=1)

    # Save train and val history (per-epoch accuracy curves)
    fold_train_histories.append(history.history['accuracy'])
    fold_val_histories.append(history.history['val_accuracy'])

    # Evaluate on the held-out fold.
    # NOTE(review): indexing assumes evaluate() returns [loss, accuracy, auc],
    # i.e. the model is compiled with metrics in that order - confirm against
    # build_model().
    results = model.evaluate(test_dataset, verbose=0)
    fold_test_accuracies.append(results[1])
    fold_test_aucs.append(results[2])

    print(f"Fold {fold+1} Test Accuracy: {results[1]:.4f}, Test AUC: {results[2]:.4f}")


print("\n=== K-Fold Cross Validation Summary ===")
print(f"Average Test Accuracy: {np.mean(fold_test_accuracies):.4f} ± {np.std(fold_test_accuracies):.4f}")
print(f"Average Test AUC: {np.mean(fold_test_aucs):.4f} ± {np.std(fold_test_aucs):.4f}")
