In [10]:
import time
import numpy as np

import tensorflow as tf
import gym

tf.__version__  # display the TensorFlow version this notebook was executed with

'2.2.0-rc3'

In [2]:
def reward_to_go(rews):
    '''Return the undiscounted reward-to-go for every step of one episode.

    Element ``t`` of the result is ``rews[t] + rews[t+1] + ... + rews[-1]``.

    Parameters
    ----------
    rews : sequence of float
        Rewards of a single episode, in time order.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``np.asarray(rews)``.
    '''
    returns = np.zeros_like(rews, dtype='float32')
    # running sum of the rewards after the current index, read back from the
    # float32 buffer so rounding matches a step-by-step float32 accumulation
    tail = 0
    for idx in range(len(rews) - 1, -1, -1):
        returns[idx] = rews[idx] + tail
        tail = returns[idx]
    return returns

def random_sample(logits):
    '''Draw one categorical sample per row of ``logits``.

    ``logits`` holds unnormalized log-probabilities, one row per distribution.
    The result is an integer tensor with one column (callers index it as
    ``[0, 0]`` to get the single sampled action).
    '''
    samples = tf.random.categorical(logits, num_samples=1)
    return samples

def vstack(arr):
    '''Stack a list of single-row tensors into one tensor with a row per element.

    Intended for the ``(1, k)`` per-step logits collected during a rollout:
    ``tf.stack`` yields shape ``(len(arr), 1, k)`` and the singleton axis 1 is
    then removed (``tf.squeeze`` raises if that axis is not of size 1).
    '''
    stacked = tf.stack(arr)
    return tf.squeeze(stacked, axis=1)

In [68]:
class net(tf.keras.Model):
    '''Fully connected Keras model with a linear (identity) output layer.

    Parameters
    ----------
    units : sequence of int
        Layer widths ``[input_dim, hidden_1, ..., hidden_k, output_dim]``.
        Only ``units[1:]`` is used to build layers; ``units[0]`` takes part in
        the length check only.
    activations : sequence of str or callable
        One activation per hidden layer (``len(units) - 2`` entries). The
        output layer always uses ``'linear'`` so the model emits raw logits.
    '''

    def __init__(self, units, activations):
        # accept any sequence (e.g. a tuple); list concatenation in build_graph needs a list
        activations = list(activations)
        assert len(units) == len(activations) + 2, \
            "expected len(units) == len(activations) + 2 (input, hidden layers..., output)"
        super().__init__()

        self.units = units
        self.activations = activations
        self.build_graph()

    def build_graph(self):
        '''Create one Dense layer per entry of ``units[1:]``, the last one linear.'''
        self.graph = [tf.keras.layers.Dense(n_units, activation=activation)
                      for n_units, activation in zip(self.units[1:],
                                                     list(self.activations) + ['linear'])]

    def call(self, x):
        '''Apply the Dense layers in order and return the output of the last one.'''
        h = x
        for layer in self.graph:
            h = layer(h)
        return h
    
