In [None]:
import pandas as pd
import yaml
from _data import (
    get_per_study_data,
    get_split_features,
    get_split_samples,
    DEFAULT_PATIENT_ID_COL,
    DEFAULT_STUDY_ID_COL,
    DEFAULT_DICOM_ID_COL,
    DEFAULT_SPLIT_COL,
    DEFAULT_VIEW_COL,
    DEFAULT_LABELS,
    DEFAULT_VIEW_ORDER,
    DEFAULT_FINDINGS_COL,
    DEFAULT_IMPRESSION_COL,
    DEFAULT_IMG_PROJ_KEY,
)
from _prompt import prepare_prompt
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import trange
import _prompt

In [None]:
def get_retrieved_idxs(section_type, k=5):
    """Collect the indices ``prepare_prompt`` retrieves for every inference study.

    The function loads the per-study tables and image features from the
    hard-coded paths below, keeps only studies whose required report
    section(s) are present, and then calls ``prepare_prompt`` once per
    inference study for each label filter ("exact", then "partial"),
    recording the ``idxs`` value it returns with ``return_relative_idxs=True``.

    Parameters
    ----------
    section_type : str
        One of "findings", "impression", "both", "findings-intersect" or
        "impression-intersect". Selects which report columns must be non-null
        for a study to be kept and is forwarded to ``prepare_prompt``.
    k : int, default 5
        Number of studies to retrieve per inference study; forwarded to
        ``prepare_prompt``.

    Returns
    -------
    tuple[list, list]
        ``(exact_filter_retrieved_idxs, partial_filter_retrieved_idxs)``.
        Each list has one entry per inference study, in the row order of the
        inference split. Each entry is whatever ``prepare_prompt`` returns as
        its fourth value (presumably similarity ranks of the retrieved
        studies -- confirm against ``_prompt.prepare_prompt``).

    Raises
    ------
    ValueError
        If ``section_type`` is not one of the supported values.
    AssertionError
        If the tables merged against true and predicted labels do not have
        identical identifier/split/view columns.
    """
    # Clear the module-level cache in `_prompt` before starting a new run
    # (presumably it memoizes state tied to a previous section type --
    # confirm against `_prompt`).
    _prompt.cached = None

    # Validate the argument before any expensive loading so that a typo fails
    # immediately instead of after the CSV merges.
    if section_type == "findings":
        report_cols = [DEFAULT_FINDINGS_COL]
    elif section_type == "impression":
        report_cols = [DEFAULT_IMPRESSION_COL]
    elif section_type in ["both", "findings-intersect", "impression-intersect"]:
        report_cols = [DEFAULT_FINDINGS_COL, DEFAULT_IMPRESSION_COL]
    else:
        raise ValueError(f"Unknown section type: {section_type}")

    # NOTE(review): absolute, machine-specific paths; consider lifting these
    # into a notebook-level configuration cell.
    split_csv = "/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-split.csv"
    metadata_csv = "/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-metadata.csv"
    true_label_csv = "/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-chexpert.csv"
    predicted_label_csv = "/opt/gpudata/rrg-data-2/image-labels/pred_pr.csv"
    report_csv = "/opt/gpudata/mimic-cxr/mimic_cxr_sectioned.csv"
    add_other_label = True
    feature_h5 = "/opt/gpudata/rrg-data-2/biovilt-features.h5"
    prompt_yaml = "prompts.yaml"
    # Copy so that appending "Other" below never touches the shared default.
    labels = DEFAULT_LABELS.copy()
    prompt_type = "simple"

    # TODO parameterize hardcoded split remapping
    split_remap = {
        "train": "retrieval",
        "validate": "retrieval",
        "test": "inference",
    }

    def load_study_data(label_csv):
        """Merge split, metadata, report and the given label CSV per study."""
        return get_per_study_data(
            split_csv=split_csv,
            metadata_csv=metadata_csv,
            label_csv=label_csv,
            report_csv=report_csv,
            patient_id_col=DEFAULT_PATIENT_ID_COL,
            study_id_col=DEFAULT_STUDY_ID_COL,
            dicom_id_col=DEFAULT_DICOM_ID_COL,
            split_col=DEFAULT_SPLIT_COL,
            view_col=DEFAULT_VIEW_COL,
            labels=labels,
            view_order=DEFAULT_VIEW_ORDER,
            # Fresh list per call, as in the original two call sites.
            report_cols=[DEFAULT_FINDINGS_COL, DEFAULT_IMPRESSION_COL],
            split_remap=split_remap,
        )

    # Load and merge data relative to true labels
    retrieval_df = load_study_data(true_label_csv)

    # Load and merge data relative to predicted labels if provided
    inference_df = load_study_data(predicted_label_csv or true_label_csv)

    # Check that true and predicted labels result in same merged dataframes
    cols = [
        DEFAULT_PATIENT_ID_COL,
        DEFAULT_STUDY_ID_COL,
        DEFAULT_DICOM_ID_COL,
        DEFAULT_SPLIT_COL,
        DEFAULT_VIEW_COL,
    ]
    assert retrieval_df[cols].equals(inference_df[cols])

    # Filter dataset to only those with given section type. The mask is built
    # from retrieval_df and reused for inference_df; the assertion above is
    # what makes the two frames row-aligned.
    mask = retrieval_df[report_cols].notna().all(axis=1)
    retrieval_df = retrieval_df[mask].reset_index(drop=True).copy()
    inference_df = inference_df[mask].reset_index(drop=True).copy()

    # Add implicit "other" label: 1 when none of the listed labels equals 1.
    if add_other_label:
        # TODO does "other" definition depend on prompt type?
        retrieval_df["Other"] = (retrieval_df[labels] != 1).all(axis=1).astype(int)
        inference_df["Other"] = (inference_df[labels] != 1).all(axis=1).astype(int)
        labels += ["Other"]

    # Prepare per-split projected embeddings
    features = get_split_features(
        feature_h5=feature_h5,
        feature_key=DEFAULT_IMG_PROJ_KEY,
        sample_df=retrieval_df,
        patient_id_col=DEFAULT_PATIENT_ID_COL,
        study_id_col=DEFAULT_STUDY_ID_COL,
        dicom_id_col=DEFAULT_DICOM_ID_COL,
        split_col=DEFAULT_SPLIT_COL,
    )
    retrieval_features = features["retrieval"]
    inference_features = features["inference"]

    # Prepare per-split metadata, labels, and reports
    retrieval_samples = get_split_samples(
        sample_df=retrieval_df,
        split_col=DEFAULT_SPLIT_COL,
    )["retrieval"]
    inference_samples = get_split_samples(
        sample_df=inference_df,
        split_col=DEFAULT_SPLIT_COL,
    )["inference"]

    # Prepare prompt templates
    with open(prompt_yaml) as f:
        prompt_templates = yaml.safe_load(f)

    # Compute similarity between inference and retrieval samples; row i holds
    # the similarities of inference sample i to every retrieval sample.
    similarity = cosine_similarity(inference_features, retrieval_features)

    n_inference = len(inference_samples)

    # Run the full pass for "exact" first and then for "partial", matching the
    # original call order in case `prepare_prompt` keeps state between calls.
    retrieved_idxs = {}
    for filter_type in ("exact", "partial"):
        idxs_per_sample = []
        for i in trange(n_inference):
            # Only the retrieved indices are needed; the prompt, target report
            # and retrieved studies are discarded.
            _, _, _, idxs = prepare_prompt(
                retrieval_samples=retrieval_samples,
                target_sample=inference_samples.iloc[i],
                target_similarity=similarity[i],
                k=k,
                prompt_templates=prompt_templates,
                filter_type=filter_type,
                prompt_type=prompt_type,
                section_type=section_type,
                labels=labels,
                findings_col=DEFAULT_FINDINGS_COL,
                impression_col=DEFAULT_IMPRESSION_COL,
                study_id_col=DEFAULT_STUDY_ID_COL,
                return_relative_idxs=True,
            )
            idxs_per_sample.append(idxs)
        retrieved_idxs[filter_type] = idxs_per_sample

    return retrieved_idxs["exact"], retrieved_idxs["partial"]

