# Random Forest Analyses

In [14]:
# Standard library
import warnings
import logging
from itertools import combinations

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize

# Set overall styling for plots: seaborn "paper" context with fonts scaled by 1.5,
# and the "ticks" axes style. Both calls apply to every figure drawn afterwards.
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

## Random Forest Classification Tasks

In [15]:
# Read in table at ASV level
biom_path = '..//Data/Tables/Absolute_Abundance_Tables/209766_filtered_feature_table.biom'
# biom_path = '../Data/Tables/Absolute_Abundance_Tables/209766_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.biom'

biom_tbl = load_table(biom_path)
# Transpose so that rows are samples and columns are ASV sequences
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# Delete the '15564.' prefix from the sample ids in the index.
# regex=False makes the '.' a literal dot on every pandas version; with regex
# matching enabled (the default before pandas 2.0) it would match any character.
df.index = df.index.str.replace('15564.', '', regex=False)
df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCG
CAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,984.0,611.0,114.0,82.0,22.0,15.0,8.0,8.0,6.0,3.0,...,0,0,0,0,0,0,0,0,0,0
900459,118.0,106.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900221,22.0,0,0,0,0,0,16.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900570,389.0,0,0,0,8.0,0,11.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900092,3106.0,1707.0,59.0,32.0,3.0,0,0,0,0,7.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,1168.0,593.0,16.0,0,28.0,0,736.0,0,0,36.0,...,0,0,0,0,0,0,0,0,0,0
900097,24.0,0,0,0,0,0,33.0,0,0,0,...,8.0,5.0,1.0,0,0,0,0,0,0,0
900498,15.0,17.0,0,0,0,0,34.0,0,14.0,0,...,0,0,0,15.0,10.0,8.0,0,0,0,0
900276,0,0,30.0,0,0,0,151.0,0,0,0,...,0,0,0,0,0,0,11.0,3.0,2.0,1.0


In [16]:
# Load the metadata
metadata_path = '../Data/Metadata/updated_clean_ant_skin_metadata.tab'

# Short group names keyed by the original case_type values;
# case_type values missing from this mapping end up as NaN in 'group'.
case_type_to_group = {
    'case-lesional skin': 'skin-ADL',
    'case-nonlesional skin': 'skin-ADNL', 
    'control-nonlesional skin': 'skin-H',
    'case-anterior nares': 'nares-AD',
    'control-anterior nares': 'nares-H'
}

metadata = pd.read_csv(metadata_path, sep='\t')

# Strip underscores from the sample ids, then use them as the index
sample_ids = metadata['#sample-id'].str.replace('_', '')
metadata['#sample-id'] = sample_ids
metadata = metadata.set_index('#sample-id')

# Create group column based on case_type to simplify group names
metadata['group'] = metadata['case_type'].map(case_type_to_group)

metadata

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,FWD_filepath,REV_filepath,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,female,8/11/2015,Winter,Unexposed,negative,7.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,female,11/20/2014,Spring,Unexposed,negative,7.0,21,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,female,9/23/2015,Spring,Unexposed,,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Ca006ONL2,6,H1,SA508,GACACCGT,SB701,CTCGACTT,SB701SA508,CTCGACTT-GACACCGT,1.010000e+21,H1,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONNL2,6,H2,SA508,GACACCGT,SB702,CGAAGTAT,SB702SA508,CGAAGTAT-GACACCGT,1.010000e+21,H2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONPN,6,F3,SA506,CGTGAGTG,SB703,TAGCAGCT,SB703SA506,TAGCAGCT-CGTGAGTG,1.010000e+21,F3,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,nares-AD


In [17]:
# View metadata columns (the list includes the derived 'group' column)
metadata.columns

Index(['PlateNumber', 'PlateLocation', 'i5', 'i5Sequence', 'i7', 'i7Sequence',
       'identifier', 'Sequence', 'Plate ID', 'Well location', 'Volume (ul)',
       'Lysozyme pretreatment', 'DNA extraction method', 'Purification method',
       'Date of DNA extraction', 'pid', 'case_type', 'participant', 'area',
       'sample_type', 'specimen', 'age_months', 'sex', 'enrolment_date',
       'enrolment_season', 'hiv_exposure', 'hiv_status', 'household_size',
       'o_scorad', 'FWD_filepath', 'REV_filepath', 'group'],
      dtype='object')

