In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

from vla_calibration.utils import *
from vla_calibration.calibration import *

import warnings
# Install a global "ignore" filter so no warnings are shown from here on.
warnings.filterwarnings("ignore")

# Apply the plotting style, then capture its colour cycle as a plain list
# (`pal`) so later cells can pick bar colours by index.
plt.style.use("seaborn-v0_8")
pal = [cycle_entry["color"] for cycle_entry in plt.rcParams["axes.prop_cycle"]]

In [None]:
def recal_experiment(
    action_conf,
    action_logits,
    correct,
    test_size=0.4,
    n_trials=100,
    n_cal_bins=10,
):
    """Compare recalibration methods over repeated random train/test splits.

    Each trial splits the three inputs with ``train_test_split_three_way``
    (``random_state`` is the trial index, so runs are repeatable) and measures
    the ECE on the held-out part for four variants:

    * uncalibrated: ``action_conf`` averaged over its last axis,
    * ``PlattScaler`` fit on that averaged confidence,
    * ``ActionPlattScaler`` (``combine_method="mean"``) fit on the un-averaged
      confidences,
    * ``TempScaler`` fit on ``action_logits``.

    Args:
        action_conf: Confidence array; it is averaged over its last axis for
            the uncalibrated and Platt variants and passed un-averaged to
            ``ActionPlattScaler``. (Presumably one row per sample -- confirm
            against ``train_test_split_three_way``.)
        action_logits: Logits handed to ``TempScaler``.
        correct: Correctness labels aligned with the other two inputs.
        test_size: Forwarded to ``train_test_split_three_way``.
        n_trials: Number of random splits to evaluate.
        n_cal_bins: Bin count forwarded to ``get_ece``.

    Returns:
        dict with the across-trial mean ECE per method (``"uncal_ece"``,
        ``"recal_ece"``, ``"action_recal_ece"``, ``"temp_scale_ece"``) and the
        across-trial ``np.std`` per method (``"uncal_ece_std"``,
        ``"recal_ece_std"``, ``"action_recal_ece_std"``, ``"temp_scale_std"``).
    """
    # Use CUDA when it is present, otherwise fall back to CPU, rather than
    # calling .cuda() unconditionally (which raises on hosts without a GPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Per-trial ECEs are stored at full precision: rounding each sample before
    # aggregating would quantize the values that feed the mean and std.
    uncal_eces = []
    recal_eces = []
    action_recal_eces = []
    temp_scale_eces = []

    for trial_no in tqdm(range(n_trials)):
        (
            conf_train,
            conf_test,
            logits_train,
            logits_test,
            correct_train,
            correct_test,
        ) = train_test_split_three_way(
            action_conf, action_logits, correct, test_size=test_size, random_state=trial_no
        )

        # Collapse the last axis to a single confidence per sample.
        mean_train_conf = np.mean(conf_train, -1)
        mean_test_conf = np.mean(conf_test, -1)

        # Baseline: no recalibration.
        uncal_eces.append(get_ece(mean_test_conf, correct_test, n_cal_bins))

        # Platt scaling on the averaged confidence.
        scaler = PlattScaler(max_iter=200, tol=1e-6)
        scaler.fit(mean_train_conf, correct_train)
        platt_probs = scaler.predict(mean_test_conf)
        recal_eces.append(get_ece(platt_probs, correct_test, n_cal_bins))

        # Action-wise Platt scaling on the un-averaged confidences.
        action_scaler = ActionPlattScaler(max_iter=200, tol=1e-8, combine_method="mean")
        action_scaler.fit(conf_train, correct_train)
        action_probs = action_scaler.predict(conf_test)
        action_recal_eces.append(get_ece(action_probs, correct_test, n_cal_bins))

        # Temperature scaling fit on the training logits.
        temp_scaler = TempScaler()
        temp_scaler.set_temperature(torch.Tensor(logits_train), torch.FloatTensor(correct_train))
        temp_scaler.eval()
        with torch.no_grad():
            scaled_test_conf = (
                temp_scaler.temperature_scale(torch.Tensor(logits_test).to(device)).cpu().numpy()
            )
        temp_scale_eces.append(get_ece(scaled_test_conf, correct_test, n_cal_bins))

    print(f"uncal ece: {np.mean(uncal_eces)} | recal ece: {np.mean(recal_eces)} | action recal ece: {np.mean(action_recal_eces)}")
    print(f"temp scale ece: {np.mean(temp_scale_eces)}")
    return {
        "uncal_ece": np.mean(uncal_eces),
        "recal_ece": np.mean(recal_eces),
        "action_recal_ece": np.mean(action_recal_eces),
        "temp_scale_ece": np.mean(temp_scale_eces),
        "uncal_ece_std": np.std(uncal_eces),
        "recal_ece_std": np.std(recal_eces),
        "action_recal_ece_std": np.std(action_recal_eces),
        "temp_scale_std": np.std(temp_scale_eces),
    }




