In [2]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from clearit.config import OUTPUTS_DIR
from clearit.shap.io import load_shap_bundle
from clearit.shap.compute import smooth_shap_maps
from clearit.plotting.shap import plot_shap_heatmaps
from clearit.shap.metrics import (
    importance_matrix_abs_mean,
    infer_pairs_by_name,
    diagonality_index,
    channel_specificity_entropy,
    center_surround_metrics,
    to_dataframe_matrix,
)

In [4]:
# Aggregate SHAP bundles for one test run: for every `*.yaml` bundle found in
# `input_dir`, compute importance metrics, write them as CSV files and render
# a heatmap figure with the importance matrix overlaid on the tiles.
test_id = "T0030"
num_per_marker = 100
dataset_name = "TNBC1-MxIF8"
annotation_name = "TME-A_ML6"
input_dir = Path(OUTPUTS_DIR, "shap", dataset_name, annotation_name, test_id, f"N{num_per_marker:04d}")
output_dir = Path("shap_aggregate", dataset_name, annotation_name, test_id, f"N{num_per_marker:04d}")

# Plotting knobs
NORMALIZE = "row"          # "none" | "global" | "row" | "column"
APPLY_SMOOTHING = True
SIGMA = 2.0
OVERLAY_FMT = "%.2g"       # formatting of numbers in tiles
OVERLAY_LOC = "lower-right"
SAVE = True

# Fail fast on a mistyped mode. Otherwise the if/elif chain below would skip
# normalization silently, while the output PNG is still named after the
# (invalid) mode and the CSVs would hold un-normalized values.
if NORMALIZE not in ("none", "global", "row", "column"):
    raise ValueError(
        f"Unsupported NORMALIZE={NORMALIZE!r}; expected 'none', 'global', 'row' or 'column'"
    )

yaml_files = sorted(input_dir.glob("*.yaml"))
if not yaml_files:
    print(f"No bundles found in {input_dir}")
else:
    for yml in yaml_files:
        base = yml.with_suffix("")  # stem without extension (e.g., SHAP_TP_highest)
        shap_values, df_cells, meta = load_shap_bundle(base)

        labels = meta.get("labels", {})
        chan_names  = labels.get("channel_strings")
        class_names = labels.get("class_strings")
        order       = labels.get("desired_channel_order")

        # Reorder channels if specified in metadata (keeps names in sync)
        if order:
            shap_values = shap_values[:, order, :, :, :]
            if chan_names:
                chan_names = [chan_names[i] for i in order]

        # Optional smoothing at load time; skipped when the bundle metadata
        # already marks the maps as smoothed.
        if APPLY_SMOOTHING and not meta.get("shap", {}).get("smoothed", False):
            shap_values = smooth_shap_maps(shap_values, sigma=SIGMA)

        # ---------- compute metrics ----------
        I = importance_matrix_abs_mean(shap_values)              # (C,K)
        if NORMALIZE in ("row", "column"):
            # "row" divides by sums taken over axis 0, "column" by sums over axis 1.
            # NOTE(review): with I laid out as (C,K), "row" therefore makes each
            # class column sum to 1 -- presumably the plot shows classes as rows; confirm.
            denom = I.sum(axis=0 if NORMALIZE == "row" else 1, keepdims=True)
            # An all-zero slice would give 0/0 = NaN and poison every derived
            # metric; dividing by 1 instead keeps that slice at zero.
            denom[denom == 0] = 1
            I = I / denom
        elif NORMALIZE == "global":
            peak = I.max()
            if peak != 0:  # all-zero matrix: leave as is instead of producing NaN
                I = I / peak

        pairs = infer_pairs_by_name(chan_names or [], class_names or [])
        diag_val = diagonality_index(I, pairs)
        ent = channel_specificity_entropy(I)                     # (K,)
        cs  = center_surround_metrics(shap_values)               # dict of (C,K)

        # Decide save dir for this bundle. `output_dir` is a Path and thus always
        # truthy, so the `base.parent` fallback only applies if it is set to None.
        save_dir = (output_dir or base.parent)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Write metrics CSVs (the importance matrix is written after normalization)
        stem = base.name
        to_dataframe_matrix(I, chan_names, class_names).to_csv(save_dir / f"{stem}_importance_absmean.csv", float_format="%.6g")
        pd.DataFrame({"diagonality_index": [diag_val], "num_pairs": [len(pairs)]}).to_csv(save_dir / f"{stem}_diagonality.csv", index=False)
        pd.DataFrame({
            "class": class_names or [f"Class {k}" for k in range(I.shape[1])],
            "entropy_nats": ent,
            "entropy_bits": ent / np.log(2)  # nats -> bits
        }).to_csv(save_dir / f"{stem}_class_entropy.csv", index=False, float_format="%.6g")

        for key, M in cs.items():  # center_abs_share, CSI, sign_inversion_frac
            to_dataframe_matrix(M, chan_names, class_names).to_csv(save_dir / f"{stem}_{key}.csv", float_format="%.6g")

        # ---------- plot with overlay (importance) ----------
        fig, _ = plot_shap_heatmaps(
            shap_values,
            channel_strings=chan_names,
            class_strings=class_names,
            title=None,
            average_over_batch=True,
            figsize_multiplier=2.0,
            normalize=NORMALIZE,
            overlay_matrix=I,
            overlay_fmt=OVERLAY_FMT,
            overlay_loc=OVERLAY_LOC,
            font_scale=1.5,
            cbar_width_scale=1.5,
        )

        out_png = save_dir / f"{stem}_{NORMALIZE}.png"
        if SAVE:
            fig.savefig(out_png, dpi=300, bbox_inches="tight")
            plt.close(fig)  # release the figure; one is created per bundle
            print(f"Saved {out_png.name} and metrics for {stem}")
        else:
            plt.show()


