In [None]:
# Directory containing the prediction files listed in FILES.
PATH = "./outputs"
# Base name of the human-label sheet read from ./human_labels/<SUBFOLDER>.csv.
SUBFOLDER = "ptsd"
# When True, heatmap rows are labelled with fixed checklist item names
# instead of the raw question ids.
GROLTS_LABELS = True
# Prediction files to score, given relative to PATH.
FILES = [
    "Qwen_Qwen3-Embedding-8B_ptsd_500_0.csv",
    "Qwen_Qwen3-Embedding-8B_ptsd_500_3.jsonl",
]

In [None]:
import os
import re
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import ConnectionPatch

# Apply the seaborn "whitegrid" theme and, in the same call, layer the
# publication overrides on top of it (set_theme applies `rc` last, so these
# values win over the theme defaults).
# Sizes are chosen for Overleaf (1-column A4).
sns.set_theme(
    style="whitegrid",
    rc={
        # inches, ~1-column width
        "figure.figsize": (3.3, 2.5),
        "axes.titlesize": 18,
        "axes.labelsize": 16,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "legend.fontsize": 14,
        # fonttype 42 -> vector fonts in PDF/PS output
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    },
)

# Maps each new-checklist question id (the position in the list below, 0-17)
# to the old-checklist question id it corresponds to. Several new questions
# may point at the same old question (old 13 and old 14 each appear twice).
id_map = dict(
    enumerate(
        [
            0,  # new 0
            1,  # new 1
            2,  # new 2
            5,  # new 3
            6,  # new 4
            7,  # new 5
            9,  # new 6
            10,  # new 7
            11,  # new 8
            12,  # new 9
            13,  # new 10
            13,  # new 11
            14,  # new 12
            14,  # new 13
            15,  # new 14
            16,  # new 15
            19,  # new 16
            20,  # new 17
        ]
    )
)

# Inverse view: old question id -> list of new question ids mapped onto it.
reverse_map = defaultdict(list)
for new_q, old_q in id_map.items():
    reverse_map[old_q].append(new_q)

In [None]:
def load_llm_accuracies(df_labels):
    """
    Score every prediction file in ``FILES`` against the human labels.

    Each file is read from ``PATH`` and classified by its number of distinct
    ``question_id`` values: 21 means old checklist, 18 means new checklist.
    New-checklist question ids are translated to old ids through ``id_map``
    before being compared with ``df_labels`` (which uses old question ids).

    Parameters:
    - df_labels: long-format DataFrame with columns ``paper_id``,
      ``question_id`` (old checklist ids) and ``answer``.

    Returns:
    - old_acc_df: DataFrame indexed by old question_id, columns = model names
    - new_acc_df: DataFrame indexed by new question_id, columns = model names

    Files whose question count matches neither checklist are skipped with a
    warning instead of being dropped silently.
    """
    old_data = {}
    new_data = {}

    # new question_id -> old question_id, as a frame so it can be merged.
    id_lookup = pd.DataFrame(
        {
            "question_id": list(id_map.keys()),
            "old_question_id": list(id_map.values()),
        }
    )
    # For the new checklist only the first label per (paper, question) is
    # used, so de-duplicate before merging to keep "first match wins".
    labels_for_new = df_labels.drop_duplicates(
        subset=["paper_id", "question_id"]
    ).rename(columns={"question_id": "old_question_id", "answer": "answer_true"})[
        ["paper_id", "old_question_id", "answer_true"]
    ]

    for f in FILES:
        file_path = f"{PATH}/{f}"
        # Column name = file name without its extension (whatever it is).
        filename, ext = os.path.splitext(os.path.basename(f))
        # FILES may mix CSV and JSON Lines outputs; pick the matching reader.
        if ext.lower() in (".jsonl", ".ndjson"):
            df = pd.read_json(file_path, lines=True)
        else:
            df = pd.read_csv(file_path)

        n_questions = df["question_id"].nunique()
        if n_questions == 21:
            # Old checklist: question ids already match the label ids.
            df_merged = df.merge(
                df_labels, on=["paper_id", "question_id"], suffixes=("_pred", "_true")
            )
            df_merged["correct"] = (
                df_merged["answer_pred"] == df_merged["answer_true"]
            ).astype(int)
            old_data[filename] = df_merged.groupby("question_id")["correct"].mean()
        elif n_questions == 18:
            # New checklist: translate ids, then join with the labels in one
            # pass instead of filtering the label frame once per row.
            scored = df.merge(id_lookup, on="question_id", how="inner").merge(
                labels_for_new, on=["paper_id", "old_question_id"], how="inner"
            )
            if not scored.empty:
                scored["correct"] = (
                    scored["answer"] == scored["answer_true"]
                ).astype(int)
                new_data[filename] = scored.groupby("question_id")["correct"].mean()
        else:
            warnings.warn(
                f"Skipping {f!r}: expected 21 (old checklist) or 18 (new checklist) "
                f"distinct question_id values, found {n_questions}."
            )

    old_acc_df = pd.DataFrame(old_data).sort_index()
    new_acc_df = pd.DataFrame(new_data).sort_index()
    return old_acc_df, new_acc_df