In [None]:
def run_exp(task_name, quant, n_bins=12, test_size=0.8, n_prompts=20, n_trials=1000, alternate_set=1):
    """Run the recalibration comparison for one task and save a two-panel figure.

    Loads data with ``get_scaling_data``, runs ``recal_experiment`` on it,
    prints the resulting dict, and draws:

    * left panel: the values in ``by_dim_results["baseline"]``, one bar per
      entry, labelled 1..N on an "Action Dimension" axis,
    * right panel: mean ECE for temperature scaling, Platt scaling and
      action-wise Platt scaling, with ``std / sqrt(n_trials)`` error bars.

    The figure is written to ``../plots/`` (created if missing) and shown.

    Args:
        task_name: Task identifier passed to ``get_scaling_data``; also used in
            the figure title and output filename.
        quant: Quantization tag passed to ``get_scaling_data`` (``None`` for no
            tag); appears in the title and filename when not ``None``.
        n_bins: Calibration bin count forwarded to ``recal_experiment``.
        test_size: Split size forwarded to ``recal_experiment``.
        n_prompts: Forwarded to ``get_scaling_data``.
        n_trials: Number of random splits; also the divisor (under a square
            root) for the error bars.
        alternate_set: Forwarded to ``get_scaling_data``.

    Returns:
        None.
    """
    # Function-scope import keeps this cell self-contained.
    import os

    base_probs, _, base_logits, _, correct, by_dim_results = get_scaling_data(
        task_name,
        alternate_set=alternate_set,
        n_prompts=n_prompts,
        quant=quant,
    )

    # Max over the last axis, mean over the second-to-last, then keep index 0
    # of the next axis. NOTE(review): axis meanings depend on what
    # get_scaling_data returns -- confirm there.
    base_conf = np.mean(np.max(base_probs, -1), -2)[:, 0]

    print("---------------------\nBase Recalibration")
    baseline_results = recal_experiment(
        base_conf,
        base_logits[:, 0, 0],
        correct,
        test_size=test_size,
        n_trials=n_trials,
        n_cal_bins=n_bins,
    )
    print(baseline_results)

    bar_colors = pal[1:4]
    scale_factor = 1.25

    fig, axs = plt.subplots(1, 2, figsize=(10, 3.25), width_ratios=[0.45, 0.55])

    # Left panel: one bar per entry of the per-dimension baseline results.
    dim_scores = by_dim_results["baseline"]
    dim_x = np.arange(len(dim_scores))
    dim_labels = [f"{i+1}" for i in range(len(dim_scores))]
    axs[0].bar(dim_x, dim_scores)
    axs[0].set_xticks(dim_x, dim_labels, fontsize=15)
    axs[0].set_xlabel("Action Dimension", fontsize=18)

    # Right panel: the three recalibration methods with standard-error bars.
    method_scores = [
        baseline_results["temp_scale_ece"],
        baseline_results["recal_ece"],
        baseline_results["action_recal_ece"],
    ]
    method_stds = np.asarray([
        baseline_results["temp_scale_std"],
        baseline_results["recal_ece_std"],
        baseline_results["action_recal_ece_std"],
    ])
    method_labels = ["Temp\nScaling", "Platt\nScaling", "Action-Wise\nPlatt Scaling"]
    method_x = np.arange(len(method_scores))
    axs[1].bar(
        method_x,
        method_scores,
        color=bar_colors,
        yerr=method_stds / np.sqrt(n_trials),
        error_kw=dict(ecolor='dimgrey', lw=2, capsize=3, capthick=2),
    )
    axs[1].set_xticks(method_x, method_labels, fontsize=15)

    for ax in axs:
        ax.set_ylabel(r"$\text{ECE}_1$", fontsize=18)
        ax.tick_params(axis="y", labelsize=13)

    quant_tag = f" ({quant})" if quant is not None else ""

    # Raise the y-axis floor to zoom in on the differences. The floor is taken
    # from ALL plotted bars so that the smallest one can never fall below the
    # axis and disappear.
    axs[1].set_ylim(min(method_scores) / scale_factor, None)

    fig.suptitle(f"{task_name.title()}{quant_tag.title()}", fontsize=18, y=0.95)
    fig.tight_layout()

    quant_save_string = quant_tag.replace("(", "").replace(")", "").strip()
    out_path = f"../plots/action_scaling_{task_name}_{quant_save_string}_baseline_w_temp_scaling.png"

    # Make sure the target directory exists so the save cannot fail after the
    # (long) experiment has already finished.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=600, bbox_inches="tight")
    plt.show()


In [None]:
# Settings shared by every run_exp call in this cell and the cells below.
n_bins = 10
test_size = 0.8
n_trials = 1000

# Runs with quant=None, one per task.
for task in ("spatial", "goal"):
    run_exp(task, quant=None, n_bins=n_bins, test_size=test_size, n_trials=n_trials)


In [None]:
# Runs with quant="quant8", one per task, reusing the shared settings.
for task in ("spatial", "goal"):
    run_exp(task, quant="quant8", n_bins=n_bins, test_size=test_size, n_trials=n_trials)


In [None]:
# Runs with quant="quant4", one per task, reusing the shared settings.
for task in ("spatial", "goal"):
    run_exp(task, quant="quant4", n_bins=n_bins, test_size=test_size, n_trials=n_trials)