In [18]:
# Check o_scorad distribution across different groups.
# Groups are visited in order of first appearance in the metadata; the counts
# are sorted by score value (o_scorad also contains the non-numeric entry 'NR').
group_column = metadata['group']
for group in group_column.unique():
    in_group = group_column == group
    scorad_counts = metadata[in_group]['o_scorad'].value_counts().sort_index()
    print(f"\nO-SCORAD distribution for {group}:")
    print(scorad_counts)



O-SCORAD distribution for skin-ADL:
o_scorad
21    4
22    2
23    3
24    2
25    4
26    2
28    3
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    2
39    3
40    4
41    4
42    5
43    7
44    9
45    1
47    1
48    1
49    1
51    1
52    1
53    3
54    3
56    2
58    2
61    2
66    2
67    1
75    2
78    4
82    1
NR    1
Name: count, dtype: int64

O-SCORAD distribution for skin-ADNL:
o_scorad
10    1
21    4
22    2
23    3
24    2
25    6
26    2
28    4
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    2
39    3
40    4
41    4
42    5
43    7
44    9
45    1
47    1
48    1
49    2
51    1
52    1
53    3
54    3
56    2
58    2
61    2
66    2
67    1
78    2
NR    1
Name: count, dtype: int64

O-SCORAD distribution for nares-AD:
o_scorad
21    4
22    2
23    3
24    2
25    5
26    2
28    4
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    1
39    3
40    4
41    4
42    5
43    6
44    9
45    1
47    1
4

