In [8]:
## Code is modified from source released by 
## https://github.com/chueh-ermon/battery-fast-charging-optimization

import importlib

from bax.env import bandit_ucb_public as bucb
from bax.utils import trainer_public as trainer
from bax.utils import plotter_public as plotter 

from collections import defaultdict

import glob
import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from matplotlib import rcParams

import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import matplotlib.patheffects as pe

import numpy as np
import pickle

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import ray

# Start a local Ray runtime (presumably used by bax.utils.trainer_public for the
# parallel training runs -- TODO confirm). ignore_reinit_error=True makes this cell
# safe to re-run in the same kernel instead of raising on a second ray.init().
ray.init(log_to_driver=False, ignore_reinit_error=True, _temp_dir='/tmp/ray/')

In [None]:
def gen_ci_fn(alpha):
    """Generate the confidence-interval (exploration bonus) function used by UCB.

    The returned ``ci_fn(T, times_pulled)`` evaluates

        sqrt( 2 * (T**alpha - 1) / (alpha * times_pulled) )

    Because ``(T**alpha - 1) / alpha -> log(T)`` as ``alpha -> 0``, the
    ``alpha == 0`` case uses the limiting value ``sqrt(2 * log(T) / times_pulled)``
    so that the bonus is continuous in ``alpha``.

    Parameters
    ----------
    alpha : float
        Exponent controlling how fast the bonus grows with the horizon ``T``.

    Returns
    -------
    callable
        ``ci_fn(T, times_pulled)`` returning ``np.inf`` when ``times_pulled == 0``
        and the bonus above otherwise.
    """
    def ci_fn(T, times_pulled):
        # Infinite bonus for an arm that has never been pulled (presumably so
        # every arm is tried at least once).
        if times_pulled == 0:
            return np.inf
        if alpha == 0:
            # Limit of the general formula as alpha -> 0 (avoids dividing by zero).
            return np.sqrt(2 * np.log(T) / times_pulled)
        return np.sqrt(2 * (T**alpha - 1) / (alpha * times_pulled))
    return ci_fn

In [12]:
## Battery "hi" prior. get_prior reads column 0 as per-arm means and column 1 as
## per-arm sigmas.
PRIOR_HI = np.load('../bax/env/battery_hi.npy')
max_mean = PRIOR_HI[:, 0].max()  # largest per-arm mean; used to normalise the prior

In [None]:
## Let's use prior for the battery dataset "hi"
def get_prior(prior, num_arms, seed=10, max_mean=1200):
    """Build a normalised prior over all arms or a random subset of arms.

    Parameters
    ----------
    prior : np.ndarray
        2-D array with one row per arm. Column 0 is read as the arm's mean and
        column 1 as its sigma; only the largest sigma is used (shared by all arms).
    num_arms : int or 'all'
        Number of arms to sample without replacement, or ``'all'`` to keep every arm.
    seed : int
        Seed for the arm subsampling (not used when ``num_arms == 'all'``).
    max_mean : float
        Scale used to normalise the means and the sigma.

    Returns
    -------
    PRIOR : np.ndarray, shape (n, 2)
        Rows sorted by ascending normalised mean; columns are
        (normalised mean, shared normalised variance).
    arm_ids : np.ndarray of int, shape (n,)
        ``arm_ids[i]`` is the row of ``prior`` that ``PRIOR[i]`` was built from,
        in both the ``'all'`` and the subset case.
    """
    from numpy.random import default_rng

    total_arms = prior.shape[0]
    # Normalise the means and the shared (largest) sigma by max_mean.
    MU = prior[:, 0] / max_mean
    SIGMA2 = (np.max(prior[:, 1]) / max_mean) ** 2

    if num_arms == 'all':
        armidx = np.arange(total_arms)
    else:
        rng = default_rng(seed)
        armidx = rng.choice(np.arange(total_arms), num_arms, replace=False)

    # Sort the selected arms by ascending mean and map back to original row ids,
    # so the returned indices line up with the rows of PRIOR.
    order = np.argsort(MU[armidx])
    arm_ids = armidx[order]
    PRIOR = np.column_stack((MU[arm_ids], np.full(len(arm_ids), SIGMA2)))
    return PRIOR, arm_ids

