In [None]:
import pickle as pkl
from scipy.special import softmax
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from vla_calibration.utils import *
from vla_calibration.calibration import *

# Global matplotlib style, plus its default colour cycle as a plain list so
# individual panels can pick colours by index.
plt.style.use('seaborn-v0_8')
pal = [prop["color"] for prop in plt.rcParams['axes.prop_cycle']]

In [None]:
def _prompt_file_suffix(alternate_set):
    """Map an alternate-prompt-set id (1, 2 or 3) to its pickle filename suffix."""
    for set_id, suffix in ((1, ""), (2, "_v2"), (3, "_v3")):
        if alternate_set == set_id:
            return suffix
    raise ValueError(f"alternate_set must be 1, 2 or 3, got {alternate_set!r}")


def _load_prompt_probs(path, top_n_steps):
    """Load one prompt's episode pickle and softmax the logits of its first steps.

    Returns an array stacked as (episode, step, *logits.shape). Every episode
    must contribute the same number of steps and the same logits shape,
    otherwise ``np.stack`` raises.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # result files produced by our own evaluation runs.
    with open(path, "rb") as f:
        episodes = pkl.load(f)
    return np.stack([
        np.stack([softmax(step["logits"], -1) for step in episode["steps"][:top_n_steps]])
        for episode in episodes
    ])


def run_experiment(
        task_name,
        quant=None,
        alternate_set=1,
        n_prompts=20,
        n_cal_bins=12,
):
    """Compare baseline vs. reprompt-ensemble calibration on one LIBERO suite.

    Baseline probabilities and per-episode correctness come from
    ``get_base_data`` on ``../results/libero_{task_name}[/{quant}]``. The
    reprompt ensemble uses the logits stored for ``n_prompts`` rephrased
    prompts in the same directory. Only the first step of each episode is
    scored.

    Args:
        task_name: suite suffix such as "spatial"; also the Dataset column.
        quant: optional quantization subdirectory (e.g. "quant8"). None means
            full precision, reported as "Full" in the Model column.
        alternate_set: which prompt-file family to read.
            1 -> ``episode_data_prompt_{i}.pkl``
            2 -> ``episode_data_prompt_{i}_v2.pkl``
            3 -> ``episode_data_prompt_{i}_v3.pkl``
        n_prompts: number of rephrased prompts averaged in the ensemble.
        n_cal_bins: number of bins passed to ``get_ece``.

    Returns:
        A two-row DataFrame (Method "baseline" and "reprompt") with columns
        Dataset, Model, Method, ECE-1, ECE-2, Brier, NLL, Accuracy.
        ECE and Brier are rounded to 3 decimals; NLL is not rounded.

    Raises:
        ValueError: if ``alternate_set`` is not 1, 2 or 3 (checked before any
            data is loaded), or if the baseline and reprompt data reduce to
            confidence arrays of different shapes (e.g. mismatched episode
            counts).
    """
    # Fail fast on a bad prompt set instead of after loading the baseline.
    suffix = _prompt_file_suffix(alternate_set)

    data_save_dir = f"../results/libero_{task_name}"
    if quant is not None:
        data_save_dir += f"/{quant}"

    top_n_steps = 1

    base_probs, _, correct = get_base_data(data_save_dir, top_n_steps)
    # Add a singleton "prompt" axis at position 2 so the baseline goes
    # through the same reductions as the ensemble below.
    base_probs = np.expand_dims(base_probs, axis=2)

    # Stacked as (prompt, episode, step, *logits.shape). The transpose needs a
    # 5-D stack, i.e. 2-D logits per step (presumably action dim x token bins;
    # TODO confirm). It moves the prompt axis next to the logits:
    # (episode, step, prompt, ...).
    all_probs = np.stack([
        _load_prompt_probs(f"{data_save_dir}/episode_data_prompt_{i}{suffix}.pkl", top_n_steps)
        for i in range(n_prompts)
    ])
    ens_probs = np.transpose(all_probs, (1, 2, 0, 3, 4))

    # Keep only the first step.
    base_probs = base_probs[:, 0]
    ens_probs = ens_probs[:, 0]

    # Confidence = max probability over the last axis. It is averaged over
    # prompts, i.e. the mean of per-prompt confidences rather than the
    # confidence of the averaged distribution. It is then averaged over the
    # remaining trailing axis.
    mean_base_conf = np.max(base_probs, -1).mean(-2).mean(-1)
    mean_ens_conf = np.max(ens_probs, -1).mean(-2).mean(-1)

    if mean_base_conf.shape != mean_ens_conf.shape:
        raise ValueError(
            f"baseline and reprompt confidences disagree in shape "
            f"({mean_base_conf.shape} vs {mean_ens_conf.shape}) for {data_save_dir}; "
            f"do the prompt pickles contain the same episodes as the baseline run?"
        )

    model_tag = quant if quant is not None else "Full"
    accuracy = np.mean(correct)

    rows = []
    for method, conf in (("baseline", mean_base_conf), ("reprompt", mean_ens_conf)):
        rows.append([
            task_name,
            model_tag,
            method,
            round(get_ece(conf, correct, n_cal_bins, p=1), 3),
            round(get_ece(conf, correct, n_cal_bins, p=2), 3),
            round(np.mean((conf - correct) ** 2), 3),  # Brier score
            cross_entropy(correct, conf),
            accuracy,
        ])

    return pd.DataFrame(rows, columns=["Dataset", "Model", "Method", "ECE-1", "ECE-2", "Brier", "NLL", "Accuracy"])
    


In [None]:
# Shared settings for every run below: the number of calibration bins used
# for ECE, and which family of rephrased-prompt files run_experiment reads.
n_bins, alternate_set = 12, 1

# Accumulator for the per-run result rows produced by the experiment cells.
full_df = pd.DataFrame()

In [None]:
# Full-precision model on every LIBERO suite (Model column = "Full").
# All runs finish before full_df is touched, so a failure part-way leaves
# full_df unchanged.
full_precision_runs = [
    run_experiment(
        task,
        alternate_set=alternate_set,
        n_cal_bins=n_bins,
        n_prompts=20,
    )
    for task in ("spatial", "object", "goal")
]

# Drop "Full" rows from an earlier execution of this cell so re-running it
# replaces them instead of appending duplicates (duplicates corrupt the
# per-model summaries built further down). The initial empty frame has no
# columns yet, hence the guard.
previous = full_df[full_df["Model"] != "Full"] if "Model" in full_df.columns else full_df
full_df = pd.concat(
    [frame for frame in (previous, *full_precision_runs) if not frame.empty],
    ignore_index=True,
)


In [None]:
# Results accumulated so far: two rows (baseline, reprompt) per run_experiment call.
full_df

In [None]:
# 8-bit quantized model on every LIBERO suite (Model column = "quant8").
# All runs finish before full_df is touched, so a failure part-way leaves
# full_df unchanged.
quant8_runs = [
    run_experiment(
        task,
        alternate_set=alternate_set,
        n_cal_bins=n_bins,
        quant="quant8",
        n_prompts=20,
    )
    for task in ("spatial", "object", "goal")
]

# Drop "quant8" rows from an earlier execution of this cell so re-running it
# replaces them instead of appending duplicates (duplicates corrupt the
# per-model summaries built further down).
previous = full_df[full_df["Model"] != "quant8"] if "Model" in full_df.columns else full_df
full_df = pd.concat(
    [frame for frame in (previous, *quant8_runs) if not frame.empty],
    ignore_index=True,
)

In [None]:
# Results table after the cells above (two rows per run_experiment call).
full_df

In [None]:
# 4-bit quantized model on every LIBERO suite (Model column = "quant4").
# All runs finish before full_df is touched, so a failure part-way leaves
# full_df unchanged.
quant4_runs = [
    run_experiment(
        task,
        alternate_set=alternate_set,
        n_cal_bins=n_bins,
        quant="quant4",
        n_prompts=20,
    )
    for task in ("spatial", "object", "goal")
]

# Drop "quant4" rows from an earlier execution of this cell so re-running it
# replaces them instead of appending duplicates (duplicates corrupt the
# per-model summaries built further down).
previous = full_df[full_df["Model"] != "quant4"] if "Model" in full_df.columns else full_df
full_df = pd.concat(
    [frame for frame in (previous, *quant4_runs) if not frame.empty],
    ignore_index=True,
)

In [None]:
# Complete results table: two rows (baseline, reprompt) per run_experiment call.
full_df

In [None]:
# Main-text table: every model except the 4-bit quantized one, metric columns only.
main_df = full_df.loc[
    full_df["Model"] != "quant4",
    ["Model", "Dataset", "Method", "ECE-1", "ECE-2", "Brier", "NLL"],
]
display(main_df)

In [None]:
# LaTeX source for the main-text table (no index, 3-decimal floats).
main_latex = main_df.to_latex(index=False, float_format="%.3f")
print(main_latex)

In [None]:
# Appendix table: only the 4-bit quantized model, metric columns only.
app_df = full_df.loc[
    full_df["Model"] == "quant4",
    ["Model", "Dataset", "Method", "ECE-1", "ECE-2", "Brier", "NLL"],
]
display(app_df)

In [None]:
# LaTeX source for the appendix table (no index, 3-decimal floats).
app_latex = app_df.to_latex(index=False, float_format="%.3f")
print(app_latex)

In [None]:
# Metric columns compared in every figure below.
metrics_list = ["ECE-1", "ECE-2", "Brier", "NLL"]

# Reprompt rows only, one per (suite, model) run.
tradeoff_df = full_df.loc[full_df["Method"].eq("reprompt")]
display(tradeoff_df)

In [None]:
# Task success rate per model (rows) and suite (columns).
accuracy_long = tradeoff_df[["Dataset", "Model", "Accuracy"]]
display(accuracy_long)

# Pivot by label instead of relying on row order: the previous positional
# build mislabelled suites if rows arrived in another order and failed with a
# length mismatch on duplicates. pivot() raises a clear error if a
# (Model, Dataset) pair is duplicated.
success_df = (
    accuracy_long
    .pivot(index="Model", columns="Dataset", values="Accuracy")
    .reindex(index=["Full", "quant8", "quant4"], columns=["spatial", "object", "goal"])
    .rename(index={"quant8": "Quant-8", "quant4": "Quant-4"}, columns=str.capitalize)
    .rename_axis(index="Model", columns=None)
    .reset_index()
)
print(success_df.to_latex(index=False, float_format="%.3f"))

In [None]:
# Baseline rows for every (suite, model) run, in insertion order.
baseline_df = full_df.loc[full_df["Method"].eq("baseline")]

In [None]:
# Reprompt rows for every (suite, model) run, in insertion order.
reprompt_df = full_df.loc[full_df["Method"].eq("reprompt")]

In [None]:
# Baseline rows; the plots below pair them by position with reprompt_df.
baseline_df

In [None]:
# Baseline (y) vs. reprompt (x) score for every (suite, model) run, one panel
# per metric. baseline_df and reprompt_df are paired by position. Points above
# the dashed y = x line are runs where reprompting lowered the (lower-is-better)
# metric.
scatter_titles = {
    "ECE-1": r"$\text{ECE}_1$",
    "ECE-2": r"$\text{ECE}_2$",
    "Brier": "Brier score",
    "NLL": "NLL",
}
scatter_ticks = {
    "ECE-1": [0.05, 0.10, 0.15],
    "ECE-2": [0.05, 0.10, 0.15, 0.2],
    "Brier": [0.1, 0.15, 0.2, 0.25],
    "NLL": [0.40, 0.50, 0.60, 0.70],
}

fig, axs = plt.subplots(1, 4, figsize=(13, 3.25))

for i, (ax, metric) in enumerate(zip(axs, metrics_list)):
    baseline_scores = baseline_df[metric].tolist()
    reprompt_scores = reprompt_df[metric].tolist()
    ax.scatter(reprompt_scores, baseline_scores, color=pal[i], s=70)

    # Square view padded by 5% around the observed scores.
    ax_min = min(min(baseline_scores), min(reprompt_scores)) * 0.95
    ax_max = max(max(baseline_scores), max(reprompt_scores)) * 1.05
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)

    # Infinite y = x reference. A fixed [0, 1] segment stops short once the
    # tick setting below widens the view or a metric such as NLL exceeds 1.
    ax.axline((0, 0), slope=1, linestyle="--", color="k", alpha=0.5)

    ax.set_title(scatter_titles.get(metric, metric), fontsize=18)
    ax.set_xlabel("Reprompt", fontsize=18)
    if i == 0:
        ax.set_ylabel("Baseline", fontsize=18)

    # Set after the limits: set_*ticks may widen the view to show every tick.
    ticks = scatter_ticks.get(metric)
    if ticks is not None:
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
    ax.tick_params(axis="both", labelsize=13)

fig.tight_layout()
plt.savefig("../plots/total_results.png", dpi=600, bbox_inches="tight")
plt.show()


In [None]:
# Per-metric change under reprompting, as a percentage of the baseline score.
# There is one "x" per (suite, model) run, and a black "+" marks each metric's
# mean ratio. baseline_df and reprompt_df are paired by position.
fig, ax = plt.subplots(figsize=(6, 3))

for i, metric in enumerate(metrics_list):
    baseline_vals = np.array(baseline_df[metric].tolist())
    reprompt_vals = np.array(reprompt_df[metric].tolist())
    ratio = reprompt_vals / baseline_vals
    print(ratio)

    x_pos = i / 2
    ax.scatter([x_pos] * len(baseline_vals), (ratio - 1) * 100, marker="x", s=100, lw=4)

    mean_style = dict(marker="+", color="k", s=150, lw=3)
    if i == 0:
        # Label only the first mean marker so the legend has a single entry.
        mean_style["label"] = "Avg. % Change"
    ax.scatter([x_pos], (np.mean(ratio) - 1) * 100, **mean_style)

metric_labels = [r"$\text{ECE}_1$", r"$\text{ECE}_2$", "Brier score", "NLL"]
ax.set_xticks(np.arange(4) / 2, metric_labels, fontsize=18)
ax.set_ylabel("% Change w/ Reprompt", fontsize=15)
plt.setp(ax.get_yticklabels(), fontsize=14)
ax.legend(fontsize=12, loc="upper left")

ax.set_xlim(-0.2, 1.7)

fig.tight_layout()

fig.savefig("../plots/pct_reduction.png", dpi=600, bbox_inches="tight")
plt.show()


In [None]:
# Task error rate (1 - accuracy) against each calibration metric for the
# baseline rows, with one "+" per (suite, model) run.
tradeoff_df = full_df.loc[full_df["Method"].eq("baseline")]
display(tradeoff_df)

metrics_list = ["ECE-1", "ECE-2", "Brier", "NLL"]

# y-axis label and tick positions per metric; tick values double as labels.
panel_specs = {
    "ECE-1": (r"$\text{ECE}_1$", [0.05, 0.10, 0.15]),
    "ECE-2": (r"$\text{ECE}_2$", [0.05, 0.10, 0.15, 0.2]),
    "Brier": (r"Brier score", [0.1, 0.15, 0.2, 0.25]),
    "NLL": (r"NLL", [0.40, 0.50, 0.60, 0.70]),
}
error_ticks = [0.15, 0.2, 0.25]

fig, axs = plt.subplots(1, 4, figsize=(13, 3.25))

for i, (ax, metric) in enumerate(zip(axs, metrics_list)):
    for model in ["Full", "quant8", "quant4"]:
        model_rows = tradeoff_df[tradeoff_df["Model"] == model]
        ax.scatter(
            1 - np.array(model_rows["Accuracy"]),
            model_rows[metric],
            marker="+", label=model, s=150, lw=5, color=pal[i],
        )

    y_label, y_ticks = panel_specs[metric]
    ax.set_xlabel("Task Error Rate", fontsize=18)
    ax.set_xticks(error_ticks, error_ticks, fontsize=13)
    ax.set_ylabel(y_label, fontsize=18)
    ax.set_yticks(y_ticks, y_ticks, fontsize=13)

fig.suptitle("Task Error vs. Calibration Error", y=0.94, fontsize=18)

fig.tight_layout()
fig.savefig("../plots/tradeoffs_baseline.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# Task error rate (1 - accuracy) against each calibration metric for the
# reprompt rows, with one "+" per (suite, model) run.
tradeoff_df = full_df.loc[full_df["Method"].eq("reprompt")]
display(tradeoff_df)

metrics_list = ["ECE-1", "ECE-2", "Brier", "NLL"]

# y-axis label and tick positions per metric; tick values double as labels.
panel_specs = {
    "ECE-1": (r"$\text{ECE}_1$", [0.05, 0.10, 0.15]),
    "ECE-2": (r"$\text{ECE}_2$", [0.05, 0.10, 0.15, 0.2]),
    "Brier": (r"Brier score", [0.1, 0.15, 0.2, 0.25]),
    "NLL": (r"NLL", [0.40, 0.50, 0.60, 0.70]),
}
error_ticks = [0.15, 0.2, 0.25]

fig, axs = plt.subplots(1, 4, figsize=(13, 3.25))

for i, (ax, metric) in enumerate(zip(axs, metrics_list)):
    for model in ["Full", "quant8", "quant4"]:
        model_rows = tradeoff_df[tradeoff_df["Model"] == model]
        ax.scatter(
            1 - np.array(model_rows["Accuracy"]),
            model_rows[metric],
            marker="+", label=model, s=150, lw=5, color=pal[i],
        )

    y_label, y_ticks = panel_specs[metric]
    ax.set_xlabel("Task Error Rate", fontsize=18)
    ax.set_xticks(error_ticks, error_ticks, fontsize=13)
    ax.set_ylabel(y_label, fontsize=18)
    ax.set_yticks(y_ticks, y_ticks, fontsize=13)

fig.suptitle("Task Error vs. Calibration Error (Reprompt)", y=0.95, fontsize=18)

fig.tight_layout()
fig.savefig("../plots/tradeoffs_reprompt.png", dpi=600, bbox_inches="tight")
plt.show()