# Random Forest Analyses

In [7]:
# Standard library
import os
import warnings
import logging
from itertools import combinations
import string
import random

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize

# For confusion matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

# Set overall styling for plots
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

In [8]:
# Set random seed for reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

## Random Forest Classification Tasks

In [9]:
# Read in table at ASV level
biom_path = '../Data/Tables/Count_Tables/1_209766_feature_table.biom'

biom_tbl = load_table(biom_path)
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# delete the prefix from the index
df.index = df.index.str.replace('15564.', '')

# Convert to relative abundance by dividing each row by its row sum
df_dense = df.div(df.sum(axis=1), axis=0)
df = pd.DataFrame(df_dense.values, index=df_dense.index, columns=df_dense.columns)  # Force dense

df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,0.529602,0.328848,0.061356,0.044133,0.011841,0.008073,0.004306,0.004306,0.003229,0.001615,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900459,0.068445,0.061485,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900221,0.000744,0.000000,0.000000,0.000000,0.000000,0.000000,0.000541,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900570,0.058922,0.000000,0.000000,0.000000,0.001212,0.000000,0.001666,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900092,0.623945,0.342909,0.011852,0.006428,0.000603,0.000000,0.000000,0.000000,0.000000,0.001406,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,0.115804,0.058794,0.001586,0.000000,0.002776,0.000000,0.072972,0.000000,0.000000,0.003569,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900097,0.011594,0.000000,0.000000,0.000000,0.000000,0.000000,0.015942,0.000000,0.000000,0.000000,...,0.003865,0.002415,0.000483,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900498,0.015707,0.017801,0.000000,0.000000,0.000000,0.000000,0.035602,0.000000,0.014660,0.000000,...,0.000000,0.000000,0.000000,0.015707,0.010471,0.008377,0.000000,0.000000,0.000000,0.000000
900276,0.000000,0.000000,0.041322,0.000000,0.000000,0.000000,0.207989,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.015152,0.004132,0.002755,0.001377


In [10]:
# Load the metadata
metadata_path = '../Metadata/16S_AD_South-Africa_metadata_subset.tsv'
metadata = pd.read_csv(metadata_path, sep='\t')

metadata['#sample-id'] = metadata['#sample-id'].str.replace('_', '')
# Set Sample-ID as the index for the metadata dataframe 
metadata = metadata.set_index('#sample-id')


# Create group column based on case_type to simplify group names
metadata['group'] = metadata['case_type'].map({
    'case-lesional_skin': 'skin-ADL',
    'case-nonlesional_skin': 'skin-ADNL', 
    'control-nonlesional_skin': 'skin-H',
    'case-anterior_nares': 'nares-AD',
    'control-anterior_nares': 'nares-H'
})

metadata

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,specimen,age_months,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,skin,24.0,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,skin,9.0,female,8/11/2015,Winter,Unexposed,negative,7.0,34,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,skin,24.0,female,11/20/2014,Spring,Unexposed,negative,7.0,21,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,skin,18.0,female,9/23/2015,Spring,Unexposed,,4.0,40,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,skin,31.0,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
900401,5,C12,SB503,AGAGTCAC,SB712,CGTAGCGA,SB712SB503,CGTAGCGA-AGAGTCAC,1.010000e+21,C12,...,skin,21.0,female,9/17/2015,Spring,Exposed,negative,12.0,38,skin-ADNL
900402,6,B4,SA502,ACTATCTG,SB704,TCTCTATG,SB704SA502,TCTCTATG-ACTATCTG,1.010000e+21,B4,...,nasal,21.0,,,,,,,,nares-AD
Ca006ONL,6,F1,SA506,CGTGAGTG,SB701,CTCGACTT,SB701SA506,CTCGACTT-CGTGAGTG,1.010000e+21,F1,...,skin,35.0,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,skin,35.0,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,skin-ADNL


In [11]:
# Subset df to only the 462 samples in subsetted metadata
df = df[df.index.isin(metadata.index)]
# df = df[df.index.isin(metadata.index)].dropna()
df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,0.529602,0.328848,0.061356,0.044133,0.011841,0.008073,0.004306,0.004306,0.003229,0.001615,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900459,0.068445,0.061485,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900221,0.000744,0.000000,0.000000,0.000000,0.000000,0.000000,0.000541,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900570,0.058922,0.000000,0.000000,0.000000,0.001212,0.000000,0.001666,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900092,0.623945,0.342909,0.011852,0.006428,0.000603,0.000000,0.000000,0.000000,0.000000,0.001406,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
900294,0.017937,0.010463,0.000000,0.000000,0.000000,0.000000,0.144993,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900097,0.011594,0.000000,0.000000,0.000000,0.000000,0.000000,0.015942,0.000000,0.000000,0.000000,...,0.003865,0.002415,0.000483,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900498,0.015707,0.017801,0.000000,0.000000,0.000000,0.000000,0.035602,0.000000,0.014660,0.000000,...,0.000000,0.000000,0.000000,0.015707,0.010471,0.008377,0.000000,0.000000,0.000000,0.000000
900276,0.000000,0.000000,0.041322,0.000000,0.000000,0.000000,0.207989,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.015152,0.004132,0.002755,0.001377


In [12]:
# =============================================================================
# UNIFIED RANDOM FOREST CONFIGURATION
# =============================================================================

# Define the RF configuration used throughout ALL analyses
RF_CONFIG = {
    'n_estimators': 1000,
    'random_state': 42,
    'n_jobs': -1
}

# Define which confounders to include
CONFOUNDER_COLS = ['age_months', 'sex', 'enrolment_season']

print("RANDOM FOREST CONFIGURATION")
print(f"RF Configuration: {RF_CONFIG}")
print(f"Confounders to include: {CONFOUNDER_COLS}")
print("="*80 + "\n")


# =============================================================================
# 1. CUSTOM GROUP-STRATIFIED K-FOLD FUNCTION
# =============================================================================

def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """Custom group-stratified k-fold splitting."""
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)
    
    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
    
    folds = [[] for _ in range(n_splits)]
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
    
    for group in sorted_groups:
        best_fold = 0
        min_imbalance = float('inf')
        
        for fold_idx in range(n_splits):
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            fold_size = sum(temp_fold_dist.values())
            proportions = [count / fold_size for count in temp_fold_dist.values()] if fold_size else [0]*len(temp_fold_dist)
            imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
            
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        
        folds[best_fold].extend(np.where(groups == group)[0])
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count
    
    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))
    
    return train_test_indices


