# Multi-armed Bandit

[*reference*](https://medium.com/@awjuliani/super-simple-reinforcement-learning-tutorial-part-1-fd544fab149#.y721lsdjn)

In [27]:
import tensorflow as tf
import numpy as np
import random

In [26]:
# list of bandits
bandits = [0.2, 0, -0.2, -5]
num_bandits = len(bandits)

# generate a random number with mean 0
def pull_bandit(bandit):
    return 1 if np.random.randn(1) > bandit else -1    

In [56]:
# demo of how function works
pull_bandit(random.choice(bandits))

-1

[tf.slice](https://www.tensorflow.org/api_docs/python/array_ops/slicing_and_joining#slice)

In [65]:
# let us build a neural network
#  output : action
#  input  : none
#  training inputs : reward, action
#  choose action : one with maximum weight (index)
tf.reset_default_graph()
w = tf.Variable(tf.ones([num_bandits])) # vector of len num_bandits
action = tf.argmax(w,0) # tis a vector
# training
reward_ = tf.placeholder(shape=[1], dtype=tf.float32)
action_ = tf.placeholder(shape=[1], dtype=tf.int32)
best_weight = tf.slice(w,action_,[1])
loss = -(tf.log(best_weight)*reward_)
train_fn = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss)

In [66]:
# params
num_epi = 1000
rall = np.zeros(num_bandits)
e = 0.1

In [67]:
init = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init)
    for i in range(num_epi):
        # choose an action with prob > e
        if np.random.randn(1) < e:
            action_v = np.random.randint(num_bandits)
        else:
            action_v = sess.run(action)
        # get reward for chosen action
        reward_v = pull_bandit(bandits[action_v])
        
        # based on the reward and chosen action, update the network
        sess.run(train_fn, feed_dict = {reward_ : [reward_v],
                                        action_ : [action_v]
                                       })
        # keep track of rewards
        rall[action_v] += reward_v
        if i%100 == 0:
            print('Reward status : {}'.format(rall))

Reward status : [-1.  0.  0.  0.]
Reward status : [  1.  -4.  -2.  48.]
Reward status : [  -2.   -6.    3.  104.]
Reward status : [   1.   -8.   12.  160.]
Reward status : [   3.  -14.   10.  210.]
Reward status : [   0.  -12.   12.  271.]
Reward status : [  -2.   -3.   22.  322.]
Reward status : [  -6.   -7.   28.  380.]
Reward status : [ -12.  -16.   35.  444.]
Reward status : [  -6.  -13.   37.  495.]


In [69]:
print('And the best bandit to pull is bandit #{}'.format(np.argmax(rall)+1))

And the best bandit to pull is bandit #4