In [None]:
# Run the retrieval once per report section; each call returns the per-study
# retrieved indices for the "exact" and the "partial" label filter.
findings_exact, findings_partial = get_retrieved_idxs(section_type="findings")
impression_exact, impression_partial = get_retrieved_idxs(section_type="impression")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Patch

# Qualitative "Set3" palette used to color the histograms below; the bare
# expression on the last line renders the swatches as the cell output.
cmap = sns.color_palette("Set3")
cmap

In [None]:
# Figure: distribution of (idx - 1) over all retrieved entries, shown as a
# 2x2 grid (rows: exact / partial filter, columns: findings / impression)
# with a narrow third column that only hosts the legend.
fig = plt.figure(layout="constrained", figsize=(11, 5))

gs = GridSpec(2, 3, figure=fig, width_ratios=[5, 5, 1])
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])
ax_legend = fig.add_subplot(gs[:, 2])

# Unit-width bins for values 0..24 plus a single catch-all bin whose right
# edge is large enough to absorb everything >= 25 (drawn as the "25+" bar).
bins = list(range(27))
bins[-1] = 1000000

# (axis, retrieved indices plotted on that axis, bar color)
panels = [
    (ax1, findings_exact, cmap[4]),
    (ax2, impression_exact, cmap[4]),
    (ax3, findings_partial, cmap[1]),
    (ax4, impression_partial, cmap[1]),
]