In [None]:
## Use every arm: `prior` rows are sorted by ascending normalised mean and
## arms[i] is the row of PRIOR_HI that prior[i] came from.
prior, arms = get_prior(PRIOR_HI, num_arms='all', max_mean=max_mean)

In [None]:
## Horizons T to train with; num_horizon is derived so it can never drift from Ts.
Ts = [25000, 45000, 70000]
num_horizon = len(Ts)  # number of distinct T

num_runs = 100  # independent runs per (alpha, T)
results_all = {}  # alpha -> result of trainer.train_helper_for_horizon

## Exponents passed to gen_ci_fn; one training sweep per alpha.
ALPHAS = [0.001]
for alpha in ALPHAS:
    print("COLLECTING DATASET FOR ALPHA={}".format(alpha))
    # NOTE(review): gen_ci_fn is passed uncalled together with alpha; presumably the
    # trainer builds the CI function as gen_ci_fn(alpha) -- confirm in trainer_public.
    results_dict = trainer.train_helper_for_horizon(prior=prior,
                                                    Ts=Ts,
                                                    algo='ucb',
                                                    ci_fn=gen_ci_fn,
                                                    alpha=alpha,
                                                    num_runs=num_runs)
    results_all[alpha] = results_dict

In [None]:
## True gaps implied by the prior (computed by plotter.get_true_gaps).
true_gap = plotter.get_true_gaps(prior)

## Estimated gaps from each training sweep, keyed by alpha.
gaps_all = dict((a, res['gaps']) for a, res in results_all.items())

## Error between estimated and true gaps.
## NOTE(review): get_mse / get_mean_ci are neither defined nor imported in any
## visible cell, so this raises NameError on a fresh kernel. Presumably they live in
## bax.utils.plotter_public next to get_true_gaps -- TODO confirm and import them.
errors_all = get_mse(gaps_all, true_gap, norm=None)

## Mean error and confidence interval across the independent runs.
emean, eci = get_mean_ci(errors_all, arm=None)

In [13]:
## param_space maps arm indices to current values; columns 0-2 are plotted as
## CC1-CC3 further below.
param_file = "../bax/env/paramspace.pkl"
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(param_file, mode='rb') as pkl_fh:
    param_space, ub, lb, mean = pickle.load(pkl_fh)

In [None]:
## Error landscape for the trained alpha, with the highest-mean arm dropped.
## `prior` is sorted by ascending mean, so (presumably, if emean's columns follow
## the order of `prior`) the last column is that arm. Slicing with :-1 keeps this
## aligned with `arms[:-1]` used in the plot instead of a hard-coded 223.
dataset_error = emean[ALPHAS[0]][:, :-1]
# Colour-scale bounds (these are errors, despite the "lifetime" names).
min_lifetime = np.min(dataset_error)
max_lifetime = np.max(dataset_error)

In [None]:
def text(x1, y1, x2, y2, k, ax=None):
    """Draw a rounded, white-boxed ``"T= <k>"`` label on the figure.

    All coordinates are figure fractions. The label text is centred on
    (x1, y1); (x2, y2) is the annotated point passed as ``xy`` (no arrow is drawn).

    Parameters
    ----------
    x1, y1 : float
        Position of the label text.
    x2, y2 : float
        Annotated point.
    k : object
        Value shown after ``"T= "`` (the horizon).
    ax : Axes-like, optional
        Object whose ``annotate`` method draws the label. Defaults to the
        notebook-level ``ax`` (the last panel created), as before.

    Returns
    -------
    The object returned by ``ax.annotate``.
    """
    if ax is None:
        ax = globals()['ax']
    return ax.annotate("T= " + str(k), xy=(x2, y2), xycoords='figure fraction',
                       xytext=(x1, y1), textcoords='figure fraction',
                       size=20, va="center", ha="center",
                       bbox=dict(boxstyle="round", fc="w"))

    
