## Inspect Prompts

In [None]:
import yaml
import difflib

In [None]:
# Prompt templates keyed by name ("naive", "simple", "verbose", "instruct" are
# the keys compared below). Read as UTF-8 explicitly rather than relying on the
# platform's locale-dependent default encoding.
with open("rrg/prompts.yaml", encoding="utf-8") as f:
    prompts = yaml.safe_load(f)

def diff(a: str, b: str) -> None:
    """Print a unified diff of two multi-line strings.

    Lines are split without their terminators and the diff output is joined
    with newlines, so every diff line is printed on its own line even when the
    last line of ``a`` or ``b`` has no trailing newline (joining the raw
    ``keepends`` lines would otherwise glue e.g. "-x" and "+y" together).
    """
    delta = difflib.unified_diff(a.splitlines(), b.splitlines(), lineterm="")
    print("\n".join(delta))

In [None]:
# Show what changes between consecutive prompt templates:
# naive -> simple
diff(prompts["naive"], prompts["simple"])

In [None]:
# simple -> verbose
diff(prompts["simple"], prompts["verbose"])

In [None]:
# verbose -> instruct
diff(prompts["verbose"], prompts["instruct"])

## Evaluate Runs

In [None]:
# Install from source while waiting for merge of https://github.com/trevismd/statannotations/pull/155
# !pip install https://github.com/getzze/statannotations/archive/compat-seaborn-13.zip

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statannotations.Annotator import Annotator
from collections import defaultdict

