In [1]:
from pathlib import Path
import os
import pandas as pd
import torch
import yaml
import matplotlib.pyplot as plt

from clearit.config import MODELS_DIR, DATASETS_DIR, OUTPUTS_DIR
from clearit.metrics.gather_test_results import get_classifier_test_results
from clearit.inference.pipeline import load_encoder_head
from clearit.shap.utils import load_model_from_test, get_dataloader_for_shap
from clearit.shap.explainer import prepare_shap_explainer
from clearit.shap.compute import compute_shap_values_batch, smooth_shap_maps
from clearit.shap.io import build_shap_metadata, save_shap_bundle
from clearit.plotting.shap import plot_shap_heatmaps
from clearit.io.select import select_cells_by_outcome_and_confidence
from clearit.io.classification_report import classification_report
import platform, torch, shap as _shap, numpy as _np


In [3]:
# Test ID to analyze - T0030 contains linear evaluation results for the best encoder/classifier pair obtained in round 2 for TNBC1-MxIF8 + TME-A_ML6
test_id = "T0030"

# Specify dataset and annotations (we explicitly define these here, but they can also be obtained from test_conf.yaml)
dataset_name = "TNBC1-MxIF8"
annotation_name = "TME-A_ML6"

# Load labels and class strings
df_all = pd.read_csv(Path(DATASETS_DIR, dataset_name, annotation_name, "labels.csv"))
class_strings = pd.read_csv(Path(DATASETS_DIR, dataset_name, annotation_name, "class_names.csv"))["name"].tolist()
markers = class_strings  # alias
assert class_strings, "class_names.csv lists no classes"

# Specify channel names and order for plotting
channel_strings_desired = ['DAPI', 'CK', 'CD3', 'CD8', 'CD20', 'CD56', 'CD68', 'AF']
desired_channel_order = [0, 1, 2, 4, 6, 5, 3, 7]
# The plotting order must reference every channel exactly once; a typo here would
# otherwise silently drop or repeat a channel in the saved metadata (and any plots built from it).
assert sorted(desired_channel_order) == list(range(len(channel_strings_desired))), (
    "desired_channel_order must be a permutation of range(len(channel_strings_desired))"
)

# Define what we will loop over and how many cells to sample for the respective prediction type
confidence_orders = ["highest", "lowest"]
outcomes = ["TP", "TN", "FP", "FN"]
num_per_marker = 100

# Output setup
output_dir = Path(OUTPUTS_DIR,"shap",dataset_name,annotation_name,test_id,f"N{num_per_marker:04d}")
output_dir.mkdir(parents=True, exist_ok=True)
master_report_entries = []

# Build test results
df_test_results = get_classifier_test_results(test_id, df_all, class_strings)
# An empty result would make every SHAP condition below silently skip; fail early instead.
assert not df_test_results.empty, f"No classifier test results found for {test_id}"

In [4]:
# Use CUDA if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the exact model used in the test
model, test_cfg, head_cfg = load_model_from_test(test_id, device=device)

# Use head_cfg to drive shapes/metadata
num_classes = int(head_cfg["num_classes"])
patch_size = int(head_cfg.get("img_size", 64))
# NOTE(review): this rebinds the dataset_name set in the config cell, so the dataloader
# and the saved metadata use the name stored in the head config, while output_dir was
# already built from the config-cell value. Confirm both refer to the same dataset.
dataset_name = head_cfg["dataset_name"]

# If df_test_results lacks the raw 'label' column, merge it from df_all once:
if "label" not in df_test_results.columns:
    df_test_results = df_test_results.merge(
        df_all[["fname", "cell_x", "cell_y", "label"]],
        on=["fname", "cell_x", "cell_y"],
        how="left",
    )

# SHAP run settings (defined here so the cell runs on a fresh kernel)
SAVE_SMOOTHED = False          # save the SHAP values unsmoothed
SMOOTH_SIGMA = 1.0             # sigma for smooth_shap_maps; only used when SAVE_SMOOTHED is True (TODO confirm value)
BACKGROUND_STRATEGY = "zeros"  # single source of truth for the explainer call AND the saved metadata
CHECK_ADDITIVITY = False
BATCH_SIZE = 64
NUM_WORKERS = 4