# =============================================================================
# 2. PREPARE CONFOUNDERS
# =============================================================================

def prepare_confounders(metadata, confounder_cols):
    """
    Prepare confounder variables for inclusion in models.
    
    Parameters:
    - metadata: DataFrame with metadata
    - confounder_cols: List of confounder column names
    
    Returns:
    - confounders_df: DataFrame with encoded and cleaned confounders
    - available_confounders: List of available confounder names
    """
    # Check which confounders are available
    available_confounders = [col for col in confounder_cols if col in metadata.columns]
    
    if not available_confounders:
        print("WARNING: No confounders available in metadata")
        return None, []
    
    print(f"Available confounders: {available_confounders}")
    
    # Extract confounders
    confounders_df = metadata[available_confounders].copy()
    
    # Encode categorical variables
    categorical_cols = []
    for col in confounders_df.columns:
        if confounders_df[col].dtype == 'object':
            categorical_cols.append(col)
            # Create dummy variables
            dummies = pd.get_dummies(confounders_df[col], prefix=col, drop_first=True)
            confounders_df = pd.concat([confounders_df.drop(columns=[col]), dummies], axis=1)
    
    # Handle missing values (use median for numeric, most common for categorical)
    for col in confounders_df.columns:
        if confounders_df[col].isna().any():
            if confounders_df[col].dtype in ['int64', 'float64']:
                confounders_df[col].fillna(confounders_df[col].median(), inplace=True)
            else:
                confounders_df[col].fillna(confounders_df[col].mode()[0], inplace=True)
    
    print(f"Final confounder features: {list(confounders_df.columns)}")
    print(f"Categorical variables encoded: {categorical_cols}")
    
    return confounders_df, available_confounders

# Prepare confounders once
confounders_full, available_confounder_names = prepare_confounders(metadata, CONFOUNDER_COLS)
print()


# =============================================================================
# 3. FUNCTION TO RUN CV WITH CONFOUNDERS AND GET ROC/FEATURE IMPORTANCE
# =============================================================================

def run_group_stratified_cv(X, y, groups, confounders=None, n_splits=5, rf_config=None):
    """
    Run group-stratified CV for Random Forest with optional confounder adjustment.
    
    Parameters:
    - X: DataFrame of microbiome features
    - y: Series of labels
    - groups: Series of group IDs (patient IDs)
    - confounders: DataFrame of confounder variables (must have same index as X)
    - n_splits: Number of CV folds
    - rf_config: Dictionary with RF parameters (uses RF_CONFIG if None)
    
    Returns:
    - cv_results: List of results per fold
    - feature_importances: DataFrame of microbiome feature importances only
    - fold_aucs: List of AUC values per fold
    """
    if rf_config is None:
        rf_config = RF_CONFIG
    
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"  Skipping fold {i+1} due to insufficient class representation")
            continue
        
        # Add confounders if provided
        if confounders is not None:
            conf_train = confounders.iloc[train_idx]
            conf_test = confounders.iloc[test_idx]
            
            # Concatenate microbiome features with confounders
            X_train_combined = pd.concat([X_train, conf_train], axis=1)
            X_test_combined = pd.concat([X_test, conf_test], axis=1)
            
            # Train RF on combined features
            clf = RandomForestClassifier(**rf_config)
            clf.fit(X_train_combined, y_train)
            
            # Get predictions
            probas = clf.predict_proba(X_test_combined)
            
            # Extract only microbiome feature importances (exclude confounder importances)
            n_microbiome_features = X.shape[1]
            feature_importances[f'fold_{i}'] = clf.feature_importances_[:n_microbiome_features]
            
        else:
            # Standard RF without confounders
            clf = RandomForestClassifier(**rf_config)
            clf.fit(X_train, y_train)
            probas = clf.predict_proba(X_test)
            feature_importances[f'fold_{i}'] = clf.feature_importances_
        
        # Calculate ROC metrics
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs


# =============================================================================
# 4. REGENERATE ROC RESULTS WITH CONFOUNDER ADJUSTMENT
# =============================================================================

print("RUNNING RANDOM FOREST ANALYSES WITH CONFOUNDER ADJUSTMENT")

# 4a. Skin vs Nares (Binary Classification)
print("\n1. Analyzing: Skin vs Nares...")
meta_skin_nares = metadata[metadata['group'].str.startswith(('skin', 'nares'))].copy()
meta_skin_nares['site_label'] = meta_skin_nares['group'].apply(lambda x: 0 if x.startswith('skin') else 1)
common_samples = df.index.intersection(meta_skin_nares.index)

X_skin_nares = df.loc[common_samples]
X_skin_nares = X_skin_nares.dropna(axis=1)

y_skin_nares = meta_skin_nares.loc[common_samples, 'site_label']
groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
y_skin_nares = pd.to_numeric(y_skin_nares, errors='coerce')

# Get confounders for this subset
conf_skin_nares = confounders_full.loc[common_samples] if confounders_full is not None else None

print(f"  Samples: {len(X_skin_nares)}")
print(f"  Using confounders: {conf_skin_nares is not None}")

skin_vs_nares_cv_results, skin_vs_nares_fi, skin_vs_nares_aucs = run_group_stratified_cv(
    X_skin_nares, y_skin_nares, groups_skin_nares, 
    confounders=conf_skin_nares, n_splits=3, rf_config=RF_CONFIG
)
print(f"  Mean AUC: {np.mean(skin_vs_nares_aucs):.3f} ± {np.std(skin_vs_nares_aucs):.3f}")


# 4b. Skin Comparisons (Pairwise Among Skin Groups)
print("\n2. Analyzing: Skin pairwise comparisons...")
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H')
]
skin_cv_results_dict = {}
skin_feature_importance_dict = {}