In [19]:
# split data into training and testing
# ------------------------------
# 1. Custom Group-Stratified K-Fold Function
# ------------------------------
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Build cross-validation folds that keep every group inside a single fold.

    Groups are assigned greedily, largest first (ties broken by a seeded
    shuffle). Each group goes to the fold that minimises
    ``var(label proportions in the fold) + fold_size / (n_samples / n_splits)``,
    i.e. the assignment favours small folds with evenly mixed labels.

    Parameters:
    -----------
    X : any
        Unused; accepted so the signature mirrors the other CV helpers.
    y : array-like
        One label per sample, positionally aligned with ``groups``.
    groups : array-like
        One group identifier per sample. Samples sharing an identifier are
        never split between train and test.
    n_splits : int
        Number of folds (at least 2).
    random_state : int or None
        Seed for the tie-breaking shuffle. A private ``RandomState`` is used so
        NumPy's global random state is left untouched.

    Returns:
    --------
    list of (train_idx, test_idx)
        One tuple per fold. Both entries are integer arrays of positional
        indices (suitable for ``.iloc``). A fold's ``test_idx`` is empty when
        there are fewer groups than folds.

    Raises:
    -------
    ValueError
        If ``n_splits`` is smaller than 2.
    """
    if n_splits < 2:
        raise ValueError(f"n_splits must be at least 2, got {n_splits}")

    # Work positionally on plain arrays so pandas index alignment cannot interfere
    groups_arr = np.asarray(groups)
    y_arr = np.asarray(y)
    n_samples = groups_arr.shape[0]
    labels = np.unique(y_arr)

    unique_groups = np.unique(groups_arr)
    # Same permutation as np.random.seed(...) + np.random.shuffle(...), but
    # without reseeding the global generator
    rng = np.random.RandomState(random_state)
    rng.shuffle(unique_groups)

    # Row positions and label counts for each group, computed once
    group_rows = {}
    group_label_dist = {}
    for group in unique_groups:
        rows = np.flatnonzero(groups_arr == group)
        group_rows[group] = rows
        group_y = y_arr[rows]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in labels}

    folds = [[] for _ in range(n_splits)]
    fold_label_dist = [{label: 0 for label in labels} for _ in range(n_splits)]

    # Stable sort: groups of equal size keep their shuffled order
    sorted_groups = sorted(unique_groups, key=lambda g: len(group_rows[g]), reverse=True)

    target_fold_size = n_samples / n_splits

    for group in sorted_groups:
        best_fold = 0
        min_imbalance = float('inf')

        for fold_idx in range(n_splits):
            # Label counts the fold would have if this group were added to it
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            fold_size = sum(temp_fold_dist.values())
            if fold_size:
                proportions = [count / fold_size for count in temp_fold_dist.values()]
            else:
                proportions = [0] * len(temp_fold_dist)
            imbalance = np.var(proportions) + fold_size / target_fold_size

            # Strict '<' keeps the lowest-numbered fold on ties
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx

        folds[best_fold].extend(group_rows[group])
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count

    # Force an integer dtype so that empty folds do not turn the index arrays
    # into floats (which positional indexing would reject)
    fold_arrays = [np.asarray(fold, dtype=np.intp) for fold in folds]

    train_test_indices = []
    for i in range(n_splits):
        test_idx = fold_arrays[i]
        train_idx = np.concatenate([fold_arrays[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))

    return train_test_indices

# ------------------------------
# 2. Function to Run CV and Get ROC/Feature Importance
# ------------------------------
def run_group_stratified_cv(X, y, groups, n_splits=5):
    """
    Cross-validate a binary random forest classifier with group-aware folds.

    Folds come from ``group_stratified_kfold``. A fold is skipped (with a
    printed message) when its training or test labels contain fewer than two
    classes, because no ROC curve can be computed for it.

    Parameters:
    -----------
    X : pd.DataFrame
        Feature table (samples x features).
    y : pd.Series
        Binary labels aligned with ``X``. The probability in column 1 of
        ``predict_proba`` is used as the score passed to ``roc_curve``.
    groups : array-like
        Group identifier per sample; samples of one group stay in one fold.
    n_splits : int
        Number of folds.

    Returns:
    --------
    cv_results : list of dict
        One dict per evaluated fold with keys 'y_true', 'y_proba', 'fpr',
        'tpr', 'auc' and 'fold' (zero-based fold number).
    feature_importances : pd.DataFrame
        One 'fold_<i>' column per evaluated fold plus 'mean_importance' and
        'std_importance', sorted by descending mean importance.
    fold_aucs : list of float
        ROC AUC of each evaluated fold.
    """
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue
            
        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        probas = clf.predict_proba(X_test)
        
        feature_importances[f'fold_{i}'] = clf.feature_importances_
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    # Take a snapshot of the per-fold columns before adding any summary column,
    # so the standard deviation is not contaminated by the mean column.
    per_fold = feature_importances[list(feature_importances.columns)]
    feature_importances['mean_importance'] = per_fold.mean(axis=1)
    feature_importances['std_importance'] = per_fold.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs

# ------------------------------
# 3. Regenerate ROC Results
# ------------------------------
# Inputs expected in scope:
#   df       - feature table (samples x ASVs) loaded from the
#              Absolute_Abundance_Tables BIOM file
#   metadata - sample metadata; the 'group' and 'pid' columns are used below
# Every task runs 3-fold CV with folds grouped by the 'pid' column.

# 3a. Skin vs Nares (Binary Classification)
# Encoding: skin -> 0, nares -> 1
is_skin_or_nares = metadata['group'].str.startswith(('skin', 'nares'))
meta_skin_nares = metadata[is_skin_or_nares].copy()
meta_skin_nares['site_label'] = meta_skin_nares['group'].map(
    lambda name: 0 if name.startswith('skin') else 1
)
common_samples = df.index.intersection(meta_skin_nares.index)
X_skin_nares = df.loc[common_samples]
groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
y_skin_nares = pd.to_numeric(
    meta_skin_nares.loc[common_samples, 'site_label'], errors='coerce'
)
skin_vs_nares_cv_results = run_group_stratified_cv(
    X_skin_nares, y_skin_nares, groups_skin_nares, n_splits=3
)[0]

# 3b. Skin Comparisons (Pairwise Among Skin Groups)
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H')
]
skin_cv_results_dict = {}
for label1, label2 in skin_comparisons:
    key = f"{label1}_vs_{label2}"
    meta_subset = metadata[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    X_skin = df.loc[common_samples]
    groups_skin = meta_subset.loc[common_samples, 'pid']
    # Binary encoding: first label of the pair -> 0, second label -> 1
    y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
    cv_results = run_group_stratified_cv(X_skin, y_skin, groups_skin, n_splits=3)[0]
    skin_cv_results_dict[key] = cv_results

# 3c. Nares Comparison (nares-AD vs nares-H)
# Encoding: nares-AD -> 0, nares-H -> 1
meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
common_samples = df.index.intersection(meta_nares.index)
X_nares = df.loc[common_samples]
groups_nares = meta_nares.loc[common_samples, 'pid']
y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
nares_ad_vs_h_cv_results = run_group_stratified_cv(
    X_nares, y_nares, groups_nares, n_splits=3
)[0]

# ------------------------------
# 4. Assemble ROC Results Dictionary
# ------------------------------
# Panel title -> {comparison key -> list of per-fold CV results}
roc_results_dict = {
    'Prediction of Skin vs. Nares Samples': {
        'skin_vs_nares': skin_vs_nares_cv_results
    },
    'Prediction of AD Status From Skin Samples': skin_cv_results_dict,
    'Prediction of AD Status From Nares Samples': {
        'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results
    }
}


# ------------------------------
# 5. Combined ROC Plotting Function (3-Panel)
# ------------------------------
def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, output_path=None):
    """
    Plots a 3-panel ROC curve figure from grouped ROC results.

    For every comparison the per-fold ROC curves are interpolated onto a common
    grid of 100 false-positive rates, averaged, and drawn with a band of
    +/- one standard deviation. The legend reports mean +/- std of the fold AUCs.
    
    Parameters:
    - roc_results_dict: dict
        A dictionary mapping panel titles to sub-dicts. Each sub-dict maps sublabels 
        to a list of cv_results dictionaries (each with 'fpr', 'tpr' and 'auc').
    - comparison_order: list
        List of panel titles in order; one panel is drawn per title, so it
        must not contain more than three entries.
    - color_map: dict
        Dictionary mapping each sublabel to a color ('gray' is used when a
        sublabel is missing).
    - output_path: str or None
        If provided, path to save the figure (600 dpi).

    Returns:
    - None. The figure is drawn on the current matplotlib state.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    # Legend text for each comparison key. Each label must name the same pair
    # as its key ("<first>_vs_<second>").
    pretty_labels = {
        'skin_vs_nares': 'All Skin vs All Nares',
        'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
        'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
        'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
        'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
    }
    
    for i, panel_title in enumerate(comparison_order):
        ax = axs[i]
        sub_dict = roc_results_dict[panel_title]
        # Loop over each sub-comparison within this panel
        for sublabel, curves in sub_dict.items():
            # Ensure curves is a list
            if not isinstance(curves, list):
                curves = [curves]
            mean_fpr = np.linspace(0, 1, 100)
            tprs = []
            aucs = []
            for result in curves:
                # Ensure the result dict has fpr and tpr keys
                if 'fpr' not in result or 'tpr' not in result:
                    continue
                tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                # Anchor every interpolated curve at the origin
                tpr_interp[0] = 0.0
                tprs.append(tpr_interp)
                aucs.append(result['auc'])
            # Nothing to draw for this comparison (e.g. every fold was skipped)
            if len(tprs) == 0:
                continue
            mean_tpr = np.mean(tprs, axis=0)
            # Force the mean curve to end at (1, 1)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            curve_color = color_map.get(sublabel, 'gray')
            ax.plot(mean_fpr, mean_tpr, lw=3,
                    label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=curve_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=curve_color)
    
        # Chance-level diagonal
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=16)
        ax.set_title(panel_title, fontsize=18)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=14)
        ax.legend(loc='lower right', fontsize=12)
    
    axs[0].set_ylabel('True Positive Rate', fontsize=16)
    plt.suptitle('Random Forest Classifications by 16S ASVs', fontsize=26, x=0.5, y=1)
    plt.tight_layout()
    
    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches='tight')

