In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# =====================================================
# Configuration (shared by every plotting cell below)
# =====================================================
# Font sizes, in points.
TITLE_FONTSIZE, LABEL_FONTSIZE = 18, 15
TICK_FONTSIZE = LEGEND_FONTSIZE = 13

# Bar geometry, in x-axis data units. Each feature gets one tick; the linear
# bar is centred GROUP_OFFSET left of it and the nonlinear bar GROUP_OFFSET
# right of it (with BAR_WIDTH == 2 * GROUP_OFFSET the two bars just touch).
BAR_WIDTH = 0.10        # thickness of each bar
GROUP_OFFSET = 0.05     # shift of each model's bar away from its feature tick
GROUP_SPACING = 0.25    # distance between consecutive feature ticks

FIGSIZE = (11, 5)       # (width, height) of a single-panel figure, in inches

# =====================================================
# Output directory — every figure is written here; create it up front.
# =====================================================
OUT_DIR = Path("Plots MeanAbsSHAPs")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# =====================================================
# Input data (Top 6 Mean |SHAP| values for HGB)
# =====================================================
# Hard-coded mean |SHAP| value per feature for the HGB target.
# `linear` feeds the "Linear model" bars, `nonlinear` the "Nonlinear model" bars.
# Keys are used verbatim as x tick labels, and every key of `nonlinear` must
# also be a key of `linear` (the plotting code indexes `linear[f]` for each
# feature taken from `nonlinear`).
# NOTE(review): values appear to be pasted from an upstream SHAP run; record
# their provenance (model, data split, date) — TODO.
linear = {
    "Iron (µg/dL)": 0.4061119,
    "Albumin (g/dL)": 0.35980982,
    "Folate (ng/mL)": 0.27461123,
    "Vitamin D (ng/mL)": 0.26317444,
    "Ferritin (ng/mL)": 0.2306221,
    "BMI": 0.20391363,
}

# Same features for the nonlinear model; its values determine the bar order.
nonlinear = {
    "Iron (µg/dL)": 0.42699871632360736,
    "Albumin (g/dL)": 0.383009782225512,
    "Folate (ng/mL)": 0.28539724598756583,
    "Vitamin D (ng/mL)": 0.28336777511317635,
    "Ferritin (ng/mL)": 0.2687011797522646,
    "BMI": 0.2111680725002225,
}

# =====================================================
# Plot HGB: features ordered by nonlinear importance (descending)
# =====================================================
# sorted() over the keys with the dict's own lookup as key; stable, so ties
# keep insertion order exactly like sorting (key, value) pairs by value.
ordered_features = sorted(nonlinear, key=nonlinear.get, reverse=True)[:6]

y_linear = [linear[feature] for feature in ordered_features]
y_nonlinear = [nonlinear[feature] for feature in ordered_features]

# One tick per feature, GROUP_SPACING apart; the two model bars straddle it.
x = GROUP_SPACING * np.arange(len(ordered_features))

fig, ax = plt.subplots(figsize=FIGSIZE)
ax.bar(x - GROUP_OFFSET, y_linear, width=BAR_WIDTH, color="tab:blue", label="Linear model")
ax.bar(x + GROUP_OFFSET, y_nonlinear, width=BAR_WIDTH, color="tab:orange", label="Nonlinear model")

# ---- titles, axis labels and ticks ----
ax.set_title("HGB", fontsize=TITLE_FONTSIZE, pad=12)
ax.set_ylabel("Mean |SHAP value|", fontsize=LABEL_FONTSIZE)
ax.set_xticks(x)
ax.set_xticklabels(ordered_features, rotation=90, fontsize=TICK_FONTSIZE)
ax.tick_params(axis="y", labelsize=TICK_FONTSIZE)

# ---- grid and legend ----
ax.grid(axis="y", linestyle="--", alpha=0.3)
ax.legend(fontsize=LEGEND_FONTSIZE, frameon=False)

