In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import pandas as pd
import seaborn as sns
import numpy as np
from data.personas import *
from data.constants import MODEL_ORDER
from data.loader import load_data
from utils.significance_testing import *
from utils.metrics import *
import matplotlib.pyplot as plt
import pickle

In [None]:
# mitigation -> metric -> per-model outcome-count table (filled by the plotting cells).
all_counts = dict()

In [None]:
# Metric key -> two-line panel title used in the aggregate figures.
# Titles carry no stray leading/trailing spaces on either line, so all panel
# titles centre the same way (the "OP" and "level1" entries previously had them).
rename_metrics = {
    'OP': "Exp. Advantage: Static\n(e.g., expert in fact-checking)",
    "level1": "Exp. Advantage: Broad\n(e.g., expert in math)",
    "level2": "Exp. Advantage: Focused\n(e.g., expert in abstract algebra)",
    "level3": "Exp. Advantage: Niche\n(e.g., expert in group theory)",
    'WU_color': "Robustness\n(Color)",
    'WU_name': "Robustness\n(Name)",
    'Fid_Ed': "Fidelity\n(education)",
    "Fid_Exp": "Fidelity\n(domain match)",
    'Fid_ExpLevel': "Fidelity\n(specialization)",
}

In [None]:
# Metric keys grouped into the three families plotted separately below.
expertise_metrics = [
    "OP",
    "level1",
    "level2",
    "level3",
]
robustness_metrics = [
    "WU_color",
    "WU_name",
]
fidelity_metrics = [
    "Fid_Ed",
    "Fid_Exp",
    "Fid_ExpLevel",
]

In [None]:
# Load every benchmark and build the persona <-> task lookup tables.
#   dataset_dfs[name]    : the dataset as a pandas DataFrame
#   persona2task[name]   : expert persona string -> task/category name
#   task2persona[name]   : inverse of persona2task
#   all_categories[name] : per-row task label (multi-task datasets only)
#   task_to_dataset      : task name -> dataset it belongs to
dataset_dfs = {}
task2persona = {}
persona2task = {}
all_categories = {}

dataset_order = ["truthfulqa", "gsm8k", "mmlu_pro", "bigbench", "math"]

# MMLU-Pro persona that stands for the catch-all "other" category.
mmlu_misc_expert = "an expert in miscellaneous fields including international relations, sociology, accounting, and human sexuality"

for dataset in dataset_order:
    print(f"Loading and processing {dataset}.")
    dataset_dfs[dataset] = load_data(dataset).to_pandas()
    if dataset == "mmlu_pro":
        # Persona "an expert in X" maps to category "X".
        persona2task["mmlu_pro"] = {
            x: "other" if x == mmlu_misc_expert else x.replace("an expert in ", "")
            for x in EXPERTS["mmlu_pro"]
        }
        all_categories["mmlu_pro"] = dataset_dfs["mmlu_pro"]["category"]
    elif dataset == "bigbench":
        experts = EXPERTS["bigbench"]
        tasks = ['contextual_parametric_knowledge_conflicts',
                 'logic_grid_puzzle',
                 'strategyqa',
                 'tracking_shuffled_objects']
        # Persona order in EXPERTS["bigbench"] does not follow this task list.
        persona2task["bigbench"] = {
            experts[0]: tasks[1],
            experts[1]: tasks[2],
            experts[2]: tasks[3],
            experts[3]: tasks[0],
        }
        all_categories["bigbench"] = dataset_dfs["bigbench"]["category"]
    elif dataset == "math":
        # Personas 1..7 of EXPERTS["math"] pair positionally with the subjects below.
        experts = EXPERTS[dataset][1:8]
        tasks = ['Algebra',
                 'Counting & Probability',
                 'Geometry',
                 'Intermediate Algebra',
                 'Number Theory',
                 'Prealgebra',
                 'Precalculus']
        # zip() would silently drop unmatched personas/subjects; fail loudly instead.
        if len(experts) != len(tasks):
            raise ValueError(
                f"Expected {len(tasks)} MATH expert personas, got {len(experts)}"
            )
        persona2task["math"] = dict(zip(experts, tasks))
        all_categories["math"] = dataset_dfs["math"]["type"]

# Invert persona -> task (assumes one persona per task; duplicates would collapse).
for k, v in persona2task.items():
    task2persona[k] = {task: persona for persona, task in v.items()}

# Single-task datasets map to themselves; category-level tasks map to their dataset.
task_to_dataset = {
    "truthfulqa": "truthfulqa",
    "gsm8k": "gsm8k",
}
mmlu_tasks = {x: "mmlu_pro" for x in task2persona["mmlu_pro"].keys()}
bigbench_tasks = {x: "bigbench" for x in task2persona["bigbench"].keys()}
math_tasks = {x: "math" for x in task2persona["math"].keys()}
task_to_dataset = {**task_to_dataset, **mmlu_tasks, **bigbench_tasks, **math_tasks}

In [None]:
# Metric families in plotting order; process_df returns one frame per entry.
metric_names = [
    expertise_metrics,
    robustness_metrics,
    fidelity_metrics,
]

In [None]:
import matplotlib.font_manager as fm

# Show which font families matplotlib can currently see.
available_fonts = sorted(fm.get_font_names())
print(available_fonts)