Saved SHAP_FN_highest_row.png and metrics for SHAP_FN_highest
Saved SHAP_FN_lowest_row.png and metrics for SHAP_FN_lowest
Saved SHAP_FP_highest_row.png and metrics for SHAP_FP_highest
Saved SHAP_FP_lowest_row.png and metrics for SHAP_FP_lowest
Saved SHAP_TN_highest_row.png and metrics for SHAP_TN_highest
Saved SHAP_TN_lowest_row.png and metrics for SHAP_TN_lowest
Saved SHAP_TP_highest_row.png and metrics for SHAP_TP_highest
Saved SHAP_TP_lowest_row.png and metrics for SHAP_TP_lowest


In [5]:
# Aggregate SHAP bundles for one test run: for every `*.yaml` bundle found in
# `input_dir`, compute importance metrics, write them as CSV files and render
# a heatmap figure with the importance matrix overlaid on the tiles.
test_id = "T0129"
num_per_marker = 100
dataset_name = "TNBC2-MIBI/TNBC2-MIBI8"  # the "/" makes Path treat this as two nested directories
annotation_name = "TME-A_ML6"
input_dir = Path(OUTPUTS_DIR, "shap", dataset_name, annotation_name, test_id, f"N{num_per_marker:04d}")
output_dir = Path("shap_aggregate", dataset_name, annotation_name, test_id, f"N{num_per_marker:04d}")

# Plotting knobs
NORMALIZE = "row"          # "none" | "global" | "row" | "column"
APPLY_SMOOTHING = True
SIGMA = 2.0
OVERLAY_FMT = "%.2g"       # formatting of numbers in tiles
OVERLAY_LOC = "lower-right"
SAVE = True

# Fail fast on a mistyped mode. Otherwise the if/elif chain below would skip
# normalization silently, while the output PNG is still named after the
# (invalid) mode and the CSVs would hold un-normalized values.
if NORMALIZE not in ("none", "global", "row", "column"):
    raise ValueError(
        f"Unsupported NORMALIZE={NORMALIZE!r}; expected 'none', 'global', 'row' or 'column'"
    )

yaml_files = sorted(input_dir.glob("*.yaml"))
if not yaml_files:
    print(f"No bundles found in {input_dir}")
