In [None]:

import torch 
import matplotlib.pyplot as plt
from data.kernel import *
from data.sampler import *
from data.function import *
from data.evaluation import * 
from policies.pbo import * 
from utils.plot import *
from policies.transformer import * 
import os 
from data.candy_data_handler import * 
from data.sushi_data_handler import * 
%reload_ext autoreload
%autoreload 2

In [None]:
def read_pbo_results(dataname, acq_function_type, dataset_id, num_seed=30, T=30):
    """Load and stack the per-seed evaluation tensors of one baseline acquisition function.

    For every seed in ``range(num_seed)`` this reads, from
    ``results/evaluation/{dataname}/{acq_function_type}/{dataset_id}/``:
    ``simple_regret_{seed}.pt``, ``cumulative_regret_{seed}.pt``,
    ``immediate_regret_{seed}.pt`` and ``cumulative_inference_time_{seed}.pt``.
    Each file is loaded onto the CPU, and the per-seed tensors are stacked along a new leading
    seed axis.

    Args:
        dataname: dataset name (sub-directory under ``results/evaluation``).
        acq_function_type: acquisition-function / policy name (sub-directory).
        dataset_id: dataset index (sub-directory).
        num_seed: number of seeds, i.e. files per metric, to load.
        T: number of optimization steps. Regret curves are expected to hold
            ``T + 1`` entries and the cumulative inference time ``T`` entries.

    Returns:
        Tuple ``(simple_regret, cumulative_regret, immediate_regret, inference_time)``.
        The three regret tensors have shape ``(num_seed, T + 1)``; ``inference_time``
        has shape ``(num_seed, T)``.

    Raises:
        AssertionError: if a stacked tensor does not have the expected shape.
    """
    root = f'results/evaluation/{dataname}/{acq_function_type}/{str(dataset_id)}'

    def load_stacked(prefix):
        # One file per seed, stacked into (num_seed, ...).
        per_seed = [
            torch.load(os.path.join(root, f"{prefix}_{seed}.pt"), map_location="cpu")
            for seed in range(num_seed)
        ]
        return torch.stack(per_seed)

    simple_regret = load_stacked("simple_regret")
    cumulative_regret = load_stacked("cumulative_regret")
    inference_time = load_stacked("cumulative_inference_time")
    immediate_regret = load_stacked("immediate_regret")

    # Validate against the requested number of seeds (previously hard-coded to 30,
    # which broke every call with num_seed != 30).
    assert simple_regret.shape == (num_seed, T + 1), f"{acq_function_type}, {simple_regret.shape}"
    assert cumulative_regret.shape == (num_seed, T + 1), f"{acq_function_type}, {cumulative_regret.shape}"
    assert immediate_regret.shape == (num_seed, T + 1), f"{acq_function_type}, {immediate_regret.shape}"
    assert inference_time.shape == (num_seed, T), f"{acq_function_type}, {inference_time.shape}"
    return simple_regret, cumulative_regret, immediate_regret, inference_time
    

In [None]:
def plot_results(results: dict, fig=None, i=0, row=1, col=3):
    """Plot simple regret, cumulative regret and cumulative inference time side by side.

    Each metric is averaged over the dataset axis (``dim=1``) and handed to
    ``plot_metric_along_trajectory``, one curve per model. The three panels are
    added to ``fig`` at subplot positions ``i+1``, ``i+2`` and ``i+3`` of a
    ``row`` x ``col`` grid. Both regret panels use a log y-scale.

    Args:
        results: dictionary of models' results. For a model m, results[m] is also a dictionary:
            - simple_regret, (num_seed, num_dataset, H): simple regret along trajectories.
            - cumulative_regret, (num_seed, num_dataset, H): cumulative regret along trajectories.
            - inference_time, (num_seed, num_dataset, T): cumulative inference time along
              trajectories. For baselines loaded with ``read_pbo_results`` this has one
              entry fewer than the regret curves (T vs. H = T + 1).
        fig: matplotlib figure to draw into; a new 16x5 figure is created when None.
        i: zero-based offset of the first panel within the subplot grid.
        row: number of subplot rows in the grid.
        col: number of subplot columns in the grid.
    """
    if fig is None:
        fig = plt.figure(figsize=(16, 5))
    model_names = list(results.keys())

    # (metric key in results[m], panel title, log-scale y-axis?)
    panels = [
        ("simple_regret", "Simple regret", True),
        ("cumulative_regret", "Cumulative regret", True),
        ("inference_time", "Cumulative inference time", False),
    ]
    for offset, (metric, title, log_scale) in enumerate(panels, start=1):
        ax = fig.add_subplot(row, col, i + offset)
        plot_metric_along_trajectory(metrics=[results[name][metric].mean(dim=1) for name in model_names],
                                     model_names=model_names,
                                     ax=ax)
        if log_scale:
            ax.set_yscale("log")
        ax.set_title(title)

In [None]:
# Configuration for this comparison.
dataname = "sushi"
num_dataset = 1
num_steps = 100  # T passed to read_pbo_results: regret files hold T+1 entries, timing files T
models = ["rs", "qEUBO", "qTS", "qEI", "qNEI", "mpes"]
# Must follow the return order of read_pbo_results.
metric_names = ["simple_regret", "cumulative_regret", "immediate_regret", "inference_time"]

# Baselines: one read_pbo_results call per dataset, then stack the datasets on
# dim=1 so every metric becomes (num_seed, num_dataset, H).
results = {}
for acq_function_type in models:
    per_dataset = [
        read_pbo_results(dataname=dataname,
                         acq_function_type=acq_function_type,
                         dataset_id=dataset_id,
                         T=num_steps)
        for dataset_id in range(num_dataset)
    ]
    # zip(*per_dataset) regroups the per-dataset 4-tuples into one tuple per metric.
    results[acq_function_type] = {
        metric: torch.stack(list(per_metric), dim=1)
        for metric, per_metric in zip(metric_names, zip(*per_dataset))
    }

# PABBO results are stored pre-aggregated, one file per metric.
# NOTE(review): assumed to already be (num_seed, num_dataset, H) like the baselines — confirm.
pabbo = "PABBO_GP4D"
pabbo_samples = 512
pabbo_files = {
    "simple_regret": "SIMPLE_REGRET",
    "inference_time": "CUMULATIVE_TIME",
    "cumulative_regret": "CUMULATIVE_REGRET",
    "immediate_regret": "IMMEDIATE_REGRET",
}
root = f'results/evaluation/{dataname}/PABBO/{pabbo}'
results[f"PABBO{pabbo_samples}"] = {
    metric: torch.load(f"{root}/{prefix}_S{pabbo_samples}_B1.pt", map_location="cpu")
    for metric, prefix in pabbo_files.items()
}

# Draw the comparison into a single explicit figure so it is rendered once, with its legend.
fig = plt.figure(figsize=(16, 5))
plot_results(results, fig=fig)
# Every panel plots the same models, so one panel's handles serve a shared figure legend.
handles, labels = fig.axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(results), bbox_to_anchor=(0.5, -0.1))
plt.show()