# Correlated Bandit Problem

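Experiments with the correlated-bandit algorithms (uniform sampling and successive rejects) from Boda & Prashanth, https://arxiv.org/abs/1902.02953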

In [None]:
import numpy as np
import random
from scipy import stats

In [None]:
class BaseCorrelatedBandit():
    """Abstract base class for a bandit whose arms are sampled jointly.

    Subclasses must implement ``_sample_distr`` (joint sampling of all arms)
    and ``mse`` (one mean-squared-error value per arm).
    """

    def __init__(self, num_arms):
        self.num_arms = num_arms

    def _sample_distr(self, size=1):
        """Draw ``size`` joint samples, each holding one value per arm.

        The result must be reshapeable to ``(size, num_arms)``.
        """
        raise NotImplementedError("subclasses must implement _sample_distr")

    # 0-indexed
    def sample_arms(self, arms=None, size=1):
        """Return ``size`` joint samples restricted to the requested arms.

        Args:
            arms: iterable of 0-indexed arm indices; ``None`` selects all arms.
            size: number of joint draws.

        Returns:
            Array of shape ``(size, len(arms))``; column order follows ``arms``.
        """
        multi_sample = self._sample_distr(size=size)
        if arms is None:
            # Use the declared arm count: len(multi_sample) is the number of
            # draws (not arms) as soon as size > 1.
            arms = list(range(self.num_arms))
        else:
            arms = list(arms)
        samples = np.reshape(multi_sample, (size, -1))
        return samples[:, arms]

    def mse(self):
        """Return a sequence with one MSE value per arm."""
        raise NotImplementedError("subclasses must implement mse")

    def best_arm(self):
        """Return the index of the arm with the smallest MSE."""
        return np.argmin(self.mse())

    def hbar(self):
        """Return the sum of inverse squared MSE gaps to the best arm.

        Every arm except the best contributes ``1 / (mse_i - mse_best)**2``.
        An arm tied with the best therefore causes a division by zero.
        """
        mse_vals = self.mse()
        best_arm = self.best_arm()
        return sum(1/(v - mse_vals[best_arm])**2 for i,v in enumerate(mse_vals) if i != best_arm)

In [None]:
class GaussianBandit(BaseCorrelatedBandit):
    """Correlated bandit whose arms follow a multivariate normal distribution."""

    def __init__(self, num_arms, mean=None, cov=1):
        """Build the joint distribution.

        Args:
            num_arms: number of arms.
            mean: mean vector; defaults to all zeros of length ``num_arms``.
            cov: covariance, forwarded to ``scipy.stats.multivariate_normal``.
        """
        super().__init__(num_arms)
        mean_vec = np.zeros(self.num_arms) if mean is None else mean
        self.distr = stats.multivariate_normal(mean=mean_vec, cov=cov)

    def _sample_distr(self, size=1):
        """Draw ``size`` joint samples from the underlying distribution."""
        return self.distr.rvs(size=size)

    def mse(self):
        """Return, for each arm i, ``sum_j var_j * (1 - corr_ij**2)``.

        The j == i term vanishes because an arm's correlation with itself is 1.
        """
        covariance = self.distr.cov
        variances = np.diagonal(covariance)
        col_std = np.sqrt(variances)
        row_std = np.sqrt(np.reshape(variances, (-1, 1)))
        # Normalise columns first, then rows, to obtain the correlation matrix.
        correlation = covariance / col_std / row_std
        unexplained = variances * (1 - correlation * correlation)
        return np.sum(unexplained, axis=1)

In [None]:
# Smoke test: a random positive semi-definite covariance (c @ c.T) over K arms.
K = 10
c = np.random.rand(K, K)
c = c @ c.T
B = GaussianBandit(num_arms=K, cov=c)
print(B.mse())

In [None]:
def sample_stats(arm_sample_dict, num_arms=None):
    """Compute pairwise empirical means, variances and correlations.

    Args:
        arm_sample_dict: maps an arm pair ``(i, j)`` to an ``M x 2`` array of
            M joint samples (column 0 is arm i, column 1 is arm j).
        num_arms: total number of arms. When ``None`` it is inferred as one
            more than the largest arm index appearing in any pair.

    Returns:
        ``(emeans, evars, ecorrs)``, each ``num_arms x num_arms``. Entry
        ``[i, j]`` of ``emeans``/``evars`` is the statistic of arm i computed
        from the samples of pair (i, j); ``ecorrs`` is symmetric. Entries for
        pairs that are absent from the dict (and the diagonal) stay zero.
    """
    if num_arms is None:
        # Previously this fell through to np.zeros((None, None)) and crashed.
        num_arms = 1 + max((max(pair) for pair in arm_sample_dict), default=-1)
    emeans = np.zeros((num_arms, num_arms))
    evars = np.zeros((num_arms, num_arms))
    ecorrs = np.zeros((num_arms, num_arms))
    for (i, j), xs in arm_sample_dict.items():
        # Switch to 2 x M so that row 0 is arm i and row 1 is arm j.
        xs = xs.transpose()
        emeans[i,j],  emeans[j,i] = np.mean(xs, axis=1)
        # ddof = 0: estimator used in paper; ddof = 1: unbiased estimator
        evars[i,j],   evars[j,i]  = np.var(xs, ddof=0, axis=1)
        ecorrs[i,j] = ecorrs[j,i] = np.corrcoef(xs)[0,1]
    return emeans, evars, ecorrs

In [None]:
def mse_arm(i, evars, ecorrs):
    """Return the MSE estimate of arm ``i``.

    Adds up ``evars[p, i] * (1 - ecorrs[p, i]**2)`` over every arm p != i.
    """
    total = 0
    for other in range(ecorrs.shape[0]):
        if other == i:
            continue
        total += evars[other, i] * (1 - ecorrs[other, i] ** 2)
    return total

def mse(arms_dict, num_arms=None):
    """Estimate the MSE of every arm from pairwise samples.

    Args:
        arms_dict: maps an arm pair ``(i, j)`` to that pair's joint samples,
            in the format expected by ``sample_stats``.
        num_arms: total number of arms, forwarded to ``sample_stats``.

    Returns:
        1-D array whose entry i is ``sum_j evars[j, i] * (1 - ecorrs[i, j]**2)``.
    """
    _, evars, ecorrs = sample_stats(arms_dict, num_arms=num_arms)
    # Transpose so that row i holds the variances of the *other* arm j in pair (i, j).
    return np.sum(np.transpose(evars)*(1-ecorrs*ecorrs), axis=1)

In [None]:
def uniform_sampling(bandit, trials):
    """Split the budget evenly across all arm pairs and pick the lowest-MSE arm.

    Args:
        bandit: object exposing ``num_arms`` and ``sample_arms(arms=, size=)``.
        trials: total sampling budget; each pair gets
            ``(2 * trials) // (num_arms * (num_arms - 1))`` joint samples.

    Returns:
        ``(best_arm_index, mse_estimates)``.
    """
    num_arms = bandit.num_arms
    per_pair = (2 * trials) // (num_arms * (num_arms - 1))
    pairs = [(i, j) for i in range(num_arms) for j in range(i + 1, num_arms)]
    arms_dict = {
        pair: bandit.sample_arms(arms=list(pair), size=per_pair)
        for pair in pairs
    }
    estimates = mse(arms_dict, num_arms=num_arms)
    return np.argmin(estimates), estimates

In [None]:
def sr_trials(n, K):
    """Return the K-1 phase thresholds for successive rejects.

    Entry k (k = 1..K-1) is ``ceil((n - K*(K-1)/2) / (ck * (K + 1 - k)))``
    with ``ck = (K-1)/2 + sum_{j=1}^{K-2} j / (K - j)``. The values are
    floats because they come from float floor-division.
    """
    weighted_sum = 0
    for j in range(1, K - 1):
        weighted_sum += j / (K - j)
    ck = (K - 1) / 2 + weighted_sum
    K2 = K * (K - 1) / 2
    budget = n - K2
    thresholds = []
    for k in range(1, K):
        denom = ck * (K + 1 - k)
        # Ceiling division: negate, floor-divide, negate back.
        thresholds.append(-(-budget // denom))
    return thresholds

# verbose: 0 => no prints, 1 => print 0-indexed, 2 => print 1-indexed
def successive_rejects(bandit, trials, verbose=0):
    """Run successive rejects: drop the highest-MSE arm after every phase.

    Args:
        bandit: object exposing ``num_arms`` and ``sample_arms(arms=, size=)``.
        trials: total sampling budget, passed to ``sr_trials``.
        verbose: 0 is silent; 1 and 2 print progress with 0- and 1-indexed
            arm numbers respectively.

    Returns:
        ``(surviving_arm, mse_ests)``. ``mse_ests[i]`` is the estimate at the
        moment arm i was rejected; the survivor gets the smallest estimate
        among the arms alive in the last phase.
    """
    num_arms = bandit.num_arms
    B = set(range(num_arms))  # arms still in contention
    A = set((i, j) for i in range(num_arms) for j in range(i + 1, num_arms))
    cumulative = [0] + sr_trials(trials, num_arms)
    # Per-phase increments of the cumulative thresholds.
    increments = np.subtract(cumulative[1:], cumulative[:-1])
    mse_ests = np.zeros(num_arms)
    arms_dict = {}
    for phase, t in enumerate(increments, start=1):
        count = int(t)
        if verbose:
            print(f"Phase {phase}: {count} samples per pair")
        if count != 0:
            for pair in A:
                fresh = bandit.sample_arms(arms=list(pair), size=count)
                if pair in arms_dict:
                    arms_dict[pair] = np.concatenate((arms_dict[pair], fresh), axis=0)
                else:
                    arms_dict[pair] = fresh
        mse_vals = mse(arms_dict, num_arms=num_arms)
        # Ascending index order, so ties resolve to the lowest arm index.
        alive = sorted(B)
        reject = max(alive, key=lambda arm: mse_vals[arm])
        val = mse_vals[reject]
        if verbose:
            print(f"  MSE Estimates: {mse_vals}")
            print(f"  Rejected arm:  {reject+verbose-1}")
        mse_ests[reject] = val
        min_mse = min(mse_vals[arm] for arm in alive)
        B.remove(reject)
        # Keep only pairs that still involve at least one surviving arm.
        A = {p for p in A if p[0] in B or p[1] in B}
        arms_dict = {p: v for p, v in arms_dict.items() if p[0] in B or p[1] in B}
    arm = B.pop()
    # these estimates might be really bad
    mse_ests[arm] = min_mse
    if verbose:
        print(f"Best arm: {arm+verbose-1}")
    return arm, mse_ests

### Ex 0: Illustrating the Algorithms

In [None]:
# Small 4-arm covariance used to illustrate both algorithms.
S0 = np.array([
    [1.0, 0.6, 0.2, 0.1],
    [0.6, 1.0, 0.1, 0.1],
    [0.2, 0.1, 1.0, 0.0],
    [0.1, 0.1, 0.0, 1.0],
])
B0 = GaussianBandit(num_arms=len(S0), cov=S0)
# Show the true best arm together with the true per-arm MSE values.
np.argmin(B0.mse()), B0.mse()

In [None]:
# Single verbose run (arms printed 1-indexed) to show the phase-by-phase rejections.
successive_rejects(B0, trials=50000, verbose=2)

In [None]:
# Compare both strategies on B0 with a fixed sampling budget.
bandit = B0
trials = 50000
print("#trials =", trials)
# Fraction of 100 independent runs in which each strategy returns the true best arm.
%time b0_unif = np.mean([uniform_sampling(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
%time b0_succ = np.mean([successive_rejects(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
# Display the two success rates.
b0_unif, b0_succ

### Ex 1-3: Recreating Experiments from Boda & Prashanth

In [None]:
# 4 strongly correlated arms (M1) plus 25 independent unit-variance arms.
M1 = np.array([
    [1.0, 0.9, 0.9, 0.9],
    [0.9, 1.0, 0.85, 0.85],
    [0.9, 0.85, 1.0, 0.85],
    [0.9, 0.85, 0.85, 1.0],
])
# Block-diagonal covariance: start from the identity and overwrite the top-left block.
S1 = np.eye(29)
S1[:4, :4] = M1
B1 = GaussianBandit(num_arms=len(S1), cov=S1)
np.argmin(B1.mse()), B1.mse()

In [None]:
# Compare both strategies on B1; the budget is scaled from the bandit's hbar() value.
bandit = B1
trials = int(bandit.hbar() * 32 * 32) * 4
print("#trials =", trials)
# Fraction of 100 independent runs in which each strategy returns the true best arm.
%time b1_unif = np.mean([uniform_sampling(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
%time b1_succ = np.mean([successive_rejects(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
# Display the two success rates.
b1_unif, b1_succ

In [None]:
# Tridiagonal 31x31 block: unit diagonal with 0.2 on the first off-diagonals.
Tr = np.eye(31) + 0.2 * np.eye(31, k=1) + 0.2 * np.eye(31, k=-1)
# Block-diagonal covariance: M1 in the top-left corner, Tr in the bottom-right.
S2 = np.zeros((35, 35))
S2[:4, :4] = M1
S2[4:, 4:] = Tr
B2 = GaussianBandit(num_arms=len(S2), cov=S2)
np.argmin(B2.mse()), B2.mse()

In [None]:
# Compare both strategies on B2; the budget is scaled from the bandit's hbar() value.
bandit = B2
# trials = 500000
trials = int(bandit.hbar() * 32 * 32) * 4
print("#trials =", trials)
# Fraction of 100 independent runs in which each strategy returns the true best arm.
%time b2_unif = np.mean([uniform_sampling(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
%time b2_succ = np.mean([successive_rejects(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
# Display the two success rates.
b2_unif, b2_succ

In [None]:
# 4 moderately correlated arms (M3) plus 30 independent unit-variance arms.
M3 = np.array([
    [1.0, 0.5, 0.45, 0.5],
    [0.5, 1.0, 0.45, 0.4],
    [0.45, 0.45, 1.0, 0.4],
    [0.5, 0.4, 0.4, 1.0],
])
# Block-diagonal covariance: start from the identity and overwrite the top-left block.
S3 = np.eye(34)
S3[:4, :4] = M3
B3 = GaussianBandit(num_arms=len(S3), cov=S3)
np.argmin(B3.mse()), B3.mse()

In [None]:
# Compare both strategies on B3; the budget is scaled from the bandit's hbar() value.
bandit = B3
# trials = 500000
trials = int(bandit.hbar() * 32 * 32) * 4
print("#trials =", trials)
# Fraction of 100 independent runs in which each strategy returns the true best arm.
%time b3_unif = np.mean([uniform_sampling(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
%time b3_succ = np.mean([successive_rejects(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
# Display the two success rates.
b3_unif, b3_succ

### Ex 4: Many arms correlated, many arms uncorrelated

In [None]:
# 15 correlated arms whose pairwise covariance is the product of per-arm
# loadings (0.5 down to 0.35), plus 15 independent unit-variance arms.
a = [np.linspace(0.5, 0.35, num=15)]
M5 = np.outer(a, a)
M5[np.diag_indices_from(M5)] = 1
# Block-diagonal covariance: start from the identity and overwrite the top-left block.
S5 = np.eye(30)
S5[:15, :15] = M5
B5 = GaussianBandit(num_arms=len(S5), cov=S5)
np.argmin(B5.mse()), B5.mse()

In [None]:
# Compare both strategies on B5; the budget is scaled from the bandit's hbar() value.
bandit = B5
# trials = 500000
trials = int(bandit.hbar() * 32 * 32 / 2)
print("#trials =", trials)
# Fraction of 100 independent runs in which each strategy returns the true best arm.
%time b5_unif = np.mean([uniform_sampling(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
%time b5_succ = np.mean([successive_rejects(bandit, trials)[0] == bandit.best_arm() for _ in range(100)])
# Display the two success rates.
b5_unif, b5_succ

### Results

In [None]:
import matplotlib.pyplot as plt

In [None]:
def bar_graph(results, width=0.15, labels=None):
    """Plot grouped bars of the error probability for both strategies.

    Args:
        results: sequence of ``(uniform_success_rate, successive_rejects_success_rate)``
            pairs, one per experiment; the plotted value is ``1 - rate``.
        width: width of each bar.
        labels: x-axis label per experiment; defaults to
            ``'Experiment 1'``, ``'Experiment 2'``, ...
    """
    if labels is None:
        labels = ['Experiment ' + str(i+1) for i in range(len(results))]
    plt.figure(figsize=(8,6))
    x = np.arange(len(results))
    # Row 0: uniform sampling, row 1: successive rejects.
    res = np.array(list(zip(*results)))
    errors = 1 - res
    plt.bar(x - width/2, errors[0,:], width, label='Uniform Sampling', zorder=2)
    plt.bar(x + width/2, errors[1,:], width, label='Successive Rejects', zorder=2)
    plt.ylabel('Probability of Error')
    # One tick per experiment, centred between its two bars.
    plt.xticks(x, labels)
    plt.yticks(np.arange(0, np.max(errors) + 0.15, 0.05))
    # Pass the on/off flag positionally: the ``b`` keyword was renamed to
    # ``visible`` and later removed from Matplotlib.
    plt.grid(True, axis='y', zorder=1)
    plt.legend(loc='upper right')

In [None]:
# Error rates for the illustrative 4-arm example.
bar_graph(results=[(b0_unif, b0_succ)], labels=['Experiment 0'])
plt.xlim(-1, 1)

In [None]:
# Error rates for the three experiments built on B1, B2 and B3.
bar_graph(results=[
    (b1_unif, b1_succ),
    (b2_unif, b2_succ),
    (b3_unif, b3_succ),
])

In [None]:
# Error rates for the B5 experiment.
bar_graph(results=[(b5_unif, b5_succ)], width=0.15, labels=['Experiment 5'])
plt.xlim(-1, 1)
plt.show()