In [19]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as spst

In [20]:
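# K-armed Bernoulli bandit
# means: list of reward means, one per arm
# strat: 0 for epsilon-greedy; any other value (conventionally 1) for UCB
# param: epsilon for epsilon-greedy, or the exploration constant c for UCB
# M rounds of simulation, each for T timesteps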

def bernoulli_bandit(means, strat, param, M, T, rng=None):
    """Simulate a K-armed Bernoulli bandit and average the results over runs.

    Parameters
    ----------
    means : sequence of float
        Success probability of each arm (K = len(means)). The best arm is
        np.argmax(means), i.e. the first maximum on ties.
    strat : int
        0 selects epsilon-greedy; any other value selects UCB.
    param : float
        Exploration parameter: epsilon for epsilon-greedy, or the multiplier
        c applied to the confidence bonus for UCB.
    M : int
        Number of independent simulation runs.
    T : int
        Number of timesteps per run.
    rng : np.random.Generator, np.random.RandomState or None, optional
        Source of randomness. None (the default) draws from NumPy's global
        legacy state, so np.random.seed() still controls the result. Pass a
        Generator or RandomState for a self-contained, reproducible run.

    Returns
    -------
    regrets : ndarray, shape (T,)
        Per-timestep regret averaged over the M runs. Within a run the regret
        is 0 when the best arm was played; otherwise it is a fresh
        Bernoulli(means[best]) draw minus the reward actually received, so a
        single-run value is -1, 0 or 1.
    bestArmFreq : ndarray, shape (T,)
        Fraction of plays so far that went to the best arm, N[best]/(t+1),
        averaged over the M runs.
    """
    if rng is None:
        # The np.random module itself exposes the legacy global-state API.
        rng = np.random
    # Generator names its integer sampler `integers`; the legacy API uses
    # `randint`. Both sample uniformly from [0, K).
    draw_arm = rng.integers if hasattr(rng, "integers") else rng.randint

    best = np.argmax(means)
    K = len(means)

    regrets = np.zeros((M, T))
    bestArmFreq = np.zeros((M, T))

    for m in range(M):
        # Keep track of action values and action counts
        Q = np.zeros(K)
        N = np.zeros(K)

        for t in range(T):
            # Action selection depending on strategy
            if strat == 0:
                if rng.random() < param:
                    arm = draw_arm(0, K)
                else:
                    arm = np.argmax(Q)
            else:
                # N+1 in the denominator keeps the bonus finite for arms
                # that have not been played yet.
                arm = np.argmax(Q + param * np.sqrt(np.log(t + 1) / (N + 1)))

            # Reward and incremental sample-mean update of the Q value
            reward = rng.binomial(1, means[arm])
            N[arm] += 1
            Q[arm] += 1 / N[arm] * (reward - Q[arm])

            # Track realised regret; regrets[m, t] stays 0 when the best arm
            # was played (the array is zero-initialised).
            if arm != best:
                regrets[m, t] = rng.binomial(1, means[best]) - reward
            bestArmFreq[m, t] = N[best] / (t + 1)

    return np.mean(regrets, axis=0), np.mean(bestArmFreq, axis=0)

In [21]:
def execute(means, strat, params=(0.1, 0.2, 0.3), M=100, T=10000):
    """Plot cumulative regret and best-arm frequency for several parameter values.

    Runs bernoulli_bandit(means, strat, p, M, T) once for every p in `params`
    and draws the results on a two-panel figure with a shared, log-scaled
    x-axis: cumulative regret on top, best-arm frequency below.

    Parameters
    ----------
    means : sequence of float
        Reward means of the arms, forwarded to bernoulli_bandit.
    strat : int
        0 for epsilon-greedy (legend entries "e=..."); any other value for
        UCB (legend entries "c=...").
    params : iterable of float, optional
        Exploration parameter values to compare, one curve per value.
    M : int, optional
        Number of simulation runs averaged per curve.
    T : int, optional
        Number of timesteps per run.
    """
    f = plt.figure()
    ax1 = f.add_subplot(211)
    ax2 = f.add_subplot(212, sharex=ax1)

    # Legend prefix names the parameter being varied.
    prefix = "e" if strat == 0 else "c"
    # Entry t of each curve describes the state after t+1 plays; 1-based
    # timesteps also keep the first point off x=0, which a log axis cannot
    # place.
    steps = np.arange(1, T + 1)

    for p in params:
        regrets, bestArmFreq = bernoulli_bandit(means, strat, p, M, T)
        ax1.plot(steps, np.cumsum(regrets), label="%s=%.2f" % (prefix, p))
        ax2.plot(steps, bestArmFreq)

    ax1.set_xscale('log', base=10)
    ax1.set_title("cumulative regret")
    ax2.set_title("percentage best arm played")
    f.legend()
    f.tight_layout()
    # pyplot.show() works on every backend; Figure.show() warns and shows
    # nothing on non-GUI backends such as the notebook inline backend.
    plt.show()