In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [7]:
class Bandit:
    """A single bandit arm whose rewards are Gaussian.

    Each pull returns ``mu + sigma * z`` where ``z`` is a standard-normal
    draw from NumPy's global random state.
    """

    def __init__(self, mu, sigma):
        # Mean and scale of the reward distribution.
        self.mu, self.sigma = mu, sigma

    def reward(self):
        """Pull the arm once and return the sampled reward."""
        noise = np.random.randn()
        return self.mu + self.sigma * noise

In [42]:
class MultiArmBandit:
    """A k-armed bandit assembled from individual arms.

    Each arm must provide a ``reward()`` method (e.g. ``Bandit``); pulling
    arm ``k`` returns that arm's reward.
    """

    def __init__(
        self,
        bandit_arms : list
    ):
        """
        Parameters
        ----------
        bandit_arms : list
            The arms, addressed by index ``0..n-1``. The sequence is copied
            so that later mutation of the caller's list cannot make
            ``arms`` and ``n`` disagree.

        Raises
        ------
        ValueError
            If no arms are given.
        """
        self.arms = list(bandit_arms)
        if not self.arms:
            raise ValueError("MultiArmBandit needs at least one arm")
        self.n = len(self.arms)
        # Arm indices 0..n-1 (sampled uniformly when exploring).
        self.indices = np.arange(self.n)

    def action(self, k):
        """Pull arm ``k`` and return its reward.

        Raises
        ------
        IndexError
            If ``k`` is not a valid arm index in ``0..n-1``.
        """
        # Reject both ends explicitly: a one-sided check would silently
        # return None for k >= n and let negative k wrap to the last arms.
        if not 0 <= k < self.n:
            raise IndexError(f"arm index {k} out of range for {self.n} arms")
        return self.arms[k].reward()

In [65]:
def estimate_q(
    multi_arm_bandit : "MultiArmBandit",
    epsilon,
    initial_q,
    trials=1000
):
    """Estimate each arm's action value with epsilon-greedy sample averaging.

    On every trial, with probability ``epsilon`` a uniformly random arm is
    pulled (exploration); otherwise an arm with the highest current estimate
    is pulled (exploitation), with ties broken uniformly at random. The
    pulled arm's estimate is then updated incrementally so that it equals
    the mean of all rewards that arm has returned so far.

    Parameters
    ----------
    multi_arm_bandit : MultiArmBandit
        Environment exposing ``n`` (number of arms), ``indices`` (the arm
        indices) and ``action(k)`` (pull arm ``k`` and return a reward).
    epsilon : float
        Exploration probability, in ``[0, 1]``.
    initial_q : float or array_like
        Starting estimate for every arm (a scalar, or one value per arm).
    trials : int, default 1000
        Number of arm pulls.

    Returns
    -------
    numpy.ndarray
        Array of length ``n`` with the final value estimates. Arms that
        were never pulled keep their initial value.

    Raises
    ------
    ValueError
        If the bandit has no arms or ``epsilon`` is outside ``[0, 1]``.

    Notes
    -----
    Randomness comes from NumPy's global random state; call
    ``np.random.seed`` beforehand for reproducible results.
    """
    n_arms = multi_arm_bandit.n
    if n_arms == 0:
        raise ValueError("multi_arm_bandit has no arms")
    if not 0 <= epsilon <= 1:
        raise ValueError(f"epsilon must be in [0, 1], got {epsilon!r}")

    q_estimate = np.zeros(n_arms, dtype=float) + initial_q
    n_calls = np.zeros(n_arms, dtype=int)

    for _ in range(trials):
        if np.random.rand() < epsilon:
            k = np.random.choice(multi_arm_bandit.indices)
        else:
            # np.argmax always returns the *first* maximum, which biases
            # exploitation toward low-index arms whenever estimates tie
            # (e.g. at the start, when every arm holds initial_q).
            best_arms = np.flatnonzero(q_estimate == q_estimate.max())
            k = np.random.choice(best_arms)

        n_calls[k] += 1
        R = multi_arm_bandit.action(k)
        # Incremental sample mean: Q_k <- Q_k + (R - Q_k) / N_k.
        q_estimate[k] += (R - q_estimate[k]) / n_calls[k]

    return q_estimate

In [69]:
# Seed NumPy's global RNG so this cell reproduces the same estimates on every
# run: both Bandit.reward and estimate_q draw from np.random.
np.random.seed(42)

two_arm_bandit = MultiArmBandit([Bandit(5, 2), Bandit(10, 3)])

estimate_q(
    multi_arm_bandit=two_arm_bandit,
    epsilon=0.1,
    initial_q=0,
    trials=10000,
)

array([ 4.95338387, 10.01167132])