for label1, label2 in skin_comparisons:
    print(f"\n   {label1} vs {label2}...")
    meta_subset = metadata[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    
    X_skin = df.loc[common_samples]
    y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
    groups_skin = meta_subset.loc[common_samples, 'pid']
    
    # Get confounders for this subset
    conf_skin = confounders_full.loc[common_samples] if confounders_full is not None else None
    
    print(f"     Samples: {len(X_skin)}")
    print(f"     Using confounders: {conf_skin is not None}")
    
    cv_results, fi, aucs = run_group_stratified_cv(
        X_skin, y_skin, groups_skin, 
        confounders=conf_skin, n_splits=3, rf_config=RF_CONFIG
    )
    
    key = f"{label1}_vs_{label2}"
    skin_cv_results_dict[key] = cv_results
    skin_feature_importance_dict[key] = fi
    print(f"     Mean AUC: {np.mean(aucs):.3f} ± {np.std(aucs):.3f}")


# 4c. Nares Comparison (nares-AD vs nares-H)
print("\n3. Analyzing: Nares AD vs H...")
meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
common_samples = df.index.intersection(meta_nares.index)

X_nares = df.loc[common_samples]
y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
groups_nares = meta_nares.loc[common_samples, 'pid']

# Get confounders for this subset
conf_nares = confounders_full.loc[common_samples] if confounders_full is not None else None

print(f"  Samples: {len(X_nares)}")
print(f"  Using confounders: {conf_nares is not None}")

nares_ad_vs_h_cv_results, nares_fi, nares_aucs = run_group_stratified_cv(
    X_nares, y_nares, groups_nares, 
    confounders=conf_nares, n_splits=3, rf_config=RF_CONFIG
)
print(f"  Mean AUC: {np.mean(nares_aucs):.3f} ± {np.std(nares_aucs):.3f}")


# =============================================================================
# 5. ASSEMBLE ROC RESULTS DICTIONARY
# =============================================================================

roc_results_dict = {
    'Skin vs. Nares Samples': {
        'skin_vs_nares': skin_vs_nares_cv_results
    },
    'AD Status From Skin Samples': skin_cv_results_dict,
    'AD Status From Nares Samples': {
        'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results
    }
}


# =============================================================================
# 6. COMBINED ROC PLOTTING FUNCTION (3-PANEL)
# =============================================================================

def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, output_path=None):
    """Plots a 3-panel ROC curve figure from grouped ROC results."""
    fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    pretty_labels = {
        'skin_vs_nares': 'All Skin vs All Nares',
        'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
        'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
        'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
        'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
    }
    
    for i, panel_title in enumerate(comparison_order):
        ax = axs[i]
        sub_dict = roc_results_dict[panel_title]
        for sublabel, curves in sub_dict.items():
            if not isinstance(curves, list):
                curves = [curves]
            mean_fpr = np.linspace(0, 1, 100)
            tprs = []
            aucs = []
            for result in curves:
                if 'fpr' not in result or 'tpr' not in result:
                    continue
                tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                tpr_interp[0] = 0.0
                tprs.append(tpr_interp)
                aucs.append(result['auc'])
            if len(tprs) == 0:
                continue
            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            ax.plot(mean_fpr, mean_tpr, lw=3,
                    label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=color_map.get(sublabel, 'gray'))
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=color_map.get(sublabel, 'gray'))
    
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=18)
        ax.set_title(panel_title, fontsize=20)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=14)
        ax.legend(loc='lower right', fontsize=12)
    
    axs[0].set_ylabel('True Positive Rate', fontsize=19)
    plt.suptitle(f'Random Forest Classifications by 16S ASVs', 
                 fontsize=24, x=0.5, y=1.02)
    plt.tight_layout()
    
    if output_path:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        plt.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"\nROC plot saved to: {output_path}")


# =============================================================================
# 7. DEFINE PANEL ORDER AND COLOR MAP, THEN PLOT
# =============================================================================

comparison_order = [
    'Skin vs. Nares Samples',
    'AD Status From Skin Samples',
    'AD Status From Nares Samples'
]
color_map = {
    'skin_vs_nares': 'black',
    'skin-ADL_vs_skin-H': 'blue',
    'skin-ADNL_vs_skin-ADL': 'green',
    'skin-ADNL_vs_skin-H': 'purple',
    'nares-AD_vs_nares-H': 'orange'
}

plot_combined_roc_3panel(
    roc_results_dict,
    comparison_order=comparison_order,
    color_map=color_map,
    output_path='../Figures/Main/Fig_2A.jpg'
)


# =============================================================================
# 8. SAVE FEATURE IMPORTANCE RESULTS
# =============================================================================

print("\n" + "="*80)
print("SAVING FEATURE IMPORTANCE TABLES")


# Create directory if it doesn't exist
os.makedirs('../Data/RF_Feature_Importances', exist_ok=True)

# Save feature importance for each comparison
skin_vs_nares_fi.to_csv('../Data/RF_Feature_Importances/feature_importance_skin_vs_nares.csv')
print("  Saved: feature_importance_skin_vs_nares.csv")

nares_fi.to_csv('../Data/RF_Feature_Importances/feature_importance_nares_AD_vs_H.csv')
print("  Saved: feature_importance_nares_AD_vs_H.csv")

for comparison_key, fi in skin_feature_importance_dict.items():
    fi.to_csv(f'../Data/RF_Feature_Importances/feature_importance_{comparison_key}.csv')
    print(f"  Saved: feature_importance_{comparison_key}.csv")

print("ANALYSIS COMPLETE")
if confounders_full is not None:
    print(f"✓ Confounders included: {available_confounder_names}")
else:
    print("✓ No confounders available/used")

print(f"✓ RF Configuration: {RF_CONFIG}")
print("✓ ROC curves generated and saved")
print("✓ Feature importance tables saved")


RANDOM FOREST CONFIGURATION
RF Configuration: {'n_estimators': 1000, 'random_state': 42, 'n_jobs': -1}
Confounders to include: ['age_months', 'sex', 'enrolment_season']

Available confounders: ['age_months', 'sex', 'enrolment_season']
Final confounder features: ['age_months', 'sex_male', 'enrolment_season_Autumn ', 'enrolment_season_Spring', 'enrolment_season_Spring ', 'enrolment_season_Summer', 'enrolment_season_Winter']
Categorical variables encoded: ['sex', 'enrolment_season']

RUNNING RANDOM FOREST ANALYSES WITH CONFOUNDER ADJUSTMENT

1. Analyzing: Skin vs Nares...
  Samples: 462
  Using confounders: True


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  confounders_df[col].fillna(confounders_df[col].median(), inplace=True)


  Mean AUC: 0.952 ± 0.022