# ---- layout, save (300-dpi PNG + PDF), display ----
fig.tight_layout()
out_base = OUT_DIR / "top6_mean_abs_shap_linear_vs_nonlinear_HGB"
fig.savefig(out_base.with_suffix(".png"), dpi=300, bbox_inches="tight")
fig.savefig(out_base.with_suffix(".pdf"), bbox_inches="tight")
plt.show()



In [None]:
# =====================================================
# Extra cell: repeat the same plot for HCT and RBC
# (assumes you already defined: TITLE_FONTSIZE, LABEL_FONTSIZE, TICK_FONTSIZE,
#  LEGEND_FONTSIZE, BAR_WIDTH, GROUP_OFFSET, GROUP_SPACING, FIGSIZE, OUT_DIR)
# =====================================================

import numpy as np
import matplotlib.pyplot as plt

def plot_top6_mean_abs_shap(
    label_name: str,
    linear_dict: dict,
    nonlinear_dict: dict,
    top_n: int = 6,
):
    """Plot, save and show linear vs. nonlinear mean |SHAP| bars for one target.

    Features are ordered by their value in ``nonlinear_dict`` (descending) and
    the first ``top_n`` are drawn. A feature that is missing from
    ``linear_dict`` is drawn with height 0.0 — note that this only means
    "not provided", not "zero importance".

    The figure is saved to OUT_DIR as
    ``top{top_n}_mean_abs_shap_linear_vs_nonlinear_<label>.png`` (300 dpi) and
    ``.pdf``, then displayed with ``plt.show()``.

    Parameters
    ----------
    label_name : str
        Target name. Surrounding whitespace is stripped; the result is used as
        the title and (sanitized) in the file name.
    linear_dict, nonlinear_dict : dict
        Feature name -> mean |SHAP| value for the linear / nonlinear model.
    top_n : int, default 6
        Number of features to draw.
    """
    # Strip so a stray space (e.g. "HCT ") neither shifts the centred title nor
    # leaves a trailing "_" in the file name.
    title = label_name.strip()

    # Order by nonlinear importance (descending); the sort is stable for ties.
    ordered_features = sorted(nonlinear_dict, key=nonlinear_dict.get, reverse=True)[:top_n]

    y_linear = [linear_dict.get(f, 0.0) for f in ordered_features]
    # Every feature comes from nonlinear_dict, so a direct lookup is safe.
    y_nonlinear = [nonlinear_dict[f] for f in ordered_features]

    # One tick per feature, GROUP_SPACING apart; bars sit GROUP_OFFSET either side.
    x = np.arange(len(ordered_features)) * GROUP_SPACING

    fig, ax = plt.subplots(figsize=FIGSIZE)
    ax.bar(x - GROUP_OFFSET, y_linear, BAR_WIDTH, label="Linear model", color="tab:blue")
    ax.bar(x + GROUP_OFFSET, y_nonlinear, BAR_WIDTH, label="Nonlinear model", color="tab:orange")

    ax.set_title(title, fontsize=TITLE_FONTSIZE, pad=12)
    ax.set_ylabel("Mean |SHAP value|", fontsize=LABEL_FONTSIZE)

    ax.set_xticks(x)
    ax.set_xticklabels(ordered_features, rotation=90, fontsize=TICK_FONTSIZE)
    ax.tick_params(axis="y", labelsize=TICK_FONTSIZE)

    ax.grid(axis="y", linestyle="--", alpha=0.3)
    ax.legend(fontsize=LEGEND_FONTSIZE, frameon=False)

    fig.tight_layout()

    safe_name = title.replace(" ", "_").replace("(", "").replace(")", "").replace("/", "_")
    stem = f"top{top_n}_mean_abs_shap_linear_vs_nonlinear_{safe_name}"
    # Append the extension instead of using Path.with_suffix: with_suffix would
    # treat a "." inside the label (e.g. "v1.2") as an existing suffix and
    # truncate the file name.
    fig.savefig(OUT_DIR / f"{stem}.png", dpi=300, bbox_inches="tight")
    fig.savefig(OUT_DIR / f"{stem}.pdf", bbox_inches="tight")

    plt.show()


