# Feature Importance Rankings from Random Forest Classification

In [47]:
# Standard library
import warnings
import logging
from itertools import combinations
from functools import reduce

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import matplotlib.gridspec as gridspec
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from matplotlib.colors import LinearSegmentedColormap

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize


In [48]:
# Define file paths
files = {
    "skin_vs_nares": "../Data/RF_Feature_Importances/ASVs/Top_10_Features_with_ASV_IDs_skin_vs_nares_V4.csv",
    "skin_ADL_vs_H": "../Data/RF_Feature_Importances/ASVs/Top_10_Features_with_ASV_IDs_skin-ADL_vs_skin-H_V4.csv",
    "skin_ADNL_vs_H": "../Data/RF_Feature_Importances/ASVs/Top_10_Features_with_ASV_IDs_skin-ADNL_vs_skin-ADL_V4.csv",
    "skin_ADNL_vs_ADL": "../Data/RF_Feature_Importances/ASVs/Top_10_Features_with_ASV_IDs_skin-ADNL_vs_skin-ADL_V4.csv",
    "nares_AD_vs_H": "../Data/RF_Feature_Importances/ASVs/Top_10_Features_with_ASV_IDs_nares-AD_vs_nares-H_V4.csv"
}

# Read and process each file
rank_dfs = {}
for key, path in files.items():
    df = pd.read_csv(path)
    df = df.sort_values("mean_importance", ascending=False).reset_index(drop=True)
    # Remove rows where the index starts with ' g__ASV'
    df = df[~df["Genus"].str.strip().str.startswith('g___ASV-2')]
    df = df[~df["Genus"].str.strip().str.startswith('g___ASV-6')]
    df = df[~df["Genus"].str.strip().str.startswith('g___ASV-7')]

    df["rank"] = range(1, len(df) + 1)
    df = df[["Genus", "rank"]]
    rank_dfs[key] = df

# List of comparisons to include
keys_to_merge = ["skin_vs_nares", "skin_ADL_vs_H", "skin_ADNL_vs_H", "skin_ADNL_vs_ADL", "nares_AD_vs_H"]

# Merge all rank dfs cleanly
merged = reduce(
    lambda left, right: pd.merge(left, right, on="Genus", how="outer"),
    [rank_dfs[key].rename(columns={"rank": key}) for key in keys_to_merge]
)

# Fill missing values with max_rank + 1
max_rank = max([df["rank"].max() for df in rank_dfs.values()])
merged = merged.set_index("Genus").fillna(max_rank + 1)

merged

Unnamed: 0_level_0,skin_vs_nares,skin_ADL_vs_H,skin_ADNL_vs_H,skin_ADNL_vs_ADL,nares_AD_vs_H
Genus,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
g__2011-GWC2-44-17_ASV-1,1706,1165,1169,1169,1250
g__AC-16_ASV-1,1803,1322,1560,1560,2106
g__Abditibacterium_ASV-1,1229,1450,2062,2062,1799
g__Abiotrophia_ASV-1,133,88,105,105,77
g__Abiotrophia_ASV-2,255,185,191,191,199
...,...,...,...,...,...
g___ASV-95,390,602,459,459,430
g___ASV-96,580,788,1847,1847,804
g___ASV-97,230,213,147,147,635
g___ASV-98,528,1312,1420,1420,847


In [49]:
# Get top 10 genera from each comparison separately
top_genera = set()
for comparison in keys_to_merge:
    top10_in_col = merged.sort_values(comparison).head(10).index
    top_genera.update(top10_in_col)

# Subset merged to include only these genera
top10 = merged.loc[list(top_genera)]

# Sort again based on total summed rank (optional, for nicer display)
top10['sum'] = top10.sum(axis=1)
top10 = top10.sort_values('sum').drop(columns='sum')

# Replace specific genus names
top10 = top10.rename(index={
    'g__F0422_ASV-1': 'g__Veillonella_F0422_ASV-1',
})

# Remove 'g__' prefix and '_ASV' from index names
top10.index = top10.index.str.replace('g__', '')
top10.index = top10.index.str.replace('_ASV', ' ASV')

top10