2. Analyzing: Skin pairwise comparisons...

   skin-ADL vs skin-H...
     Samples: 183
     Using confounders: True
     Mean AUC: 0.820 ± 0.014

   skin-ADNL vs skin-ADL...
     Samples: 198
     Using confounders: True
     Mean AUC: 0.753 ± 0.017

   skin-ADNL vs skin-H...
     Samples: 183
     Using confounders: True
     Mean AUC: 0.764 ± 0.078

3. Analyzing: Nares AD vs H...
  Samples: 180
  Using confounders: True
  Mean AUC: 0.717 ± 0.028

ROC plot saved to: ../Figures/Main/Fig_2A.jpg

SAVING FEATURE IMPORTANCE TABLES
  Saved: feature_importance_skin_vs_nares.csv
  Saved: feature_importance_nares_AD_vs_H.csv
  Saved: feature_importance_skin-ADL_vs_skin-H.csv
  Saved: feature_importance_skin-ADNL_vs_skin-ADL.csv
  Saved: feature_importance_skin-ADNL_vs_skin-H.csv
ANALYSIS COMPLETE
✓ Confounders included: ['age_months', 'sex', 'enrolment_season']
✓ RF Configuration: {'n_estimators': 1000, 'random_state': 42, 'n_jobs': -1}
✓ ROC curves generated and sav

### Comparison of different ML models

In [13]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

In [14]:
# =============================================================================
# UNIFIED RANDOM FOREST CONFIGURATION (this takes about 2 mins)
# =============================================================================

# Define the RF configuration used throughout ALL analyses
RF_CONFIG = {
    'n_estimators': 1000,
    'random_state': 42,
    'n_jobs': -1
}


print("RANDOM FOREST CONFIGURATION")

print(f"Configuration: {RF_CONFIG}")
print("This configuration will be used for:")
print("  1. Main RF analysis (ROC curves, feature importance)")
print("  2. Model comparison benchmark")
print("="*80 + "\n")


# =============================================================================
# PART 1: GROUP-STRATIFIED CV FUNCTIONS
# =============================================================================

def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """Custom group-stratified k-fold splitting."""
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)
    
    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
    
    folds = [[] for _ in range(n_splits)]
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
    
    for group in sorted_groups:
        best_fold = 0
        min_imbalance = float('inf')
        
        for fold_idx in range(n_splits):
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            fold_size = sum(temp_fold_dist.values())
            proportions = [count / fold_size for count in temp_fold_dist.values()] if fold_size else [0]*len(temp_fold_dist)
            imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
            
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        
        folds[best_fold].extend(np.where(groups == group)[0])
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count
    
    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))
    
    return train_test_indices


def run_group_stratified_cv(X, y, groups, n_splits=5, rf_config=None):
    """
    Run group-stratified CV for Random Forest with feature importance tracking.
    
    Parameters:
    - X: DataFrame of features
    - y: Series of labels
    - groups: Series of group IDs (patient IDs)
    - n_splits: Number of CV folds
    - rf_config: Dictionary with RF parameters (uses RF_CONFIG if None)
    
    Returns:
    - cv_results: List of results per fold
    - feature_importances: DataFrame of feature importances
    - fold_aucs: List of AUC values per fold
    """
    if rf_config is None:
        rf_config = RF_CONFIG
    
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue
        
        # Use the standardized RF configuration
        clf = RandomForestClassifier(**rf_config)
        clf.fit(X_train, y_train)
        probas = clf.predict_proba(X_test)
        
        feature_importances[f'fold_{i}'] = clf.feature_importances_
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs


def run_group_stratified_cv_multimodel(X, y, groups, models_dict, n_splits=5):
    """
    Run group-stratified CV for multiple models and return comparative results.
    
    Parameters:
    - X: DataFrame of features
    - y: Series of labels
    - groups: Series of group IDs (e.g., patient IDs)
    - models_dict: Dictionary of model name -> model object
    - n_splits: Number of CV folds
    
    Returns:
    - results_summary: DataFrame with performance metrics for each model
    - detailed_results: Dictionary with detailed results for each model
    """
    
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    results_summary = []
    detailed_results = {}
    
    for model_name, model in models_dict.items():
        print(f"  Training {model_name}...")
        
        fold_aucs = []
        fold_accuracies = []
        cv_results = []
        
        for i, (train_idx, test_idx) in enumerate(folds):
            X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
            y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
            
            # Skip fold if insufficient class representation
            if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
                print(f"    Skipping fold {i+1} for {model_name} due to insufficient class representation")
                continue
            
            # Train model
            model.fit(X_train, y_train)
            
            # Get predictions
            probas = model.predict_proba(X_test)
            y_pred = model.predict(X_test)
            
            # Calculate metrics
            fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
            roc_auc = auc(fpr, tpr)
            accuracy = accuracy_score(y_test, y_pred)
            
            fold_aucs.append(roc_auc)
            fold_accuracies.append(accuracy)
            
            cv_results.append({
                'y_true': y_test,
                'y_proba': probas[:, 1],
                'y_pred': y_pred,
                'fpr': fpr,
                'tpr': tpr,
                'auc': roc_auc,
                'accuracy': accuracy,
            })
        
        # Calculate summary statistics
        if len(cv_results) > 0:
            results_summary.append({
                'Model': model_name,
                'AUC-ROC': np.mean(fold_aucs),
                'AUC_std': np.std(fold_aucs),
                'Accuracy': np.mean(fold_accuracies),
                'Acc_std': np.std(fold_accuracies),
            })
            
            detailed_results[model_name] = cv_results
    
    results_df = pd.DataFrame(results_summary)
    results_df = results_df.sort_values('AUC-ROC', ascending=False)
    
    return results_df, detailed_results


# =============================================================================
# PART 2: MAIN RF ANALYSIS (ROC CURVES + FEATURE IMPORTANCE)
# =============================================================================

print("\n" + "="*80)
print("PART 1: MAIN RANDOM FOREST ANALYSIS")