# ------------------------------
# 6. Define Panel Order and Color Map, Then Plot
# ------------------------------
# Panel titles, left to right; each must be a key of roc_results_dict
comparison_order = [
    'Prediction of Skin vs. Nares Samples',
    'Prediction of AD Status From Skin Samples',
    'Prediction of AD Status From Nares Samples'
]

# One colour per comparison key
color_map = dict([
    ('skin_vs_nares', 'black'),
    ('skin-ADL_vs_skin-H', 'blue'),
    ('skin-ADNL_vs_skin-ADL', 'green'),
    ('skin-ADNL_vs_skin-H', 'purple'),
    ('nares-AD_vs_nares-H', 'orange'),
])

# Destination of the saved figure.
# Alternative used for the prevalence-filtered table:
# '../Plots/Analysis_figures/Random_Forest/rf_classifications_ROC_curves_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.png'
roc_output_path = '../Plots/Analysis_figures/Random_Forest/rf_classifications_final.png'

plot_combined_roc_3panel(
    roc_results_dict,
    comparison_order=comparison_order,
    color_map=color_map,
    output_path=roc_output_path,
)


## Random Forest Regression Tasks

In [20]:
# Additional Python packages for regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

In [21]:
def clr_transform(df, pseudocount=1e-6, standardize=False):
    """
    Apply Centered Log-Ratio (CLR) transformation to a dense or sparse feature table.

    Each value is log-transformed and the mean log value of its row (sample)
    is subtracted, so every transformed row sums to zero.

    Parameters:
    -----------
    df : pd.DataFrame
        Microbial abundance table (samples x features) with non-negative values
    pseudocount : float
        Value substituted for exact zeros before the log transformation
    standardize : bool
        Whether to z-score standardize each feature (column) after CLR, using
        the population standard deviation (ddof=0). Columns with zero variance
        are centred only, i.e. they become all zeros.

    Returns:
    --------
    pd.DataFrame
        CLR-transformed (and optionally standardized) DataFrame with the same
        index and columns as the input

    Raises:
    -------
    ValueError
        If the table contains negative values, for which the log (and hence
        CLR) is undefined. This typically means the table has already been
        CLR-transformed.
    """
    # Convert to dense if sparse
    if hasattr(df, 'sparse'):
        df = df.sparse.to_dense()

    # Fail loudly instead of silently producing NaN from log(negative)
    if (df < 0).to_numpy().any():
        raise ValueError(
            "clr_transform expects non-negative abundances; negative values found "
            "(was the table already CLR-transformed?)"
        )

    # Replace zeros with pseudocount
    df_pseudo = df.replace(0, pseudocount)

    # Log transform
    log_df = np.log(df_pseudo)

    # Subtract row-wise mean (CLR); index and columns are carried over
    clr_df = log_df.sub(log_df.mean(axis=1), axis=0)

    # Optionally standardize each feature column
    if standardize:
        col_mean = clr_df.mean(axis=0)
        col_std = clr_df.std(axis=0, ddof=0)
        # Avoid division by zero for constant columns
        col_std = col_std.where(col_std > 0, 1.0)
        clr_df = (clr_df - col_mean) / col_std

    return clr_df
    

