In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib as mpl
import pickle 
import os 
from cycler import cycler
from collections import Counter
from tqdm.notebook import tqdm 

%load_ext autoreload
%autoreload 2

from policies import Egreedy, Random, UCB, TS, Exp3
from model import Ct
from utils import minmax_containment, experiment_trial, gen_regions

# Test Bandit Policies

Lines that are commented out were used to generate the population estimate figures.

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
n_steps = 30     # forwarded to experiment_trial as its step count
n_trials = 5     # seeded repetitions per (region count, budget) pair
beta = 0.001     # forwarded unchanged to experiment_trial

gamma = 0.5      # `gamma` argument shared by every non-random policy

budgets = [500, 900, 1500, 2000]
n_regions = [10]

# results[n_regs][policy name][budget]       -> one summed final case count per trial
# pop_estimates[n_regs][policy name][budget] -> {'est': [...], 'true': [...]}, one entry per trial
results = {count: {} for count in n_regions}
pop_estimates = {count: {} for count in n_regions}

for n_regs in tqdm(n_regions, desc='regions'):
    for budget in tqdm(budgets, desc='budgets', leave=False):
        for trial in tqdm(range(n_trials), desc='trials', leave=False):

            # Fresh policy objects for every trial; the trial index doubles
            # as the seed handed to each policy.
            thomp = TS(n_regs, gamma=gamma, seed=trial)
            ucb68 = UCB(n_regs, gamma=gamma, alpha=0.16, seed=trial)
            ucb95 = UCB(n_regs, gamma=gamma, alpha=0.025, seed=trial)
            exp310 = Exp3(n_regs, eps=0.1, gamma=gamma, seed=trial)
            egreedy1 = Egreedy(n_regs, eps=0.01, gamma=gamma, seed=trial)
            egreedy10 = Egreedy(n_regs, eps=0.1, gamma=gamma, seed=trial)
            rand = Random(n_regs, seed=trial)

            # Only policies listed here take part in the trial.
            policies = [
                thomp, ucb68, ucb95, exp310, egreedy1, egreedy10, rand
            ]
            # For the population-estimate figures, restrict the run instead:
            # policies = [ucb68]

            # Seed the global NumPy RNG before the regions are generated.
            np.random.seed(trial)
            regions = gen_regions(n_regs, policies)
            regions, pop_ests = experiment_trial(
                regions, budget, beta, n_steps, policies
            )

            for policy in policies:
                name = policy.name

                # Total of the last 'cases_true' value over all regions.
                cases_by_budget = results[n_regs].setdefault(
                    name, {b: [] for b in budgets}
                )
                final_cases = [
                    regions[region][name]['cases_true'][-1]
                    for region in regions.keys()
                ]
                cases_by_budget[budget].append(np.sum(final_cases))

                # Estimated and true population series for this trial.
                pops_by_budget = pop_estimates[n_regs].setdefault(
                    name, {b: {'est': [], 'true': []} for b in budgets}
                )
                for key in ('est', 'true'):
                    pops_by_budget[budget][key].append(pop_ests[name][key])
                    


In [None]:
# Persist the per-policy case totals gathered by the experiment loop.
# NOTE(review): only `results` is written; `pop_estimates` is not saved here.
# open(..., 'wb') does not create missing parent directories, so make sure
# the output folder exists first (no-op when it is already there).
os.makedirs('local', exist_ok=True)
with open(f'local/comparison_n_steps_{n_steps}_{n_trials}_beta_{beta}_'
        f'max_budget_{max(budgets)}_n_policies_{len(policies)}.p', 'wb') as f:
    pickle.dump(results, f)

### Plot improvement over random

In [None]:
# Notebook-wide matplotlib defaults for every figure drawn after this cell.
mpl.rcParams.update({
    'font.family': 'Times',
    'font.size': '15',
    'axes.linewidth': 0.4,
    'axes.edgecolor': 'gray',
    'text.usetex': True,  # text is rendered through LaTeX
})

In [None]:
# Plot, per region count, how many fewer final true cases each policy yields
# compared with the 'Random' baseline, as a function of the budget.
# Relies on `results`, `policies`, `budgets`, `n_regions`, `n_steps`,
# `n_trials` and `beta` left in the namespace by the experiment cell.
fig, ax = plt.subplots(2, 2, figsize=(14,8), sharex=True)
ax = ax.ravel()  # flat indexing: one panel per entry of n_regions (at most 4)
keywd = budgets
fill = True      # draw the spread band around each mean curve

c_wheel = [
        'tab:blue', 'tab:purple', 'tab:olive', 'darkgreen', 'tab:cyan',
        'gold', 'pink', 'lightgreen'
]
ms = ['o', 'v', 'D', 's', '*', '^']


