In [76]:
import numpy as np
import matplotlib.pyplot as plt
from numba import jit


In [2]:
n = 10 # 10个臂
bandit_p = np.random.rand(n)

def gen_data(i):
    assert i < n
    if np.random.rand() < bandit_p[i]:
        return 1  # reward
    return 0 # no reward

In [1]:
# 随机选择一个

def random_select(ctx):
    return np.random.randint(ctx['n'])


def determine_select(ctx):
    return np.argmax(ctx['cum_reward_action'] / (ctx['cum_action'] + 1))

# epsilon 贪心

def epsilon_greedy(ctx):
    epsilon = 1 / (1 + np.sqrt(ctx['step']))
    if np.random.rand() < epsilon: # random
        return random_select(ctx)
    else:
        return determine_select(ctx)
        
        
# naive

def naive_select(ctx):
    if ctx['step'] < 500:
        return random_select(ctx)
    else:
        return determine_select(ctx)

# softmax
def softmax_select(ctx):
    n0 = 0.1
    avg_reward = ctx['cum_reward_action'] / (ctx['cum_action'] + 1)
    p = np.exp(avg_reward / n0)
    p = p / np.sum(p)
    #print ctx['step'], p
    return np.random.multinomial(1, p)[0]

###

In [92]:
@jit
def simulation(select_action, n_ = 3000):
    t = []
    reward = []
    cum_reward_per = 0.0
    cum_reward_action = np.zeros(n)
    cum_action = np.zeros(n)
    for i in range(n_):
        t.append(i)
        ctx = {'n' : n, 'cum_reward_action' : cum_reward_action, 'cum_action' : cum_action, 'step': i}
        action = select_action(ctx)
        cum_action[action] += 1
        r = gen_data(action)
        action_n = cum_action[action]
        cum_reward_action[action] += r
        cum_reward_per += r
        reward.append(cum_reward_per)
    return np.array(t), np.array(reward)

@jit
def run(select_action, rnd_ = 100, n_ = 3000):
    t = []
    avg_reward = 0
    for i in range(rnd_):
        t, reward = simulation(select_action, n_)
        avg_reward = avg_reward * i / (i + 1) + reward / (i+1)
    return t, np.max(bandit_p) * (1 + np.arange(n_)) - avg_reward

In [2]:
t, r = run(random_select)
plt.plot(t, r)

t, r = run(epsilon_greedy)
plt.plot(t, r)

t, r = run(naive_select)
plt.plot(t, r)

t, r = run(softmax_select)
plt.plot(t, r)

plt.legend(['random', 'epsilon greedy', 'naive', 'softmax select'])
plt.xlabel('t')
plt.ylabel('regret')

NameError: name 'run' is not defined

In [30]:
np.argmax(bandit_p) - np.mean(bandit_p)

3.5700523317477573

In [102]:
run(softmax_select, 1, 50)

0 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
1 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
2 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
3 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
4 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
5 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
6 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
7 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
8 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
9 [0.23196932 0.08533674 0.08533674 0.08533674 0.08533674 0.08533674
 0.08533674 0.08533674 0.08533674 0.08533674]
10 [0.21616887 0.08709235 0.08709235 0.08709235 0.08709235 0.08709235
 0.08709235 0.08709235 0.08709235 0.08709235]
11 [0.37038677 0.06995703 0.06995703 0.06995703 0.06995703 0.06995703
 0.06995703 0.06995703 0.06995703 0.06995703]
12 [0.34101753 0.07322027 0.07322027 0.07322027 0.07322027 0.07322027
 0.07322027 0.07322027 0.07322027 0.07322027]
13 [0.0289154  0.92141694 0.00620846 0.00620846 0.00620846 0.00620846
 0.00620846 0.00620846 0.00620846 0.00620846]
14 [0.02598444 0.92419799

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]),
 array([ 0.89353146,  1.78706293,  2.68059439,  3.57412585,  4.46765731,
         5.36118878,  6.25472024,  7.1482517 ,  7.04178316,  7.93531463,
         7.82884609,  8.72237755,  8.61590901,  9.50944048, 10.40297194,
        11.2965034 , 12.19003486, 13.08356633, 13.97709779, 14.87062925,
        15.76416071, 16.65769218, 17.55122364, 18.4447551 , 19.33828656,
        19.23181803, 20.12534949, 21.01888095, 21.91241241, 22.80594388,
        23.69947534, 24.5930068 , 25.48653826, 26.38006973, 27.27360119,
        27.16713265, 28.06066412, 27.95419558, 28.84772704, 28.7412585 ,
        29.63478997, 30.52832143, 30.42185289, 31.31538435, 32.20891582,
        33.10244728, 33.99597874, 34.8895102 , 35.78304167, 36.67657313]))