Unnamed: 0_level_0,skin_vs_nares,skin_ADL_vs_H,skin_ADNL_vs_H,skin_ADNL_vs_ADL,nares_AD_vs_H
Genus,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Streptococcus ASV-1,4,1,1,1,2
Staphylococcus ASV-1,3,4,3,3,1
Cutibacterium ASV-1,8,9,10,10,6
Staphylococcus ASV-2,6,5,12,12,9
Haemophilus_D_734546 ASV-1,5,21,9,9,5
Veillonella_A ASV-1,45,2,2,2,17
Streptococcus ASV-2,22,22,11,11,7
Cutibacterium ASV-2,10,12,25,25,10
Micrococcus ASV-1,7,14,23,23,23
Massilia ASV-1,17,11,6,6,53


In [50]:
# --- Truncate function to remove very white parts ---
def truncate_colormap(cmap, minval=0.0, maxval=0.8, n=100):
    new_cmap = LinearSegmentedColormap.from_list(
        f'trunc({cmap.name},{minval:.2f},{maxval:.2f})',
        cmap(np.linspace(minval, maxval, n))
    )
    return new_cmap

# --- Your custom colormaps (non-reversed!) ---
colormaps = {
    'skin_vs_nares': 'Greys',
    'skin_ADL_vs_H': 'Blues',
    'skin_ADNL_vs_H': 'Greens',
    'skin_ADNL_vs_ADL': 'Purples',
    'nares_AD_vs_H': 'Oranges'
}

# --- Transpose so comparisons are rows, features are columns ---
top10_t = top10.T

# --- Create the figure ---
fig, ax = plt.subplots(figsize=(12, 4))  # wide figure

# Normalize rows
normed_data = pd.DataFrame(index=top10_t.index, columns=top10_t.columns)

for comparison in top10_t.index:
    row = top10_t.loc[comparison]
    normed = (row - row.min()) / (row.max() - row.min())
    normed_data.loc[comparison] = normed

# --- Plot manually ---
for idx, comparison in enumerate(top10_t.index):
    base_cmap = plt.get_cmap(colormaps[comparison])
    cmap = truncate_colormap(base_cmap, 0.0, 0.8)  # Cut off super light top
    row_data = normed_data.loc[comparison].astype(float)
    
    for jdx, value in enumerate(row_data):
        corrected_value = 1 - value  # <-- flip! low values colorful, high values light
        color = cmap(corrected_value)
        
        rect = plt.Rectangle((jdx, idx), 1, 1, facecolor=color, edgecolor='white', linewidth=0.5)
        ax.add_patch(rect)
        
        original_val = top10_t.loc[comparison].iloc[jdx]
        ax.text(jdx + 0.5, idx + 0.5, f"{int(original_val)}",
                ha='center', va='center', color='black', fontsize=8)

# Set axis limits
ax.set_xlim(0, top10_t.shape[1])
ax.set_ylim(0, top10_t.shape[0])

# Set ticks
ax.set_xticks(np.arange(top10_t.shape[1]) + 0.5)
ax.set_yticks(np.arange(top10_t.shape[0]) + 0.5)
ax.set_xticklabels(top10_t.columns, rotation=45, ha='right')
ax.set_yticklabels(["All Skin vs. All Nares", "Skin ADL vs Skin H", "Skin ADNL vs Skin H", "Skin ADNL vs Skin ADL", "Nares AD vs Nares H"])

# Reverse y-axis
ax.invert_yaxis()

# Clean up
ax.set_xticks(np.arange(top10_t.shape[1]), minor=True)
ax.set_yticks(np.arange(top10_t.shape[0]), minor=True)
ax.grid(False)
ax.tick_params(which="minor", bottom=False, left=False)

# Title
plt.title("Top 10 Features from Each Classification Pooled", fontsize=16, pad=35, x=0.45)

# Smaller subtitle
plt.text(
    0.45, 1.25, 
    "(lower rank values correspond to higher feature importance)",
    ha='center', va='center',
    transform=ax.transAxes,
    fontsize=12
)

plt.tight_layout()
plt.savefig('../Figures/Main/Fig_5B.png', dpi=600, bbox_inches='tight')