In [None]:
# Register the CMU Serif faces with matplotlib: cmunrm = regular, cmunbx = bold.
for font_path in ("/usr/share/fonts/truetype/cmu/cmunrm.ttf",
                  "/usr/share/fonts/truetype/cmu/cmunbx.ttf"):
    fm.fontManager.addfont(font_path)
# Bold face used for panel titles and bar annotations.
bold_font = fm.FontProperties(fname="/usr/share/fonts/truetype/cmu/cmunbx.ttf")

In [None]:
# Family name of the regular CMU face, used as the global font family.
cmu_regular = fm.FontProperties(fname="/usr/share/fonts/truetype/cmu/cmunrm.ttf")
cmu_serif = cmu_regular.get_name()
print("Font name:", cmu_serif)  # expected: "CMU Serif"

In [None]:
# Global matplotlib style: CMU Serif text with Computer Modern math, rendered by
# matplotlib's built-in mathtext (no external LaTeX process).
plt.rcParams.update({
    "text.usetex": False,      # LaTeX rendering OFF; mathtext handles $...$ instead
    "mathtext.fontset": "cm",  # Computer Modern glyphs for math text
    "font.family": cmu_serif,  # family name of the registered CMU regular face
    "font.size": 14,           # base font size
    "axes.titlesize": 16,      # title font size
    "axes.labelsize": 14,      # axis label font size
    "xtick.labelsize": 12,     # x-axis tick font size
    "ytick.labelsize": 12,     # y-axis tick font size
    "legend.fontsize": 14      # legend font size
})

### Make Dataset Table

In [None]:
# Empty frame that the following cells fill column by column.
table = pd.DataFrame(data=None)

In [None]:
# Dataset names in load order.
datasets = [name for name in dataset_dfs]

In [None]:
# One entry per task: each category of a multi-task dataset, or the dataset itself.
tasks = []
for dataset in datasets:
    if dataset not in all_categories:
        tasks.append(dataset)
        continue
    tasks.extend(np.unique(all_categories[dataset]))

In [None]:
table = table.assign(Task=tasks)

In [None]:
# Look up each task's parent dataset (KeyError on unknown tasks, as before).
table["Dataset"] = table["Task"].map(task_to_dataset.__getitem__)

In [None]:
# Number of instances per task. Per-category counts are computed once per
# dataset from all_categories (the same label columns used to build the task
# list), instead of re-running value_counts for every task.
category_counts = {name: labels.value_counts() for name, labels in all_categories.items()}
samples = []
for task in table.Task.tolist():
    dataset = task_to_dataset[task]
    if dataset in category_counts:
        samples.append(category_counts[dataset].loc[task])
    else:
        # Single-task dataset (truthfulqa, gsm8k): every row belongs to the task.
        samples.append(len(dataset_dfs[dataset]))

In [None]:
table = table.assign(**{"# Instances": samples})

In [None]:
# Reorder columns. .copy() gives an independent frame, so the later
# column assignment does not raise SettingWithCopyWarning.
table = table[["Dataset", "Task", "# Instances"]].copy()

In [None]:
# Note: str.capitalize also lower-cases the remaining letters
# (e.g. "Number Theory" -> "Number theory").
table = table.assign(Task=table["Task"].str.capitalize())

In [None]:
# Total number of instances across all loaded datasets.
count = sum(len(frame) for frame in dataset_dfs.values())
count

In [None]:
# The per-task instance counts must add up to the size of all datasets combined.
total_instances = table["# Instances"].sum()
assert total_instances == count, f"table sums to {total_instances}, datasets hold {count}"
total_instances

In [None]:
latex_table = table.to_latex(index=False)
print(latex_table)

In [None]:
# Copy as tab-separated text (pandas defaults spelled out); needs a system clipboard.
table.to_clipboard(excel=True, sep=None)

### Aggregate figures

In [None]:
mitigation = 'base'  # no mitigation; results are read from ./results/

In [None]:
# Load the precomputed base-setting results. `with` closes each file handle
# (bare open() leaked them). NOTE: pickle.load can execute arbitrary code;
# only load result files produced by this project.
with open("./results/all_metrics.pkl", "rb") as fh:
    metrics = pickle.load(fh)
with open("./results/all_pvalues.pkl", "rb") as fh:
    pvalues = pickle.load(fh)
with open("./results/all_results.pkl", "rb") as fh:
    all_results = pickle.load(fh)
with open("./results/fidelity_significances.pkl", "rb") as fh:
    significances = pickle.load(fh)

In [None]:
def get_significances(metric, task, model, alpha=.05):
    """Return whether ``metric`` is significant for (``task``, ``model``).

    Reads the notebook globals ``pvalues``, ``significances`` and
    ``all_results`` at call time, so it reflects the most recently loaded results.

    - metrics containing "Fid": the precomputed value in ``significances``.
    - "OP": p-value of the "in-expert" row.
    - "WU_color" / "WU_name": p-value of the worst-case persona among
      ``COLOR_PERSONAS`` / ``NAMES`` (as picked by ``worst_case_utility``).
    - metrics containing "level": p-value stored under the metric name.

    P-values are compared with ``alpha`` (default 0.05).

    Raises:
        ValueError: for a metric matching none of the rules above (this used
            to surface as an UnboundLocalError on ``pvalue``).
    """
    if "Fid" in metric:
        return significances[task].loc[metric, model]
    if metric == "OP":
        row = "in-expert"
    elif metric == "WU_color":
        row = worst_case_utility(all_results[task], COLOR_PERSONAS, return_persona=True)[1][model]
    elif metric == "WU_name":
        row = worst_case_utility(all_results[task], NAMES, return_persona=True)[1][model]
    elif "level" in metric:
        row = metric
    else:
        raise ValueError(f"No significance rule for metric {metric!r}")
    return pvalues[task].loc[row, model] < alpha

