After multi-armed and contextual bandits, the full reinforcement learning problem must also take observations from the world and choose actions that yield the optimal reward, not just in the present but over the long run. This problem is modeled as a **Markov Decision Process (MDP)**. <U>These environments not only provide rewards and state transitions given actions, but those rewards also depend on the state of the environment and on the action taken within that state. These dynamics are also temporal: rewards can be delayed over time.</U>

* ***Delayed reward*** in the OpenAI Gym CartPole environment:

Keeping the pole in the air as long as possible means moving in ways that are advantageous both now and in the future. To accomplish this, we adjust the reward value for each observation-action pair using a function that weighs actions over time.

To take reward over time into account, the form of Policy Gradient used in the previous tutorials needs a few adjustments. First, the agent must now be updated with more than one experience at a time. To do this, we collect experiences in a buffer and periodically use them to update the agent all at once. These sequences of experience are sometimes called rollouts, or experience traces. We can't apply these rollouts as they are, however: the rewards must first be adjusted by a discount factor, γ.

Intuitively, this makes each action partly responsible not only for the immediate reward, but for all the rewards that follow.
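Written out, each raw reward r_t in an episode of length T is replaced by its discounted return G_t, with γ the discount factor (`gamma` in the code):

$$
G_t = \sum_{k=0}^{T-1-t} \gamma^{k}\, r_{t+k} = r_t + \gamma\, G_{t+1}, \qquad G_T = 0
$$

The recursive form on the right is what a single backward pass over the episode computes. For example, with γ = 0.99 and CartPole's reward of +1 per step, a three-step episode gets returns (2.9701, 1.99, 1): earlier actions receive credit for more of the rewards that follow them.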

In [1]:
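import tensorflow as tf          # TensorFlow 1.x API (placeholders, sessions)
import tensorflow.contrib as tc  # tf.contrib exists only in TensorFlow 1.x
import numpy as np
import gym
import matplotlib.pyplot as plt
%matplotlib inline

# Python 2/3 compatibility: fall back to range where xrange does not exist.
try:
    xrange = xrange
except NameError:
    xrange = range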

In [2]:
env = gym.make('CartPole-v0')

In [3]:
gamma = 0.99

def discount_rewards(r):
    """ take 1D float array of rewards and compute discounted reward """
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(xrange(0, r.size)):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r
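As a quick sanity check of the backward recursion in `discount_rewards`, the standalone sketch below re-implements it (as the hypothetical `discount_rewards_ref`, taking `gamma` as a parameter instead of a global) and compares it with the direct sum G_t = Σ_k γ^k r_{t+k}. The direct version is O(T²) and is only meant for verification.

```python
import numpy as np

def discount_rewards_ref(r, gamma=0.99):
    """Backward recursion G_t = r_t + gamma * G_{t+1}, as in discount_rewards."""
    r = np.asarray(r, dtype=np.float64)
    out = np.zeros_like(r)
    running = 0.0
    for t in reversed(range(r.size)):
        running = running * gamma + r[t]
        out[t] = running
    return out

def discount_rewards_direct(r, gamma=0.99):
    """Direct sum G_t = sum_k gamma**k * r_{t+k}; quadratic time, for checking only."""
    r = np.asarray(r, dtype=np.float64)
    T = r.size
    return np.array([sum(gamma ** k * r[t + k] for k in range(T - t)) for t in range(T)])

# CartPole gives +1 for every step the pole stays up.
rewards = [1.0, 1.0, 1.0]
print(discount_rewards_ref(rewards))  # approximately [2.9701, 1.99, 1.0]
```

The first step receives the largest return because the most rewards follow it, which is exactly the "partly responsible for later rewards" intuition.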

In [60]:
class pg_net():
    def __init__(self, lr, n_obs, n_action, n_hidden, layer_norm):
        self.n_obs = n_obs
        self.n_action = n_action
        self.n_hidden = n_hidden
        self.layer_norm = layer_norm
        
        self.obs, self.output, self.chosen_action = self.build_net()
        
        self.reward_plc = tf.placeholder(shape=[None], dtype=tf.float32)
        self.action_plc = tf.placeholder(shape=[None], dtype=tf.int32)
        
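        # Debug tensors: row indices, and the offset where each row starts in the flattened output.
        self.sha = tf.range(0, tf.shape(self.output)[0])
        self.sha2 = tf.range(0, tf.shape(self.output)[0])*tf.shape(self.output)[1]
        # Flat index of each chosen action in the row-major-flattened output: row * n_action + action.
        self.indexes = tf.range(0, tf.shape(self.output)[0])*tf.shape(self.output)[1] + self.action_plc
        self.resha = tf.reshape(self.output, [-1])
        # Probability the policy assigned to the action actually taken at each step.
        self.responsible_outputs = tf.gather(tf.reshape(self.output, [-1]), self.indexes)

        # Policy gradient loss: -mean(log pi(a_t | s_t) * discounted return).
        loss = -tf.reduce_mean(tf.log(self.responsible_outputs)*self.reward_plc)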
        
        tvars = tf.trainable_variables()
        self.gradient_holders = []
        for idx,var in enumerate(tvars):
            placeholder = tf.placeholder(tf.float32,name=str(idx)+'_holder')
            self.gradient_holders.append(placeholder)
        
        self.gradients = tf.gradients(loss, tvars)
        
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        self.update_batch = optimizer.apply_gradients(zip(self.gradient_holders,tvars))
        
    def build_net(self):
        obs_plc = tf.placeholder(shape=[None, self.n_obs], dtype=tf.float32)
        
        # Hidden layer: dense -> ReLU (-> optional layer norm).
        x = tf.layers.dense(obs_plc, self.n_hidden)
        x = tf.nn.relu(x)
        if self.layer_norm:
            x = tc.layers.layer_norm(x, center=True, scale=True)
        
        # Output layer: one logit per action (-> optional layer norm), then softmax.
        x = tf.layers.dense(x, self.n_action)
        if self.layer_norm:
            x = tc.layers.layer_norm(x, center=True, scale=True)
        x = tf.nn.softmax(x)
        
        # Greedy action; the training loop samples from the softmax output instead.
        chosen_action = tf.argmax(x, 1)
    
        return obs_plc, x, chosen_action
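To make the graph's computation concrete, here is a NumPy-only sketch of one forward and loss pass, assuming `layer_norm=False` and random stand-in weights. The helper names (`policy_forward`, `flat_indexes`, `pg_loss`) are hypothetical, not part of the notebook. It shows why `row * n_action + action` on the flattened output selects the same entries as fancy indexing `probs[rows, actions]`, which is what the debug prints `indx` and `res` display, and how the loss -mean(log π(a_t|s_t) · G_t) is formed from them.

```python
import numpy as np

# NumPy-only sketch of what pg_net's graph computes (layer_norm=False).
# Weights are random stand-ins; helper names are hypothetical.
rng = np.random.default_rng(0)

def softmax(z):
    # Subtract each row's max before exponentiating for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def policy_forward(obs, w1, b1, w2, b2):
    # dense -> ReLU -> dense -> softmax, mirroring build_net.
    h = np.maximum(obs @ w1 + b1, 0.0)
    return softmax(h @ w2 + b2)

def flat_indexes(n_action, actions):
    # Row i of the row-major-flattened [batch, n_action] array starts at i * n_action.
    return np.arange(len(actions)) * n_action + actions

def pg_loss(probs, actions, returns):
    # Same form as the graph's loss: -mean(log pi(a_t | s_t) * G_t).
    picked = probs.reshape(-1)[flat_indexes(probs.shape[1], actions)]
    return -np.mean(np.log(picked) * returns)

n_obs, n_hidden, n_action, n_batch = 4, 8, 2, 5
w1 = rng.normal(scale=0.1, size=(n_obs, n_hidden))
b1 = np.zeros(n_hidden)
w2 = rng.normal(scale=0.1, size=(n_hidden, n_action))
b2 = np.zeros(n_action)

obs = rng.normal(size=(n_batch, n_obs))             # stand-in observations
actions = np.array([0, 1, 1, 0, 1])                 # actions taken
returns = np.array([4.90, 3.94, 2.97, 1.99, 1.00])  # five +1 rewards discounted with gamma=0.99, rounded

probs = policy_forward(obs, w1, b1, w2, b2)
picked = probs.reshape(-1)[flat_indexes(n_action, actions)]

print(flat_indexes(n_action, actions))                          # [0 3 5 6 9]
print(np.allclose(picked, probs[np.arange(n_batch), actions]))  # True
print(pg_loss(probs, actions, returns))
```

The flattening trick is one way to express the selection with `tf.gather` on a 1-D tensor; the sketch only illustrates that it agrees with direct two-dimensional indexing.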

In [63]:
tf.reset_default_graph() # Clear the TensorFlow graph.

myAgent = pg_net(lr=1e-2, n_obs=4, n_action=2, n_hidden=8, layer_norm=False) # Build the agent.

total_episodes = 5000 # Total number of episodes to train the agent on.
max_ep = 999          # Maximum number of steps per episode.
update_frequency = 5  # Apply the accumulated gradients every this many episodes.

init = tf.global_variables_initializer()

# Launch the TensorFlow graph
with tf.Session() as sess:
    sess.run(init)
    i = 0
    total_reward = []
    total_length = []
        
    # Gradient buffer with the same shapes as the trainable variables, zeroed.
    gradBuffer = sess.run(tf.trainable_variables())
    for ix,grad in enumerate(gradBuffer):
        gradBuffer[ix] = grad * 0
        
    while i < total_episodes:
        s = env.reset()
        running_reward = 0
        ep_history = []
        for j in range(max_ep):
            # Sample an action index from the policy's output distribution.
            a_dist = sess.run(myAgent.output, feed_dict={myAgent.obs:[s]})
            a = np.random.choice(myAgent.n_action, p=a_dist[0])

            s1,r,d,_ = env.step(a) 
            ep_history.append([s,a,r,s1])
            s = s1
            running_reward += r
            if d: # Episode finished: compute this episode's gradients.
                # dtype=object because each row mixes arrays (states) and scalars.
                ep_history = np.array(ep_history, dtype=object)
                ep_history[:,2] = discount_rewards(ep_history[:,2].astype(np.float32))
                feed_dict={myAgent.reward_plc:ep_history[:,2].astype(np.float32),
                           myAgent.action_plc:ep_history[:,1].astype(np.int32),
                           myAgent.obs:np.vstack(ep_history[:,0])}
                output, resha, sha, sha2, grads, indx, rep = sess.run(
                    [myAgent.output, myAgent.resha, myAgent.sha, myAgent.sha2,
                     myAgent.gradients, myAgent.indexes, myAgent.responsible_outputs],
                    feed_dict=feed_dict)
                # Debug: inspect how the flat indices select each chosen action's probability.
                print('shape1',sha)
                print('shape2',sha2)
                print('action', ep_history[:,1])
                print('resha', resha)
                print('indx', indx)
                print('res', rep)
                print('-------------------------------------------------------')
                for idx,grad in enumerate(grads):
                    gradBuffer[idx] += grad

                if i % update_frequency == 0 and i != 0:
                    feed_dict = dict(zip(myAgent.gradient_holders, gradBuffer))
                    _ = sess.run([myAgent.update_batch], feed_dict=feed_dict)
                    for ix,grad in enumerate(gradBuffer):
                        gradBuffer[ix] = grad * 0
                
                total_reward.append(running_reward)
                total_length.append(j + 1)  # number of steps taken this episode
                break

        # Print the mean reward over the last 100 episodes.
        if i % 100 == 0:
            print(np.mean(total_reward[-100:]))
        i += 1

shape1 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20]
shape2 [ 0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40]
action [0 1 1 0 0 1 1 1 0 0 1 1 0 1 0 1 0 1 0 1 0]
resha [0.50085986 0.4991401  0.49579355 0.5042064  0.49984673 0.50015324
 0.48060802 0.519392   0.49945408 0.5005459  0.49844992 0.5015501
 0.5013579  0.4986421  0.4833899  0.5166101  0.46491128 0.5350887
 0.48413965 0.5158603  0.5038809  0.49611902 0.48602712 0.5139729
 0.46799842 0.53200155 0.48777464 0.51222533 0.4700838  0.5299162
 0.49028343 0.50971663 0.47296348 0.5270365  0.49364758 0.5063524
 0.47673914 0.5232609  0.49798143 0.5020186  0.4815328  0.51846725]
indx [ 0  3  5  6  8 11 13 15 16 18 21 23 24 27 28 31 32 35 36 39 40]
res [0.50085986 0.5042064  0.50015324 0.48060802 0.49945408 0.5015501
 0.4986421  0.5166101  0.46491128 0.48413965 0.49611902 0.5139729
 0.46799842 0.51222533 0.4700838  0.50971663 0.47296348 0.5063524
 0.47673914 0.5020186  0.4815328 ]
-------------------------------------------------------

KeyboardInterrupt: 
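The training cell sums each episode's gradients into `gradBuffer` and feeds them through `gradient_holders` to `update_batch` only every `update_frequency` episodes. The framework-free sketch below reproduces that accumulate-then-apply pattern on plain NumPy arrays. The function name `train_with_buffer` is hypothetical, and plain SGD stands in for the Adam optimizer the notebook uses.

```python
import numpy as np

# Framework-free sketch of the gradBuffer / gradient_holders / update_batch pattern.
# train_with_buffer is a hypothetical name; plain SGD stands in for Adam.

def zeros_like_all(arrays):
    return [np.zeros_like(a) for a in arrays]

def train_with_buffer(params, episode_grads, update_frequency=5, lr=1e-2):
    """params: list of arrays; episode_grads: one list of gradient arrays per episode."""
    grad_buffer = zeros_like_all(params)
    n_updates = 0
    for i, grads in enumerate(episode_grads):
        # Sum this episode's gradients into the buffer (as gradBuffer[idx] += grad does).
        for k, g in enumerate(grads):
            grad_buffer[k] += g
        # Same condition as the notebook: apply on every update_frequency-th episode except episode 0.
        if i % update_frequency == 0 and i != 0:
            params = [p - lr * g for p, g in zip(params, grad_buffer)]
            grad_buffer = zeros_like_all(params)  # reset, like gradBuffer[ix] = grad * 0
            n_updates += 1
    return params, grad_buffer, n_updates

# Eleven episodes, each contributing a gradient of ones for a two-element parameter.
params, buf, n_updates = train_with_buffer([np.zeros(2)], [[np.ones(2)] for _ in range(11)])
print(n_updates, params[0])  # 2 updates; params approximately [-0.11 -0.11]
```

Because the condition is `i % update_frequency == 0 and i != 0`, the first update sums episodes 0 through 5 (six episodes) and every later update sums exactly five. The gradients are summed rather than averaged, matching `gradBuffer[idx] += grad`.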