In [None]:
# Experiment group -> ordered list of (display name, path to per-study metrics CSV).
# Group keys are "<Section> - <Experiment>" (or just "Section" for the cross-section
# comparison); the part before " - " is used as the plot title and table section.
# Runs inside one group are aligned on shared study ids and compared with paired
# tests, so every run in a group must be evaluated on the same report section.
experiments = {
    "Findings - Model": [
        ("LaB-RAG", "/opt/gpudata/rrg-data-2/exp-findings/exp-filter/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
        ("RGRG", "/opt/gpudata/rrg-data-2/exp-baseline/rgrg_findings_METRICS.csv"),
        ("CheXagent", "/opt/gpudata/rrg-data-2/exp-baseline/chexagent_findings_METRICS.csv"),
        ("CXRMate", "/opt/gpudata/rrg-data-2/exp-baseline/cxr-mate_findings_METRICS.csv"),
    ],
    "Impression - Model": [
        # Must be LaB-RAG's impression run so it is compared like-for-like
        # with the impression baselines below.
        ("LaB-RAG", "/opt/gpudata/rrg-data-2/exp-impression/exp-filter/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
        ("CXR-RePaiR", "/opt/gpudata/rrg-data-2/exp-baseline/cxr-repair_impression_METRICS.csv"),
        ("CXR-ReDonE", "/opt/gpudata/rrg-data-2/exp-baseline/cxr-redone_impression_METRICS.csv"),
        ("X-REM", "/opt/gpudata/rrg-data-2/exp-baseline/x-rem_impression_METRICS.csv"),
        ("CheXagent", "/opt/gpudata/rrg-data-2/exp-baseline/chexagent_impression_METRICS.csv"),
        ("CXRMate", "/opt/gpudata/rrg-data-2/exp-baseline/cxr-mate_impression_METRICS.csv"),
    ],
    "Both - Model": [
        # Must be LaB-RAG's "both sections" run (same file as "Section" -> "Both").
        ("LaB-RAG", "/opt/gpudata/rrg-data-2/exp-section/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_both_METRICS.csv"),
        ("CheXagent", "/opt/gpudata/rrg-data-2/exp-baseline/chexagent_both_METRICS.csv"),
        ("CXRMate", "/opt/gpudata/rrg-data-2/exp-baseline/cxr-mate_both_METRICS.csv"),
    ],
    "Findings - Filter": [
        ("No-filter", "/opt/gpudata/rrg-data-2/exp-findings/exp-filter/Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_findings_METRICS.csv"),
        ("Exact", "/opt/gpudata/rrg-data-2/exp-findings/exp-filter/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
        ("Partial", "/opt/gpudata/rrg-data-2/exp-findings/exp-filter/Mistral-7B-Instruct-v0.3_partial_pred-label_simple_top-5_findings_METRICS.csv"),
    ],
    "Findings - Prompt": [
        ("Naive", "/opt/gpudata/rrg-data-2/exp-findings/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_findings_METRICS.csv"),
        ("Simple", "/opt/gpudata/rrg-data-2/exp-findings/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
        ("Verbose", "/opt/gpudata/rrg-data-2/exp-findings/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_verbose_top-5_findings_METRICS.csv"),
        ("Instruct", "/opt/gpudata/rrg-data-2/exp-findings/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_instruct_top-5_findings_METRICS.csv"),
    ],
    "Findings - Language Model": [
        ("Mistral-v1", "/opt/gpudata/rrg-data-2/exp-findings/exp-model/Mistral-7B-Instruct-v0.1_exact_pred-label_simple_top-5_findings_METRICS.csv"),
        ("BioMistral", "/opt/gpudata/rrg-data-2/exp-findings/exp-model/BioMistral-7B_exact_pred-label_simple_top-5_findings_METRICS.csv"),
        ("Mistral-v3", "/opt/gpudata/rrg-data-2/exp-findings/exp-model/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
    ],
    "Findings - Label": [
        ("True", "/opt/gpudata/rrg-data-2/exp-findings/exp-label/Mistral-7B-Instruct-v0.3_exact_true-label_simple_top-5_findings_METRICS.csv"),
        ("Predicted", "/opt/gpudata/rrg-data-2/exp-findings/exp-label/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
    ],
    "Findings - Filter & Prompt": [
        ("Standard RAG", "/opt/gpudata/rrg-data-2/exp-findings/exp-redundancy/Mistral-7B-Instruct-v0.3_no-filter_pred-label_naive_top-5_findings_METRICS.csv"),
        ("Label Filter only", "/opt/gpudata/rrg-data-2/exp-findings/exp-redundancy/Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_findings_METRICS.csv"),
        ("Label Format only", "/opt/gpudata/rrg-data-2/exp-findings/exp-redundancy/Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_findings_METRICS.csv"),
        ("LaB-RAG", "/opt/gpudata/rrg-data-2/exp-findings/exp-redundancy/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv"),
    ],
    "Impression - Filter": [
        ("No-filter", "/opt/gpudata/rrg-data-2/exp-impression/exp-filter/Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_impression_METRICS.csv"),
        ("Exact", "/opt/gpudata/rrg-data-2/exp-impression/exp-filter/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
        ("Partial", "/opt/gpudata/rrg-data-2/exp-impression/exp-filter/Mistral-7B-Instruct-v0.3_partial_pred-label_simple_top-5_impression_METRICS.csv"),
    ],
    "Impression - Prompt": [
        ("Naive", "/opt/gpudata/rrg-data-2/exp-impression/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_impression_METRICS.csv"),
        ("Simple", "/opt/gpudata/rrg-data-2/exp-impression/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
        ("Verbose", "/opt/gpudata/rrg-data-2/exp-impression/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_verbose_top-5_impression_METRICS.csv"),
        ("Instruct", "/opt/gpudata/rrg-data-2/exp-impression/exp-prompt/Mistral-7B-Instruct-v0.3_exact_pred-label_instruct_top-5_impression_METRICS.csv"),
    ],
    "Impression - Language Model": [
        ("Mistral-v1", "/opt/gpudata/rrg-data-2/exp-impression/exp-model/Mistral-7B-Instruct-v0.1_exact_pred-label_simple_top-5_impression_METRICS.csv"),
        ("BioMistral", "/opt/gpudata/rrg-data-2/exp-impression/exp-model/BioMistral-7B_exact_pred-label_simple_top-5_impression_METRICS.csv"),
        ("Mistral-v3", "/opt/gpudata/rrg-data-2/exp-impression/exp-model/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
    ],
    "Impression - Label": [
        ("True", "/opt/gpudata/rrg-data-2/exp-impression/exp-label/Mistral-7B-Instruct-v0.3_exact_true-label_simple_top-5_impression_METRICS.csv"),
        ("Predicted", "/opt/gpudata/rrg-data-2/exp-impression/exp-label/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
    ],
    "Impression - Filter & Prompt": [
        ("Standard RAG", "/opt/gpudata/rrg-data-2/exp-impression/exp-redundancy/Mistral-7B-Instruct-v0.3_no-filter_pred-label_naive_top-5_impression_METRICS.csv"),
        ("Label Filter only", "/opt/gpudata/rrg-data-2/exp-impression/exp-redundancy/Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_impression_METRICS.csv"),
        ("Label Format only", "/opt/gpudata/rrg-data-2/exp-impression/exp-redundancy/Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_impression_METRICS.csv"),
        ("LaB-RAG", "/opt/gpudata/rrg-data-2/exp-impression/exp-redundancy/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv"),
    ],
    "Section": [
        ("Findings-Intersect", "/opt/gpudata/rrg-data-2/exp-section/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings-intersect_METRICS.csv"),
        ("Impression-Intersect", "/opt/gpudata/rrg-data-2/exp-section/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression-intersect_METRICS.csv"),
        ("Both", "/opt/gpudata/rrg-data-2/exp-section/Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_both_METRICS.csv"),
    ],
}

## Check duplicate runs are equivalent

In [None]:
# Group every referenced metrics CSV by file basename: basename -> list of the
# full paths (one entry per appearance in `experiments`, duplicates included).
count = defaultdict(list)
for _group_runs in experiments.values():
    for _run_name, path in _group_runs:
        count[os.path.basename(path)].append(path)

In [None]:
# number of distinct metrics-CSV basenames referenced in `experiments`
len(count.keys())

In [None]:
# total number of (name, path) entries across all groups
sum([len(l) for l in count.values()])

In [None]:
# basenames referenced more than once: the same file reused by several groups,
# or the same filename living in different experiment directories
dupes = {k: v for k, v in count.items() if len(v) > 1}
print(len(dupes))
dupes

In [None]:
# Every path sharing a basename should hold the same per-study metrics.
# Metric annotation cols (for radgraph and chexbert) are hard to compare
# with np.isclose but should be same if derived metrics are the same.
cols = ["study_id", "bleu4", "rougeL", "bertscore", "f1radgraph", "f1chexbert"]
for group, runs in dupes.items():
    group_dfs = [pd.read_csv(run) for run in runs]
    ref = group_dfs[0]
    # Pair each later frame with its OWN path (runs[1:]); zipping with `runs`
    # would label every comparison with the previous file's path.
    for df, run in zip(group_dfs[1:], runs[1:]):
        print(run)
        assert len(df) == len(ref), f"{run}: {len(df)} rows vs {len(ref)} in {runs[0]}"
        # equal_nan: a value missing in both copies still counts as identical.
        assert np.isclose(ref[cols], df[cols], equal_nan=True).all(), f"{run} differs from {runs[0]}"

In [None]:
# map colors to experiments
# (list every metrics-CSV basename, sorted, to see which runs need a colour below)
sorted(list(count.keys()))

In [None]:
import seaborn as sns

def lighten_color(color, amount=0.5):
    """
    Lighten ``color`` by blending its HLS lightness toward white.

    The colour is converted to HLS and its lightness ``l`` is replaced by
    ``1 - amount * (1 - l)``: ``amount=0`` gives white, ``amount=1`` keeps the
    original lightness, and values in between blend toward white.
    Input can be matplotlib color string, hex string, or RGB tuple.

    Returns:
        An ``(r, g, b)`` tuple of floats.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    import colorsys

    import matplotlib.colors as mc

    # ``to_rgb`` already resolves named colours (including every ``mc.cnames``
    # entry), so the separate name lookup guarded by a bare ``except:`` -- which
    # also swallowed unrelated errors such as KeyboardInterrupt -- is unnecessary.
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(color))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)

In [None]:
# Palette for LaB-RAG / Mistral-family runs (entries cmap[0]..cmap[11] used below).
cmap = sns.color_palette(palette='Set3')
cmap

In [None]:
# Separate palette for the literature baseline models.
baseline_cmap = sns.color_palette(palette='CMRmap')
baseline_cmap

In [None]:
# colour -> metrics-CSV basenames drawn in that colour. Baselines use (some
# lightened) CMRmap entries, LLM runs use Set3 entries; a run keeps the same
# colour across its findings / impression / both files.
temp = {
    lighten_color(baseline_cmap[0], 0.6): [
        "chexagent_both_METRICS.csv",
        "chexagent_findings_METRICS.csv",
        "chexagent_impression_METRICS.csv",
    ],
    baseline_cmap[2]: [
        "cxr-mate_both_METRICS.csv",
        "cxr-mate_findings_METRICS.csv",
        "cxr-mate_impression_METRICS.csv",
    ],
    lighten_color(baseline_cmap[1], 0.7): [
        "cxr-redone_impression_METRICS.csv",
    ],
    baseline_cmap[3]: [
        "cxr-repair_impression_METRICS.csv",
    ],
    baseline_cmap[4]: [
        "rgrg_findings_METRICS.csv",
    ],
    baseline_cmap[5]: [
        "x-rem_impression_METRICS.csv",
    ],
    cmap[0]: [
        "Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_no-filter_pred-label_simple_top-5_impression_METRICS.csv",
    ],
    cmap[1]: [
        "Mistral-7B-Instruct-v0.3_partial_pred-label_simple_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_partial_pred-label_simple_top-5_impression_METRICS.csv",
    ],
    cmap[4]: [
        "Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_both_METRICS.csv",
    ],
    cmap[2]: [
        "Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_pred-label_naive_top-5_impression_METRICS.csv",
    ],
    cmap[3]: [
        "Mistral-7B-Instruct-v0.3_exact_pred-label_verbose_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_pred-label_verbose_top-5_impression_METRICS.csv",
    ],
    cmap[5]: [
        "Mistral-7B-Instruct-v0.3_exact_pred-label_instruct_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_pred-label_instruct_top-5_impression_METRICS.csv",
    ],
    cmap[6]: [
        "Mistral-7B-Instruct-v0.1_exact_pred-label_simple_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.1_exact_pred-label_simple_top-5_impression_METRICS.csv",
    ],
    cmap[7]: [
        "BioMistral-7B_exact_pred-label_simple_top-5_findings_METRICS.csv",
        "BioMistral-7B_exact_pred-label_simple_top-5_impression_METRICS.csv",
    ],
    cmap[8]: [
        "Mistral-7B-Instruct-v0.3_exact_true-label_simple_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_exact_true-label_simple_top-5_impression_METRICS.csv",
    ],
    cmap[9]: [
        "Mistral-7B-Instruct-v0.3_no-filter_pred-label_naive_top-5_findings_METRICS.csv",
        "Mistral-7B-Instruct-v0.3_no-filter_pred-label_naive_top-5_impression_METRICS.csv",
    ],
    cmap[10]: ["Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_findings-intersect_METRICS.csv"],
    cmap[11]: ["Mistral-7B-Instruct-v0.3_exact_pred-label_simple_top-5_impression-intersect_METRICS.csv"],
}
# Invert to basename -> colour; plot_metrics looks colours up by basename.
# If a basename were listed under two colours, the later entry would win.
colors = {v: k for k, vs in temp.items() for v in vs}

## Figures

In [None]:
# Metric columns of interest (order = x-axis order when passed to plot_metrics).
metrics = [
    "bleu4",
    "rougeL",
    "bertscore",
    "f1radgraph",
    "f1chexbert",
]
# Output folders for the compact and the full figure sets.
os.makedirs("figs", exist_ok=True)
os.makedirs("figs-full", exist_ok=True)

In [None]:
# Global matplotlib font settings used by every figure below.
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 9

In [None]:
def _load_aligned_runs(runs):
    """Read each run's metrics CSV and restrict all of them to their shared study ids.

    Args:
        runs: sequence of ``(display_name, csv_path)`` pairs.

    Returns:
        ``(study_ids, frames)``: the sorted intersection of ``study_id`` values
        across the CSVs, and one DataFrame per run (in ``runs`` order) holding
        exactly those study ids in that order. The shared row order is what
        makes the paired t-tests below pair the same study across runs.
    """
    frames = [pd.read_csv(path) for _, path in runs]
    shared = set(frames[0]["study_id"])
    for frame in frames[1:]:
        shared &= set(frame["study_id"])
    study_ids = sorted(shared)
    frames = [frame.set_index("study_id").loc[study_ids].reset_index() for frame in frames]
    return study_ids, frames


def _to_long_format(group, runs, frames, metrics):
    """Stack the requested metric columns of every run into one long DataFrame.

    Only ``metrics`` are melted (metrics a CSV lacks are skipped), so the
    ``value`` column holds metric values only rather than every other CSV
    column as well. Each row is tagged with the run's display name in a column
    named after ``group``.
    """
    pieces = []
    for frame, (name, _) in zip(frames, runs):
        present = [metric for metric in metrics if metric in frame.columns]
        long = frame.melt(id_vars="study_id", value_vars=present, var_name="metric")
        long[group] = name
        pieces.append(long)
    return pd.concat(pieces, ignore_index=True)


def _comparison_pairs(group, metrics, hue_order):
    """Return the ((metric, variant), (metric, variant)) pairs to test.

    Literature-comparison groups only compare LaB-RAG against each baseline;
    every other group compares all pairs of variants within each metric.
    """
    if group in {
        "Findings - Model",
        "Impression - Model",
        "Both - Model",
    }:
        return [
            ((metric, "LaB-RAG"), (metric, other))
            for metric in metrics
            for other in hue_order[1:]
        ]
    return [
        ((metric, first), (metric, second))
        for metric in metrics
        for idx, first in enumerate(hue_order)
        for second in hue_order[idx + 1:]
    ]


def plot_metrics(metrics, folder, extra_room=False):
    """Save one grouped bar chart (mean +/- SE) per experiment group.

    For each ``(group, runs)`` in the module-level ``experiments`` dict, the
    runs are restricted to their shared study ids, the requested ``metrics``
    are drawn as bars coloured via the module-level ``colors`` map, and
    Bonferroni-corrected paired t-tests are annotated (non-significant pairs
    hidden). Each figure is written to ``{folder}/{group}.pdf``.

    Args:
        metrics: metric column names to plot, in x-axis order.
        folder: output directory for the PDFs (created if missing).
        extra_room: use a wider 6x3 figure with y-limits up to 1.55 to leave
            room for annotations; otherwise a 3x3 figure up to 1.05.
    """
    os.makedirs(folder, exist_ok=True)
    for group, runs in experiments.items():
        print("\n\n\n\n")
        print(group)
        study_ids, frames = _load_aligned_runs(runs)
        df = _to_long_format(group, runs, frames, metrics)

        x, y, hue = "metric", "value", group
        order = metrics
        hue_order = [name for name, _ in runs]
        palette = [colors[os.path.basename(path)] for _, path in runs]
        pairs = _comparison_pairs(group, metrics, hue_order)
        bar_width = 0.15 * len(hue_order)

        fig, ax = plt.subplots(figsize=(6, 3) if extra_room else (3, 3))
        sns.barplot(
            df,
            x=x,
            y=y,
            order=order,
            hue=hue,
            hue_order=hue_order,
            palette=palette,
            ax=ax,
            saturation=1,
            zorder=15,
            errorbar="se",
            err_kws={
                "zorder": 25,
                "linewidth": 1,
                "alpha": 1,
            },
            width=bar_width,
        )
        annot = Annotator(
            ax,
            pairs,
            data=df,
            x=x,
            y=y,
            order=order,
            hue=hue,
            hue_order=hue_order,
            palette=palette,
            width=bar_width,
        )
        # Annotation text size is set through a private attribute.
        annot._pvalue_format.fontsize = 9
        annot.configure(
            test="t-test_paired",
            comparisons_correction="Bonferroni",
            hide_non_significant=True,
            line_height=0.04,
            text_offset=-3,
            line_offset=10000,
            line_offset_to_group=0.1,
            line_width=0.75,
            pvalue_thresholds=[[0.05, "*"], [1, "ns"]],
        )
        annot.apply_test().annotate(line_offset=10000)

        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_ylim([-0.05, 1.55] if extra_room else [-0.05, 1.05])
        ax.set_xlim([-0.5, len(metrics) - 0.5])
        ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax.grid(which="major", axis="y", zorder=0)
        ax.set_title(f"{group.split(' - ')[0]}, N={len(study_ids)}", fontsize=10)
        legend = ax.legend(title=None, loc="upper left")
        legend.set_zorder(10)
        fig.show()
        fig.tight_layout()
        fig.savefig(f"{folder}/{group}.pdf")

In [None]:
# Compact 3x3 figures with only f1radgraph and f1chexbert, saved to figs/.
plot_metrics(metrics=["f1radgraph", "f1chexbert"], folder="figs")

In [None]:
# All five metrics in wider 6x3 figures (extra_room=True), saved to figs-full/.
plot_metrics(metrics=["bleu4", "rougeL", "bertscore", "f1radgraph", "f1chexbert"], folder="figs-full", extra_room=True)

### ROC/PR curves

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from classify import plot_pr_curve, plot_roc_curve
from _data import DEFAULT_LABELS

In [None]:
# Test-split ground-truth labels and the classifier's predicted probabilities
# (NOTE(review): presumably row-aligned with each other — confirm).
y_test = pd.read_csv("/opt/gpudata/rrg-data-2/image-labels/test_true.csv")
y_prob_test = pd.read_csv("/opt/gpudata/rrg-data-2/image-labels/test_prob.csv")

In [None]:
# Draw the test-set ROC and PR curves for DEFAULT_LABELS, each to its own PDF.
for plot_curve, curve_title, curve_path in (
    (plot_roc_curve, "Test ROC Curve", "figs/Classifier ROC Curve.pdf"),
    (plot_pr_curve, "Test PR Curve", "figs/Classifier PR Curve.pdf"),
):
    plot_curve(
        df_trues=y_test,
        df_probs=y_prob_test,
        labels=DEFAULT_LABELS,
        title=curve_title,
        output_path=curve_path,
    )

# Tables

## Experiment Results

In [None]:
# Metric columns averaged per run for the results table.
metrics = ["bleu4", "rougeL", "bertscore", "f1radgraph", "f1chexbert"]

In [None]:
# One Series of per-metric means per run, named (experiment, section, run name).
# "Findings - Filter" -> section "Findings", experiment "Filter"; a group key
# without " - " (i.e. "Section") becomes experiment=<key>, section "Intersection".
results = []
for group, runs in experiments.items():
    if " - " in group:
        section, experiment = group.split(" - ")[:2]
    else:
        section, experiment = "Intersection", group
    for run_name, run_path in runs:
        temp = pd.read_csv(run_path)[metrics].mean()
        temp.name = (experiment, section, run_name)
        results.append(temp)

In [None]:
# [(e.split(" - ")[1], e.split(" - ")[0], n) for e, nps in experiments.items() for n, p in nps if e != "Section"]

In [None]:
# Display order of the LaTeX results table. Each tuple is an
# (experiment, section, variable) index entry produced by the results loop;
# `.loc[rows]` below selects and orders the table rows by this list.
rows = [
    ("Model", "Findings", "LaB-RAG"),
    ("Model", "Findings", "RGRG"),
    ("Model", "Findings", "CheXagent"),
    ("Model", "Findings", "CXRMate"),
    ("Model", "Impression", "LaB-RAG"),
    ("Model", "Impression", "CXR-RePaiR"),
    ("Model", "Impression", "CXR-ReDonE"),
    ("Model", "Impression", "X-REM"),
    ("Model", "Impression", "CheXagent"),
    ("Model", "Impression", "CXRMate"),
    ("Model", "Both", "LaB-RAG"),
    ("Model", "Both", "CheXagent"),
    ("Model", "Both", "CXRMate"),
    # ==================================================================
    ("Filter & Prompt", "Findings", "Standard RAG"),
    ("Filter & Prompt", "Findings", "Label Filter only"),
    ("Filter & Prompt", "Findings", "Label Format only"),
    ("Filter & Prompt", "Findings", "LaB-RAG"),
    ("Filter & Prompt", "Impression", "Standard RAG"),
    ("Filter & Prompt", "Impression", "Label Filter only"),
    ("Filter & Prompt", "Impression", "Label Format only"),
    ("Filter & Prompt", "Impression", "LaB-RAG"),
    # ==================================================================
    ("Filter", "Findings", "No-filter"),
    ("Filter", "Findings", "Exact"),
    ("Filter", "Findings", "Partial"),
    ("Filter", "Impression", "No-filter"),
    ("Filter", "Impression", "Exact"),
    ("Filter", "Impression", "Partial"),
    # ==================================================================
    ("Prompt", "Findings", "Naive"),
    ("Prompt", "Findings", "Simple"),
    ("Prompt", "Findings", "Verbose"),
    ("Prompt", "Findings", "Instruct"),
    ("Prompt", "Impression", "Naive"),
    ("Prompt", "Impression", "Simple"),
    ("Prompt", "Impression", "Verbose"),
    ("Prompt", "Impression", "Instruct"),
    # ==================================================================
    ("Language Model", "Findings", "Mistral-v1"),
    ("Language Model", "Findings", "BioMistral"),
    ("Language Model", "Findings", "Mistral-v3"),
    ("Language Model", "Impression", "Mistral-v1"),
    ("Language Model", "Impression", "BioMistral"),
    ("Language Model", "Impression", "Mistral-v3"),
    # ==================================================================
    ("Label", "Findings", "True"),
    ("Label", "Findings", "Predicted"),
    ("Label", "Impression", "True"),
    ("Label", "Impression", "Predicted"),
    # ==================================================================
    ("Section", "Intersection", "Findings-Intersect"),
    ("Section", "Intersection", "Impression-Intersect"),
    ("Section", "Intersection", "Both"),
]

In [None]:
# Stack the per-run Series into one frame indexed by (experiment, section, variable),
# sorted first, then reordered/filtered to the display order in `rows`.
results = pd.DataFrame(results).sort_index(level=0)
results.index = results.index.set_names(["experiment", "section", "variable"])
results = results.loc[rows].copy()

In [None]:
results

In [None]:
# Render the results table as LaTeX with 3 decimal places.
latex = results.style.format(precision=3).to_latex()

In [None]:
# Insert horizontal rules before rows that open a \multirow cell: a full-width
# \cline{1-8} before lines starting with "\multirow" and a \cline{2-8} before
# lines starting with " & \multirow" (first column empty).
ruled_latex = latex.replace("\n\\multirow", "\n\\cline{1-8}\n\\multirow")
ruled_latex = ruled_latex.replace("\n & \\multirow", "\n\\cline{2-8}\n & \\multirow")
print(ruled_latex)

## View Counts

In [None]:
import pandas as pd
from _data import DEFAULT_VIEW_ORDER

# View name -> its position in DEFAULT_VIEW_ORDER, used as a sort key below.
view_idx = {view_name: position for position, view_name in enumerate(DEFAULT_VIEW_ORDER)}

In [None]:
# Per-split label CSVs, then a ViewPosition x split count table (0 where a view
# is absent) ordered by DEFAULT_VIEW_ORDER, with an "overall" column first.
splits = {}
for split in ["train", "val", "test"]:
    splits[split] = pd.read_csv(f"/opt/gpudata/rrg-data-2/image-labels/{split}_true.csv")
per_split_views = {name: frame["ViewPosition"].value_counts() for name, frame in splits.items()}
view_counts = pd.DataFrame(per_split_views).fillna(0).astype(int)
view_counts = view_counts.sort_index(key=lambda views: [view_idx[view] for view in views])
view_counts["overall"] = view_counts.sum(axis=1)
view_counts = view_counts[["overall", "train", "val", "test"]]

In [None]:
# Number of unique studies per split, plus their total as "overall".
lens = pd.Series({split: len(df.drop_duplicates("study_id")) for split, df in splits.items()})
lens.loc["overall"] = lens.sum()
lens = lens[["overall", "train", "val", "test"]]

In [None]:
lens

In [None]:
# Format each cell as "count (percent)"; dividing by `lens` aligns on the split columns.
# NOTE(review): view_counts counts CSV rows per ViewPosition while `lens` counts
# unique studies, so a split's percentages need not sum to 100 when studies have
# several rows — confirm this denominator is intended.
view_counts = view_counts.astype(str) + " (" + (view_counts / lens * 100).map(lambda x: f"{x:.1f}") + ")"

In [None]:
view_counts

In [None]:
print(view_counts.to_latex())

## Section Counts

In [None]:
import pandas as pd

In [None]:
# Report sections per study (impression/findings columns) and the official split file.
sectioned = pd.read_csv("/opt/gpudata/mimic-cxr/mimic_cxr_sectioned.csv")
splits = pd.read_csv("/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-split.csv")

In [None]:
# One split row per study, joined to the sectioned reports; each section
# column becomes a "present" flag and "both" marks studies with both sections.
splits = splits[["study_id", "subject_id", "split"]].drop_duplicates()
merged = sectioned.merge(splits, on="study_id")
merged["impression"] = merged["impression"].notna()
merged["findings"] = merged["findings"].notna()
merged["both"] = merged["impression"] & merged["findings"]
merged.loc[merged["split"] == "validate", "split"] = "val"
# Studies with each section present, per split (section x split table).
# Summing the boolean flags does not depend on row order, unlike picking
# value_counts() rows by position: value_counts orders True/False by frequency,
# so a fixed iloc could select the "absent" counts, and it would break if a
# split had no studies with a given section.
temp = merged.groupby("split")[["impression", "findings", "both"]].sum().T
temp.index.name = "section"
temp = temp[["train", "val", "test"]]
temp["overall"] = temp.sum(axis=1)
temp = temp[["overall", "train", "val", "test"]]

In [None]:
# Unique studies per split in the split file ("validate" relabelled "val"), plus overall.
lens = splits.drop_duplicates("study_id").groupby("split").size().rename({"validate": "val"})
lens.loc["overall"] = lens.sum()
lens = lens[["overall", "train", "val", "test"]]

In [None]:
lens

In [None]:
temp

In [None]:
# Format as "count (percent of the split's studies)"; division aligns `lens` on columns.
temp = temp.astype(str) + " (" + (temp / lens * 100).map(lambda x: f"{x:.1f}") +  ")"

In [None]:
print(temp.to_latex())

## Demographics

In [None]:
import pandas as pd
from tableone import TableOne

In [None]:
# Source tables: ED stays (race), patients (sex/anchor age), CXR metadata (study
# dates), and the CXR split file reduced to one row per (study, subject, split)
# with "validate" relabelled "val".
edstays = pd.read_csv("/opt/gpudata/mimic/iv/ed/edstays.csv.gz")
patients = pd.read_csv("/opt/gpudata/mimic/iv/hosp/patients.csv.gz")
metadata = pd.read_csv("/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-metadata.csv")
splits = pd.read_csv("/opt/gpudata/mimic-cxr/mimic-cxr-2.0.0-split.csv")
splits = splits[["study_id", "subject_id", "split"]].drop_duplicates()
splits["split"] = splits["split"].replace({"validate": "val"})

In [None]:
# Detailed ED race strings -> coarse reporting categories. Applied with
# Series.replace, so any value not listed here is left unchanged.
mapping = {
    "WHITE": "White",
    "BLACK/AFRICAN AMERICAN": "Black",
    "OTHER": "Other",
    "UNKNOWN": "Unknown",
    "ASIAN": "Asian",
    "HISPANIC/LATINO - PUERTO RICAN": "Hispanic or Latino",
    "WHITE - OTHER EUROPEAN": "White",
    "ASIAN - CHINESE": "Asian",
    "HISPANIC/LATINO - DOMINICAN": "Hispanic or Latino",
    "BLACK/CAPE VERDEAN": "Black",
    "BLACK/AFRICAN": "Black",
    "WHITE - RUSSIAN": "White",
    "HISPANIC OR LATINO": "Hispanic or Latino",
    "BLACK/CARIBBEAN ISLAND": "Black",
    "HISPANIC/LATINO - GUATEMALAN": "Hispanic or Latino",
    "ASIAN - ASIAN INDIAN": "Asian",
    "ASIAN - SOUTH EAST ASIAN": "Asian",
    "WHITE - BRAZILIAN": "White",
    "HISPANIC/LATINO - MEXICAN": "Hispanic or Latino",
    "HISPANIC/LATINO - SALVADORAN": "Hispanic or Latino",
    "WHITE - EASTERN EUROPEAN": "White",
    "HISPANIC/LATINO - COLUMBIAN": "Hispanic or Latino",
    "PORTUGUESE": "Other",
    "AMERICAN INDIAN/ALASKA NATIVE": "American Indian or Alaska Native",
    "SOUTH AMERICAN": "Other",
    "PATIENT DECLINED TO ANSWER": "Unknown",
    "ASIAN - KOREAN": "Asian",
    "HISPANIC/LATINO - HONDURAN": "Hispanic or Latino",
    "HISPANIC/LATINO - CENTRAL AMERICAN": "Hispanic or Latino",
    "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "Native Hawaiian or Pacific Islander",
    "HISPANIC/LATINO - CUBAN": "Hispanic or Latino",
    "UNABLE TO OBTAIN": "Unknown",
    "MULTIPLE RACE/ETHNICITY": "Other",
}

In [None]:
# Race from each subject's latest ED stay (sorted by intime, keep the last row
# per subject), collapsed to coarse categories via `mapping`.
last_stay_race = edstays.sort_values(by=["subject_id", "intime"], ascending=True).drop_duplicates(subset="subject_id", keep="last")[["subject_id", "race"]]
last_stay_race["race"] = last_stay_race["race"].replace(mapping)

In [None]:
# Study year = first four characters of StudyDate (presumably YYYYMMDD — confirm).
metadata["study_year"] = metadata["StudyDate"].astype(str).str[:4].astype(int)
study_year = metadata[["subject_id", "study_id", "study_year"]]

In [None]:
# Sanity checks: one race per subject, one patients row per subject, and no
# (subject, study) pair with more than one distinct study year.
assert not last_stay_race[["subject_id", "race"]].drop_duplicates()["subject_id"].duplicated(keep=False).any()
assert not patients["subject_id"].duplicated().any()
assert not last_stay_race["subject_id"].duplicated().any()
assert not study_year[["subject_id", "study_id", "study_year"]].drop_duplicates()[["subject_id", "study_id"]].duplicated(keep=False).any()

In [None]:
# Attach patient fields, last-ED-stay race and study year to every split row
# (left joins: unmatched studies get NaN rather than being dropped).
merged = (
    splits.merge(patients, on="subject_id", how="left")
    .merge(last_stay_race, on="subject_id", how="left")
    .merge(study_year.drop_duplicates(["subject_id", "study_id"]), on=["subject_id", "study_id"], how="left")
).rename(columns={"gender": "sex"})

# Age at the study = anchor_age shifted by the years between anchor_year and study_year.
# NOTE(review): assumes anchor_age is the patient's age in anchor_year — confirm.
merged["year_diff"] = merged["study_year"] - merged["anchor_year"]
merged["age"] = merged["anchor_age"] + merged["year_diff"]
merged["sex"] = merged["sex"].replace({"F": "Female", "M": "Male"})

In [None]:
# The joins must neither drop nor duplicate studies.
assert splits["split"].value_counts().equals(merged["split"].value_counts())
assert not merged["study_id"].duplicated().any()

In [None]:
# TableOne configuration: age (non-normal, reported as median [Q1,Q3]), sex and
# race, grouped by split.
groupby = "split"
columns = ["age", "sex", "race"]
categorical = ["sex", "race"]
continuous = ["age"]
nonnormal=["age"]

In [None]:
# Build the table (missing-value rows are added manually in the following cells)
# and keep only the per-split columns, with "Overall" renamed "overall".
table = TableOne(
    data=merged,
    columns=columns,
    categorical=categorical,
    continuous=continuous,
    groupby=groupby,
    nonnormal=nonnormal,
    overall=True,
    include_null=False
)
table = table.tableone.droplevel(level=0, axis="columns")[["Overall", "train", "val", "test"]].rename(columns={"Overall": "overall"})

In [None]:
# "Missing" row for age: count (percent) of studies with no age, per split and overall.
temp = merged[["split", "age"]].copy()
temp["age"] = temp["age"].isna()
missing = temp.groupby("split")["age"].sum()
missing.loc["overall"] = temp["age"].sum()
sizes = temp.groupby("split").size()
sizes.loc["overall"] = len(temp)
table.loc[("age, median [Q1,Q3]", "Missing"), :] = (missing.astype(str) + " (" + (missing / sizes).map(lambda x: f"{x*100:.1f}") + ")")

In [None]:
# Same "Missing" row for sex.
temp = merged[["split", "sex"]].copy()
temp["sex"] = temp["sex"].isna()
missing = temp.groupby("split")["sex"].sum()
missing.loc["overall"] = temp["sex"].sum()
sizes = temp.groupby("split").size()
sizes.loc["overall"] = len(temp)
table.loc[("sex, n (%)", "Missing"), :] = (missing.astype(str) + " (" + (missing / sizes).map(lambda x: f"{x*100:.1f}") + ")")

In [None]:
# Same "Missing" row for race.
temp = merged[["split", "race"]].copy()
temp["race"] = temp["race"].isna()
missing = temp.groupby("split")["race"].sum()
missing.loc["overall"] = temp["race"].sum()
sizes = temp.groupby("split").size()
sizes.loc["overall"] = len(temp)
table.loc[("race, n (%)", "Missing"), :] = (missing.astype(str) + " (" + (missing / sizes).map(lambda x: f"{x*100:.1f}") + ")")

In [None]:
# Patient counts: unique subject_ids per split (first occurrence kept; presumably
# each subject belongs to a single split — confirm), plus overall.
temp = merged.drop_duplicates("subject_id")["split"].value_counts()
temp.loc["overall"] = temp.sum()
table.loc[("n", "Patients"), :] = temp

In [None]:
# Move rows from TableOne's labels (keys) to the labels used in the paper (values).
for k1, k2 in {
    ("n", ""): ("n", "Studies"),
    ("age, median [Q1,Q3]", ""): ("age", "Median [Q1,Q3]"),
    ("age, median [Q1,Q3]", "Missing"): ("age", "Missing, n (%)"),
}.items():
    table.loc[k2, :] = table.loc[k1].copy()
    table.drop(index=k1, inplace=True)

In [None]:
print(table.sort_index().to_latex())

# Misc