In [None]:
def process_df(metrics_df, comp=False, n_tasks=None):
    """Count significant positive / negative / null outcomes per (metric, model).

    Each row's ``score`` is collapsed to +1 (positive and significant), -1
    (negative and significant) or 0 (otherwise); significance comes from
    ``get_significances``. Outcomes are counted per (metric, model) and split
    into one frame per metric family in the global ``metric_names``.

    Args:
        metrics_df: long frame with columns ``model``, ``metric``, ``score`` and
            ``task``; rows whose metric is "empty" are dropped.
        comp: unused; kept so existing calls keep working.
        n_tasks: denominator of the ``percent`` column. Defaults to the number
            of distinct tasks in ``metrics_df`` (previously hard-coded as 27).

    Returns:
        A list aligned with ``metric_names``. Each frame has columns ``metric``
        (ordered categorical), ``model``, ``score`` (ordered categorical
        "Negative" < "Not significant" < "Positive"), ``count`` and
        ``percent``, sorted by metric, model and score.
    """
    metrics_df = metrics_df[metrics_df.metric != "empty"].copy()
    if n_tasks is None:
        n_tasks = metrics_df["task"].nunique()

    metrics_df["significant"] = metrics_df.apply(
        lambda row: get_significances(row.metric, row.task, row.model), axis=1
    )
    significant = metrics_df["significant"].astype(bool)
    metrics_df["score"] = np.select(
        [(metrics_df["score"] > 0) & significant, (metrics_df["score"] < 0) & significant],
        [1, -1],
        default=0,
    )

    counts = metrics_df.groupby(["metric", "model"], as_index=False).score.value_counts()

    # Order models as in MODEL_ORDER, then drop the last "-"-separated token of
    # each name for display (naming scheme assumed; TODO confirm). Series.apply
    # maps the categorical's categories, so the order presumably survives as
    # long as the shortened names stay unique.
    counts["model"] = counts["model"].astype(pd.CategoricalDtype(categories=MODEL_ORDER, ordered=True))
    counts["model"] = counts["model"].apply(lambda name: "-".join(name.split("-")[:-1]))
    counts["percent"] = counts["count"] / n_tasks * 100

    labels = ["Negative", "Not significant", "Positive"]
    label_of = dict(zip((-1, 0, 1), labels))
    label_dtype = pd.CategoricalDtype(categories=labels, ordered=True)

    family_dfs = []
    for family in metric_names:
        family_df = counts[counts.metric.isin(family)].copy()
        family_df["metric"] = family_df["metric"].astype(pd.CategoricalDtype(categories=family, ordered=True))
        family_df["score"] = family_df["score"].map(label_of).astype(label_dtype)
        family_dfs.append(family_df.sort_values(["metric", "model", "score"]))

    return family_dfs




In [None]:
# Long format: one row per (model, metric, task). Each per-task table has metrics
# as rows and models as columns, so .T.stack() yields (model, metric) pairs.
# Frames are concatenated once; growing a frame with pd.concat in the loop is
# quadratic and warns about the empty seed frame.
metric_frames = []
for task, df in metrics.items():
    ms = df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    metric_frames.append(ms)
metrics_df = pd.concat(metric_frames, axis=0) if metric_frames else pd.DataFrame()

In [None]:
# Add expertise-level rows: level1..level3 score minus the no-persona ("empty") row.
# NOTE: not idempotent; re-running this cell appends the level rows again.
level_frames = []
for task, df in all_results.items():
    expertise_df = df.loc[["level1", "level2", "level3"]] - df.loc[["empty"]].values
    ms = expertise_df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    level_frames.append(ms)
metrics_df = pd.concat([metrics_df, *level_frames], axis=0)

In [None]:
# Outcome counts split by metric family (order follows metric_names).
(expertise_df,
 robustness_df,
 fidelity_df) = process_df(metrics_df=metrics_df)

In [None]:
# Expertise-advantage panels (2x2): per model, a stacked bar counting tasks with a
# Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(expertise_df, col="metric", sharey=True, sharex=True, col_wrap=2, height=4, aspect=.95)
panels = zip(g.axes.flat, expertise_df.groupby("metric", observed=True))
for idx, (ax, (metric, sub_df)) in enumerate(panels):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27).
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    if idx // 2 == 1:
        # Second row only: long ticks extend the separators through the tick labels.
        sec2 = ax.secondary_xaxis(location=0)
        sec2.set_xticks([2.5, 5.5], labels=[])
        sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/expertise_aggregate_base.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Robustness panels (color / name personas): per model, a stacked bar counting
# tasks with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(robustness_df, col="metric", sharey=True, col_wrap=2, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, robustness_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