# 3a. Skin vs Nares (Binary Classification)
print("\nAnalyzing: Skin vs Nares...")
meta_skin_nares = metadata[metadata['group'].str.startswith(('skin', 'nares'))].copy()
meta_skin_nares['site_label'] = meta_skin_nares['group'].apply(lambda x: 0 if x.startswith('skin') else 1)
common_samples = df.index.intersection(meta_skin_nares.index)
X_skin_nares = df.loc[common_samples]
y_skin_nares = meta_skin_nares.loc[common_samples, 'site_label']
groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
y_skin_nares = pd.to_numeric(y_skin_nares, errors='coerce')

skin_vs_nares_cv_results, skin_vs_nares_fi, skin_vs_nares_aucs = run_group_stratified_cv(
    X_skin_nares, y_skin_nares, groups_skin_nares, n_splits=3, rf_config=RF_CONFIG
)
print(f"  Mean AUC: {np.mean(skin_vs_nares_aucs):.3f} ± {np.std(skin_vs_nares_aucs):.3f}")

# 3b. Skin Comparisons (Pairwise Among Skin Groups)
print("\nAnalyzing: Skin pairwise comparisons...")
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H')
]
skin_cv_results_dict = {}
skin_feature_importance_dict = {}

for label1, label2 in skin_comparisons:
    print(f"  {label1} vs {label2}...")
    meta_subset = metadata[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    X_skin = df.loc[common_samples]
    y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
    groups_skin = meta_subset.loc[common_samples, 'pid']
    
    cv_results, fi, aucs = run_group_stratified_cv(
        X_skin, y_skin, groups_skin, n_splits=3, rf_config=RF_CONFIG
    )
    key = f"{label1}_vs_{label2}"
    skin_cv_results_dict[key] = cv_results
    skin_feature_importance_dict[key] = fi
    print(f"    Mean AUC: {np.mean(aucs):.3f} ± {np.std(aucs):.3f}")

# 3c. Nares Comparison (nares-AD vs nares-H)
print("\nAnalyzing: Nares AD vs H...")
meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
common_samples = df.index.intersection(meta_nares.index)
X_nares = df.loc[common_samples]
y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
groups_nares = meta_nares.loc[common_samples, 'pid']

nares_ad_vs_h_cv_results, nares_fi, nares_aucs = run_group_stratified_cv(
    X_nares, y_nares, groups_nares, n_splits=3, rf_config=RF_CONFIG
)
print(f"  Mean AUC: {np.mean(nares_aucs):.3f} ± {np.std(nares_aucs):.3f}")

# Assemble ROC Results Dictionary
roc_results_dict = {
    'Skin vs. Nares Samples': {
        'skin_vs_nares': skin_vs_nares_cv_results
    },
    'AD Status From Skin Samples': skin_cv_results_dict,
    'AD Status From Nares Samples': {
        'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results
    }
}

# Save feature importance results
print("Saving feature importance tables...")

os.makedirs('../Data/RF_Feature_Importances', exist_ok=True)
skin_vs_nares_fi.to_csv('../Data/RF_Feature_Importances/feature_importance_skin_vs_nares.csv')
nares_fi.to_csv('../Data/RF_Feature_Importances/feature_importance_nares_AD_vs_H.csv')

for comparison_key, fi in skin_feature_importance_dict.items():
    fi.to_csv(f'../Data/RF_Feature_Importances/feature_importance_{comparison_key}.csv')

print("Feature importance tables saved!")


# =============================================================================
# PART 3: ROC CURVE PLOTTING
# =============================================================================

def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, output_path=None):
    """Plots a 3-panel ROC curve figure from grouped ROC results."""
    fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    pretty_labels = {
        'skin_vs_nares': 'All Skin vs All Nares',
        'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
        'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
        'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
        'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
    }
    
    for i, panel_title in enumerate(comparison_order):
        ax = axs[i]
        sub_dict = roc_results_dict[panel_title]
        for sublabel, curves in sub_dict.items():
            if not isinstance(curves, list):
                curves = [curves]
            mean_fpr = np.linspace(0, 1, 100)
            tprs = []
            aucs = []
            for result in curves:
                if 'fpr' not in result or 'tpr' not in result:
                    continue
                tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                tpr_interp[0] = 0.0
                tprs.append(tpr_interp)
                aucs.append(result['auc'])
            if len(tprs) == 0:
                continue
            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            ax.plot(mean_fpr, mean_tpr, lw=3,
                    label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=color_map.get(sublabel, 'gray'))
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=color_map.get(sublabel, 'gray'))
    
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=18)
        ax.set_title(panel_title, fontsize=20)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=14)
        ax.legend(loc='lower right', fontsize=12)
    
    axs[0].set_ylabel('True Positive Rate', fontsize=19)
    plt.suptitle(f'Random Forest Classifications by 16S ASVs', 
                 fontsize=24, x=0.5, y=1.02)
    plt.tight_layout()
    
    if output_path:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        plt.savefig(output_path, dpi=600, bbox_inches='tight')
        print(f"\nROC plot saved to: {output_path}")

comparison_order = [
    'Skin vs. Nares Samples',
    'AD Status From Skin Samples',
    'AD Status From Nares Samples'
]
color_map = {
    'skin_vs_nares': 'black',
    'skin-ADL_vs_skin-H': 'blue',
    'skin-ADNL_vs_skin-ADL': 'green',
    'skin-ADNL_vs_skin-H': 'purple',
    'nares-AD_vs_nares-H': 'orange'
}

plot_combined_roc_3panel(
    roc_results_dict,
    comparison_order=comparison_order,
    color_map=color_map,
    output_path='../Figures/Main/Fig_2A.jpg'
)


# =============================================================================
# PART 4: MODEL COMPARISON (Using SAME RF configuration)
# =============================================================================

print("PART 2: MODEL COMPARISON")
print("Comparing RF (with identical config) against alternative ML methods...")

# Define models to compare - NOTE: RF uses the SAME config as main analysis
models_to_compare = {
    'Random Forest': RandomForestClassifier(**RF_CONFIG),  # SAME AS MAIN ANALYSIS
    
    'XGBoost': XGBClassifier(
        n_estimators=1000,
        random_state=42,
        learning_rate=0.1,
        eval_metric='logloss',
        n_jobs=-1
    ),
    
    'LightGBM': LGBMClassifier(
        n_estimators=1000,
        random_state=42,
        learning_rate=0.1,
        verbose=-1,
        n_jobs=-1
    ),
    
    
    'SVM': SVC(
        probability=True,
        random_state=42,
        kernel='rbf'
    )
}