# NOTE: this overwrites `df` with its CLR-transformed version, so every later
# cell sees CLR values instead of the loaded counts. Re-running this cell would
# apply clr_transform to data that has already been transformed.
df = clr_transform(df)  # Just CLR
df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCG
CAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,20.612399,20.135870,18.456971,18.127492,16.811815,16.428823,15.800215,15.800215,15.512532,14.819385,...,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738,-0.094738
900459,18.363989,18.256743,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,...,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206,-0.222206
900221,16.585737,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,16.267284,-0.320816,-0.320816,-0.320816,...,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816,-0.320816
900570,18.922757,-0.856333,-0.856333,-0.856333,15.038620,-0.856333,15.357073,-0.856333,-0.856333,-0.856333,...,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333,-0.856333
900092,21.760542,21.161944,17.796988,17.185187,14.818063,-0.096060,-0.096060,-0.096060,-0.096060,15.665361,...,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060,-0.096060
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,19.411458,18.733605,15.120999,-1.467100,15.680615,-1.467100,18.949640,-1.467100,-1.467100,15.931929,...,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100,-1.467100
900097,16.382896,-0.610668,-0.610668,-0.610668,-0.610668,-0.610668,16.701350,-0.610668,-0.610668,-0.610668,...,15.284284,14.814280,13.204842,-0.610668,-0.610668,-0.610668,-0.610668,-0.610668,-0.610668,-0.610668
900498,16.105696,16.230859,-0.417864,-0.417864,-0.417864,-0.417864,16.924007,-0.417864,16.036703,-0.417864,...,-0.417864,-0.417864,-0.417864,16.105696,15.700231,15.477088,-0.417864,-0.417864,-0.417864,-0.417864
900276,-0.414507,-0.414507,16.802201,-0.414507,-0.414507,-0.414507,18.418284,-0.414507,-0.414507,-0.414507,...,-0.414507,-0.414507,-0.414507,-0.414507,-0.414507,-0.414507,15.798899,14.499616,14.094151,13.401004


In [22]:
# Keep AD lesional skin samples that are also present in the feature table
is_adl = metadata['group'] == 'skin-ADL'
adl_meta = metadata[is_adl]
common_samples = df.index.intersection(adl_meta.index)