# -----------------------------
# HCT data
# -----------------------------
# Hard-coded mean |SHAP| value per feature for the HCT target
# (`*_hct`: linear vs. nonlinear model). Keys are used verbatim as tick labels.
# The combined-figure cells index the linear dict with every nonlinear key,
# so both dicts must share the same keys.
# NOTE(review): values appear to be pasted from an upstream SHAP run; record
# their provenance — TODO.
linear_hct = {
    "Iron (µg/dL)": 0.42298865,
    "Albumin (g/dL)": 0.35477132,
    "Folate (ng/mL)": 0.28520676,
    "Vitamin D (ng/mL)": 0.27396882,
    "Ferritin (ng/mL)": 0.23225398,
    "BMI": 0.080196686,
}

nonlinear_hct = {
    "Iron (µg/dL)": 0.44640653483211645,
    "Albumin (g/dL)": 0.37933094788651217,
    "Vitamin D (ng/mL)": 0.29622803418812693,
    "Folate (ng/mL)": 0.2940664429066876,
    "Ferritin (ng/mL)": 0.2720817886075711,
    "BMI": 0.09116738995907327,
}

# -----------------------------
# RBC data
# -----------------------------
# Same structure for the RBC target (`*_rbc`: linear vs. nonlinear model).
linear_rbc = {
    "Iron (µg/dL)": 0.42047942,
    "Albumin (g/dL)": 0.35793138,
    "Ferritin (ng/mL)": 0.2919717,
    "Folate (ng/mL)": 0.28232604,
    "Vitamin D (ng/mL)": 0.27763432,
    "BMI": 0.20987643,
}

nonlinear_rbc = {
    "Iron (µg/dL)": 0.43299445251358704,
    "Albumin (g/dL)": 0.3721666073518536,
    "Ferritin (ng/mL)": 0.3015694642137678,
    "Folate (ng/mL)": 0.28912278492386395,
    "Vitamin D (ng/mL)": 0.2885201269167228,
    "BMI": 0.21413452070932545,
}

# -----------------------------
# Generate plots
# -----------------------------
# The label becomes both the title and part of the output file name, so it
# must not carry stray whitespace ("HCT " would yield "..._HCT_.png").
plot_top6_mean_abs_shap("HCT", linear_hct, nonlinear_hct)
plot_top6_mean_abs_shap("RBC", linear_rbc, nonlinear_rbc)


In [None]:
# =====================================================
# Single figure with 3 subplots: HGB, HCT, RBC
# =====================================================

import numpy as np
import matplotlib.pyplot as plt