plt.subplots_adjust(wspace=.05, hspace=0)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/robustness_aggregate_base.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Fidelity panels (one row of three): per model, a stacked bar counting tasks
# with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(fidelity_df, col="metric", sharey=True, sharex=True, col_wrap=3, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, fidelity_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=14,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font, fontsize=18)
    ax.set_ylabel("# of Tasks", fontsize=18)
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)
    ax.xaxis.set_tick_params(labelsize=16)

plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    "../persona_performance_paper/media/fidelity_aggregate_base.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
mitigation = 'instruction'  # results are read from ./results/instruction/

In [None]:
# Load the precomputed results for the current mitigation. `with` closes each
# file handle (bare open() leaked them). NOTE: pickle.load can execute
# arbitrary code; only load result files produced by this project.
with open(f"./results/{mitigation}/all_metrics.pkl", "rb") as fh:
    metrics = pickle.load(fh)
with open(f"./results/{mitigation}/all_pvalues.pkl", "rb") as fh:
    pvalues = pickle.load(fh)
with open(f"./results/{mitigation}/all_results.pkl", "rb") as fh:
    all_results = pickle.load(fh)
with open(f"./results/{mitigation}/fidelity_significances.pkl", "rb") as fh:
    significances = pickle.load(fh)

In [None]:
def get_significances(metric, task, model, alpha=.05):
    """Return whether ``metric`` is significant for (``task``, ``model``).

    Reads the notebook globals ``pvalues``, ``significances`` and
    ``all_results`` at call time, so it reflects the most recently loaded results.

    - metrics containing "Fid": the precomputed value in ``significances``.
    - "OP": p-value of the "in-expert" row.
    - "WU_color" / "WU_name": p-value of the worst-case persona among
      ``COLOR_PERSONAS`` / ``NAMES`` (as picked by ``worst_case_utility``).
    - metrics containing "level": p-value stored under the metric name.

    P-values are compared with ``alpha`` (default 0.05).

    Raises:
        ValueError: for a metric matching none of the rules above (this used
            to surface as an UnboundLocalError on ``pvalue``).
    """
    if "Fid" in metric:
        return significances[task].loc[metric, model]
    if metric == "OP":
        row = "in-expert"
    elif metric == "WU_color":
        row = worst_case_utility(all_results[task], COLOR_PERSONAS, return_persona=True)[1][model]
    elif metric == "WU_name":
        row = worst_case_utility(all_results[task], NAMES, return_persona=True)[1][model]
    elif "level" in metric:
        row = metric
    else:
        raise ValueError(f"No significance rule for metric {metric!r}")
    return pvalues[task].loc[row, model] < alpha

In [None]:
# Long format: one row per (model, metric, task). Each per-task table has metrics
# as rows and models as columns, so .T.stack() yields (model, metric) pairs.
# Frames are concatenated once; growing a frame with pd.concat in the loop is
# quadratic and warns about the empty seed frame.
metric_frames = []
for task, df in metrics.items():
    ms = df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    metric_frames.append(ms)
metrics_df = pd.concat(metric_frames, axis=0) if metric_frames else pd.DataFrame()

In [None]:
# Add expertise-level rows: level1..level3 score minus the no-persona ("empty") row.
# NOTE: not idempotent; re-running this cell appends the level rows again.
level_frames = []
for task, df in all_results.items():
    expertise_df = df.loc[["level1", "level2", "level3"]] - df.loc[["empty"]].values
    ms = expertise_df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    level_frames.append(ms)
metrics_df = pd.concat([metrics_df, *level_frames], axis=0)

In [None]:
# Outcome counts split by metric family (order follows metric_names).
(expertise_df,
 robustness_df,
 fidelity_df) = process_df(metrics_df=metrics_df)

In [None]:
# Expertise-advantage panels (2x2): per model, a stacked bar counting tasks with a
# Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(expertise_df, col="metric", sharey=True, sharex=True, col_wrap=2, height=4, aspect=.95)
panels = zip(g.axes.flat, expertise_df.groupby("metric", observed=True))
for idx, (ax, (metric, sub_df)) in enumerate(panels):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27).
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    if idx // 2 == 1:
        # Second row only: long ticks extend the separators through the tick labels.
        sec2 = ax.secondary_xaxis(location=0)
        sec2.set_xticks([2.5, 5.5], labels=[])
        sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Instruction", fontproperties=bold_font, y=1.05, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/expertise_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Robustness panels (color / name personas): per model, a stacked bar counting
# tasks with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(robustness_df, col="metric", sharey=True, col_wrap=2, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, robustness_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Instruction", fontproperties=bold_font, y=1.05, fontsize=20)
plt.subplots_adjust(wspace=.05, hspace=0)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/robustness_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Fidelity panels (one row of three): per model, a stacked bar counting tasks
# with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(fidelity_df, col="metric", sharey=True, sharex=True, col_wrap=3, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, fidelity_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=14,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font, fontsize=18)
    ax.set_ylabel("# of Tasks", fontsize=18)
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)
    ax.xaxis.set_tick_params(labelsize=16)