# Parse SCORAD scores as numbers; entries that cannot be parsed become NaN
# and the corresponding samples are excluded
y_adl_full = pd.to_numeric(adl_meta.loc[common_samples, 'o_scorad'], errors='coerce')
valid_samples = y_adl_full[y_adl_full.notna()].index

# Final filtered features, targets and CV groups, all on the same samples
X_adl = df.loc[valid_samples]
y_adl = y_adl_full.loc[valid_samples]
groups_adl = adl_meta.loc[valid_samples, 'pid']


In [23]:
def run_group_stratified_regression(X, y, groups, n_splits=3):
    """
    Cross-validate a random forest regressor using group-aware folds.

    Folds are produced by ``group_stratified_kfold``; a fresh
    ``RandomForestRegressor`` (1000 trees, random_state=42) is fitted per fold.

    Parameters:
    -----------
    X : pd.DataFrame
        Feature table (samples x features).
    y : pd.Series
        Regression target aligned with ``X``.
    groups : array-like
        Group identifier per sample, forwarded to the fold builder.
    n_splits : int
        Number of folds.

    Returns:
    --------
    list of dict
        One dict per fold with keys 'y_true', 'y_pred', 'r2', 'mse' and 'mae',
        all computed on that fold's held-out samples.
    """
    results = []

    for train_idx, test_idx in group_stratified_kfold(X, y, groups, n_splits=n_splits):
        regressor = RandomForestRegressor(n_estimators=1000, random_state=42)
        regressor.fit(X.iloc[train_idx], y.iloc[train_idx])

        # Evaluate on the held-out fold
        y_test = y.iloc[test_idx]
        y_pred = regressor.predict(X.iloc[test_idx])

        fold_summary = {'y_true': y_test, 'y_pred': y_pred}
        fold_summary['r2'] = r2_score(y_test, y_pred)
        fold_summary['mse'] = mean_squared_error(y_test, y_pred)
        fold_summary['mae'] = mean_absolute_error(y_test, y_pred)
        results.append(fold_summary)

    return results

In [24]:
# Run the 3-fold group-aware regression on the AD lesional skin data.
# The bare `adl_results` expression displays the complete per-fold result dicts.
adl_results = run_group_stratified_regression(X_adl, y_adl, groups_adl, n_splits=3)
adl_results

[{'y_true': Ca006ONL     34.0
  Ca006ONL2    34.0
  900403       78.0
  9004032      78.0
  9000107      25.0
  900096       44.0
  900569       36.0
  Ca010EBL     21.0
  900554       44.0
  900122       58.0
  900078       24.0
  900066       61.0
  900232       21.0
  900482       42.0
  900575       48.0
  900572       43.0
  900129       44.0
  900119       36.0
  900113       32.0
  900584       61.0
  900057       33.0
  900463       43.0
  900299       42.0
  900229       23.0
  900436       66.0
  900235       49.0
  900423       32.0
  900308       56.0
  900587       28.0
  900261       42.0
  900110       25.0
  900420       43.0
  900102       43.0
  900084       32.0
  Name: o_scorad, dtype: float64,
  'y_pred': array([36.747, 39.523, 39.196, 39.689, 50.404, 40.782, 38.784, 41.686,
         42.081, 38.984, 44.808, 46.704, 40.224, 38.231, 41.451, 39.384,
         39.052, 40.84 , 44.002, 34.359, 40.094, 50.23 , 45.139, 44.843,
         39.184, 40.024, 39.669, 40.331, 41.465

In [25]:
# Pool the held-out true and predicted values from every fold
y_true_all = np.concatenate([fold['y_true'] for fold in adl_results])
y_pred_all = np.concatenate([fold['y_pred'] for fold in adl_results])

# End points of the identity line: the observed range of the true scores
score_min, score_max = y_true_all.min(), y_true_all.max()

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_true_all, y_pred_all, alpha=0.7, edgecolor='k')
# Dashed red line marks perfect prediction (predicted == true)
ax.plot([score_min, score_max], [score_min, score_max], 'r--', lw=2)
ax.set_xlabel("True SCORAD", fontsize=14)
ax.set_ylabel("Predicted SCORAD", fontsize=14)
ax.set_title("AD Lesional Skin: True vs Predicted SCORAD", fontsize=16)
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
fig.savefig("../Plots/Analysis_figures/Random_Forest/regression_true_vs_predicted_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.png", dpi=600)

# plt.savefig("../Plots/Analysis_figures/Random_Forest/regression_true_vs_predicted_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.png", dpi=600)
