In [None]:
import pickle as pkl
from scipy.special import softmax
import numpy as np
from sklearn.calibration import calibration_curve
from pathlib import Path
import matplotlib.pyplot as plt

from vla_calibration.utils import *
from vla_calibration.calibration import *

# Switch every figure in this notebook to the seaborn-flavoured matplotlib style.
plt.style.use("seaborn-v0_8")

# Colour list of the active style's property cycle; plotting code indexes into it.
pal = plt.rcParams.get("axes.prop_cycle").by_key()["color"]

In [None]:
# Legend labels for each confidence-aggregation method key.
methods_map = dict(
    last_step="Current",
    sliding_5="Window (5)",
    avg_all="Avg. All",
)

In [None]:
def get_timestep_data(data_save_dir, x_points, methods):
    """Load (or compute and cache) per-episode confidences at several points of each episode.

    For every episode the per-step logits are converted to probabilities with a
    softmax over the last axis, and the per-step confidence is the maximum
    probability along that axis. For each fraction in ``x_points`` a reference
    step ``ceil(len(steps) * fraction)`` (clamped to the final step) is picked,
    and every entry of ``methods`` reduces the confidences to a single scalar:

    * ``"last_step"``   -- mean confidence at the reference step only.
    * ``"avg_all"``     -- mean confidence over steps ``0..reference`` inclusive.
    * ``"sliding_<k>"`` -- mean confidence over the reference step plus up to
      ``k`` steps before it (so at most ``k + 1`` steps).

    Results are cached in ``<data_save_dir>/base_probs.pkl``. An existing cache
    is reused only when it already holds an entry for every requested
    ``(x_point, method)`` pair; otherwise everything is recomputed from
    ``<data_save_dir>/episode_data_true_prompt.pkl`` and the cache is rewritten.

    Args:
        data_save_dir: Directory holding the episode pickle and the cache file.
        x_points: Iterable of episode fractions at which to evaluate confidence.
        methods: Iterable of aggregation method names (see above).

    Returns:
        ``(all_probs, correct)`` where ``all_probs[x_point][method]`` is a list
        with one confidence per episode and ``correct`` is a list with
        ``int(episode["done"])`` per episode, in episode order.

    Raises:
        ValueError: If a method name is not one of the supported forms.
        FileNotFoundError: If data must be computed and the episode pickle is missing.
    """
    cache_path = Path(f"{data_save_dir}/base_probs.pkl").expanduser().resolve()

    if cache_path.is_file():
        print(f"[load_or_create_pickle] Loading existing pickle: {cache_path}")
        # SECURITY: unpickling can execute arbitrary code; only point this at
        # result directories you trust.
        with cache_path.open("rb") as f:
            all_probs, correct = pkl.load(f)
        if _cache_covers(all_probs, x_points, methods):
            return all_probs, correct
        # A cache written for other x_points/methods would only fail later with
        # a KeyError in the plotting code, so rebuild it instead.
        print(f"[load_or_create_pickle] Cache lacks requested x_points/methods, recomputing: {cache_path}")

    episodes_path = Path(f"{data_save_dir}/episode_data_true_prompt.pkl").expanduser()
    with episodes_path.open("rb") as f:
        episodes = pkl.load(f)

    all_probs = {percent: {method: [] for method in methods} for percent in x_points}
    correct = []

    for episode in episodes:
        steps = episode["steps"]

        # Stack per-step probabilities, then keep only the top probability.
        # NOTE(review): the slicing in _aggregate_confidence needs each step's
        # logits to be at least 2-D -- presumably (action_dims, vocab); confirm.
        episode_probs = np.stack([softmax(step["logits"], -1) for step in steps])
        episode_conf = np.max(episode_probs, -1)

        for percent in x_points:
            # Clamp so that fractions at or near 1.0 still index a real step.
            last_step = min(int(np.ceil(len(steps) * percent)), len(steps) - 1)
            for method in methods:
                all_probs[percent][method].append(
                    _aggregate_confidence(episode_conf, last_step, method)
                )

        correct.append(int(episode["done"]))

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as f:
        pkl.dump((all_probs, correct), f, protocol=pkl.HIGHEST_PROTOCOL)

    return all_probs, correct