g.fig.suptitle("Instruction", fontproperties=bold_font, y=1.1, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/fidelity_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
mitigation = 'refine'  # results are read from ./results/refine/

In [None]:
# Load the precomputed results for the current mitigation. `with` closes each
# file handle (bare open() leaked them). NOTE: pickle.load can execute
# arbitrary code; only load result files produced by this project.
with open(f"./results/{mitigation}/all_metrics.pkl", "rb") as fh:
    metrics = pickle.load(fh)
with open(f"./results/{mitigation}/all_pvalues.pkl", "rb") as fh:
    pvalues = pickle.load(fh)
with open(f"./results/{mitigation}/all_results.pkl", "rb") as fh:
    all_results = pickle.load(fh)
with open(f"./results/{mitigation}/fidelity_significances.pkl", "rb") as fh:
    significances = pickle.load(fh)

In [None]:
def get_significances(metric, task, model, alpha=.05):
    """Return whether ``metric`` is significant for (``task``, ``model``).

    Reads the notebook globals ``pvalues``, ``significances`` and
    ``all_results`` at call time, so it reflects the most recently loaded results.

    - metrics containing "Fid": the precomputed value in ``significances``.
    - "OP": p-value of the "in-expert" row.
    - "WU_color" / "WU_name": p-value of the worst-case persona among
      ``COLOR_PERSONAS`` / ``NAMES`` (as picked by ``worst_case_utility``).
    - metrics containing "level": p-value stored under the metric name.

    P-values are compared with ``alpha`` (default 0.05).

    Raises:
        ValueError: for a metric matching none of the rules above (this used
            to surface as an UnboundLocalError on ``pvalue``).
    """
    if "Fid" in metric:
        return significances[task].loc[metric, model]
    if metric == "OP":
        row = "in-expert"
    elif metric == "WU_color":
        row = worst_case_utility(all_results[task], COLOR_PERSONAS, return_persona=True)[1][model]
    elif metric == "WU_name":
        row = worst_case_utility(all_results[task], NAMES, return_persona=True)[1][model]
    elif "level" in metric:
        row = metric
    else:
        raise ValueError(f"No significance rule for metric {metric!r}")
    return pvalues[task].loc[row, model] < alpha

In [None]:
# Long format: one row per (model, metric, task). Each per-task table has metrics
# as rows and models as columns, so .T.stack() yields (model, metric) pairs.
# Frames are concatenated once; growing a frame with pd.concat in the loop is
# quadratic and warns about the empty seed frame.
metric_frames = []
for task, df in metrics.items():
    ms = df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    metric_frames.append(ms)
metrics_df = pd.concat(metric_frames, axis=0) if metric_frames else pd.DataFrame()

In [None]:
# Add expertise-level rows: level1..level3 score minus the no-persona ("empty") row.
# NOTE: not idempotent; re-running this cell appends the level rows again.
level_frames = []
for task, df in all_results.items():
    expertise_df = df.loc[["level1", "level2", "level3"]] - df.loc[["empty"]].values
    ms = expertise_df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    level_frames.append(ms)
metrics_df = pd.concat([metrics_df, *level_frames], axis=0)

In [None]:
# Outcome counts split by metric family (order follows metric_names).
(expertise_df,
 robustness_df,
 fidelity_df) = process_df(metrics_df=metrics_df)

In [None]:
# Expertise-advantage panels (2x2): per model, a stacked bar counting tasks with a
# Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(expertise_df, col="metric", sharey=True, sharex=True, col_wrap=2, height=4, aspect=.95)
panels = zip(g.axes.flat, expertise_df.groupby("metric", observed=True))
for idx, (ax, (metric, sub_df)) in enumerate(panels):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27).
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    if idx // 2 == 1:
        # Second row only: long ticks extend the separators through the tick labels.
        sec2 = ax.secondary_xaxis(location=0)
        sec2.set_xticks([2.5, 5.5], labels=[])
        sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Refine + Instruction", fontproperties=bold_font, y=1.05, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/expertise_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Robustness panels (color / name personas): per model, a stacked bar counting
# tasks with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(robustness_df, col="metric", sharey=True, col_wrap=2, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, robustness_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=12,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Refine + Instruction", fontproperties=bold_font, y=1.08, fontsize=20)
plt.subplots_adjust(wspace=.05, hspace=0)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/robustness_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Fidelity panels (one row of three): per model, a stacked bar counting tasks
# with a Negative / Not significant / Positive outcome.
cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)
g = sns.FacetGrid(fidelity_df, col="metric", sharey=True, sharex=True, col_wrap=3, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, fidelity_df.groupby("metric", observed=True)):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=cmap, legend=False)
    # Label every non-empty segment with its number of tasks (the bar height);
    # previously the label was int(percent), which disagreed with the y-axis.
    for container in bars.containers:
        for rect in container:
            seg_count = rect.get_height()
            if seg_count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + seg_count / 2,
                        f"{int(seg_count)}", ha='center', va='center', fontsize=14,
                        color="black", fontproperties=bold_font)
    # Separators after the 3rd and 6th model, spanning the tasks-per-model total
    # (previously hard-coded as 27); long secondary ticks extend them below.
    n_total = pivot_count.sum(axis=1).max()
    ax.vlines([2.5, 5.5], 0, n_total, lw=1, color="black", linestyles="-")
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks([2.5, 5.5], labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font, fontsize=18)
    ax.set_ylabel("# of Tasks", fontsize=18)
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)
    ax.xaxis.set_tick_params(labelsize=16)

