In [1]:
import jax
import jax.numpy as jnp

from gymnax.environments import spaces
from typing import Optional

import sys
sys.path.append('..')

from experior.utils import PRNGSequence
from experior.envs import BayesStochasticBandit
from experior.bandit_agents import (
  make_thompson_sampling,
  make_max_ent_thompson_sampling,
  make_bernoulli_thompson_sampling,
  make_multi_armed_explore_ucb,
  make_multi_armed_ucb,
  make_multi_armed_bc,
  LinearDiscreteRewardModel
)
from experior.experts import generate_optimal_trajectories

# Turn on JAX's 64-bit mode (the `jax_enable_x64` flag).
jax.config.update('jax_enable_x64', True)

import logging
# Timestamped INFO-level logging; the cells below use `logging.info` to report
# sweep progress and the selected hyperparameters.
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')


%load_ext autoreload
%autoreload 2


### Hyperparameter Search


We tune the hyperparameters for all baselines by optimizing the Bayesian regret for a beta task distribution $\mu^\star = \mathrm{Beta}(0.5, 0.5)$ with $K = 5$ arms (`NUM_ACTIONS` in the next cell) and $T = 1024$ episodes. We sample $N_{\text{task}} = 100$ tasks to estimate the Bayesian regret.


In [2]:
# Settings shared by every hyperparameter sweep in this section.
NUM_ACTIONS = 5  # number of arms: size of the discrete action space and of the prior sample
NUM_ENVS = 100  # passed as `num_envs` to the agent factories
NUM_STEPS = 1024  # passed as `total_steps` to the agent factories
SEED = 42  # seed of the PRNG key handed to the training functions

#### Naive-TS


We search over the following hyperparameters:

- SGLD batch size in `[128, 256, 1024]`
- Number of Langevin update steps per episode `[1, 5, 10, 20]`
- SGLD learning rate `[1e-1, 1e-2, 5e-2, 1e-3]`


In [3]:
# Bernoulli bandit whose per-arm success probabilities are drawn from Beta(alpha, beta).
alpha, beta = 0.5, 0.5
action_space = spaces.Discrete(NUM_ACTIONS)


def _sample_arm_means(key, _):
    """Draw one vector of arm means (length NUM_ACTIONS) from the Beta prior."""
    return jax.random.beta(
        key,
        alpha * jnp.ones((NUM_ACTIONS,)),
        beta * jnp.ones((NUM_ACTIONS,)),
        shape=(NUM_ACTIONS,),
    )


def _sample_reward(key, means, _, action):
    """Draw a Bernoulli reward for `action`, returned as float32."""
    outcome = jax.random.bernoulli(key, means[action])
    return outcome.astype(jnp.float32)


def _best_action_and_value(means, _):
    """Return the index and the mean of the best arm."""
    return (means.argmax(), means.max())


def _expected_reward(means, _, action):
    """Return the mean reward of `action`."""
    return means[action]


def _one_hot_action(obs, action):
    """Feature map: one-hot encoding of the action (the observation is ignored)."""
    return jax.nn.one_hot(action, NUM_ACTIONS)


prior_function = jax.tree_util.Partial(_sample_arm_means)
reward_dist_fn = jax.tree_util.Partial(_sample_reward)
best_action_value_fn = jax.tree_util.Partial(_best_action_and_value)
reward_mean_fn = jax.tree_util.Partial(_expected_reward)
mutli_armed_bandit = BayesStochasticBandit(
    action_space, prior_function, reward_dist_fn, reward_mean_fn, best_action_value_fn
)

feature_fn = jax.tree_util.Partial(_one_hot_action)

# One parameter per arm, Bernoulli reward likelihood.
reward_model = LinearDiscreteRewardModel(
    n_actions=NUM_ACTIONS,
    params_dim=NUM_ACTIONS,
    feature_fn=feature_fn,
    dist="bernoulli",
)

# Grid search over SGLD batch size, Langevin updates per step, and SGLD learning rate.
# The first two are looped over in Python; the learning rate is vectorised with vmap.
b_sizes = [128, 256, 1024]
steps = [1, 5, 10, 20]
langevin_learning_rates = jnp.array([1e-1, 1e-2, 5e-2, 1e-3])
results = jnp.zeros((len(b_sizes), len(steps), langevin_learning_rates.shape[0]))

for i in range(len(b_sizes)):
    b_size = b_sizes[i]
    for j in range(len(steps)):
        update_per_step = steps[j]

        ts_train = make_thompson_sampling(
            env=mutli_armed_bandit,
            reward_model=reward_model,
            num_envs=NUM_ENVS,
            total_steps=NUM_STEPS,
            langevin_batch_size=b_size,
            langevin_updates_per_step=update_per_step,
        )

        # Same PRNG key for every learning rate; only the second argument is mapped.
        jit_train_ts_hyper = jax.vmap(jax.jit(ts_train), in_axes=(None, 0))
        train_key = jax.random.PRNGKey(SEED)
        state, hyper_metrics = jit_train_ts_hyper(
            train_key, langevin_learning_rates
        )  # shape: (n_lrs, n_steps, n_envs, ...)

        # Average the per-step regret over the last axis, then sum over the steps.
        step_regret = hyper_metrics["optimal_value"] - hyper_metrics["reward_mean"]
        bayes_regret = step_regret.mean(axis=-1).sum(axis=-1)  # shape: (n_lrs,)

        results = results.at[i, j, :].set(bayes_regret)
        logging.info(f"batch size: {b_size}, updates per step: {update_per_step} done!")