else:
    for yml in yaml_files:
        base = yml.with_suffix("")  # stem without extension (e.g., SHAP_TP_highest)
        shap_values, df_cells, meta = load_shap_bundle(base)

        labels = meta.get("labels", {})
        chan_names  = labels.get("channel_strings")
        class_names = labels.get("class_strings")
        order       = labels.get("desired_channel_order")

        # Reorder channels if specified in metadata (keeps names in sync)
        if order:
            shap_values = shap_values[:, order, :, :, :]
            if chan_names:
                chan_names = [chan_names[i] for i in order]

        # Optional smoothing at load time; skipped when the bundle metadata
        # already marks the maps as smoothed.
        if APPLY_SMOOTHING and not meta.get("shap", {}).get("smoothed", False):
            shap_values = smooth_shap_maps(shap_values, sigma=SIGMA)

        # ---------- compute metrics ----------
        I = importance_matrix_abs_mean(shap_values)              # (C,K)
        if NORMALIZE in ("row", "column"):
            # "row" divides by sums taken over axis 0, "column" by sums over axis 1.
            # NOTE(review): with I laid out as (C,K), "row" therefore makes each
            # class column sum to 1 -- presumably the plot shows classes as rows; confirm.
            denom = I.sum(axis=0 if NORMALIZE == "row" else 1, keepdims=True)
            # An all-zero slice would give 0/0 = NaN and poison every derived
            # metric; dividing by 1 instead keeps that slice at zero.
            denom[denom == 0] = 1
            I = I / denom
        elif NORMALIZE == "global":
            peak = I.max()
            if peak != 0:  # all-zero matrix: leave as is instead of producing NaN
                I = I / peak

        pairs = infer_pairs_by_name(chan_names or [], class_names or [])
        diag_val = diagonality_index(I, pairs)
        ent = channel_specificity_entropy(I)                     # (K,)
        cs  = center_surround_metrics(shap_values)               # dict of (C,K)

        # Decide save dir for this bundle. `output_dir` is a Path and thus always
        # truthy, so the `base.parent` fallback only applies if it is set to None.
        save_dir = (output_dir or base.parent)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Write metrics CSVs (the importance matrix is written after normalization)
        stem = base.name
        to_dataframe_matrix(I, chan_names, class_names).to_csv(save_dir / f"{stem}_importance_absmean.csv", float_format="%.6g")
        pd.DataFrame({"diagonality_index": [diag_val], "num_pairs": [len(pairs)]}).to_csv(save_dir / f"{stem}_diagonality.csv", index=False)
        pd.DataFrame({
            "class": class_names or [f"Class {k}" for k in range(I.shape[1])],
            "entropy_nats": ent,
            "entropy_bits": ent / np.log(2)  # nats -> bits
        }).to_csv(save_dir / f"{stem}_class_entropy.csv", index=False, float_format="%.6g")

        for key, M in cs.items():  # center_abs_share, CSI, sign_inversion_frac
            to_dataframe_matrix(M, chan_names, class_names).to_csv(save_dir / f"{stem}_{key}.csv", float_format="%.6g")

        # ---------- plot with overlay (importance) ----------
        fig, _ = plot_shap_heatmaps(
            shap_values,
            channel_strings=chan_names,
            class_strings=class_names,
            title=None,
            average_over_batch=True,
            figsize_multiplier=2.0,
            normalize=NORMALIZE,
            overlay_matrix=I,
            overlay_fmt=OVERLAY_FMT,
            overlay_loc=OVERLAY_LOC,
            font_scale=1.5,
            cbar_width_scale=1.5,
        )

        out_png = save_dir / f"{stem}_{NORMALIZE}.png"
        if SAVE:
            fig.savefig(out_png, dpi=300, bbox_inches="tight")
            plt.close(fig)  # release the figure; one is created per bundle
            print(f"Saved {out_png.name} and metrics for {stem}")
        else:
            plt.show()


Saved SHAP_FN_highest_row.png and metrics for SHAP_FN_highest
Saved SHAP_FN_lowest_row.png and metrics for SHAP_FN_lowest
Saved SHAP_FP_highest_row.png and metrics for SHAP_FP_highest
Saved SHAP_FP_lowest_row.png and metrics for SHAP_FP_lowest
Saved SHAP_TN_highest_row.png and metrics for SHAP_TN_highest
Saved SHAP_TN_lowest_row.png and metrics for SHAP_TN_lowest
Saved SHAP_TP_highest_row.png and metrics for SHAP_TP_highest
Saved SHAP_TP_lowest_row.png and metrics for SHAP_TP_lowest