for ax, retrieved_idxs, color in panels:
    # Flatten the per-study index lists and shift by one so the x axis
    # starts at 0.
    ranks = [x - 1 for xs in retrieved_idxs for x in xs]

    sns.histplot(ranks, bins=bins, ax=ax, color=color, linewidth=1, zorder=10, alpha=1)

    ax.set_ylim([0, 550])
    ax.set_xlim([0, 26])
    ax.set_xticks([0, 5, 10, 15, 20, 25])
    ax.set_yticks([0, 100, 200, 300, 400, 500])
    ax.grid(which="major", axis="y", zorder=0)

    # The overflow bar is clipped by the y limit, so print its count on top
    # of it. The count is taken from the same data that was drawn on this
    # axis; previously the ax2 and ax3 annotations were swapped.
    overflow_n = sum(rank >= 25 for rank in ranks)
    ax.text(25.6, 275, f"{overflow_n}", zorder=15, rotation=90, ha="center", va="center")

# Top row: no x tick labels or axis label (shared visually with the bottom row).
for ax in (ax1, ax2):
    ax.set_xticklabels([])
    ax.set_xlabel("")

# Bottom row: labelled x axis, with the last tick marking the overflow bin.
for ax in (ax3, ax4):
    ax.set_xticklabels([0, 5, 10, 15, 20, "25+"])
    ax.set_xlabel("Image Similarity Rank")

# Left column keeps the y tick labels (and seaborn's default y label).
for ax in (ax1, ax3):
    ax.set_yticklabels([0, 100, 200, 300, 400, 500])

# Right column: drop the y label and tick labels.
for ax in (ax2, ax4):
    ax.set_ylabel("")
    ax.set_yticklabels([])

# Column titles; N is the number of entries returned for that section.
ax1.set_title(f"Findings, N={len(findings_exact)}")
ax2.set_title(f"Impression, N={len(impression_exact)}")

legend_elements = [
    Patch(facecolor=cmap[4], edgecolor="gray", label="Exact"),
    Patch(facecolor=cmap[1], edgecolor="gray", label="Partial"),
]

ax_legend.legend(handles=legend_elements, loc="center left", title="Filter", title_fontproperties={"weight": "semibold"})
ax_legend.axis("off")

# x=0.47 centers the title over the four histogram panels rather than over
# the full figure including the legend column.
fig.suptitle("Top 5 Filtered Image Similarity", x=0.47)

fig.savefig("../figs/pngs/filter-rank-hist.png", dpi=300)
fig.savefig("../figs/pdfs/filter-rank-hist.pdf")