def _cache_covers(all_probs, x_points, methods):
    """Return True when the cached mapping has every requested (x_point, method) entry."""
    return all(
        percent in all_probs and all(method in all_probs[percent] for method in methods)
        for percent in x_points
    )


def _aggregate_confidence(episode_conf, last_step, method):
    """Reduce per-step confidences up to ``last_step`` to one scalar using ``method``."""
    if "sliding" in method:
        window_size = int(method.split("_")[-1])
        first_step = max(0, last_step - window_size)
        window = episode_conf[first_step:last_step + 1, :]
    elif method == "avg_all":
        window = episode_conf[:last_step + 1, :]
    elif method == "last_step":
        return np.mean(episode_conf[last_step], -1)
    else:
        # Previously an unknown name silently reused the previous method's value.
        raise ValueError(
            f"Unknown confidence method {method!r}; "
            "expected 'last_step', 'avg_all' or 'sliding_<k>'."
        )
    # Average within each step first, then across the selected steps.
    return np.mean(np.mean(window, -1), -1)

In [None]:
def produce_plots(
        all_probs,
        correct,
        task_name,
        x_points,
        methods,
        quant,
        save_string,
        n_cal_bins=12,
        save_fig=True,
        trailing_k=5,
        title_addition="",
):
    """Plot calibration metrics over task completion plus reliability diagrams.

    Figure 1 has four panels, all with "% Task Completion" on the x axis and one
    dashed line per method: ECE (via ``get_ece``), Brier score, and the average
    confidence on successful / failed trials (via ``average_confidences``). Each
    curve is smoothed with ``trailing_average(..., k=trailing_k)``. The first two
    panels also get a horizontal reference line at the first method's unsmoothed
    value at the first x point.

    Figure 2 is a 2 x 4 grid of reliability diagrams (quantile-binned
    ``calibration_curve``) for the ``"last_step"`` and ``"avg_all"`` methods at
    0%, 50%, 75% and 99% completion. A panel is left without points when the
    corresponding x point or method was not supplied.

    Args:
        all_probs: Mapping indexed as ``all_probs[x_point][method]`` giving one
            confidence per episode.
        correct: Per-episode outcomes, aligned with the confidence lists; used as
            the target for the Brier score and its mean is printed as accuracy.
        task_name: Used in the figure title and output file names.
        x_points: Episode fractions to plot (multiplied by 100 for the x axis).
        methods: Method keys to plot; labels come from ``methods_map`` and fall
            back to the raw key when no label is registered.
        quant: Optional quantization tag appended to the title and file names.
        save_string: Suffix for the output file names.
        n_cal_bins: Bin count forwarded to ``get_ece`` and ``calibration_curve``.
        save_fig: When true, PNGs are written under ``../plots`` (created if missing).
        trailing_k: Window forwarded to ``trailing_average``.
        title_addition: Extra text appended to the figure title.
    """
    quant_save_string = "" if quant is None else f"_{quant}"
    quant_string = "" if quant is None else f" ({str.title(quant)})"
    x_pct = np.array(x_points) * 100

    # ------------------------------------------------------------------
    # Figure 1: metrics as a function of task completion
    # ------------------------------------------------------------------
    fig, axs = plt.subplots(1, 4, figsize=(13, 3.5))

    ece_results = {}
    brier_results = {}
    for method in methods:
        # Fall back to the raw key so methods without a registered label still plot.
        label = methods_map.get(method, method)
        ece_results[method] = []
        brier_results[method] = []
        for percent in x_points:
            probs = np.array(all_probs[percent][method])
            ece_results[method].append(get_ece(probs, correct, n_cal_bins))
            brier_results[method].append(np.mean((probs - correct) ** 2))

        axs[0].plot(x_pct, trailing_average(np.array(ece_results[method]), k=trailing_k), "--", label=label, lw=2)
        axs[1].plot(x_pct, trailing_average(np.array(brier_results[method]), k=trailing_k), "--", label=label, lw=2)

    # Horizontal reference: first method's value at the first x point.
    for ax, per_method in ((axs[0], ece_results), (axs[1], brier_results)):
        baseline = per_method[methods[0]][0]
        ax.plot(x_pct, [baseline] * len(x_points), "--", color="k", alpha=0.5)

    # Completion levels (as strings) at which reliability diagrams are drawn.
    reliability_data = {
        "0.0": dict(),
        "0.5": dict(),
        "0.75": dict(),
        "0.99": dict(),
    }
    rd_pcts = list(reliability_data.keys())

    print("-"*20)
    print("accuracy:", np.mean(correct))

    ax23_ymin = 100.0
    ax23_ymax = 0.0
    for method in methods:
        label = methods_map.get(method, method)
        correct_conf = []
        incorrect_conf = []
        for percent in x_points:
            probs = np.array(all_probs[percent][method])
            avg_incorrect, avg_correct = average_confidences(correct, probs)
            correct_conf.append(avg_correct)
            incorrect_conf.append(avg_incorrect)

            # Compare numerically rather than via str(percent), which breaks for
            # x points that are off by floating-point noise.
            for pct_key in rd_pcts:
                if np.isclose(float(percent), float(pct_key), rtol=0.0, atol=1e-9):
                    reliability_data[pct_key][method] = (probs, correct)

        avg_correct = trailing_average(np.array(correct_conf), k=trailing_k)
        avg_incorrect = trailing_average(np.array(incorrect_conf), k=trailing_k)

        axs[2].plot(x_pct, avg_correct, "--", label=label, lw=2)
        axs[3].plot(x_pct, avg_incorrect, "--", label=label, lw=2)

        # Track a common y-range so the two confidence panels are comparable.
        ax23_ymin = min(min(avg_correct), ax23_ymin)
        ax23_ymin = min(min(avg_incorrect), ax23_ymin)

        ax23_ymax = max(max(avg_correct), ax23_ymax)
        ax23_ymax = max(max(avg_incorrect), ax23_ymax)

    axs[0].set_title("Calibration Error", fontsize=18)
    axs[1].set_title("Calibration Error", fontsize=18)

    axs[0].set_ylabel(r"$\text{ECE}_1$", fontsize=18)
    axs[1].set_ylabel("Brier Score", fontsize=18)

    axs[2].set_title("Successful Trials", fontsize=18)
    axs[3].set_title("Failed Trials", fontsize=18)

    axs[2].set_ylabel("Avg. Conf.", fontsize=18)
    axs[3].set_ylabel("Avg. Conf.", fontsize=18)

    for ax in axs:
        ax.set_xlabel("% Task Completion", fontsize=16)
        ax.tick_params(axis='y', labelsize=11)
        ax.set_xticks([0, 25, 50, 75, 100], [0, 25, 50, 75, 100], fontsize=12)

    for ax in axs[2:]:
        ax.set_ylim(ax23_ymin*0.995, ax23_ymax*1.005)

    axs[-1].legend(fontsize=12)

    fig.suptitle(f"{str.title(task_name)}{quant_string}{title_addition}", y=0.94, fontsize=18)

    fig.tight_layout()
    if save_fig:
        _save_figure(fig, f"../plots/across_time_{task_name}{quant_save_string}_{save_string}.png")
    plt.show()
    plt.close(fig)

    # ------------------------------------------------------------------
    # Figure 2: reliability diagrams at selected completion levels
    # ------------------------------------------------------------------
    fig, axs = plt.subplots(2, len(rd_pcts), sharex=True, sharey=True, figsize=(9, 4))

    for pct_idx, pct in enumerate(rd_pcts):

        for i, method in enumerate(["last_step", "avg_all"]):
            ax = axs[i, pct_idx]
            snapshot = reliability_data[pct].get(method)
            if snapshot is not None:
                # Only draw points when this (completion level, method) was
                # supplied; otherwise the panel keeps just the diagonal.
                probs, outcomes = snapshot
                prob_true, prob_pred = calibration_curve(np.array(outcomes), probs, n_bins=n_cal_bins, strategy="quantile")
                ax.scatter(prob_pred, prob_true, color=pal[i+3])
            ax.set_xlim(.5, 1.02)
            ax.set_ylim(0.0, 1.0)
            # Diagonal = perfect calibration.
            ax.plot([0, 1], [0, 1], "--", color="k", alpha=0.5)

        axs[0, pct_idx].set_title(f"{int(float(pct)*100)}% Complete", fontsize=14)

    axs[0, 0].set_ylabel("Accuracy", fontsize=14)
    axs[1, 0].set_ylabel("Accuracy", fontsize=14)

    # Row labels placed left of the axes (negative x is in figure coordinates).
    fig.text(-0.12, 0.72, "Current", fontsize=14, fontweight="bold")
    fig.text(-0.12, 0.3, "Avg. All", fontsize=14, fontweight="bold")

    for i in range(len(rd_pcts)):
        axs[1, i].set_xlabel("Confidence", fontsize=14)

    fig.tight_layout()

    if save_fig:
        _save_figure(fig, f"../plots/across_time_{task_name}{quant_save_string}_{save_string}_reliability_diagrams.png")
    plt.show()
    plt.close(fig)

    print("-"*20)