## Global text settings; configured before any figure artist is created.
FS = 16
LW = 3

rcParams['pdf.fonttype'] = 42  # Type 42 (TrueType) fonts so PDF/PS text stays editable
rcParams['ps.fonttype'] = 42
rcParams['font.size'] = FS
rcParams['axes.labelsize'] = FS
rcParams['xtick.labelsize'] = FS
rcParams['ytick.labelsize'] = FS
rcParams['font.sans-serif'] = ['Arial']

## 15 x 11 in canvas. The 3x3 grid is hidden; the visible 3-D panels and the
## colour bar are placed manually with fig.add_axes below.
fig, axes = plt.subplots(3, 3, figsize=(15, 11))
for placeholder in axes.flat:
    placeholder.set_axis_off()

##############################################################################
# PLOTTING PARAMETERS
batches_to_plot = [0, 1, 2]  # rows of dataset_error (one per horizon in Ts)

colormap = 'plasma_r'
el, az = 30, 240  # 3-D view elevation / azimuth
point_size = 50
margin = 0.18     # horizontal spacing of the "T=" labels (figure fraction)
##############################################################################

## Arm coordinates are the same for every panel; arms[:-1] drops the highest-mean
## arm, matching the columns kept in dataset_error.
CC1 = param_space[arms[:-1], 0]
CC2 = param_space[arms[:-1], 1]
CC3 = param_space[arms[:-1], 2]

## MAKE SUBPLOTS -- one 3-D scatter per horizon, coloured by per-arm error
for k, batch_idx in enumerate(batches_to_plot):
    with sns.axes_style('white'):
        ax = fig.add_axes([0.05 + 0.165 * k, 0.48, 0.24, 0.24], projection='3d')
    # NOTE(review): 'equal' on 3-D axes raises NotImplementedError on some
    # matplotlib releases (3.1-3.5); confirm against the installed version.
    ax.set_aspect('equal')

    lifetime = dataset_error[batch_idx]
    with plt.style.context('classic'):
        # Pass the colormap explicitly instead of mutating global state via plt.set_cmap.
        ax.scatter(CC1, CC2, CC3, s=point_size, c=lifetime.ravel(), cmap=colormap,
                   vmin=min_lifetime, vmax=max_lifetime)

    ax.set_xlim([3, 8])
    ax.set_ylim([3, 8])
    ax.set_zlim([3, 8])

    if k == 0:
        # Axis labels on the first panel only.
        ax.set_xlabel('CC1', fontsize=FS)
        ax.set_ylabel('CC2', fontsize=FS)
        ax.set_zlabel('CC3', fontsize=FS, rotation=90)

    ax.view_init(elev=el, azim=az)

## ADD COLORBAR -- a standalone mappable sharing the panels' colour scale
cbar_ax = fig.add_axes([0.65, 0.45, 0.02, 0.3])  # [left, bottom, width, height]
norm = matplotlib.colors.Normalize(min_lifetime, max_lifetime)
m = plt.cm.ScalarMappable(norm=norm, cmap=colormap)
m.set_array([])

cbar = plt.colorbar(m, cax=cbar_ax)
cbar.ax.tick_params(labelsize=FS, length=0)
# Raw string: the LaTeX backslashes are not Python escape sequences.
cbar.ax.set_title(r'MSE  $\mathbb{E} | \hat{\mu}_i - \mu_i|^2$', fontsize=FS)

## Horizon labels above the panels, one per plotted batch.
for k, batch_idx in enumerate(batches_to_plot):
    text(0.15 + margin * k, 0.78, 0.22 + margin * k, 0.78, Ts[batch_idx])

plt.tight_layout()
plt.savefig('battery_allarms_errorlandscape.png', bbox_inches='tight')
plt.savefig('battery_allarms_errorlandscape.pdf', bbox_inches='tight', format='pdf')