def plot_on_axis(
    ax,
    label_name,
    linear_dict,
    nonlinear_dict,
    show_legend=False,
    *,
    top_n=6,
    show_ylabel=True,
    fontsizes=None,
):
    """Draw paired linear / nonlinear mean |SHAP| bars for one target on ``ax``.

    Features are ordered by their value in ``nonlinear_dict`` (descending) and
    only the first ``top_n`` are drawn. The linear value is looked up for the
    same features, so each of them must be a key of ``linear_dict``
    (otherwise ``KeyError`` is raised). Nothing is saved or shown here.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    label_name : str
        Panel title.
    linear_dict, nonlinear_dict : dict
        Feature name -> mean |SHAP| value for the linear / nonlinear model.
    show_legend : bool, default False
        Draw the legend on this axes.
    top_n : int, default 6
        Number of features to draw.
    show_ylabel : bool, default True
        Set the y-axis label; turn off for inner panels of a row sharing its y-axis.
    fontsizes : dict, optional
        Overrides for any of the keys ``"title"``, ``"label"``, ``"tick"``,
        ``"legend"``. Missing keys fall back to the module-level
        TITLE_FONTSIZE, LABEL_FONTSIZE, TICK_FONTSIZE and LEGEND_FONTSIZE,
        read at call time.
    """
    sizes = {
        "title": TITLE_FONTSIZE,
        "label": LABEL_FONTSIZE,
        "tick": TICK_FONTSIZE,
        "legend": LEGEND_FONTSIZE,
    }
    if fontsizes:
        sizes.update(fontsizes)

    # Order by nonlinear importance (descending); the sort is stable for ties.
    ordered_features = sorted(nonlinear_dict, key=nonlinear_dict.get, reverse=True)[:top_n]

    y_linear = [linear_dict[f] for f in ordered_features]
    y_nonlinear = [nonlinear_dict[f] for f in ordered_features]

    # One tick per feature, GROUP_SPACING apart; bars sit GROUP_OFFSET either side.
    x = np.arange(len(ordered_features)) * GROUP_SPACING

    ax.bar(x - GROUP_OFFSET, y_linear, BAR_WIDTH, color="tab:blue", label="Linear model")
    ax.bar(x + GROUP_OFFSET, y_nonlinear, BAR_WIDTH, color="tab:orange", label="Nonlinear model")

    ax.set_title(label_name, fontsize=sizes["title"], pad=10)
    if show_ylabel:
        ax.set_ylabel("Mean |SHAP value|", fontsize=sizes["label"])

    ax.set_xticks(x)
    ax.set_xticklabels(ordered_features, rotation=90, fontsize=sizes["tick"])
    ax.tick_params(axis="y", labelsize=sizes["tick"])

    ax.grid(axis="y", linestyle="--", alpha=0.3)

    if show_legend:
        ax.legend(fontsize=sizes["legend"], frameon=False)


# =====================================================
# Vertical stack, top to bottom: HGB, HCT, RBC
# (each panel keeps its own x ticks; legend on the top panel only)
# =====================================================
fig, axes = plt.subplots(3, 1, figsize=(FIGSIZE[0], 3 * FIGSIZE[1]))

plot_on_axis(axes[0], "HGB", linear, nonlinear, show_legend=True)
plot_on_axis(axes[1], "HCT", linear_hct, nonlinear_hct)
plot_on_axis(axes[2], "RBC", linear_rbc, nonlinear_rbc)

fig.tight_layout()

# ---- save combined figure (300-dpi PNG + PDF), then display ----
out_base = OUT_DIR / "top6_mean_abs_shap_linear_vs_nonlinear_HGB_HCT_RBC"
fig.savefig(out_base.with_suffix(".png"), dpi=300, bbox_inches="tight")
fig.savefig(out_base.with_suffix(".pdf"), bbox_inches="tight")

plt.show()


In [None]:
# =====================================================
# Single figure with 3 subplots arranged horizontally
# =====================================================

import numpy as np
import matplotlib.pyplot as plt

# Larger fonts for the wide three-panel horizontal figure. They are kept under
# their own name instead of reassigning TITLE_FONTSIZE & co.: overwriting those
# globals made any later re-run of the earlier cells silently use 40 pt titles.
# BAR_WIDTH, GROUP_OFFSET and GROUP_SPACING are reused unchanged from the
# configuration cell.
HORIZONTAL_FONTSIZES = {"title": 40, "label": 30, "tick": 30, "legend": 30}