for outcome in outcomes:
    for order in confidence_orders:
        df_filtered = select_cells_by_outcome_and_confidence(
            df_test_results, markers=markers, outcome=outcome, order=order, num_per_marker=num_per_marker
        )
        if df_filtered.empty:
            print(f"Skipping {outcome}/{order}: no matching cells")
            continue

        loader, _ = get_dataloader_for_shap(
            df_filtered,
            dataset_name=dataset_name,
            patch_size=patch_size,
            label_mode="multilabel",
            num_classes=num_classes,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
        )

        explainer, background, sample_in = prepare_shap_explainer(
            model, loader, device=device, background_strategy=BACKGROUND_STRATEGY
        )

        shap_values, _ = compute_shap_values_batch(
            explainer, loader, device=device, check_additivity=CHECK_ADDITIVITY
        )

        # Optionally smooth now; otherwise keep raw and smooth when plotting
        if SAVE_SMOOTHED:
            shap_values_to_save = smooth_shap_maps(shap_values, sigma=SMOOTH_SIGMA)
        else:
            shap_values_to_save = shap_values

        # Build metadata for this condition
        selection = {"outcome": outcome, "order": order, "num_per_marker": int(num_per_marker)}
        background_md = {"strategy": BACKGROUND_STRATEGY, "num_batches": 1}
        shap_md = {
            "check_additivity": CHECK_ADDITIVITY,
            "smoothed": bool(SAVE_SMOOTHED),
            "sigma": float(SMOOTH_SIGMA) if SAVE_SMOOTHED else None,
            # Plain list of ints: a tuple may be written with a Python-specific YAML tag
            # (e.g. !!python/tuple) that yaml.safe_load cannot read back.
            "array_shape": [int(d) for d in shap_values_to_save.shape],
            # NOTE(review): dtype of the in-memory array; the bundle below is saved as float32.
            "array_dtype": str(shap_values_to_save.dtype),
        }
        meta = build_shap_metadata(
            test_id=test_id,
            dataset_name=dataset_name,
            annotation_name=annotation_name,
            head_cfg=head_cfg,
            selection=selection,
            channel_strings=channel_strings_desired,
            class_strings=class_strings,
            desired_channel_order=desired_channel_order,
            background=background_md,
            shap_config=shap_md,
        )

        # Save under a stable, discoverable name
        base = output_dir / f"SHAP_{outcome}_{order}"
        paths = save_shap_bundle(base, shap_values_to_save, df_filtered, meta, dtype="float32", compressed=True)
        print(f"Saved bundle: {paths['npz'].name}, {paths['table'].name}, {paths['yaml'].name}")


[encoder] missing=2 unexpected=4
  eg missing: ['main_backbone.fc.weight', 'main_backbone.fc.bias']
  eg unexpected: ['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias']
[encoder] checkpoint has no fc weights -> using Identity fc to match old training
[head] strict load ok.


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TP_highest.npz, SHAP_TP_highest.csv.gz, SHAP_TP_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TP_lowest.npz, SHAP_TP_lowest.csv.gz, SHAP_TP_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TN_highest.npz, SHAP_TN_highest.csv.gz, SHAP_TN_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TN_lowest.npz, SHAP_TN_lowest.csv.gz, SHAP_TN_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FP_highest.npz, SHAP_FP_highest.csv.gz, SHAP_FP_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FP_lowest.npz, SHAP_FP_lowest.csv.gz, SHAP_FP_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FN_highest.npz, SHAP_FN_highest.csv.gz, SHAP_FN_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FN_lowest.npz, SHAP_FN_lowest.csv.gz, SHAP_FN_lowest.yaml


In [5]:
# Test ID to analyze - T0129 contains linear evaluation results for the best encoder/classifier pair obtained in round 2 for TNBC2-MIBI8 + TME-A_ML6
test_id = "T0129"

# Specify dataset and annotations (we explicitly define these here, but they can also be obtained from test_conf.yaml)
dataset_name = "TNBC2-MIBI/TNBC2-MIBI8"
annotation_name = "TME-A_ML6"

# Load labels and class strings
df_all = pd.read_csv(Path(DATASETS_DIR, dataset_name, annotation_name, "labels.csv"))
class_strings = pd.read_csv(Path(DATASETS_DIR, dataset_name, annotation_name, "class_names.csv"))["name"].tolist()
markers = class_strings  # alias
assert class_strings, "class_names.csv lists no classes"

# Specify channel names and order for plotting
channel_strings_desired = ['dsDNA', 'Pan-Keratin', 'CD3', 'CD8', 'CD20', 'CD56', 'CD68', 'BG']
desired_channel_order = [0, 1, 2, 4, 6, 5, 3, 7]
# The plotting order must reference every channel exactly once; a typo here would
# otherwise silently drop or repeat a channel in the saved metadata (and any plots built from it).
assert sorted(desired_channel_order) == list(range(len(channel_strings_desired))), (
    "desired_channel_order must be a permutation of range(len(channel_strings_desired))"
)

# Define what we will loop over and how many cells to sample for the respective prediction type
confidence_orders = ["highest", "lowest"]
outcomes = ["TP", "TN", "FP", "FN"]
num_per_marker = 100

# Output setup
output_dir = Path(OUTPUTS_DIR,"shap",dataset_name,annotation_name,test_id,f"N{num_per_marker:04d}")
output_dir.mkdir(parents=True, exist_ok=True)
master_report_entries = []

# Build test results
df_test_results = get_classifier_test_results(test_id, df_all, class_strings)
# An empty result would make every SHAP condition below silently skip; fail early instead.
assert not df_test_results.empty, f"No classifier test results found for {test_id}"

In [6]:
# Use CUDA if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the exact model used in the test
model, test_cfg, head_cfg = load_model_from_test(test_id, device=device)