2024-04-09 15:23:07 INFO     Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter
2024-04-09 15:23:07 INFO     Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2024-04-09 15:23:15 INFO     batch size: 128, updates per step: 1 done!
2024-04-09 15:23:20 INFO     batch size: 128, updates per step: 5 done!
2024-04-09 15:23:22 INFO     batch size: 128, updates per step: 10 done!
2024-04-09 15:23:25 INFO     batch size: 128, updates per step: 20 done!
2024-04-09 15:23:28 INFO     batch size: 256, updates per step: 1 done!
2024-04-09 15:23:30 INFO     batch size: 256, updates per step: 5 done!
2024-04-09 15:23:33 INFO     batch size: 256, updates per step: 10 done!
2024-04-09 15:23:36 INFO     batch size: 256, updates per step: 20 done!
2024-04-09 15:23:38 INFO     batch size: 1024, updates per step: 1 done!
2024-04-09 15:23:41 INFO     batch size:

In [4]:
# Locate the grid cell with the lowest Bayesian regret.
# An exact float-equality lookup (`results == results.min()`) matches nothing as soon
# as one configuration produced NaN, because `min()` is then NaN as well.
# `nanargmin` skips such entries, and `unravel_index` turns the flat position back
# into (batch-size, updates-per-step, learning-rate) indices. Ties resolve to the
# first cell in row-major order, as before.
indices = jnp.unravel_index(jnp.nanargmin(results), results.shape)
best_b_size = b_sizes[int(indices[0])]
best_updates_per_step = steps[int(indices[1])]
best_learning_rate = langevin_learning_rates[indices[2]]
logging.info(
    f"""
    Best hyperparameters: batch: {best_b_size},
    langevin steps: {best_updates_per_step},
    learning rate: {best_learning_rate},
    regret: {results[indices]}"""
)

2024-04-09 15:23:49 INFO     
    Best hyperparameters: batch: 1024,
    langevin steps: 5,
    learning rate: 0.05000000074505806,
    regret: 18.292362213134766


#### ExPerior


For ExPerior, we search over the following hyperparameters:

- SGLD learning rate `[1e-2, 5e-2, 1e-3, 2.5e-4]`
- Lagrange multiplier ($\lambda^\star$) `[0.1, 1.0, 10.0, 50.0, 100.0]`
- Expert competence level ($\beta$) `[0.1, 1.0, 3.0, 10.0]`


In [5]:
# Hyperparameters held fixed during the ExPerior search.
# B_SIZE and UPDATES_PER_STEP are passed as the SGLD batch size and the number of
# Langevin updates per step; they equal the best Naive-TS values logged above.

B_SIZE = 1024
UPDATES_PER_STEP = 5

MAX_ENT_SEED = 712
MAX_ENT_LR = 1e-2  # learning rate to optimize (6) in Proposition 1
MAX_ENT_STEPS = 1000  # number of steps to optimize (6) in Proposition 1
MAX_ENT_SAMPLES = (
    1024  # number of samples to estimate the expectation in (6) in Proposition 1
)

EXPERT_TRAJ_SEED = 1233
N_EXPERT_TRAJECTORY = 1000  # number of expert trajectories `N` in Proposition 1

In [6]:
# Generate the expert trajectories used by ExPerior and UCB-ExPLORe.
# The two trailing positional arguments (1 and None) are presumably the horizon and
# the prior index -- confirm against `generate_optimal_trajectories`.
expert_key = jax.random.PRNGKey(EXPERT_TRAJ_SEED)
expert_trajectories = generate_optimal_trajectories(
    expert_key, mutli_armed_bandit, N_EXPERT_TRAJECTORY, 1, None
)

In [7]:
reward_model = LinearDiscreteRewardModel(
    n_actions=NUM_ACTIONS,
    params_dim=NUM_ACTIONS,
    feature_fn=feature_fn,
    dist="bernoulli",
)

max_ent_ts_train = make_max_ent_thompson_sampling(
    env=mutli_armed_bandit,
    reward_model=reward_model,
    num_envs=NUM_ENVS,
    total_steps=NUM_STEPS,
    langevin_batch_size=B_SIZE,
    langevin_updates_per_step=UPDATES_PER_STEP,
    max_ent_prior_n_samples=MAX_ENT_SAMPLES,
    max_ent_steps=MAX_ENT_STEPS,
)

# Search grids.
langevin_learning_rates = jnp.array([1e-2, 5e-2, 1e-3, 2.5e-4])
max_ent_lambdas = jnp.array([0.1, 1.0, 10.0, 50.0, 100.0])
expert_betas = jnp.array([0.1, 1.0, 3.0, 10.0])