def plot_on_axis(
    ax,
    label_name,
    linear_dict,
    nonlinear_dict,
    show_legend=False,
    *,
    top_n=6,
    show_ylabel=True,
    fontsizes=None,
):
    """Draw paired linear / nonlinear mean |SHAP| bars for one target on ``ax``.

    Features are ordered by their value in ``nonlinear_dict`` (descending) and
    only the first ``top_n`` are drawn. The linear value is looked up for the
    same features, so each of them must be a key of ``linear_dict``
    (otherwise ``KeyError`` is raised). Nothing is saved or shown here.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    label_name : str
        Panel title.
    linear_dict, nonlinear_dict : dict
        Feature name -> mean |SHAP| value for the linear / nonlinear model.
    show_legend : bool, default False
        Draw the legend on this axes.
    top_n : int, default 6
        Number of features to draw.
    show_ylabel : bool, default True
        Set the y-axis label; turn off for inner panels of a row sharing its y-axis.
    fontsizes : dict, optional
        Overrides for any of the keys ``"title"``, ``"label"``, ``"tick"``,
        ``"legend"``. Missing keys fall back to the module-level
        TITLE_FONTSIZE, LABEL_FONTSIZE, TICK_FONTSIZE and LEGEND_FONTSIZE,
        read at call time.
    """
    sizes = {
        "title": TITLE_FONTSIZE,
        "label": LABEL_FONTSIZE,
        "tick": TICK_FONTSIZE,
        "legend": LEGEND_FONTSIZE,
    }
    if fontsizes:
        sizes.update(fontsizes)

    # Order by nonlinear importance (descending); the sort is stable for ties.
    ordered_features = sorted(nonlinear_dict, key=nonlinear_dict.get, reverse=True)[:top_n]

    y_linear = [linear_dict[f] for f in ordered_features]
    y_nonlinear = [nonlinear_dict[f] for f in ordered_features]

    # One tick per feature, GROUP_SPACING apart; bars sit GROUP_OFFSET either side.
    x = np.arange(len(ordered_features)) * GROUP_SPACING

    ax.bar(x - GROUP_OFFSET, y_linear, BAR_WIDTH, color="tab:blue", label="Linear model")
    ax.bar(x + GROUP_OFFSET, y_nonlinear, BAR_WIDTH, color="tab:orange", label="Nonlinear model")

    ax.set_title(label_name, fontsize=sizes["title"], pad=10)
    if show_ylabel:
        ax.set_ylabel("Mean |SHAP value|", fontsize=sizes["label"])

    ax.set_xticks(x)
    ax.set_xticklabels(ordered_features, rotation=90, fontsize=sizes["tick"])
    ax.tick_params(axis="y", labelsize=sizes["tick"])

    ax.grid(axis="y", linestyle="--", alpha=0.3)

    if show_legend:
        ax.legend(fontsize=sizes["legend"], frameon=False)


# =====================================================
# Create horizontal figure
# =====================================================
VERTICAL_SCALE = 2.5  # figure height as a multiple of FIGSIZE[1]

fig, axes = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(FIGSIZE[0] * 3, FIGSIZE[1] * VERTICAL_SCALE),
    sharey=True,
)

# Panels left to right: HCT, HGB, RBC; legend on the first panel only.
# With sharey=True, plt.subplots hides the inner panels' y tick labels, so the
# y-axis label is likewise drawn only on the leftmost panel.
plot_on_axis(
    axes[0], "HCT", linear_hct, nonlinear_hct,
    show_legend=True, fontsizes=HORIZONTAL_FONTSIZES,
)
plot_on_axis(
    axes[1], "HGB", linear, nonlinear,
    show_ylabel=False, fontsizes=HORIZONTAL_FONTSIZES,
)
plot_on_axis(
    axes[2], "RBC", linear_rbc, nonlinear_rbc,
    show_ylabel=False, fontsizes=HORIZONTAL_FONTSIZES,
)

fig.tight_layout()

# =====================================================
# Save combined figure (300-dpi PNG + PDF), then display
# =====================================================
out_base = OUT_DIR / "top6_mean_abs_shap_linear_vs_nonlinear_HGB_HCT_RBC_horizontal"
fig.savefig(out_base.with_suffix(".png"), dpi=300, bbox_inches="tight")
fig.savefig(out_base.with_suffix(".pdf"), bbox_inches="tight")

plt.show()