def _save_figure(fig, path):
    """Write ``fig`` to ``path`` as a 600-dpi PNG, creating the parent directory first."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=600, bbox_inches="tight")


def run_experiment(
        task_name,
        x_points=[0.1, 0.5, 0.99],
        methods=["last_step", "sliding_5", "sliding_10", "avg_all"],
        quant=None,
        n_cal_bins=12,
        save_fig=True,
        trailing_k=5
):
    """Load the confidence data for one task and draw its plots.

    The data directory is ``../results/libero_<task_name>``, with ``/<quant>``
    appended when ``quant`` is given. Confidences come from
    ``get_timestep_data`` and are handed to ``produce_plots`` with the
    ``"_baseline"`` file-name suffix and no extra title text.

    Args:
        task_name: Task identifier used in the results path, title and file names.
        x_points: Episode fractions at which confidence is evaluated.
        methods: Confidence-aggregation method keys.
        quant: Optional sub-directory / tag for a quantized run.
        n_cal_bins: Number of calibration bins forwarded to ``produce_plots``.
        save_fig: Whether ``produce_plots`` should write the figures to disk.
        trailing_k: Smoothing window forwarded to ``produce_plots``.
    """
    results_dir = f"../results/libero_{task_name}"
    if quant is not None:
        results_dir = f"{results_dir}/{quant}"

    confidences, outcomes = get_timestep_data(results_dir, x_points, methods)

    plot_options = {
        "save_string": "_baseline",
        "n_cal_bins": n_cal_bins,
        "save_fig": save_fig,
        "trailing_k": trailing_k,
        "title_addition": "",
    }
    produce_plots(confidences, outcomes, task_name, x_points, methods, quant, **plot_options)


In [None]:
# Shared settings for the experiment runs in the cells below.
n_bins = 12  # calibration bin count passed to run_experiment
x_points = [fraction for fraction in np.arange(100) / 100]  # 0.00, 0.01, ..., 0.99
methods = ["last_step", "sliding_5", "avg_all"]

In [None]:
# One run per suite, with `quant` left at its default (None).
for suite in ("spatial", "object", "goal"):
    run_experiment(suite, x_points, methods, n_cal_bins=n_bins, trailing_k=3)



In [None]:
# Same suites again, this time reading the "quant8" results sub-directory.
for suite in ("spatial", "object", "goal"):
    run_experiment(
        suite,
        x_points,
        methods,
        quant="quant8",
        n_cal_bins=n_bins,
        trailing_k=3,
    )