for i, n_regs in enumerate(n_regions):

    # Rows follow `keywd` (budgets); columns are the individual trials.
    rand_performance = np.array([results[n_regs]['Random'][k] for k in keywd])
    for j, policy in enumerate(policies):
        if policy.name == 'Random':
            continue  # the baseline is not compared against itself

        # Cycle through the style lists so that a longer policy list, or a
        # 'Random' entry that is not last, cannot index past their end.
        color = c_wheel[j % len(c_wheel)]
        marker = ms[j % len(ms)]

        # Positive values mean the policy ended with fewer cases than Random.
        diff = rand_performance - np.array([results[n_regs][policy.name][k] for k in keywd])
        means = np.mean(diff, axis=1)  # average over trials for each budget
        ax[i].plot(keywd, means, label=policy.formal_name, lw=2, c=color, marker=marker)

        if fill:
            # presumably the bounds enclosing 68% of the trials at each
            # budget - confirm against utils.minmax_containment
            ymin, ymax = minmax_containment(np.transpose(diff), 0.68)
            ax[i].fill_between(keywd, ymin, ymax, alpha=0.5, color=color)


    ax[i].set_title(f'{n_regs} regions')

# One shared legend for the whole figure, taken from the first panel.
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=len(policies), bbox_to_anchor=(0.5, -0.02))

fig.suptitle('Improvement over random sampling', y=1.0, size=20)
fig.text(0.5, 0.04, 'Budget', ha='center', size=16)
fig.text(0.07, 0.5, 'Reduction in Cases', va='center', rotation='vertical', size=16)

# savefig does not create missing directories; make sure the target exists.
os.makedirs('local', exist_ok=True)
fig.savefig(f'local/comparison_n_steps_{n_steps}_{n_trials}_beta_{beta}_'
            f'max_budget_{max(budgets)}_fill_{fill}.jpeg', dpi=300, bbox_inches='tight')

### Plot population estimates

In [None]:
# Left panel: estimated vs. true population series per budget for one policy.
# Right panel: estimate minus truth, with a spread band.
# Relies on `pop_estimates`, `n_steps` and `ucb68` left in the namespace by
# the experiment cell.
fig, ax = plt.subplots(1, 2, figsize=(14,4), sharex=True)
ax = ax.ravel()

n_regs = 10
budgets = [500,900, 1500,2000]

# The policy object from the last trial; only its `.name` is used below.
policy = ucb68

c_wheel = ['tab:olive', 'tab:green', 'tab:purple', 'tab:cyan']
for i, budget in enumerate(budgets): 

    # Rows are trials; the first column is dropped, so the x-axis starts at 1.
    # assumes each per-trial series has n_steps entries - TODO confirm
    ests = np.array(pop_estimates[n_regs][policy.name][budget]['est'])[:, 1:]
    true = np.array(pop_estimates[n_regs][policy.name][budget]['true'])[:, 1:]
    
    # Thin line: mean estimate. Thick line with markers: mean true value.
    ax[0].plot(range(1, n_steps), np.mean(ests, axis=0), 
               lw=2, color=c_wheel[i])
    ax[0].plot(range(1, n_steps), np.mean(true, axis=0), 
               lw=3, color=c_wheel[i], marker='o')
    
    # presumably the bounds enclosing 68% of the trials per timestep -
    # confirm against utils.minmax_containment
    ymin_est, ymax_est = minmax_containment(ests, 0.68)
    ax[0].fill_between(range(1, n_steps), ymin_est, ymax_est, alpha=0.5, color=c_wheel[i])
    
    
    # Per-trial estimation error; the band here uses 0.95 rather than 0.68.
    diff = ests - true
    ax[1].plot(range(1, n_steps), np.mean(diff, axis=0), label=f'Budget {budget}', color=c_wheel[i])
    ymin, ymax = minmax_containment(diff, 0.95)
    ax[1].fill_between(range(1, n_steps), ymin, ymax, alpha=0.5, color=c_wheel[i])

ax[0].set_ylabel('Prevalence Rate', x=1)
ax[1].set_ylabel('Difference from true')
ax[1].yaxis.set_label_coords(-0.14,0.5)
ax[0].set_xlabel('Timestep')
ax[1].set_xlabel('Timestep')    
# Empty proxy line so the left legend explains the marker style of the truth.
ax[0].plot([], [], c='k', marker='o', label='True')    
ax[0].legend()    
# Budget colours are labelled once, in a figure-level legend.
handles, labels = ax[1].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.06))

fig.suptitle('UCB Population Estimates', y=1.0, size=20)
# savefig does not create missing directories, and no other cell creates
# 'local/figs', so make sure it exists before writing.
os.makedirs('local/figs', exist_ok=True)
fig.savefig('local/figs/UCB_pop_estimate.png',
           bbox_inches='tight', dpi=300)