The simplest way to implement a ***policy gradient*** network is one that produces explicit outputs. In the <U>multi-armed bandit</U> problem (here, four arms), there is no state to condition those outputs on. As such, the network consists of just a set of weights, one for each arm of the bandit, and each weight represents how good the agent thinks it is to pull that arm. If those weights are initialized to 1, the agent will be somewhat optimistic about each arm's potential reward.

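```python
import numpy as np
# The graph code below uses the TensorFlow 1.x API; tensorflow.compat.v1 keeps it working under TensorFlow 2.x.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
```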

```python
# Arm 4 (index 3) has the lowest threshold, so it most often provides a positive reward.
bandit_arms = [0.2, 0, -0.2, -2]
num_arms = len(bandit_arms)

def pullBandit(bandit):
    # Draw a single sample from a standard normal distribution.
    result = np.random.randn()
    if result > bandit:
        # Return a positive reward.
        return 1
    else:
        # Return a negative reward.
        return -1
```
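Because `pullBandit` compares a standard-normal draw against the arm's threshold, each arm pays +1 with probability $1 - \Phi(\text{threshold})$, where $\Phi$ is the standard normal CDF. The sketch below computes those probabilities and the expected reward $2p - 1$ using only the standard library; the helper names `prob_positive` and `expected_reward` are illustrative and not part of the original notebook.

```python
import math

bandit_arms = [0.2, 0, -0.2, -2]  # same thresholds as in the notebook

def prob_positive(threshold):
    """P(Z > threshold) for Z ~ N(0, 1), i.e. the chance pullBandit returns +1."""
    # Standard normal CDF via the error function: Phi(x) = (1 + erf(x / sqrt(2))) / 2
    phi = 0.5 * (1.0 + math.erf(threshold / math.sqrt(2.0)))
    return 1.0 - phi

def expected_reward(threshold):
    """Rewards are +1 or -1, so E[reward] = p - (1 - p) = 2p - 1."""
    return 2.0 * prob_positive(threshold) - 1.0

for index, threshold in enumerate(bandit_arms):
    print(f"arm {index + 1}: P(+1) = {prob_positive(threshold):.4f}, "
          f"E[reward] = {expected_reward(threshold):+.4f}")
```

With these thresholds, arm 4 pays +1 about 97.7% of the time, which is why it is the arm the agent should learn to prefer.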

$\text{loss} = -\log(\pi) \cdot \text{Advantage}$

The Advantage measures how much better an action was than some baseline. Here the baseline is 0, so the Advantage is simply the reward the agent received for that action.

$\pi$ is the probability the policy assigns to the chosen action, i.e. the softmax of the weights evaluated at that arm.

> Minimizing this loss increases the weight of actions that yielded a positive reward and decreases the weight of actions that yielded a negative reward.
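To see why minimizing this loss moves the weights in those directions, write the policy as a softmax over the weights $w$ and differentiate. This is a standard derivation for a softmax policy, with $A$ standing for the Advantage (here, the reward), $a$ for the chosen arm, and $\eta$ for a generic learning rate:

$$
\pi_j = \frac{e^{w_j}}{\sum_k e^{w_k}}, \qquad \text{loss} = -\log(\pi_a) \cdot A
$$

$$
\frac{\partial\, \text{loss}}{\partial w_j} = -A\left(\mathbb{1}[j = a] - \pi_j\right)
$$

$$
w_j \leftarrow w_j - \eta \frac{\partial\, \text{loss}}{\partial w_j} = w_j + \eta A\left(\mathbb{1}[j = a] - \pi_j\right)
$$

With $A > 0$, the chosen weight $w_a$ grows by $\eta A (1 - \pi_a)$ and every other weight shrinks by $\eta A \pi_j$; with $A < 0$ the signs flip. The notebook trains with Adam, which rescales and smooths these gradients, so the plain form above describes the direction of a single gradient rather than Adam's exact step.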

```python
tf.reset_default_graph()

# These two lines establish the feed-forward part of the network.
weights = tf.Variable(tf.ones([num_arms]))
output = tf.nn.softmax(weights)

# The next six lines establish the training procedure. We feed the reward and chosen action into the network
# to compute the loss, and use it to update the network.
reward_holder = tf.placeholder(shape=[1], dtype=tf.float32)
action_holder = tf.placeholder(shape=[1], dtype=tf.int32)

# Pick out the probability of the chosen action.
responsible_output = tf.slice(output, action_holder, [1])
loss = -(tf.log(responsible_output) * reward_holder)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
update = optimizer.minimize(loss)
```
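The graph above applies a softmax to `weights`, picks the chosen arm's probability with `tf.slice`, and scales its negative log by the reward. As a framework-free sanity check, the sketch below reproduces that forward pass in NumPy and compares the hand-derived softmax gradient against central finite differences. The function and variable names are hypothetical helpers, chosen so they do not overwrite the notebook's `weights`, `action`, or `reward`.

```python
import numpy as np

def softmax(w):
    """Numerically stable softmax, the NumPy analogue of tf.nn.softmax."""
    z = np.exp(w - np.max(w))
    return z / z.sum()

def policy_loss(w, action, reward):
    """-log(pi[action]) * reward, mirroring `loss` in the graph above."""
    pi = softmax(w)
    responsible_output = pi[action]  # the element tf.slice(output, [action], [1]) picks
    return -np.log(responsible_output) * reward

def policy_loss_grad(w, action, reward):
    """Analytic gradient: d loss / d w_j = -reward * (1[j == action] - pi_j)."""
    pi = softmax(w)
    one_hot = np.zeros_like(w)
    one_hot[action] = 1.0
    return -reward * (one_hot - pi)

def numerical_grad(f, w, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = np.zeros_like(w)
    for j in range(w.size):
        step = np.zeros_like(w)
        step[j] = eps
        grad[j] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad

w_init = np.ones(4)                  # same initialization as tf.ones([num_arms])
chosen_arm, sample_reward = 3, 1.0   # arm 4 pulled, positive reward

print("loss:", policy_loss(w_init, chosen_arm, sample_reward))  # log(4), about 1.386, at a uniform policy
analytic = policy_loss_grad(w_init, chosen_arm, sample_reward)  # [0.25, 0.25, 0.25, -0.75]
numeric = numerical_grad(lambda w: policy_loss(w, chosen_arm, sample_reward), w_init)
print("analytic gradient:", analytic)
print("max abs difference:", np.max(np.abs(analytic - numeric)))
```

Descending this gradient raises the weight of arm 4 (its component is negative) and lowers the other three, matching the behaviour described for a positive reward.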

```python
total_episodes = 1000  # Total number of episodes to train the agent on.
total_reward = np.zeros(num_arms)  # Scoreboard for the bandit arms, starting at 0.

init = tf.global_variables_initializer()

# Launch the TensorFlow graph.
with tf.Session() as sess:
    sess.run(init)
    for i in range(total_episodes):
        # Choose an action according to the Boltzmann (softmax) distribution.
        actions = sess.run(output)
        print(actions)
        # Sample the arm index directly. Looking the sampled probability back up with
        # np.argmax(actions == a) always picked the first of any arms with tied probabilities.
        action = np.random.choice(num_arms, p=actions)
        print(actions[action])

        reward = pullBandit(bandit_arms[action])  # Get a reward from pulling the chosen arm.

        # Update the network.
        _, ww = sess.run([update, weights], feed_dict={reward_holder: [reward], action_holder: [action]})

        # Update the running tally of scores.
        total_reward[action] += reward
        if i % 50 == 0:
            print("Running reward for the " + str(num_arms) + " arms of the bandit: " + str(total_reward))

print("\nThe agent thinks arm " + str(np.argmax(ww) + 1) + " is the most promising...")
# The best arm is the one with the lowest threshold in bandit_arms.
if np.argmax(ww) == np.argmax(-np.array(bandit_arms)):
    print("...and it was right!")
else:
    print("...and it was wrong!")
```

```
[0.25 0.25 0.25 0.25]
0.25
Running reward for the 4 arms of the bandit: [-1.  0.  0.  0.]
[0.24962518 0.25012493 0.25012493 0.25012493]
0.24962518
...
[0.23778509 0.2445667  0.24927044 0.26837775]
0.26837775
...
[0.22990519 0.23427732 0.25031048 0.285507  ]
0.25031048
...
[0.22286302 0.22740275 0.24471779 0.30501637]
0.24471779
...
[0.21198949 0.21390498 0.23707171 0.3370338 ]
0.23707171
...
[0.19481462 0.20524298 0.23101659 0.3689258 ]
0.23101659
...
[0.19423997 0.20400864 0.23222657 0.3695248 ]
0.23222657
...
```
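For readers without a TensorFlow 1.x environment, here is a minimal NumPy sketch of the same training loop. It is not the notebook's code: it swaps Adam for plain gradient descent with an assumed learning rate of 0.1, seeds a `numpy.random.Generator` for reproducibility, and samples the arm index directly with `rng.choice`. The names `pull_bandit`, `softmax`, `learning_rate`, `policy_weights`, and `reward_tally` are illustrative.

```python
import numpy as np

bandit_arms = [0.2, 0, -0.2, -2]   # thresholds from the notebook
num_arms = len(bandit_arms)
total_episodes = 1000
learning_rate = 0.1                # assumed step size for plain gradient descent

rng = np.random.default_rng(0)     # seeded so runs are reproducible

def pull_bandit(threshold):
    """+1 if a standard-normal draw beats the threshold, else -1 (like pullBandit)."""
    return 1 if rng.standard_normal() > threshold else -1

def softmax(w):
    z = np.exp(w - np.max(w))
    return z / z.sum()

policy_weights = np.ones(num_arms)  # same initialization as the notebook
reward_tally = np.zeros(num_arms)

for episode in range(total_episodes):
    probs = softmax(policy_weights)
    # Sample the arm index directly, so tied probabilities cannot bias the choice.
    action = rng.choice(num_arms, p=probs)
    reward = pull_bandit(bandit_arms[action])

    # Gradient of -log(pi[action]) * reward w.r.t. the weights (baseline = 0).
    grad = -reward * (np.eye(num_arms)[action] - probs)
    policy_weights -= learning_rate * grad   # plain gradient descent instead of Adam

    reward_tally[action] += reward

print("final policy:", np.round(softmax(policy_weights), 3))
print("running reward per arm:", reward_tally)
print("agent's pick: arm", np.argmax(policy_weights) + 1)
print("best arm:     arm", np.argmin(bandit_arms) + 1)
```

With these settings the probability of arm 4 should climb well above the other arms within the 1000 episodes; a smaller learning rate slows that climb, and the exact numbers depend on the seed.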