# Display names for the old checklist rows, in question order.
_OLD_GROLTS_LABELS = [
    "1",
    "2",
    "3a",
    "3b",
    "3c",
    "4",
    "5",
    "6a",
    "6b",
    "7",
    "8",
    "9",
    "10",
    "11",
    "12",
    "13",
    "14a",
    "14b",
    "14c",
    "15",
    "16",
]
# Display names for the new checklist rows: "1" .. "18".
_NEW_GROLTS_LABELS = [str(n) for n in range(1, 19)]


def _question_tick_labels(index, grolts_labels):
    """Return exactly one tick label per entry of ``index``.

    When the frame has one row per known label, the labels are used
    positionally. Otherwise (some questions are missing) each question id is
    looked up as a 0-based position in ``grolts_labels`` so the remaining rows
    keep their correct names; ids outside the list fall back to ``str(id)``.
    NOTE(review): the fallback assumes question ids are 0-based, as in
    ``id_map`` -- confirm against the prediction files.
    """
    if len(index) == len(grolts_labels):
        return list(grolts_labels)

    labels = []
    for qid in index:
        try:
            pos = int(qid)
        except (TypeError, ValueError):
            pos = -1
        if 0 <= pos < len(grolts_labels):
            labels.append(grolts_labels[pos])
        else:
            labels.append(str(qid))
    return labels