# Positional arguments of the training function as it is called below:
#   0: max-ent PRNG key   1: env PRNG key   2: expert trajectories
#   3: lambda             4: max-ent lr     5: expert beta          6: SGLD lr
# Wrap it in one vmap per swept argument, innermost first, so that the output axes
# come out ordered (lambda, beta, SGLD lr).
n_train_args = 7
jit_max_ent_train_ts_hyper = jax.jit(max_ent_ts_train)
for swept_position in (6, 5, 3):
    jit_max_ent_train_ts_hyper = jax.vmap(
        jit_max_ent_train_ts_hyper,
        in_axes=tuple(
            0 if position == swept_position else None
            for position in range(n_train_args)
        ),
    )

state, max_ent_state, hyper_metrics = jit_max_ent_train_ts_hyper(
    jax.random.PRNGKey(MAX_ENT_SEED),
    jax.random.PRNGKey(SEED),
    expert_trajectories,
    max_ent_lambdas,
    MAX_ENT_LR,
    expert_betas,
    langevin_learning_rates,
)  # shape: (n_lambdas, n_betas, n_lrs, n_steps, n_envs ...)

# Average the per-step regret over the last axis, then sum over the steps.
step_regret = hyper_metrics["optimal_value"] - hyper_metrics["reward_mean"]
bayes_regret = step_regret.mean(axis=-1).sum(axis=-1)  # shape: (n_lambdas, n_betas, n_lrs)

In [8]:
# Locate the (lambda, beta, learning-rate) cell with the lowest Bayesian regret.
# An exact float-equality lookup (`bayes_regret == bayes_regret.min()`) matches
# nothing as soon as one configuration produced NaN, because `min()` is then NaN.
# `nanargmin` skips such entries, and `unravel_index` recovers the per-axis indices.
# Ties resolve to the first cell in row-major order, as before.
indices = jnp.unravel_index(jnp.nanargmin(bayes_regret), bayes_regret.shape)
best_lambda = max_ent_lambdas[indices[0]]
best_beta = expert_betas[indices[1]]
best_learning_rate = langevin_learning_rates[indices[2]]
logging.info(
    f"""
    Best hyperparameters: lambda: {best_lambda},
    beta: {best_beta},
    learning rate: {best_learning_rate},
    regret: {bayes_regret[indices]}"""
)

2024-04-09 15:24:23 INFO     
    Best hyperparameters: lambda: 1.0,
    beta: 10.0,
    learning rate: 0.05000000074505806,
    regret: 18.13381576538086


#### Naive-UCB and UCB-ExPLORe


For Naive-UCB, we only search over a constant factor `rho` $\in$ `[1, 2, 4, 8]` that scales the confidence interval.

For UCB-ExPLORe, we also have a `burn_in` $\in$ `[0, 5, 10, 20, 50, 100]` parameter, which is the number of steps of running Naive-UCB before applying the optimistic rewards.


In [9]:
def _summed_bayes_regret(metrics):
    """Per-step regret averaged over the last axis, then summed over the new last axis."""
    step_regret = metrics["optimal_value"] - metrics["reward_mean"]
    return step_regret.mean(axis=-1).sum(axis=-1)


rhos = jnp.array([1, 2, 4, 8])
rng = jax.random.PRNGKey(SEED)

# Naive-UCB: sweep the confidence-interval scale `rho` (second argument is mapped).
ucb_train = make_multi_armed_ucb(
    env=mutli_armed_bandit, num_envs=NUM_ENVS, total_steps=NUM_STEPS
)
sweep_over_rho = jax.vmap(jax.jit(ucb_train), in_axes=(None, 0))
state, hyper_metrics = sweep_over_rho(rng, rhos)
bayes_regret = _summed_bayes_regret(hyper_metrics)  # shape: (n_rhos,)

UCB_RHO = rhos[bayes_regret.argmin()]

# UCB-ExPLORe: empirical frequency of each action in the expert trajectories.
action_counts = jnp.bincount(
    expert_trajectories.action.reshape(-1), length=NUM_ACTIONS
)
expert_fractions = action_counts / N_EXPERT_TRAJECTORY

# Sweep the burn-in length with the best `rho` found above (fourth argument is mapped).
burn_ins = jnp.array([0, 5, 10, 20, 50, 100])
ucb_train = make_multi_armed_explore_ucb(
    env=mutli_armed_bandit, num_envs=NUM_ENVS, total_steps=NUM_STEPS
)
sweep_over_burn_in = jax.vmap(jax.jit(ucb_train), in_axes=(None, None, None, 0))
state, hyper_metrics = sweep_over_burn_in(rng, expert_fractions, UCB_RHO, burn_ins)
bayes_regret = _summed_bayes_regret(hyper_metrics)  # shape: (n_burn_ins,)
UCB_BURN_IN = burn_ins[bayes_regret.argmin()]

logging.info(f"UCB rho: {UCB_RHO}, burn in: {UCB_BURN_IN}")

2024-04-09 15:24:44 INFO     UCB rho: 1, burn in: 50


### Final Hyperparams