g.fig.suptitle("Refine + Instruction", fontproperties=bold_font, y=1.1, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the figure for the paper.
g.fig.savefig(
    f"../persona_performance_paper/media/fidelity_aggregate_{mitigation}.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
mitigation = 'refine_basic'  # results are read from ./results/refine_basic/

In [None]:
# Load the precomputed results for the current mitigation. `with` closes each
# file handle (bare open() leaked them). NOTE: pickle.load can execute
# arbitrary code; only load result files produced by this project.
with open(f"./results/{mitigation}/all_metrics.pkl", "rb") as fh:
    metrics = pickle.load(fh)
with open(f"./results/{mitigation}/all_pvalues.pkl", "rb") as fh:
    pvalues = pickle.load(fh)
with open(f"./results/{mitigation}/all_results.pkl", "rb") as fh:
    all_results = pickle.load(fh)
with open(f"./results/{mitigation}/fidelity_significances.pkl", "rb") as fh:
    significances = pickle.load(fh)

In [None]:
def get_significances(metric, task, model, alpha=.05):
    """Return whether ``model``'s result on ``task`` for ``metric`` is significant.

    Args:
        metric: Metric key.
            - "OP", "WU_color", "WU_name" and any key containing "level" are
              answered by thresholding a p-value from the module-level
              ``pvalues`` tables.
            - Keys containing "Fid" are looked up directly in the precomputed
              ``significances`` tables.
        task: Task key into ``pvalues`` / ``all_results`` / ``significances``.
        model: Model column label.
        alpha: Significance level for the p-value based metrics (default 0.05).

    Returns:
        The stored fidelity significance for "Fid" metrics, otherwise
        ``pvalue < alpha``.

    Raises:
        ValueError: If ``metric`` matches none of the known metric families.
    """
    if "Fid" in metric:
        # Fidelity significance is precomputed; no p-value threshold applies here.
        return significances[task].loc[metric, model]
    if metric == "OP":
        row = "in-expert"
    elif metric == "WU_color":
        # Test against the persona returned by worst_case_utility
        # (presumably the worst-performing colour persona for this model).
        row = worst_case_utility(all_results[task], COLOR_PERSONAS, return_persona=True)[1][model]
    elif metric == "WU_name":
        row = worst_case_utility(all_results[task], NAMES, return_persona=True)[1][model]
    elif "level" in metric:
        row = metric
    else:
        raise ValueError(f"Unknown metric {metric!r}")
    return pvalues[task].loc[row, model] < alpha

In [None]:
# Long-format table with columns (model, metric, score, task) built from the
# per-task metric tables.
metric_frames = []
for task, df in metrics.items():
    ms = df.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    metric_frames.append(ms)
# A single concat avoids the quadratic cost of concatenating once per task.
# pd.concat([]) raises, so an empty `metrics` still yields an empty frame.
metrics_df = pd.concat(metric_frames, axis=0) if metric_frames else pd.DataFrame()

In [None]:
# Append expertise-advantage rows: each level persona's score minus the
# no-persona ("empty") score, per task and model.
level_frames = []
for task, df in all_results.items():
    # Subtract the "empty" row from every level row, aligned by column label.
    level_gain = df.loc[["level1", "level2", "level3"]].sub(df.loc["empty"], axis="columns")
    ms = level_gain.T.stack().reset_index()
    ms["task"] = task
    ms.columns = ["model", "metric", "score", "task"]
    level_frames.append(ms)
# One concat for all tasks instead of one per iteration.
metrics_df = pd.concat([metrics_df, *level_frames], axis=0)

In [None]:
# Split the long-format metrics table into the expertise, robustness and
# fidelity frames that the plots below consume.
metric_families = process_df(metrics_df)
expertise_df, robustness_df, fidelity_df = metric_families

In [None]:
# Expertise-advantage panels: for each metric, a stacked bar per model showing
# how many tasks fall into each significance outcome.
n_tasks = 27  # y-extent of the separator lines; presumably the total number of tasks
family_breaks = [2.5, 5.5]  # x positions between consecutive groups of three models
score_cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)

g = sns.FacetGrid(expertise_df, col="metric", sharey=True, sharex=True, col_wrap=2, height=4, aspect=.95)
for idx, (ax, (metric, sub_df)) in enumerate(zip(g.axes.flat, expertise_df.groupby("metric"))):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    # Keep the counts for the cross-method comparison further down.
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=score_cmap, legend=False)
    # Label each non-empty segment with its task count so labels match bar heights.
    for container, score in zip(bars.containers, pivot_count.columns):
        for rect, count in zip(container, pivot_count[score]):
            if count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + rect.get_height() / 2, f"{int(count)}",
                        ha='center', va='center', fontsize=12, color="black", fontproperties=bold_font)
    ax.vlines(family_breaks, 0, n_tasks, lw=1, color="black", linestyles="-")
    if idx // 2 == 1:
        # Long group-separator ticks below the x labels, on the second row only.
        sec2 = ax.secondary_xaxis(location=0)
        sec2.set_xticks(family_breaks, labels=[])
        sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Refine", fontproperties=bold_font, y=1.05, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the expertise-advantage figure for the current mitigation.
