In [None]:
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import palettable
import seaborn as sns
import wandb

sns.set(style='whitegrid', font_scale=1.75)

from palettable.cartocolors.qualitative import *

In [None]:
# Public W&B API client; the plotting cell below fetches runs through `api.runs`.
api = wandb.Api()

In [None]:
# Benchmark tasks to plot, one figure each; every name is matched against the
# runs' `config.task/name` value when querying W&B.
tasks = 'zinc_penalized_logp zinc_3pbl_docking zinc_qed poas_stability'.split()

## Benchmark: cumulative regret of UCB vs. C-UCB on tabular bandits

In [None]:
# W&B config filters common to every method in this comparison; the methods
# differ only in their acquisition function (`config.acq_fn`).
_shared_filters = {
    'config.exp_name': 'tabular_bandits',
    'config.version': 'v0.0.5',
    'config.num_total_rows': 256,
    'config.max_num_steps': 128,
    'config.num_baseline': 32,
}

# One entry per plotted curve: legend label, line color, and the run filters.
# Each entry gets its own fresh `filters` dict (acq_fn first, then the shared keys).
result_queries = [
    {
        'label': label,
        'color': color,
        'filters': {'config.acq_fn': acq_fn, **_shared_filters},
    }
    for label, color, acq_fn in (
        ('UCB', "#003049", 'ucb'),
        ('C-UCB', "#f6ae2d", 'cucb'),
    )
]

In [None]:
# Metric to plot: the logged column whose name contains this substring.
key = 'cum_regret'

# W&B project holding the runs; figures are written under ./figures.
WANDB_PROJECT = 'samuelstanton/conformal-bayesopt'
os.makedirs('./figures', exist_ok=True)

# Config version embedded in every output filename (also used by the legend cell).
# Equals the version of the last query, which is what the per-task loop used before.
q_version = result_queries[-1]['filters']['config.version']


def _unique_seed_runs(runs):
    """Keep the first run for each distinct ``seed`` config value, preserving order.

    Runs whose config has no ``seed`` entry are treated as sharing one missing seed.
    """
    seeds = pd.Series([r.config.get('seed') for r in runs])
    return [runs[i] for i in seeds.drop_duplicates().index]


def _smoothed_quantile(frame, x_col, y_col, q, window=5):
    """Per-step ``q``-quantile of ``y_col`` across runs, smoothed by a trailing rolling mean.

    Args:
        frame: concatenated run histories, one row per logged step per run.
        x_col: single-element list naming the step column to group by.
        y_col: list of metric column names to aggregate.
        q: quantile in [0, 1].
        window: rolling-mean window length (``min_periods=1`` keeps the first rows).

    Returns:
        DataFrame with the ``x_col`` and ``y_col`` columns, sorted by step. The rolling
        mean is applied to every column, including the step column, so each smoothed
        point sits at the mean step of its window.
    """
    per_step = (
        frame.groupby(x_col)[y_col]
        .quantile(q, interpolation='linear')
        .reset_index()
    )
    # `axis=0` is the default; passing it explicitly is deprecated in recent pandas.
    return per_step.sort_values(x_col).rolling(window=window, min_periods=1).mean()


for task in tasks:
    fig, ax = plt.subplots(figsize=(6, 5))

    for query in result_queries:
        filters = {'config.task/name': task, **query['filters']}
        # Materialize the paginated result so emptiness can be checked and runs indexed.
        runs = list(api.runs(path=WANDB_PROJECT, filters=filters))
        if not runs:
            # No runs for this (task, method) pair: leave the method off this plot.
            continue

        unique_runs = _unique_seed_runs(runs)
        history = [r.scan_history() for r in unique_runs]
        print(len(history))

        df = pd.concat([pd.DataFrame(h) for h in history])
        x_col = ['_step']
        y_col = [c for c in df.columns if key in c]
        if not y_col:
            print(f"no '{key}' column logged for {task} / {query['label']}; skipping")
            continue
        if len(y_col) > 1:
            # The reshape(-1) calls below assume a single metric column.
            raise ValueError(f"ambiguous metric {key!r}: matched columns {y_col}")

        # Median curve with a 20th-80th percentile band across seeds.
        y_med = _smoothed_quantile(df, x_col, y_col, 0.5)
        y_lb = _smoothed_quantile(df, x_col, y_col, 0.2)[y_col].values.reshape(-1)
        y_ub = _smoothed_quantile(df, x_col, y_col, 0.8)[y_col].values.reshape(-1)
        x_vals = y_med[x_col].values.reshape(-1)
        y_med = y_med[y_col].values.reshape(-1)

        ax.plot(x_vals, y_med, linewidth=4, zorder=3, color=query['color'], label=query['label'])
        ax.fill_between(x_vals, y_lb, y_ub, alpha=0.25, color=query['color'])

    ax.set_xlabel('Online Queries')
    ax.set_ylabel('Cumulative Regret')

    fig.tight_layout()
    task_name = task.replace('_', '-')
    fig.savefig(f"./figures/tab_bandits_{key.replace('_', '-')}_{task_name}_{q_version}.pdf")

In [None]:
# Standalone legend shared by all per-task figures. Handles are built from
# `result_queries` (same color and line width as the plotted curves) instead of
# being read off the last task's axes, so every method appears even when that
# particular task had no runs for it.
legend_handles = [
    plt.Line2D([], [], color=q['color'], linewidth=4, label=q['label'])
    for q in result_queries
]

figlegend = plt.figure(figsize=(2, 2))
legend_ax = figlegend.add_subplot(1, 1, 1)
legend_ax.legend(
    handles=legend_handles,
    loc='upper left',
    fontsize=32,
    ncol=4,
)
# Hide the host axes; only the legend should be rendered.
legend_ax.axis("off")
figlegend.savefig(
    f"./figures/tab_bandits_{key.replace('_', '-')}_legend_{q_version}.pdf", bbox_inches="tight"
)