In [2]:
# Values picked by the searches above (see the logged "Best hyperparameters" output).

# NOTE: despite the ORACLE_ prefix, these three are passed to `make_thompson_sampling`
# for the Naive-TS baseline in the comparison loop below.
ORACLE_TS_SGLD_BSIZE = 1024
ORACLE_TS_UPDATES_PER_STEP = 5
ORACLE_TS_SGLD_LR = 0.05

# ExPerior (max-entropy Thompson sampling).
EXPERIOR_SGLD_BSIZE = 1024
EXPERIOR_UPDATES_PER_STEP = 5
EXPERIOR_SGLD_LR = 0.05
EXPERIOR_LAMBDA = 1.0
EXPERIOR_BETA = 10.0

# Max-entropy prior optimisation; same values as used during the ExPerior search.
MAX_ENT_LR = 1e-2
MAX_ENT_STEPS = 1000
MAX_ENT_SAMPLES = 1024

# Naive-UCB / UCB-ExPLORe.
UCB_BURN_IN = 50
UCB_RHO = 1

In [3]:
# Settings for the final comparison run. NUM_STEPS and NUM_ENVS override the values
# used during the hyperparameter search.
NUM_STEPS = 1500  # passed as `total_steps` to the agent factories
NUM_ENVS = 128  # passed as `num_envs` to the agent factories
NUM_PRIORS = 256  # number of (alpha, beta) prior parameter sets sampled per number of arms
EXPERT_N_TRAJECTORY = 500  # number of expert trajectories generated per prior
HORIZON = 1  # NOTE(review): not referenced in the visible cells; the generator call uses a literal 1

ENV_SEED = 42  # PRNG seed passed to every training function
MAX_ENT_SEED = 512  # PRNG seed for the max-entropy prior step of ExPerior
SETTING_SEED = 2048  # seeds the PRNGSequence that draws priors and expert trajectories

### Comparison to Baselines and Empirical Regret Analysis


In [12]:
# Run every method for K = 2, ..., 10 arms. For each K, NUM_PRIORS Beta priors are
# sampled and each method is trained once per prior (vmapped over the prior index).
# Each list below receives one entry per K.
NUM_ACTIONS_LIST = range(2, 11)

naive_ts = []
naive_ucb = []
ucb_explore = []
oracle_ts = []
bc = []
experior = []
expert_entropies_list = []