def plot_mapped_accuracy_heatmaps(old_acc_df, new_acc_df, id_map):
    """
    Draw old- and new-checklist accuracy heatmaps side by side and connect
    the rows that ``id_map`` declares equivalent.

    Parameters:
    - old_acc_df: accuracies indexed by old question_id, one column per model
    - new_acc_df: accuracies indexed by new question_id, one column per model
    - id_map: dict mapping new question_id -> old question_id

    Row labels come from the module-level ``GROLTS_LABELS`` switch: fixed
    checklist item names when it is true, raw index values otherwise.

    The figure is written to ``./viz/comparison.pdf`` (the directory is
    created if needed) and then shown.
    """
    # Sort for consistent mapping
    old_acc_df = old_acc_df.sort_index()
    new_acc_df = new_acc_df.sort_index()

    fig = plt.figure(figsize=(14, 10))
    gs = fig.add_gridspec(1, 2, width_ratios=[1, 1], wspace=0.1)

    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    # Both panels share one colour scale so cells are directly comparable.
    cmap = sns.diverging_palette(20, 145, as_cmap=True)

    # --- Old checklist heatmap (no colorbar; the right panel carries it) ---
    sns.heatmap(
        old_acc_df,
        vmin=0,
        vmax=1,
        cmap=cmap,
        cbar=False,
        ax=ax1,
        annot=True,
        annot_kws={"size": 14},
        fmt=".2f",
    )
    ax1.set_title("Old Checklist Accuracy")
    ax1.set_ylabel("Old Question ID")
    ax1.set_xlabel("LLM")
    # Heatmap cell i spans [i, i + 1], so i + 0.5 is its vertical centre.
    ax1.set_yticks(np.arange(len(old_acc_df)) + 0.5)
    if GROLTS_LABELS:
        ax1.set_yticklabels(
            _question_tick_labels(old_acc_df.index, _OLD_GROLTS_LABELS),
            rotation=0,
        )
    else:
        ax1.set_yticklabels(old_acc_df.index, rotation=0)
    ax1.set_xticklabels(old_acc_df.columns, rotation=45, ha="right")

    # --- New checklist heatmap (with the shared colorbar) ---
    sns.heatmap(
        new_acc_df,
        vmin=0,
        vmax=1,
        cmap=cmap,
        cbar=True,
        ax=ax2,
        cbar_kws={"label": "Accuracy", "pad": 0.1},
        annot=True,
        annot_kws={"size": 14},
        fmt=".2f",
    )
    ax2.set_title("New Checklist Accuracy")
    ax2.set_ylabel("New Question ID")
    ax2.set_xlabel("LLM")
    ax2.set_yticks(np.arange(len(new_acc_df)) + 0.5)
    if GROLTS_LABELS:
        ax2.set_yticklabels(
            _question_tick_labels(new_acc_df.index, _NEW_GROLTS_LABELS),
            rotation=0,
        )
    else:
        ax2.set_yticklabels(new_acc_df.index, rotation=0)
    ax2.set_xticklabels(new_acc_df.columns, rotation=45, ha="right")
    # Put the new-checklist labels on the right so the gap between the two
    # panels stays free for the connection lines.
    ax2.yaxis.tick_right()
    ax2.tick_params(axis="y", which="both", length=0)
    ax2.yaxis.set_label_position("right")

    # --- Add connection lines ---
    old_index_to_pos = {qid: i for i, qid in enumerate(old_acc_df.index)}
    new_index_to_pos = {qid: i for i, qid in enumerate(new_acc_df.index)}

    # In data coordinates the old panel's right edge is its column count and
    # the new panel's left edge is 0.
    x_right_old = len(old_acc_df.columns)
    x_left_new = 0

    for new_q, old_q in id_map.items():
        # Only connect pairs for which both rows are actually plotted.
        if old_q in old_index_to_pos and new_q in new_index_to_pos:
            old_y = old_index_to_pos[old_q] + 0.5
            new_y = new_index_to_pos[new_q] + 0.5
            con = ConnectionPatch(
                xyA=(x_right_old, old_y),
                coordsA=ax1.transData,
                xyB=(x_left_new, new_y),
                coordsB=ax2.transData,
                color="gray",
                lw=1.0,
                alpha=0.5,
            )
            fig.add_artist(con)

    out_path = "./viz/comparison.pdf"
    # savefig does not create missing directories; make sure ./viz exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")
    plt.show()

In [None]:
# Load labels: read the ';'-separated sheet (one column per question), reshape
# it to one row per (paper_id, question_id) and make all three columns integer.
df_labels = (
    pd.read_csv(f"./human_labels/{SUBFOLDER}.csv", delimiter=";", dtype=int)
    .melt(id_vars=["paper_id"], var_name="question_id", value_name="answer")
    .astype({"paper_id": int, "question_id": int, "answer": int})
)

# Score the prediction files, then draw the side-by-side comparison.
old_acc_df, new_acc_df = load_llm_accuracies(df_labels)
plot_mapped_accuracy_heatmaps(old_acc_df, new_acc_df, id_map)

In [None]:
# Per-model mean over questions, best model first (old checklist, then new).
print("Mean accuracies per model:")
display(old_acc_df.mean().sort_values(ascending=False))
display(new_acc_df.mean().sort_values(ascending=False))

# Same summary after scaling every per-question accuracy by the number of
# questions in the respective checklist (21 old, 18 new).
# NOTE(review): .mean() already averages over questions, so this divides by
# the question count a second time -- confirm this is the intended metric.
print("Normalized accuracies per model:")
norm_old_acc_df = old_acc_df.div(21)
norm_new_acc_df = new_acc_df.div(18)
display(norm_old_acc_df.mean().sort_values(ascending=False))
display(norm_new_acc_df.mean().sort_values(ascending=False))