expertise_pdf = f"../persona_performance_paper/media/expertise_aggregate_{mitigation}.pdf"
g.fig.savefig(expertise_pdf, dpi=300, bbox_inches="tight")

In [None]:
# Robustness panels: for each metric, a stacked bar per model showing how many
# tasks fall into each significance outcome.
n_tasks = 27  # y-extent of the separator lines; presumably the total number of tasks
family_breaks = [2.5, 5.5]  # x positions between consecutive groups of three models
score_cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)

g = sns.FacetGrid(robustness_df, col="metric", sharey=True, col_wrap=2, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, robustness_df.groupby("metric")):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    # Keep the counts for the cross-method comparison further down.
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=score_cmap, legend=False)
    # Label each non-empty segment with its task count so labels match bar heights.
    for container, score in zip(bars.containers, pivot_count.columns):
        for rect, count in zip(container, pivot_count[score]):
            if count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + rect.get_height() / 2, f"{int(count)}",
                        ha='center', va='center', fontsize=12, color="black", fontproperties=bold_font)
    ax.vlines(family_breaks, 0, n_tasks, lw=1, color="black", linestyles="-")
    # Long group-separator ticks below the x labels.
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks(family_breaks, labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font)
    ax.set_ylabel("# of Tasks")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)

g.fig.suptitle("Refine", fontproperties=bold_font, y=1.05, fontsize=20)
plt.subplots_adjust(wspace=.05, hspace=0)
plt.show()

In [None]:
# Export the robustness figure for the current mitigation.
robustness_pdf = f"../persona_performance_paper/media/robustness_aggregate_{mitigation}.pdf"
g.fig.savefig(robustness_pdf, dpi=300, bbox_inches="tight")

In [None]:
# Fidelity panels: for each metric, a stacked bar per model showing how many
# tasks fall into each significance outcome.
n_tasks = 27  # y-extent of the separator lines; presumably the total number of tasks
family_breaks = [2.5, 5.5]  # x positions between consecutive groups of three models
score_cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)

g = sns.FacetGrid(fidelity_df, col="metric", sharey=True, sharex=True, col_wrap=3, height=4, aspect=.95)
for ax, (metric, sub_df) in zip(g.axes.flat, fidelity_df.groupby("metric")):
    pivot_count = sub_df.pivot(index="model", columns="score", values="count").fillna(0)
    # Keep the counts for the cross-method comparison further down.
    all_counts.setdefault(mitigation, {})[metric] = pivot_count
    bars = pivot_count.plot(kind="bar", stacked=True, ax=ax, colormap=score_cmap, legend=False)
    # Label each non-empty segment with its task count so labels match bar heights.
    for container, score in zip(bars.containers, pivot_count.columns):
        for rect, count in zip(container, pivot_count[score]):
            if count > 0:
                ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + rect.get_height() / 2, f"{int(count)}",
                        ha='center', va='center', fontsize=14, color="black", fontproperties=bold_font)
    ax.vlines(family_breaks, 0, n_tasks, lw=1, color="black", linestyles="-")
    # Long group-separator ticks below the x labels.
    sec2 = ax.secondary_xaxis(location=0)
    sec2.set_xticks(family_breaks, labels=[])
    sec2.tick_params('x', length=80, width=1, grid_linestyle="dashed")
    ax.set_title(rename_metrics[metric], fontproperties=bold_font, fontsize=18)
    ax.set_ylabel("# of Tasks", fontsize=18)
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=90)
    ax.xaxis.set_tick_params(labelsize=16)

g.fig.suptitle("Refine", fontproperties=bold_font, y=1.1, fontsize=20)
plt.subplots_adjust(wspace=.07, hspace=.25)
plt.show()

In [None]:
# Export the fidelity figure for the current mitigation.
fidelity_pdf = f"../persona_performance_paper/media/fidelity_aggregate_{mitigation}.pdf"
g.fig.savefig(fidelity_pdf, dpi=300, bbox_inches="tight")

### Metric comparison

In [None]:
# Internal mitigation key -> display name. Insertion order matters: the first
# entry ("Base") is treated as the reference method in the comparison plot.
rename = dict(
    base="Base",
    instruction="Instruction",
    refine_basic="Refine",
    refine="Ref. + Inst",
)

In [None]:
from functools import reduce

n_tasks = 27  # divisor turning task counts into task shares (presumably the number of tasks)


def _mean_task_share(counts, keys):
    """Average the per-metric count tables named in ``keys`` as task shares.

    ``DataFrame.add(..., fill_value=0)`` treats a score column missing from one
    table (no task received that outcome for that metric) as zero. Plain ``+``
    would turn that whole column into NaN instead.
    """
    frames = [counts[k] for k in keys]
    total = reduce(lambda acc, frame: acc.add(frame, fill_value=0), frames)
    return total / n_tasks / len(keys)


agg_metrics = {}
for method, counts in all_counts.items():
    agg_metrics[method] = {
        "Exp. Advantage": _mean_task_share(counts, expertise_metrics),
        "Robustness": _mean_task_share(counts, robustness_metrics),
        "Fidelity": _mean_task_share(counts, fidelity_metrics),
    }