for NUM_ACTIONS in NUM_ACTIONS_LIST:
    # The sequence is re-created from the same SETTING_SEED for every K.
    rng = PRNGSequence(SETTING_SEED)

    # Beta(1, 1) draws (i.e. uniform on [0, 1]) scaled by 4: one (alpha, beta) pair
    # per prior and per arm, stored along the last axis.
    alpha_betas = jax.random.beta(next(rng), 1.0, 1.0, (NUM_PRIORS, NUM_ACTIONS, 2)) * 4

    action_space = spaces.Discrete(NUM_ACTIONS)

    # Samples the arm means from the i-th prior (i indexes the first axis of alpha_betas).
    def prior_function(key, i: Optional[int] = 0):
        return jax.random.beta(key, alpha_betas[i, :, 0], alpha_betas[i, :, 1])

    prior_function = jax.tree_util.Partial(prior_function)
    # Bernoulli reward for the chosen arm, cast to float32.
    reward_dist_fn = jax.tree_util.Partial(
        lambda key, means, _, action: jax.random.bernoulli(key, means[action]).astype(
            jnp.float32
        )
    )
    # (index of the best arm, mean of the best arm)
    best_action_value_fn = jax.tree_util.Partial(
        lambda means, _: (means.argmax(), means.max())
    )
    reward_mean_fn = jax.tree_util.Partial(lambda means, _, action: means[action])
    mutli_armed_bandit = BayesStochasticBandit(
        action_space,
        prior_function,
        reward_dist_fn,
        reward_mean_fn,
        best_action_value_fn,
    )

    # generate the expert trajectories (only the last argument, the prior index, is mapped)
    expert_trajectories = jax.vmap(
        generate_optimal_trajectories, in_axes=(None, None, None, None, 0)
    )(
        next(rng),
        mutli_armed_bandit,
        EXPERT_N_TRAJECTORY,
        1,
        jnp.arange(alpha_betas.shape[0]),
    )  # shape: (n_priors, n_steps, n_envs, ...)

    # per prior: count of each action in the expert data divided by EXPERT_N_TRAJECTORY
    expert_fractions = jax.vmap(
        lambda a: jnp.bincount(a.reshape(-1), length=NUM_ACTIONS) / EXPERT_N_TRAJECTORY
    )(expert_trajectories.action)

    # calculate the entropy (in bits) of the expert action distribution of each prior;
    # epsilon keeps log2 finite for actions that never occur.
    # NOTE: the normalised bincount below is the same quantity as `expert_fractions`.
    epsilon = 1e-8
    entropy_fn = lambda exp_traj: -jax.vmap(lambda p: p * jnp.log2(p + epsilon))(
        jnp.bincount(exp_traj.action.flatten(), length=NUM_ACTIONS)
        / EXPERT_N_TRAJECTORY
    ).sum()
    expert_entropies = jax.vmap(entropy_fn)(expert_trajectories)

    expert_entropies_list.append(expert_entropies)

    # train the naive thompson sampling
    # (configured by the ORACLE_TS_* constants, despite their name)
    feature_fn = jax.tree_util.Partial(
        lambda obs, action: jax.nn.one_hot(action, NUM_ACTIONS)
    )
    reward_model = LinearDiscreteRewardModel(
        n_actions=NUM_ACTIONS,
        params_dim=NUM_ACTIONS,
        feature_fn=feature_fn,
        dist="bernoulli",
    )

    ts_train = make_thompson_sampling(
        env=mutli_armed_bandit,
        reward_model=reward_model,
        num_envs=NUM_ENVS,
        total_steps=NUM_STEPS,
        langevin_batch_size=ORACLE_TS_SGLD_BSIZE,
        langevin_updates_per_step=ORACLE_TS_UPDATES_PER_STEP,
    )

    # same key and learning rate for every prior; only the prior index is mapped
    state, no_prior_metrics = jax.vmap(jax.jit(ts_train), in_axes=(None, None, 0))(
        jax.random.PRNGKey(ENV_SEED),
        ORACLE_TS_SGLD_LR,
        jnp.arange(alpha_betas.shape[0]),
    )  # shape: (n_priors, n_steps, n_envs, ...)

    # per-step regret averaged over the last axis, then accumulated over the steps;
    # the same reduction is applied to every method below
    naive_ts.append(
        (no_prior_metrics["optimal_value"] - no_prior_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"Naive-TS done for {NUM_ACTIONS} actions!")

    # train the oracle thompson sampling (it is given the true Beta parameters of each prior)
    prior_alpha_betas = lambda i: alpha_betas[i, :, :]
    true_prior_ts_train = make_bernoulli_thompson_sampling(
        env=mutli_armed_bandit,
        num_envs=NUM_ENVS,
        total_steps=NUM_STEPS,
        prior_alpha_betas=prior_alpha_betas,
    )

    state, true_prior_metrics = jax.vmap(
        jax.jit(true_prior_ts_train), in_axes=(None, 0)
    )(jax.random.PRNGKey(ENV_SEED), jnp.arange(alpha_betas.shape[0]))

    oracle_ts.append(
        (true_prior_metrics["optimal_value"] - true_prior_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"Oracle-TS done for {NUM_ACTIONS} actions!")

    # train the max entropy thompson sampling (ExPerior)
    reward_model = LinearDiscreteRewardModel(
        n_actions=NUM_ACTIONS,
        params_dim=NUM_ACTIONS,
        feature_fn=feature_fn,
        dist="bernoulli",
    )
    max_ent_ts_train = make_max_ent_thompson_sampling(
        env=mutli_armed_bandit,
        reward_model=reward_model,
        num_envs=NUM_ENVS,
        total_steps=NUM_STEPS,
        langevin_batch_size=EXPERIOR_SGLD_BSIZE,
        langevin_updates_per_step=EXPERIOR_UPDATES_PER_STEP,
        max_ent_prior_n_samples=MAX_ENT_SAMPLES,
        max_ent_steps=MAX_ENT_STEPS,
    )
    # mapped arguments: the expert trajectories (index 2) and the prior index (index 7)
    state, max_ent_state, max_ent_metrics = jax.vmap(
        jax.jit(max_ent_ts_train),
        in_axes=(None, None, 0, None, None, None, None, 0),
    )(
        jax.random.PRNGKey(MAX_ENT_SEED),
        jax.random.PRNGKey(ENV_SEED),
        expert_trajectories,
        EXPERIOR_LAMBDA,
        MAX_ENT_LR,
        EXPERIOR_BETA,
        EXPERIOR_SGLD_LR,
        jnp.arange(alpha_betas.shape[0]),
    )
    experior.append(
        (max_ent_metrics["optimal_value"] - max_ent_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"ExPerior done for {NUM_ACTIONS} actions!")

    # naive ucb
    ucb_train = make_multi_armed_ucb(
        env=mutli_armed_bandit, num_envs=NUM_ENVS, total_steps=NUM_STEPS
    )
    state, ucb_metrics = jax.vmap(jax.jit(ucb_train), in_axes=(None, None, 0))(
        jax.random.PRNGKey(ENV_SEED), UCB_RHO, jnp.arange(alpha_betas.shape[0])
    )
    naive_ucb.append(
        (ucb_metrics["optimal_value"] - ucb_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"Naive-UCB done for {NUM_ACTIONS} actions!")

    # explore ucb (mapped over the per-prior expert fractions and the prior index)
    ucb_train = make_multi_armed_explore_ucb(
        env=mutli_armed_bandit, num_envs=NUM_ENVS, total_steps=NUM_STEPS
    )
    state, explore_ucb_metrics = jax.vmap(
        jax.jit(ucb_train), in_axes=(None, 0, None, None, 0)
    )(
        jax.random.PRNGKey(ENV_SEED),
        expert_fractions,
        UCB_RHO,
        UCB_BURN_IN,
        jnp.arange(alpha_betas.shape[0]),
    )

    ucb_explore.append(
        (explore_ucb_metrics["optimal_value"] - explore_ucb_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"UCB-ExPLORe done for {NUM_ACTIONS} actions!")

    # bc (mapped over the per-prior expert fractions and the prior index)
    bc_train = make_multi_armed_bc(
        env=mutli_armed_bandit, num_envs=NUM_ENVS, total_steps=NUM_STEPS
    )
    state, bc_metrics = jax.vmap(jax.jit(bc_train), in_axes=(None, 0, 0))(
        jax.random.PRNGKey(ENV_SEED), expert_fractions, jnp.arange(alpha_betas.shape[0])
    )

    bc.append(
        (bc_metrics["optimal_value"] - bc_metrics["reward_mean"])
        .mean(axis=-1)
        .cumsum(axis=-1)
    )
    logging.info(f"BC done for {NUM_ACTIONS} actions!")

2024-04-09 15:25:53 INFO     Naive-TS done for 2 actions!
2024-04-09 15:25:57 INFO     Oracle-TS done for 2 actions!
2024-04-09 15:27:38 INFO     ExPerior done for 2 actions!
2024-04-09 15:27:39 INFO     Naive-UCB done for 2 actions!
2024-04-09 15:27:40 INFO     UCB-ExPLORe done for 2 actions!
2024-04-09 15:27:41 INFO     BC done for 2 actions!
2024-04-09 15:28:30 INFO     Naive-TS done for 3 actions!
2024-04-09 15:28:34 INFO     Oracle-TS done for 3 actions!
2024-04-09 15:30:05 INFO     ExPerior done for 3 actions!
2024-04-09 15:30:07 INFO     Naive-UCB done for 3 actions!
2024-04-09 15:30:08 INFO     UCB-ExPLORe done for 3 actions!
2024-04-09 15:30:10 INFO     BC done for 3 actions!
2024-04-09 15:30:51 INFO     Naive-TS done for 4 actions!
2024-04-09 15:30:55 INFO     Oracle-TS done for 4 actions!
2024-04-09 15:32:27 INFO     ExPerior done for 4 actions!
2024-04-09 15:32:29 INFO     Naive-UCB done for 4 actions!
2024-04-09 15:32:30 INFO     UCB-ExPLORe done for 4 actions!
2024-04-09 

### Save the Results


In [13]:
# Bundle the per-K result lists under the keys the plotting cells read back.
# NOTE(review): this writes `regret_results_temp.npz`, while the plotting section loads
# `regret_results.npz` -- presumably the file is renamed by hand; confirm.
regret_arrays = {
    "naive_ts": naive_ts,
    "naive_ucb": naive_ucb,
    "ucb_explore": ucb_explore,
    "oracle_ts": oracle_ts,
    "bc": bc,
    "experior": experior,
    "expert_entropies_list": expert_entropies_list,
    "num_actions_list": NUM_ACTIONS_LIST,
}
jnp.savez("../output/bandit/regret_results_temp.npz", **regret_arrays)

### Plot the Results


In [4]:
# load the results
# NOTE(review): this reads `regret_results.npz`, whereas the save cell writes
# `regret_results_temp.npz` -- presumably the file is renamed by hand; confirm.
# `results` also reuses the name of the Naive-TS search grid from the tuning section.
results = jnp.load("../output/bandit/regret_results.npz")

In [5]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from experior.utils import (
    latexify,
    FIG_WIDTH,
    GOLDEN_RATIO,
    FONT_SIZE,
    LEGEND_SIZE,
    LIGHT_COLORS,
)

# Select matplotlib's PDF backend; every figure below is written to disk with `fig.savefig`.
mpl.use("pdf")

#### Comparison to baselines


In [6]:
import matplotlib.patches as mpatches

# fix actions to 10 (the last entry of the saved `num_actions_list`)
action_ind = -1

# Split the priors into three groups by the entropy of the expert's action distribution.
expert_entropies_k = results["expert_entropies_list"][action_ind]
low_entropies = expert_entropies_k < 0.8
mid_entropies = (expert_entropies_k >= 0.8) & (expert_entropies_k <= 1.6)
high_entropies = expert_entropies_k > 1.6

# plot the mean bayes regret for each method under different entropy levels - use bar plots
latexify(
    FIG_WIDTH,
    FIG_WIDTH * GOLDEN_RATIO * 0.3,
    font_size=FONT_SIZE,
    legend_size=LEGEND_SIZE,
    labelsize=LEGEND_SIZE,
)
fig, ax = plt.subplots(1, 1, figsize=(FIG_WIDTH, FIG_WIDTH * GOLDEN_RATIO * 0.3))
method_names = [
    r"Oracle-TS",
    r"ExPerior ({Ours})",
    r"Naïve-TS",
    r"Naïve-UCB",
    r"UCB-ExPLORe",
    r"BC",
]
keys = [
    "oracle_ts",
    "experior",
    "naive_ts",
    "naive_ucb",
    "ucb_explore",
    "bc",
]
# Each entry is the cumulative-regret curve per prior: shape (n_priors, n_steps).
method_regrets = [results[key][action_ind] for key in keys]
ax.set_ylabel(r"Bayesian Regret")
width = 0.2
# Three bars (low / mid / high entropy) centred on each method's integer position.
x_values = [
    j for i in range(len(method_names)) for j in [i - 1.1 * width, i, i + 1.1 * width]
]
colors = [LIGHT_COLORS["blue"], LIGHT_COLORS["green"], LIGHT_COLORS["red"]]
hatches = ["\\\\\\\\\\", "---", "/////"]
# Only the middle bar of each triple carries the method name.
ax.set_xticks(
    x_values,
    [j for i in range(len(method_names)) for j in ["", method_names[i], ""]],
    ha="center",
    va="top",
)
# Bayesian regret at the final step, averaged over the priors of each entropy group.
# The stored curves are cumulative, so the regret after T steps is the last column.
# Averaging the whole curve (as a bare `.mean()` on the masked array would) mixes in
# every intermediate time step and disagrees with the `[:, -1]` convention used by the
# regret-vs-arms plot.
means = [
    final_regret[group].mean()
    for final_regret in (curve[:, -1] for curve in method_regrets)
    for group in (low_entropies, mid_entropies, high_entropies)
]
# NOTE(review): these tick positions are hard-coded; adjust if the bars exceed 150.
ax.set_yticks(range(0, 151, 50))
ax.bar(
    x_values,
    means,
    width=width,
    color=colors * len(method_names),
    hatch=hatches * len(method_names),
)

labels = [r"Low Entropy", r"Mid Entropy", r"High Entropy"]
handles = [
    mpatches.Patch(facecolor=colors[i], label=labels[i], hatch=hatches[i])
    for i in range(len(labels))
]
ax.legend(handles=handles, ncol=1)

plt.subplots_adjust(left=0.07, right=0.95, bottom=0.24)

fig.savefig("../output/bandit/baselines.pdf")

#### Empirical regret analysis


In [7]:
# Three-panel figure for Naive-TS, ExPerior and Oracle-TS:
#   left   -- final Bayesian regret vs. number of arms,
#   middle -- final regret vs. entropy of the expert's action (fixed number of arms),
#   right  -- cumulative regret vs. step (fixed number of arms).
# The former `if i == -1:` error-bar branches could never run inside
# `range(len(bases))`, so only the shaded-band (`fill_between`) path is kept.
latexify(
    0.75 * FIG_WIDTH,
    FIG_WIDTH * GOLDEN_RATIO / 3,
    font_size=7,
    legend_size=LEGEND_SIZE,
    labelsize=LEGEND_SIZE,
)
fig, axes = plt.subplots(1, 3, figsize=(0.75 * FIG_WIDTH, FIG_WIDTH * GOLDEN_RATIO / 3))
marker_size = 7

bases = [
    results["naive_ts"],
    results["experior"],
    results["oracle_ts"],
]
names = [r"Naïve-TS", r"ExPerior", r"Oracle-TS"]
color_list = ["red", "blue", "green", "black"]
fill_colors = ["red", "blue", "green", "black"]
markers = ["v", "^", "x", "o"]
linestyles = ["-.", "--", "-", ":"]
linewidth = 1

# Number of steps stored in the loaded curves (overrides the notebook-level NUM_STEPS).
NUM_STEPS = results["oracle_ts"][0].shape[1]

# regret v.s. actions
axis = axes[0]
for i in range(len(bases)):
    # Final-step regret per prior; mean and std are taken over the priors.
    means = jnp.array(
        [bases[i][j][:, -1].mean() for j in range(len(results["num_actions_list"]))]
    )
    stds = jnp.array(
        [bases[i][j][:, -1].std() for j in range(len(results["num_actions_list"]))]
    )
    axis.plot(
        results["num_actions_list"],
        means,
        label=names[i],
        linewidth=linewidth,
        linestyle=linestyles[i],
        c=LIGHT_COLORS[color_list[i]],
    )
    axis.fill_between(
        results["num_actions_list"],
        means - stds,
        means + stds,
        color=LIGHT_COLORS[fill_colors[i]],
        alpha=0.5,
        linewidth=linewidth,
    )
axis.set_xticks(results["num_actions_list"])
axis.set_ylabel(r"Bayesian Regret")
axis.set_xlabel(r"Number of Arms, $K$")

# regret v.s. entropy for actions 10
axis = axes[1]
for i in range(len(bases)):
    axis.scatter(
        results["expert_entropies_list"][action_ind],
        bases[i][action_ind][:, -1],
        label=names[i],
        s=marker_size,
        marker=markers[i],
        c=LIGHT_COLORS[color_list[i]],
        alpha=0.6 if i == 1 else 1,
    )
axis.set_xlabel(r"Entropy of Optimal Action")

# regret v.s. horizon for actions 10
axis = axes[2]
for i in range(len(bases)):
    # Mean and std over the priors at every step.
    means = bases[i][action_ind].mean(axis=0)
    stds = bases[i][action_ind].std(axis=0)
    axis.plot(
        jnp.arange(1, NUM_STEPS + 1),
        means,
        label=names[i],
        linestyle=linestyles[i],
        linewidth=linewidth,
        c=LIGHT_COLORS[color_list[i]],
    )
    axis.fill_between(
        jnp.arange(1, NUM_STEPS + 1),
        means - stds,
        means + stds,
        color=LIGHT_COLORS[fill_colors[i]],
        alpha=0.5,
        linewidth=linewidth,
    )
axis.set_xlabel(r"Episodes, $T$")

# Shared legend: each entry pairs the scatter marker with the line of the same method.
hs0 = axes[1].get_legend_handles_labels()[0]
hs1 = axes[0].get_legend_handles_labels()[0]

handles = [(hs0[i], hs1[i]) for i in range(len(hs0))]
fig.legend(handles, names, loc="upper center", ncol=3)
plt.subplots_adjust(left=0.07, bottom=0.25, right=0.97, top=0.83, wspace=0.2)
fig.savefig("../output/bandit/emp_regret.pdf")

2024-04-09 16:00:20 INFO     Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
2024-04-09 16:00:20 INFO     Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
  axis.scatter(


#### Frequentist regret bound v.s. entropy of optimal actions


In [10]:
def entropy(px):
    """Shannon entropy, in bits, of the probability vector `px`.

    Zero-probability entries contribute 0 (the limit of p * log p as p -> 0)
    instead of turning the whole result into NaN via 0 * log(0).

    Args:
        px: 1-D array-like of probabilities.

    Returns:
        Scalar JAX array holding -sum_i px[i] * log2(px[i]).
    """
    px = jnp.asarray(px)
    # Replace exact zeros by 1 inside the log only: log(1) = 0, so those terms vanish.
    log_arg = jnp.where(px == 0, 1.0, px)
    return -1 * jnp.dot(px, jnp.log(log_arg) / jnp.log(2))


def regret(px, T):
    """Evaluate the regret expression plotted against the entropy of `px`.

    Computes

        2 * sqrt(T * log(T)) * sum_{i != j} sqrt(r_ij * (1 - r_ij)) * (sqrt(p_i) + sqrt(p_j))

    with r_ij = p_i / (p_i + p_j). All ordered pairs are evaluated at once by
    broadcasting instead of a nested Python loop, and no builtin name is shadowed.

    Args:
        px: 1-D array-like of probabilities.
        T: number of steps; must be positive for the logarithm to be defined.

    Returns:
        Scalar JAX array with the value of the expression.
    """
    px = jnp.asarray(px)
    p_i = px[:, None]
    p_j = px[None, :]
    ratio = p_i / (p_i + p_j)
    pair_terms = jnp.sqrt(ratio * (1 - ratio)) * (jnp.sqrt(p_i) + jnp.sqrt(p_j))
    # Exclude the i == j terms (also hides the 0/0 that occurs there when p_i == 0).
    off_diagonal = ~jnp.eye(px.shape[0], dtype=bool)
    total = jnp.sum(jnp.where(off_diagonal, pair_terms, 0.0))
    return total * 2 * jnp.sqrt(T * jnp.log(T))


def sample_simplex(key, n):
    """
    Sample a probability vector of length n from the probability simplex.

    Each coordinate is drawn from either Beta(1, 10) or Beta(10, 1), chosen by a
    fair coin, and the vector is then divided by its sum.
    """
    mix_key, low_key, high_key = jax.random.split(key, 3)

    # Candidate draws from the two skewed Beta distributions.
    mostly_small = jax.random.beta(low_key, 1, 10, shape=(n,))
    mostly_large = jax.random.beta(high_key, 10, 1, shape=(n,))

    # Per coordinate: 1 selects the Beta(1, 10) draw, 0 the Beta(10, 1) draw.
    selector = jax.random.choice(mix_key, 2, shape=(n,))
    mixed = jnp.where(selector == 1, mostly_small, mostly_large)

    # Normalize to ensure the sum is 1
    return mixed / jnp.sum(mixed)


# Draw random probability vectors and record (entropy, regret expression) for each.
K = 2
n_samples = 500
rng = jax.random.PRNGKey(42)
x, y = [], []
for _ in range(n_samples):
    # Sequentially split the key so every draw uses fresh randomness.
    rng, key = jax.random.split(rng)
    p = sample_simplex(key, K)
    x.append(entropy(p))
    y.append(regret(p, T=100))

# Scatter the regret expression against the entropy and save the figure.
fig_width = 0.25 * FIG_WIDTH
fig_height = FIG_WIDTH * GOLDEN_RATIO / 3
latexify(
    fig_width,
    fig_height,
    font_size=FONT_SIZE,
    legend_size=LEGEND_SIZE,
    labelsize=LEGEND_SIZE,
)
fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_height))

ax.scatter(x, y, s=1, c=LIGHT_COLORS["black"])
ax.set_xlabel(r"Entropy of Optimal Action")
plt.subplots_adjust(left=0.09, bottom=0.25, right=0.97, top=0.83, wspace=0.2)
fig.savefig("../output/bandit/freq_regret.pdf")