# Dictionary to store all comparison results
all_model_comparisons = {}

# 1. Skin vs Nares Comparison
print("MODEL COMPARISON: Skin vs. Nares")
skin_nares_comparison, _ = run_group_stratified_cv_multimodel(
    X_skin_nares, y_skin_nares, groups_skin_nares, 
    models_to_compare, n_splits=3
)
all_model_comparisons['Skin vs Nares'] = skin_nares_comparison
print("\nResults:")
print(skin_nares_comparison.round(3).to_string(index=False))

# 2. Skin Pairwise Comparisons
for label1, label2 in skin_comparisons:
    print(f"MODEL COMPARISON: {label1} vs {label2}")
    
    meta_subset = metadata[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    X_skin = df.loc[common_samples]
    y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
    groups_skin = meta_subset.loc[common_samples, 'pid']
    
    comparison_results, _ = run_group_stratified_cv_multimodel(
        X_skin, y_skin, groups_skin,
        models_to_compare, n_splits=3
    )
    
    key = f"{label1} vs {label2}"
    all_model_comparisons[key] = comparison_results
    print("\nResults:")
    print(comparison_results.round(3).to_string(index=False))

# 3. Nares AD vs H Comparison
print("MODEL COMPARISON: Nares AD vs H")
nares_comparison, _ = run_group_stratified_cv_multimodel(
    X_nares, y_nares, groups_nares,
    models_to_compare, n_splits=3
)
all_model_comparisons['Nares AD vs H'] = nares_comparison
print("\nResults:")
print(nares_comparison.round(3).to_string(index=False))


# =============================================================================
# PART 5: SAVE AND VISUALIZE MODEL COMPARISON
# =============================================================================

def create_comparison_summary_table(all_comparisons):
    """Create a comprehensive summary table across all comparisons."""
    summary_data = []
    
    for comparison_name, results_df in all_comparisons.items():
        for _, row in results_df.iterrows():
            summary_data.append({
                'Comparison': comparison_name,
                'Model': row['Model'],
                'AUC-ROC': f"{row['AUC-ROC']:.3f} ± {row['AUC_std']:.3f}",
                'Accuracy': f"{row['Accuracy']:.3f} ± {row['Acc_std']:.3f}",
            })
    
    summary_df = pd.DataFrame(summary_data)
    return summary_df

summary_table = create_comparison_summary_table(all_model_comparisons)
print("COMPREHENSIVE MODEL COMPARISON SUMMARY")
print(summary_table.to_string(index=False))

# Save to CSV
os.makedirs('../Data/ML_comparison', exist_ok=True)
summary_table.to_csv('../Data/ML_comparison/model_comparison_summary.csv', index=False)
print("\nSummary saved to: ../Data/ML_comparison/model_comparison_summary.csv")


# Heatmap visualization
def plot_model_comparison(all_comparisons, output_path=None):
    """Create a heatmap showing AUC-ROC for all models across comparisons."""
    
    models = list(models_to_compare.keys())
    comparisons = list(all_comparisons.keys())
    
    data_matrix = np.zeros((len(models), len(comparisons)))
    
    for j, comp_name in enumerate(comparisons):
        comp_df = all_comparisons[comp_name]
        for i, model_name in enumerate(models):
            model_row = comp_df[comp_df['Model'] == model_name]
            if not model_row.empty:
                data_matrix[i, j] = model_row['AUC-ROC'].values[0]
            else:
                data_matrix[i, j] = np.nan
    
    fig, ax = plt.subplots(figsize=(14, 6))
    sns.heatmap(data_matrix, 
                annot=True, 
                fmt='.2f',
                cmap='Blues',
                xticklabels=comparisons,
                yticklabels=models,
                cbar_kws={'label': 'AUC-ROC'},
                vmin=0.5, vmax=1.0,
                linewidths=0.5,
                ax=ax)
    
    plt.title('Model Performance Comparison Across All Binary Classifications', 
              fontsize=20, pad=20)
    plt.xlabel('Classification Task', rotation=0, fontsize=16, labelpad=20)
    plt.ylabel('Model', fontsize=16)
    plt.xticks(rotation=0, ha='center')
    plt.yticks(rotation=0)
    plt.tight_layout()
    
    if output_path:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\nHeatmap saved to: {output_path}")
    plt.show()

plot_model_comparison(all_model_comparisons, 
                      output_path='../Figures/Supplementary/Suppl_Fig_2.jpg')

print("\n" + "="*80)
print("ANALYSIS COMPLETE")

print(f"RF Configuration used throughout: {RF_CONFIG}")
print("✓ Main RF analysis complete (ROC curves + feature importance)")
print("✓ Model comparison complete (RF vs 6 alternatives)")
print("✓ All results saved")

RANDOM FOREST CONFIGURATION
Configuration: {'n_estimators': 1000, 'random_state': 42, 'n_jobs': -1}
This configuration will be used for:
  1. Main RF analysis (ROC curves, feature importance)
  2. Model comparison benchmark


PART 1: MAIN RANDOM FOREST ANALYSIS

Analyzing: Skin vs Nares...
  Mean AUC: 0.949 ± 0.021

Analyzing: Skin pairwise comparisons...
  skin-ADL vs skin-H...
    Mean AUC: 0.790 ± 0.018
  skin-ADNL vs skin-ADL...
    Mean AUC: 0.751 ± 0.011
  skin-ADNL vs skin-H...
    Mean AUC: 0.689 ± 0.076

Analyzing: Nares AD vs H...
  Mean AUC: 0.646 ± 0.064
Saving feature importance tables...
Feature importance tables saved!

ROC plot saved to: ../Figures/Main/Fig_2A.jpg
PART 2: MODEL COMPARISON
Comparing RF (with identical config) against alternative ML methods...
MODEL COMPARISON: Skin vs. Nares
  Training Random Forest...
  Training XGBoost...
  Training LightGBM...
  Training SVM...

Results:
        Model  AUC-ROC  AUC_std  Accuracy  Acc_std
      XGBoost    0.955    0.01

  plt.show()


## Classification Split by Urban (Cape Town) vs Rural (Umtata) (Supplementary Plots)

In [15]:
# Split data by area (Cape Town vs Umtata)
cape_town_samples = metadata[metadata['area'] == 'Cape Town'].index
umtata_samples = metadata[metadata['area'] == 'Umtata'].index

# Filter the ASV table to get samples from each area
df_cape_town = df.loc[df.index.isin(cape_town_samples)]
df_umtata = df.loc[df.index.isin(umtata_samples)]

In [16]:
# Define area-specific DataFrames

area_df_dict = {
    'Capetown': df_cape_town,
    'Umtata': df_umtata
}


# Loop through both study areas

for area_name, df in area_df_dict.items():   
    
    # 1. Custom Group-Stratified K-Fold Function (UNCHANGED)
    
    def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
        unique_groups = np.unique(groups)
        np.random.seed(random_state)
        np.random.shuffle(unique_groups)
        
        group_label_dist = {}
        for group in unique_groups:
            group_mask = groups == group
            group_y = y[group_mask]
            group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
        
        folds = [[] for _ in range(n_splits)]
        fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
        
        sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
        
        for group in sorted_groups:
            best_fold = 0
            min_imbalance = float('inf')
            
            for fold_idx in range(n_splits):
                temp_fold_dist = fold_label_dist[fold_idx].copy()
                for label, count in group_label_dist[group].items():
                    temp_fold_dist[label] += count
                fold_size = sum(temp_fold_dist.values())
                proportions = [count / fold_size for count in temp_fold_dist.values()] if fold_size else [0]*len(temp_fold_dist)
                imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
                
                if imbalance < min_imbalance:
                    min_imbalance = imbalance
                    best_fold = fold_idx
            
            folds[best_fold].extend(np.where(groups == group)[0])
            for label, count in group_label_dist[group].items():
                fold_label_dist[best_fold][label] += count
        
        train_test_indices = []
        for i in range(n_splits):
            test_idx = np.array(folds[i])
            train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
            train_test_indices.append((train_idx, test_idx))
        
        return train_test_indices

    
    # 2. MODIFIED Function to Run CV with Confounders
    
    def run_group_stratified_cv(X, y, groups, confounders=None, n_splits=5):
        """
        Run group-stratified CV with optional confounder adjustment.
        
        Parameters:
        - X: DataFrame of microbiome features
        - y: Series of labels
        - groups: Series of group IDs (patient IDs)
        - confounders: DataFrame of confounder variable
                       Must have same index as X
        - n_splits: Number of CV folds
        
        Returns:
        - cv_results: List of results per fold
        - feature_importances: DataFrame of feature importances
        - fold_aucs: List of AUC values per fold
        """
        folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)

        cv_results = []
        feature_importances = pd.DataFrame(index=X.columns)
        fold_aucs = []

        for i, (train_idx, test_idx) in enumerate(folds):
            X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
            y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

            if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
                print(f"Skipping fold {i+1} due to insufficient class representation")
                continue

            # ADD CONFOUNDERS IF PROVIDED
            if confounders is not None:
                # Ensure confounders are numeric
                conf_train = confounders.iloc[train_idx]
                conf_test = confounders.iloc[test_idx]
                
                # Concatenate microbiome features with confounders
                X_train_combined = pd.concat([X_train, conf_train], axis=1)
                X_test_combined = pd.concat([X_test, conf_test], axis=1)
                
                # Train RF on combined features
                clf = RandomForestClassifier(n_estimators=1000, random_state=42, n_jobs=-1)
                clf.fit(X_train_combined, y_train)
                
                # Get predictions
                probas = clf.predict_proba(X_test_combined)
                y_pred = clf.predict(X_test_combined)
                
                # Extract only microbiome feature importances (exclude confounder importances)
                n_microbiome_features = X.shape[1]
                feature_importances[f'fold_{i}'] = clf.feature_importances_[:n_microbiome_features]
                
            else:
                # Standard RF without confounders
                clf = RandomForestClassifier(n_estimators=1000, random_state=42, n_jobs=-1)
                clf.fit(X_train, y_train)
                probas = clf.predict_proba(X_test)
                y_pred = clf.predict(X_test)
                feature_importances[f'fold_{i}'] = clf.feature_importances_

            fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
            roc_auc = auc(fpr, tpr)
            fold_aucs.append(roc_auc)

            cv_results.append({
                'y_true': y_test,
                'y_pred': y_pred,
                'y_proba': probas[:, 1],
                'fpr': fpr,
                'tpr': tpr,
                'auc': roc_auc,
                'fold': i
            })

        feature_importances['mean_importance'] = feature_importances.mean(axis=1)
        feature_importances['std_importance'] = feature_importances.std(axis=1)
        feature_importances = feature_importances.sort_values('mean_importance', ascending=False)

        return cv_results, feature_importances, fold_aucs

    
    # 3.Prepare Confounders
    
    # Define which confounders to include
    confounder_cols = ['age_months', 'sex', 'enrolment_season']

    # Check which confounders are available
    available_confounders = [col for col in confounder_cols if col in metadata.columns]
    print(f"\nArea: {area_name}")
    print(f"Available confounders: {available_confounders}")
    
    # Prepare confounder dataframe
    if available_confounders:
        confounders_full = metadata[available_confounders].copy()
        
        # Encode categorical variables (e.g., age_months)
        for col in confounders_full.columns:
            if confounders_full[col].dtype == 'object':
                # Create dummy variables for categorical
                confounders_full = pd.get_dummies(confounders_full, columns=[col], drop_first=True)
        
        # Handle missing values
        confounders_full = confounders_full.fillna(confounders_full.median())
        
        print(f"Final confounder variables: {list(confounders_full.columns)}")
    else:
        confounders_full = None
        print("No confounders available - running without confounder adjustment")
    
    
    # 4. MODIFIED: Regenerate ROC Results with Confounders
    
    # 4a. Skin vs Nares
    meta_skin_nares = metadata[metadata['group'].str.startswith(('skin', 'nares'))].copy()
    meta_skin_nares['site_label'] = meta_skin_nares['group'].apply(lambda x: 0 if x.startswith('skin') else 1)
    common_samples = df.index.intersection(meta_skin_nares.index)
    X_skin_nares = df.loc[common_samples]
    y_skin_nares = meta_skin_nares.loc[common_samples, 'site_label']
    groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
    y_skin_nares = pd.to_numeric(y_skin_nares, errors='coerce')
    
    # Get confounders for this subset
    conf_skin_nares = confounders_full.loc[common_samples] if confounders_full is not None else None
    
    skin_vs_nares_cv_results, _, _ = run_group_stratified_cv(
        X_skin_nares, y_skin_nares, groups_skin_nares, 
        confounders=conf_skin_nares, n_splits=3
    )

    # 4b. Skin Comparisons
    skin_comparisons = [
        ('skin-ADL', 'skin-H'),
        ('skin-ADNL', 'skin-ADL'),
        ('skin-ADNL', 'skin-H')
    ]
    skin_cv_results_dict = {}
    for label1, label2 in skin_comparisons:
        meta_subset = metadata[metadata['group'].isin([label1, label2])]
        common_samples = df.index.intersection(meta_subset.index)
        X_skin = df.loc[common_samples]
        y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
        groups_skin = meta_subset.loc[common_samples, 'pid']
        
        # Get confounders for this subset
        conf_skin = confounders_full.loc[common_samples] if confounders_full is not None else None
        
        cv_results, _, _ = run_group_stratified_cv(
            X_skin, y_skin, groups_skin, 
            confounders=conf_skin, n_splits=3
        )
        key = f"{label1}_vs_{label2}"
        skin_cv_results_dict[key] = cv_results

    # 4c. Nares Comparison
    meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
    common_samples = df.index.intersection(meta_nares.index)
    X_nares = df.loc[common_samples]
    y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
    groups_nares = meta_nares.loc[common_samples, 'pid']
    
    # Get confounders for this subset
    conf_nares = confounders_full.loc[common_samples] if confounders_full is not None else None
    
    nares_ad_vs_h_cv_results, _, _ = run_group_stratified_cv(
        X_nares, y_nares, groups_nares, 
        confounders=conf_nares, n_splits=3
    )

    roc_results_dict = {
        'Prediction of Skin vs. Nares Samples': {
            'skin_vs_nares': skin_vs_nares_cv_results
        },
        'Prediction of AD Status From Skin Samples': skin_cv_results_dict,
        'Prediction of AD Status From Nares Samples': {
            'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results
        }
    }

    
    # 5. Combined ROC Plotting Function (UNCHANGED)
    
    def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, area_name, output_path=None):
        fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

        pretty_labels = {
            'skin_vs_nares': 'All Skin vs All Nares',
            'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
            'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
            'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
            'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
        }

        for i, panel_title in enumerate(comparison_order):
            ax = axs[i]
            sub_dict = roc_results_dict[panel_title]
            for sublabel, curves in sub_dict.items():
                if not isinstance(curves, list):
                    curves = [curves]
                mean_fpr = np.linspace(0, 1, 100)
                tprs, aucs = [], []
                for result in curves:
                    if 'fpr' not in result or 'tpr' not in result:
                        continue
                    tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                    tpr_interp[0] = 0.0
                    tprs.append(tpr_interp)
                    aucs.append(result['auc'])
                if not tprs:
                    continue
                mean_tpr = np.mean(tprs, axis=0)
                mean_tpr[-1] = 1.0
                std_tpr = np.std(tprs, axis=0)
                ax.plot(mean_fpr, mean_tpr, lw=3,
                        label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                        color=color_map.get(sublabel, 'gray'))
                ax.fill_between(mean_fpr,
                                np.maximum(mean_tpr - std_tpr, 0),
                                np.minimum(mean_tpr + std_tpr, 1),
                                color=color_map.get(sublabel, 'gray'),
                                alpha=0.3)
            ax.plot([0, 1], [0, 1], 'k--', lw=1)
            ax.set_xlim([0.0, 1.0])
            ax.set_ylim([0.0, 1.05])
            ax.set_xlabel('False Positive Rate', fontsize=16)
            ax.set_title(panel_title, fontsize=18)
            ax.grid(True, linestyle='--', alpha=0.7)
            ax.tick_params(axis='both', labelsize=14)
            ax.legend(loc='lower right', fontsize=12)

        axs[0].set_ylabel('True Positive Rate', fontsize=16)
        if area_name == 'Capetown':
            plt.suptitle('Random Forest Classifications by 16S ASVs of Cape Town Samples', fontsize=26)
        elif area_name == 'Umtata':
            plt.suptitle('Random Forest Classifications by 16S ASVs of Umtata Samples', fontsize=26)

        plt.tight_layout(rect=[0, 0, 1, 0.95])

        if output_path:
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            plt.savefig(output_path, dpi=600, bbox_inches='tight')
            print(f"Saved ROC plot for {area_name} to {output_path}")
        else:
            plt.show()

        plt.close(fig)

    
    # 6. Define Panel Order and Color Map, Then Plot
    
    comparison_order = [
        'Prediction of Skin vs. Nares Samples',
        'Prediction of AD Status From Skin Samples',
        'Prediction of AD Status From Nares Samples'
    ]
    color_map = {
        'skin_vs_nares': 'black',
        'skin-ADL_vs_skin-H': 'blue',
        'skin-ADNL_vs_skin-ADL': 'green',
        'skin-ADNL_vs_skin-H': 'purple',
        'nares-AD_vs_nares-H': 'orange'
    }

    if area_name == 'Umtata':
        output_path = '../Figures/Supplementary/Suppl_Fig_1A.jpg'
    elif area_name == 'Capetown':
        output_path = '../Figures/Supplementary/Suppl_Fig_1B.jpg'

    plot_combined_roc_3panel(
        roc_results_dict,
        comparison_order=comparison_order,
        color_map=color_map,
        area_name=area_name,
        output_path=output_path
    )


Area: Capetown
Available confounders: ['age_months', 'sex', 'enrolment_season']
Final confounder variables: ['age_months', 'sex_male', 'enrolment_season_Autumn ', 'enrolment_season_Spring', 'enrolment_season_Spring ', 'enrolment_season_Summer', 'enrolment_season_Winter']
Saved ROC plot for Capetown to ../Figures/Supplementary/Suppl_Fig_1B.jpg

Area: Umtata
Available confounders: ['age_months', 'sex', 'enrolment_season']
Final confounder variables: ['age_months', 'sex_male', 'enrolment_season_Autumn ', 'enrolment_season_Spring', 'enrolment_season_Spring ', 'enrolment_season_Summer', 'enrolment_season_Winter']
Saved ROC plot for Umtata to ../Figures/Supplementary/Suppl_Fig_1A.jpg