class VPG(object):
    '''Vanilla Policy Gradient (REINFORCE) with reward-to-go weights.

    Parameters
    ----------
    env : gym.Env
        Environment with a ``gym.spaces.Box`` observation space and a
        ``gym.spaces.Discrete`` action space, using the 4-tuple
        ``obs, rew, done, info = env.step(act)`` API.
    model : tf.keras.Model
        Policy network; called on a ``(1, obs_dim)`` observation it must return
        one row of unnormalized action logits (``n_acts`` columns).
    optimizer : tf.keras.optimizers.Optimizer
        Optimizer that applies the policy-gradient update to ``model``.
    render : bool, optional
        If True, render the first episode of every training epoch.
    '''
    def __init__(self, env, model, optimizer, render=False):
        assert isinstance(env.observation_space, gym.spaces.Box), \
            "This example only works for envs with continuous state spaces."
        assert isinstance(env.action_space, gym.spaces.Discrete), \
            "This example only works for envs with discrete action spaces."
        self.env = env
        self.model = model
        self.optimizer = optimizer

        self.render = render

        self.obs_dim = env.observation_space.shape[0]
        self.n_acts = env.action_space.n

    def loss_object(self, actions, logits, weights):
        '''Policy-gradient surrogate loss ``-mean(weights * log pi(action | obs))``.

        Parameters
        ----------
        actions : sequence of int
            Action index taken at each step.
        logits : tf.Tensor
            Logits of shape ``(n_steps, n_acts)``, one row per step.
        weights : tf.Tensor
            Per-step weight of shape ``(n_steps,)`` (here: the reward-to-go).

        The value is a surrogate whose gradient w.r.t. the model parameters is
        the policy-gradient estimate; it is not a measure of policy quality.
        '''
        # log pi(a|s) of the action actually taken: mask the log-softmax with a one-hot
        log_probs = tf.reduce_sum(tf.one_hot(actions, self.n_acts) * tf.nn.log_softmax(logits),
                                  axis=1
                                 )
        return -tf.reduce_mean(weights * log_probs)

    def _collect_batch(self, batch_size):
        '''Roll out the current policy for whole episodes until more than ``batch_size`` steps are stored.

        Must run under an active ``tf.GradientTape`` so that the per-step
        logits (which are later differentiated) are recorded.

        Returns
        -------
        tuple
            ``(batch_logits, batch_acts, batch_weights, batch_rets, batch_lens)``.
        '''
        batch_logits = []       # logits the model produced for each step's observation
        batch_acts = []         # action taken at each step
        batch_weights = []      # reward-to-go weight for each log pi(a|s)
        batch_rets = []         # episode returns (for logging)
        batch_lens = []         # episode lengths (for logging)

        obs = self.env.reset()  # first obs comes from the starting distribution
        ep_rews = []            # rewards accrued during the current episode

        # render only the first episode of each epoch
        finished_rendering_this_epoch = False

        while True:
            if self.render and not finished_rendering_this_epoch:
                self.env.render()

            # sample an action from the current policy
            logits = self.model(obs.reshape(1, -1), training=True)
            act = random_sample(logits)[0, 0].numpy()
            obs, rew, done, _ = self.env.step(act)

            batch_logits.append(logits)
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                batch_rets.append(sum(ep_rews))
                batch_lens.append(len(ep_rews))

                # weight each log pi(a_t|s_t) by the rewards collected from step t onward
                batch_weights += list(reward_to_go(ep_rews))

                obs, ep_rews = self.env.reset(), []
                finished_rendering_this_epoch = True

                # stop only at an episode boundary, so the batch may exceed batch_size
                if len(batch_logits) > batch_size:
                    break

        return batch_logits, batch_acts, batch_weights, batch_rets, batch_lens

    def train_one_epoch(self, batch_size):
        '''Collect one batch of experience and apply a single policy-gradient step.

        Returns
        -------
        tuple
            ``(loss, batch_rets, batch_lens)``: the surrogate loss tensor and the
            per-episode returns and lengths observed in the batch.
        '''
        with tf.GradientTape() as tape:
            # the rollout runs under the tape so the logits used for sampling are recorded
            batch_logits, batch_acts, batch_weights, batch_rets, batch_lens = \
                self._collect_batch(batch_size)
            loss = self.loss_object(batch_acts, vstack(batch_logits), tf.stack(batch_weights))

        # differentiate and update outside the tape context so these ops are not recorded
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return loss, batch_rets, batch_lens

    def train(self, epochs=50, batch_size=5000, verbose=1):
        '''Run ``epochs`` epochs of more than ``batch_size`` steps each; print a summary line per epoch if ``verbose > 0``.'''
        for i in range(epochs):
            batch_loss, batch_rets, batch_lens = self.train_one_epoch(batch_size)
            if verbose > 0:
                print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                        (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))

    def act(self, observation):
        '''Sample an action index for a single observation from the current policy.'''
        logits = self.model(observation.reshape(1,-1), training=False)
        return random_sample(logits)[0, 0].numpy()

    def demo(self, n_steps=200):
        '''Render ``n_steps`` steps of the current policy, pausing 1 s after each finished episode.'''
        observation = self.env.reset()
        for _ in range(n_steps):
            self.env.render()
            action = self.act(observation)
            observation, reward, done, info = self.env.step(action)

            if done:
                observation = self.env.reset()
                time.sleep(1)


In [73]:
# environment used for training; VPG below asserts a Box observation space
# and a Discrete action space
env_name = 'CartPole-v0'
env = gym.make(env_name)

In [74]:
# policy network: obs_dim inputs -> one hidden layer of 32 tanh units -> n_acts linear logits
model = net([env.observation_space.shape[0], 32, env.action_space.n], ['tanh'])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)

In [None]:
# train with VPG.train defaults (50 epochs, batches of more than 5000 steps);
# prints one summary line per epoch
vpg = VPG(env, model, optimizer)
vpg.train()



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

epoch:   0 	 loss: -10.118 	 return: 22.044 	 ep_len: 22.044
epoch:   1 	 loss: -8.190 	 return: 18.266 	 ep_len: 18.266
