<a href="https://colab.research.google.com/github/probml/bandits/blob/main/bandits/scripts/subspace_bandits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bayesian Subspace bandits

See  https://arxiv.org/abs/2112.00195 for details.


## Installation

In [1]:
# Shallow clone (`--depth 1`: latest commit only) of the bandits repository into ./bandits.
!git clone --depth 1 https://github.com/probml/bandits

Cloning into 'bandits'...
remote: Enumerating objects: 56, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 56 (delta 11), reused 23 (delta 1), pack-reused 0[K
Unpacking objects: 100% (56/56), done.


In [2]:
!pip install -qqq fire
!pip install -qqq ml-collections
!pip install -qqq git+git://github.com/deepmind/optax.git
!pip install -qqq --upgrade git+https://github.com/google/flax.git

[?25l[K     |███▊                            | 10 kB 32.8 MB/s eta 0:00:01[K     |███████▌                        | 20 kB 37.0 MB/s eta 0:00:01[K     |███████████▏                    | 30 kB 22.2 MB/s eta 0:00:01[K     |███████████████                 | 40 kB 18.1 MB/s eta 0:00:01[K     |██████████████████▊             | 51 kB 17.1 MB/s eta 0:00:01[K     |██████████████████████▍         | 61 kB 14.0 MB/s eta 0:00:01[K     |██████████████████████████▏     | 71 kB 11.9 MB/s eta 0:00:01[K     |██████████████████████████████  | 81 kB 13.1 MB/s eta 0:00:01[K     |████████████████████████████████| 87 kB 5.8 MB/s 
[?25h  Building wheel for fire (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 88 kB 6.0 MB/s 
[K     |████████████████████████████████| 65 kB 3.3 MB/s 
[?25h  Building wheel for optax (setup.py) ... [?25l[?25hdone
  Building wheel for flax (setup.py) ... [?25l[?25hdone


## Test the installation

In [3]:
%%bash
# Installation smoke test: run the checkout's `test` command from the repository root.
cd /content/bandits
python bandits test

Expected Reward : 4419.70 ± 13.78
Time : 11.732s


## Setup 

In [None]:
# Start inside the experiments directory; the next cell steps up one level
# (os.chdir("..")) before importing the `scripts` package.
%cd /content/bandits/bandits/experiments

/content/bandits/bandits/experiments


In [None]:
import os
os.chdir("..")

import glob
from datetime import datetime

import jax
import matplotlib.pyplot as plt
import ml_collections
import pandas as pd
import seaborn as sns

import scripts.movielens_exp as movielens_run
import scripts.mnist_exp as mnist_run
import scripts.tabular_exp as tabular_run
import scripts.tabular_subspace_exp as tabular_sub_run

print(jax.device_count())



1


In [None]:
def get_config(results_filename, ntrials=2):
  """Build the configuration object handed to each experiment's `main`.

  Args:
    results_filename: Path of the results CSV, stored as `config.filepath`.
    ntrials: Number of trials, stored as `config.ntrials`. Defaults to 2 to
      keep the notebook quick; the paper used 10.

  Returns:
    An `ml_collections.ConfigDict` with `filepath` and `ntrials` set.
  """
  config = ml_collections.ConfigDict()
  config.filepath = results_filename
  config.ntrials = ntrials
  return config

In [None]:
# One timestamp per session; it is embedded in every results file name created below.
timestamp = datetime.timestamp(datetime.now())

In [None]:
def plot_figure(data, x, y, filename, figsize=(24, 9), log_scale=False, palette=None):
    """Draw a bar plot of `y` against `x`, grouped by "Method", and save it as a PNG.

    Args:
        data: DataFrame holding the `x`, `y` and "Method" columns.
        x: Name of the column placed on the x axis.
        y: Name of the column used for the bar heights.
        filename: Stem of the image written to ./figures/<filename>.png.
        figsize: Figure size forwarded to `plt.subplots`.
        log_scale: If True, the y axis uses a logarithmic scale.
        palette: Optional seaborn palette. When omitted, the notebook-level
            `colors` mapping is used if it has already been defined; otherwise
            seaborn's default palette applies.
    """
    if palette is None:
        # `colors` is defined in a later cell than the first call to this
        # function, so it may not exist yet on a fresh kernel.
        palette = globals().get("colors")

    sns.set(font_scale=1.5)
    # Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*";
    # older releases only know the unprefixed name.
    poster_style = "seaborn-v0_8-poster"
    if poster_style not in plt.style.available:
        poster_style = "seaborn-poster"
    plt.style.use(poster_style)

    fig, ax = plt.subplots(figsize=figsize, dpi=300)
    g = sns.barplot(x=x, y=y, hue="Method", data=data, errwidth=2, ax=ax, palette=palette)
    if log_scale:
        g.set_yscale("log")
    # Anchor the legend just outside the right edge of the axes.
    ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))
    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("./figures", exist_ok=True)
    fig.savefig(f"./figures/{filename}.png")
    plt.show()

def read_data(dataset_name):
    """Load one results CSV for `dataset_name` and expand it for bar plotting.

    The file is the lexicographically last match of
    ./results/<dataset_name>_results*.csv (the names embed a timestamp, so
    this is normally the most recent run).

    For "mnist", the rows of the "Lin" and "Lin-KF" methods are additionally
    replicated with Model set to "MLP2" and to "LeNet5" (presumably so the
    linear baselines appear in every Model group of the plot).

    Every row is returned twice: once with Reward + Std and once with
    Reward - Std (presumably so the bar plot's mean equals the reward while
    its error bar reflects Std).

    Args:
        dataset_name: Prefix of the results file; "tabular" rows are sorted by
            "Rank" only, every other dataset by "Rank" then "AltRank".

    Returns:
        A DataFrame with twice as many rows as the (possibly augmented) CSV.

    Raises:
        FileNotFoundError: If no results file matches the pattern.
    """
    pattern = f"./results/{dataset_name}_results*.csv"
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No results file matches {pattern!r}")
    df = pd.read_csv(matches[-1])

    if dataset_name == "mnist":
        linear_df = df[df["Method"].isin(["Lin-KF", "Lin"])]
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        df = pd.concat([
            df,
            linear_df.assign(Model="MLP2"),
            linear_df.assign(Model="LeNet5"),
        ])

    by = ["Rank"] if dataset_name == "tabular" else ["Rank", "AltRank"]
    ordered = df.sort_values(by=by)

    data_up = ordered.assign(Reward=ordered["Reward"] + ordered["Std"])
    data_down = ordered.assign(Reward=ordered["Reward"] - ordered["Std"])
    return pd.concat([data_up, data_down])

def plot_subspace_figure(df, filename=None):
    """Plot Reward against "Subspace Dim" per Method with a +/- Std band, and save it.

    Args:
        df: DataFrame with "Subspace Dim", "Reward", "Std", "Method" and
            "Dataset" columns; the title uses the "Dataset" value of its
            first row.
        filename: Optional stem of the PNG written under ./figures/. When
            None, "<dataset>_sub_reward" is used.

    Raises:
        ValueError: If `df` has no rows.
    """
    if df.empty:
        raise ValueError("plot_subspace_figure received an empty DataFrame")
    df = df.reset_index(drop=True)

    # Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*";
    # older releases only know the unprefixed name.
    darkgrid_style = "seaborn-v0_8-darkgrid"
    if darkgrid_style not in plt.style.available:
        darkgrid_style = "seaborn-darkgrid"
    plt.style.use(darkgrid_style)

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.lineplot(x="Subspace Dim", y="Reward", hue="Method", marker="o", data=df, ax=ax)

    # Shade reward +/- std around each method's line, reusing the line colour.
    lines, labels = ax.get_legend_handles_labels()
    for line, method in zip(lines, labels):
        # fill_between draws points in the order given, so sort along x first.
        data = df[df["Method"] == method].sort_values("Subspace Dim")
        color = line.get_c()
        y_lower_bound = data["Reward"] - data["Std"]
        y_upper_bound = data["Reward"] + data["Std"]
        ax.fill_between(data["Subspace Dim"], y_lower_bound, y_upper_bound, color=color, alpha=0.3)

    ax.set_ylabel("Reward", fontsize=16)
    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)
    ax.set_xlabel("Subspace Dimension(d)", fontsize=16)
    dataset = df.iloc[0]["Dataset"]
    ax.set_title(f"{dataset.title()} - Subspace Dim vs. Reward", fontsize=18)
    legend = ax.legend(loc="lower right", prop={'size': 16}, frameon=True)
    frame = legend.get_frame()
    frame.set_color('white')
    frame.set_alpha(0.6)

    stem = f"{dataset}_sub_reward" if filename is None else filename
    # savefig does not create missing directories.
    os.makedirs("./figures", exist_ok=True)
    fig.savefig(f"./figures/{stem}.png")

# Run tabular experiments

In [None]:
# The cells below address ./results and ./figures relative to this directory.
%cd /content/bandits/bandits

/content/bandits/bandits/experiments


In [None]:
# Run the tabular experiments. The timestamped path is handed over as config.filepath;
# presumably tabular_run.main writes its results there (read_data later globs for this name pattern).
tabular_filename = f"./results/tabular_results_{timestamp}.csv"
config = get_config(tabular_filename)
tabular_run.main(config)

Environment :  shuttle
	Bandit : Linear
		Expected Reward : 4413.50 ± 4.50
		Time : 10.469s
	Bandit : Linear KF
		Expected Reward : 4414.50 ± 4.50
		Time : 6.309s
	Bandit : Linear Wide
		Expected Reward : 4210.00 ± 10.00
		Time : 25.030s
	Bandit : Limited Neural Linear
		Expected Reward : 3840.00 ± 3.00
		Time : 23.608s
	Bandit : Unlimited Neural Linear
		Expected Reward : 4089.00 ± 70.00
		Time : 42.628s
	Bandit : EKF Subspace SVD
		Expected Reward : 4731.00 ± 116.00
		Time : 198.925s
	Bandit : EKF Subspace RND
		Expected Reward : 4846.50 ± 1.50
		Time : 199.065s
	Bandit : EKF Diagonal Subspace SVD
		Expected Reward : 4831.00 ± 0.00
		Time : 9.122s
	Bandit : EKF Diagonal Subspace RND
		Expected Reward : 4797.00 ± 0.00
		Time : 9.127s
	Bandit : EKF Orig Diagonal
		Expected Reward : 3915.00 ± 4.00
		Time : 6.106s
	Bandit : EKF Orig Full
		Expected Reward : 3913.00 ± 2.00
		Time : 875.099s
Environment :  covertype
	Bandit : Linear
		Expected Reward : 3016.50 ± 13.50
		Time : 20.976s
	Ban

In [None]:
# Load the tabular results and keep only the methods listed in tabular_rows.
dataset_name = "tabular"
tabular_df = read_data(dataset_name)
tabular_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',
                'EKF-Orig-Full',  'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin', 'Lim2', 'NeuralTS']
tabular_df = tabular_df[tabular_df['Method'].isin(tabular_rows)]

In [None]:
# Reward per dataset; saved as ./figures/tabular_reward.png.
x, y = "Dataset", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(tabular_df, x, y, filename)

In [None]:
# Running time per dataset on a log-scaled y axis; NeuralTS is left out of this plot.
x, y = "Dataset", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(tabular_df[tabular_df["Method"] != "NeuralTS"], x, y, filename, log_scale=True)

# Run movielens experiments

In [None]:
# Run the MovieLens experiments. The timestamped path is handed over as config.filepath;
# presumably movielens_run.main writes its results there.
movielens_filename = f"./results/movielens_results_{timestamp}.csv"
config = get_config(movielens_filename)
movielens_run.main(config)

In [None]:
# Load the MovieLens results and keep only the methods listed in movielens_rows.
dataset_name = "movielens"
movielens_df = read_data(dataset_name)
movielens_rows =  ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND',
                   'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']
movielens_df = movielens_df[movielens_df['Method'].isin(movielens_rows)]

In [None]:
# Reward per model; saved as ./figures/movielens_reward.png.
x, y = "Model", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(movielens_df, x, y, filename)

In [None]:
# Running time per model (linear y axis); saved as ./figures/movielens_time.png.
x, y = "Model", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(movielens_df, x, y, filename)

# Run MNIST experiments

In [None]:
# Run the MNIST experiments. The timestamped path is handed over as config.filepath;
# presumably mnist_run.main writes its results there.
mnist_filename = f"./results/mnist_results_{timestamp}.csv"
config = get_config(mnist_filename)
mnist_run.main(config)

In [None]:
# Index of each method into seaborn's "Paired" palette.
# NOTE: plot_figure reads the module-level `colors` built here, so on a fresh
# kernel this cell has to run before the first plot_figure call.
method_ordering = {
    "EKF-Sub-SVD": 0,
    "EKF-Sub-RND": 1,
    "EKF-Sub-Diag-SVD": 2,
    "EKF-Sub-Diag-RND": 3,
    "EKF-Orig-Full": 4,
    "EKF-Orig-Diag": 5,
    "NL-Lim": 6,
    "NL-Unlim": 7,
    "Lin": 8,
    "Lin-KF": 9,
    "Lin-Wide": 9,
    "Lim2": 10,
    "NeuralTS": 11,
}

# Method name -> colour, used as the bar-plot palette.
colors = {
    method: sns.color_palette("Paired")[palette_index]
    for method, palette_index in method_ordering.items()
}
# "Lin-KF" shares palette index 9 with "Lin-Wide"; it takes a colour from "tab20" instead.
colors["Lin-KF"] = sns.color_palette("tab20")[8]

In [None]:
# Methods kept for the MNIST plots.
dataset_name = "mnist"
# For possible methods, run mnist_df.Method.unique()
mnist_rows = ['EKF-Sub-SVD', 'EKF-Sub-RND', 'EKF-Sub-Diag-SVD', 'EKF-Sub-Diag-RND', 'EKF-Orig-Diag', 'NL-Lim', 'NL-Unlim', 'Lin']

In [None]:
# Load the MNIST results and keep only the methods listed in mnist_rows.
mnist_df = read_data(dataset_name)
mnist_df = mnist_df[mnist_df['Method'].isin(mnist_rows)]

In [None]:
# Reward per model; saved as ./figures/mnist_reward.png.
x, y = "Model", "Reward"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(mnist_df, x, y, filename)

In [None]:
# Running time per model on a log-scaled y axis; saved as ./figures/mnist_time.png.
x, y = "Model", "Time"
filename = f"{dataset_name}_{y.lower()}"
plot_figure(mnist_df, x, y, filename, log_scale=True)

# Run tabular subspace experiment

In [None]:
# Run the tabular subspace-dimension experiment. The timestamped path is handed over
# as config.filepath; presumably tabular_sub_run.main writes its results there.
tabular_sub_filename = f"./results/tabular_subspace_results_{timestamp}.csv"
config = get_config(tabular_sub_filename)
tabular_sub_run.main(config)

In [None]:
# Pick the lexicographically last matching results file (names embed the timestamp)
# and load it. The star-unpack raises ValueError when no file matches.
*_, filename = sorted(glob.glob(f"./results/tabular_subspace_results*.csv"))
tabular_sub_df = pd.read_csv(filename)

In [None]:
# Subspace dimension vs. reward for the "shuttle" rows only.
dataset_name = "shuttle"
shuttle = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(shuttle)

In [None]:
# Subspace dimension vs. reward for the "adult" rows only.
dataset_name = "adult"
adult = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(adult)

In [None]:
# Subspace dimension vs. reward for the "covertype" rows only.
dataset_name = "covertype"
covertype = tabular_sub_df[tabular_sub_df["Dataset"]==dataset_name]
plot_subspace_figure(covertype)