# Use head_cfg to drive shapes/metadata
num_classes = int(head_cfg["num_classes"])
patch_size = int(head_cfg.get("img_size", 64))
# NOTE(review): this rebinds the dataset_name set in the config cell, so the dataloader
# and the saved metadata use the name stored in the head config, while output_dir was
# already built from the config-cell value. Confirm both refer to the same dataset.
dataset_name = head_cfg["dataset_name"]

# If df_test_results lacks the raw 'label' column, merge it from df_all once:
if "label" not in df_test_results.columns:
    df_test_results = df_test_results.merge(
        df_all[["fname", "cell_x", "cell_y", "label"]],
        on=["fname", "cell_x", "cell_y"],
        how="left",
    )

# SHAP run settings (defined here so the cell runs on a fresh kernel)
SAVE_SMOOTHED = False          # save the SHAP values unsmoothed
SMOOTH_SIGMA = 1.0             # sigma for smooth_shap_maps; only used when SAVE_SMOOTHED is True (TODO confirm value)
BACKGROUND_STRATEGY = "zeros"  # single source of truth for the explainer call AND the saved metadata
CHECK_ADDITIVITY = False
BATCH_SIZE = 64
NUM_WORKERS = 4

for outcome in outcomes:
    for order in confidence_orders:
        df_filtered = select_cells_by_outcome_and_confidence(
            df_test_results, markers=markers, outcome=outcome, order=order, num_per_marker=num_per_marker
        )
        if df_filtered.empty:
            print(f"Skipping {outcome}/{order}: no matching cells")
            continue

        loader, _ = get_dataloader_for_shap(
            df_filtered,
            dataset_name=dataset_name,
            patch_size=patch_size,
            label_mode="multilabel",
            num_classes=num_classes,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
        )

        explainer, background, sample_in = prepare_shap_explainer(
            model, loader, device=device, background_strategy=BACKGROUND_STRATEGY
        )

        shap_values, _ = compute_shap_values_batch(
            explainer, loader, device=device, check_additivity=CHECK_ADDITIVITY
        )

        # Optionally smooth now; otherwise keep raw and smooth when plotting
        if SAVE_SMOOTHED:
            shap_values_to_save = smooth_shap_maps(shap_values, sigma=SMOOTH_SIGMA)
        else:
            shap_values_to_save = shap_values

        # Build metadata for this condition
        selection = {"outcome": outcome, "order": order, "num_per_marker": int(num_per_marker)}
        background_md = {"strategy": BACKGROUND_STRATEGY, "num_batches": 1}
        shap_md = {
            "check_additivity": CHECK_ADDITIVITY,
            "smoothed": bool(SAVE_SMOOTHED),
            "sigma": float(SMOOTH_SIGMA) if SAVE_SMOOTHED else None,
            # Plain list of ints: a tuple may be written with a Python-specific YAML tag
            # (e.g. !!python/tuple) that yaml.safe_load cannot read back.
            "array_shape": [int(d) for d in shap_values_to_save.shape],
            # NOTE(review): dtype of the in-memory array; the bundle below is saved as float32.
            "array_dtype": str(shap_values_to_save.dtype),
        }
        meta = build_shap_metadata(
            test_id=test_id,
            dataset_name=dataset_name,
            annotation_name=annotation_name,
            head_cfg=head_cfg,
            selection=selection,
            channel_strings=channel_strings_desired,
            class_strings=class_strings,
            desired_channel_order=desired_channel_order,
            background=background_md,
            shap_config=shap_md,
        )

        # Save under a stable, discoverable name
        base = output_dir / f"SHAP_{outcome}_{order}"
        paths = save_shap_bundle(base, shap_values_to_save, df_filtered, meta, dtype="float32", compressed=True)
        print(f"Saved bundle: {paths['npz'].name}, {paths['table'].name}, {paths['yaml'].name}")


[encoder] missing=2 unexpected=4
  eg missing: ['main_backbone.fc.weight', 'main_backbone.fc.bias']
  eg unexpected: ['mlp.0.weight', 'mlp.0.bias', 'mlp.2.weight', 'mlp.2.bias']
[encoder] checkpoint has no fc weights -> using Identity fc to match old training
[head] strict load ok.


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TP_highest.npz, SHAP_TP_highest.csv.gz, SHAP_TP_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TP_lowest.npz, SHAP_TP_lowest.csv.gz, SHAP_TP_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TN_highest.npz, SHAP_TN_highest.csv.gz, SHAP_TN_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_TN_lowest.npz, SHAP_TN_lowest.csv.gz, SHAP_TN_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FP_highest.npz, SHAP_FP_highest.csv.gz, SHAP_FP_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FP_lowest.npz, SHAP_FP_lowest.csv.gz, SHAP_FP_lowest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FN_highest.npz, SHAP_FN_highest.csv.gz, SHAP_FN_highest.yaml


unrecognized nn.Module: Identity
Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.


Saved bundle: SHAP_FN_lowest.npz, SHAP_FN_lowest.csv.gz, SHAP_FN_lowest.yaml