In [None]:
# Collapse each method's {metric family: share table} into a single frame with
# "method" and "metric" columns attached, ready to be stacked across methods.
agg_metrics = {
    method: pd.concat(
        [share.assign(method=method, metric=family) for family, share in per_family.items()],
        axis=0,
    )
    for method, per_family in agg_metrics.items()
}

In [None]:
# Stack the per-method frames into one table; the row index carries the model names.
per_method_frames = list(agg_metrics.values())
agg_metrics_df = pd.concat(per_method_frames, axis=0)

In [None]:
# Map internal mitigation keys to display names. `replace` leaves values that are
# already display names untouched, so re-running this cell is a no-op rather
# than a KeyError. Item assignment avoids pandas' attribute-setting pitfalls.
agg_metrics_df["method"] = agg_metrics_df["method"].replace(rename)

In [None]:
from itertools import product

from matplotlib.ticker import PercentFormatter

# Compare mitigation methods. Each (metric family, method) panel stacks the
# share of tasks per significance outcome for every model. The "Base" method's
# outcome boundaries are overlaid as markers for reference.
df = agg_metrics_df.reset_index()
df_melted = df.melt(
    id_vars=["model", "method", "metric"],
    value_vars=["Negative", "Not significant", "Positive"],
    var_name="score",
    value_name="value"
)

# Optional: ensure ordering
model_order = [
    "gemma-2-2b", "gemma-2-9b", "gemma-2-27b",
    "Llama-3.2-3B", "Llama-3.1-8B", "Llama-3.1-70B",
    "Qwen2.5-3B", "Qwen2.5-7B", "Qwen2.5-72B"
]

score_order = ["Negative", "Not significant", "Positive"]

# Combined "method|model" label, ordered method-major then by model_order.
label_order = [f"{me}|{mo}" for me, mo in product(rename.values(), model_order)]
metric_order = ["Exp. Advantage", "Robustness", "Fidelity"]
# Every method except the first ("Base"), which is drawn as reference markers.
method_order = list(rename.values())[1:]
label_type = pd.CategoricalDtype(label_order, ordered=True)
metric_type = pd.CategoricalDtype(metric_order, ordered=True)
method_type = pd.CategoricalDtype(method_order, ordered=True)

df_melted["label"] = (df_melted["method"].astype(str) + "|" + df_melted["model"].astype(str)).astype(label_type)
df_melted["metric"] = df_melted["metric"].astype(metric_type)
df_melted = df_melted.sort_values(["label", "metric"])
base_df = df_melted[df_melted.method == "Base"]
# .copy() so the dtype change below writes to an independent frame rather than
# a filtered view of the previous one (avoids SettingWithCopyWarning).
df_melted = df_melted[df_melted.method != "Base"].copy()
df_melted["method"] = df_melted["method"].astype(method_type)
df_melted = df_melted.sort_values(["label", "method", "metric"])

score_cmap = sns.diverging_palette(260, 30, l=70, s=100, center='light', as_cmap=True)

# One row per metric family, one column per (non-Base) method.
g = sns.FacetGrid(df_melted, col="method", row="metric", sharey=True, height=3, aspect=.95)

# observed=False states the current default explicitly: every category pair is
# produced, in category order. That matches the row-major order of g.axes.flat,
# and passing it silences pandas' FutureWarning about the changing default.
for ax, ((metric, method), sub_df) in zip(g.axes.flat, df_melted.groupby(["metric", "method"], observed=False)):
    pivot_df = sub_df.pivot_table(index="model", columns="score", values="value", aggfunc="sum").fillna(0)
    pivot_df = pivot_df[score_order]  # reorder columns
    pivot_base = base_df[base_df.metric == metric]
    pivot_base = pivot_base.pivot_table(index="model", columns="score", values="value", aggfunc="sum").fillna(0)
    # Align the Base markers with the bar order explicitly and place them at the
    # bars' integer x positions (pandas draws bar i at x = i).
    pivot_base = pivot_base.reindex(pivot_df.index)
    lower = pivot_base["Negative"]
    upper = pivot_base["Negative"] + pivot_base["Not significant"]
    positions = np.arange(len(pivot_df.index))
    pivot_df.plot(kind="bar", stacked=True, ax=ax, colormap=score_cmap, legend=False)
    for marker, size in (("_", 500), ("X", 50)):
        ax.scatter(positions, lower, color="blue", marker=marker, s=size)
        ax.scatter(positions, upper, color="orange", marker=marker, s=size)
    ax.vlines([2.5, 5.5, 8.5], 0, 1, lw=1, color="black", linestyles="dashed")
    if metric == "Exp. Advantage":
        ax.set_title(method, fontproperties=bold_font)
    else:
        ax.set_title("")
    # Values are fractions in [0, 1]; render the ticks as percentages to match the label.
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
    ax.set_ylabel(f"% of Tasks ({metric})")
    ax.set_xlabel("")
    ax.tick_params(axis='x', rotation=90)

plt.subplots_adjust(wspace=.07, hspace=.05)
plt.show()

In [None]:
# Export the cross-method comparison figure (the path has no placeholders).
methods_comp_pdf = "../persona_performance_paper/media/methods_comp_all.pdf"
g.fig.savefig(methods_comp_pdf, dpi=300, bbox_inches="tight")