## n-armed bandit problem

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

**Bandit**

In [3]:
# List our bandit arms
#Currently arm 4 (index #3) is set to most often provide a positive reward.
bandit_arms = [0.2,0,-2,-0.2]
num_arms = len(bandit_arms)
def pullBandit(bandit):
    #Get a random number.
    result = np.random.randn(1)
    if result > bandit:
        #return a positive reward.
        return 1
    else:
        #return a negative reward.
        return -1


In [6]:
# Train
w = Variable(torch.ones(num_arms), requires_grad = True)
output = nn.functional.softmax(w)

optimizer = optim.Adam([w])

total_episodes = 1000
total_reward = np.zeros(num_arms)

i = 0
while i < total_episodes:
    #Choose action according to Boltzmann distribution.
    actions = output
    a = np.random.choice(actions.data.numpy(), p=actions.data.numpy())
    
    action = np.argmax(actions.data.numpy == a)
    
    reward = pullBandit(bandit_arms[action]) # Get reward for picking one of the bandit arms

    action_v = Variable(torch.from_numpy(np.array(action)), requires_grad = False)
    reward_v = Variable(torch.from_numpy(np.array(reward)), requires_grad = False)
    
    responsible_output = output[action:action+1]
#     responsible_output_f = responsible_output.type(torch.FloatTensor)
#     print(output, responsible_output)
    optimizer.zero_grad()
    
    loss = -(torch.log(responsible_output))
    loss.backward(retain_variables=True)
    optimizer.step()
#     print(w)
#     print('=========')
#     print(w)
#     w = w - 0.1*w.grad.data
#     print('=========')
    
    total_reward[action] += reward
    
    if i % 50 == 0:
        print("Running reward for the " + str(num_arms) + " arms of the bandit: " + str(total_reward))
    i += 1
    
print("\nThe agent thinks arm " + str(np.argmax(w)+1) + " is the most promising....")
if np.argmax(w) == np.argmax(-np.array(bandit_arms)):
    print("...and it was right!")
else:
    print("...and it was wrong!")

        

Running reward for the 4 arms of the bandit: [-1.  0.  0.  0.]
Running reward for the 4 arms of the bandit: [-15.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-33.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-21.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-15.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-27.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-29.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-35.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-47.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-41.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-53.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-67.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-79.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-75.   0.   0.   0.]
Running reward for the 4 arms of the bandit: [-81.   0.   0.   0.]