In [1]:
import os
import tempfile
import time

import numpy as np
import tensorflow as tf
import gym
import pybullet_envs
import tensorflow_probability as tfp
tfd = tfp.distributions



  "The `registry.env_specs` property along with `EnvSpecTree` is deprecated. Please use `registry` directly as a dictionary instead."


In [2]:
#EXPERIMENT PARAMETERS: MODIFY THIS

env_name = 'CartPole-v1'
seed = 0
n_critics = 3
batch_size = 10000
epochs = 200
learning_rate = 3e-4
opt = tf.optimizers.Adam(learning_rate)
γ = .99
λ = 0.97
kl_target = 0.01

#run variables
norm_rew = False
norm_obs = False
kl_stop = False
kl_rollback = False
bootstrap = False

#save directory (for kl_rollback optimization)
os.makedirs("saves", exist_ok=True)
save_dir = tempfile.mkdtemp(dir='saves', prefix=env_name)

    

In [3]:
env = gym.make(env_name)
obs_spc = env.observation_space
act_spc = env.action_space

if act_spc.shape:
    env = gym.wrappers.ClipAction(env)

if norm_obs:
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.TransformObservation(env, lambda obs: tf.clip_by_value(obs, -10, 10))

#seeding
tf.random.set_seed(seed)
env.seed(seed)
act_spc.seed(seed)
obs_spc.seed(seed)

# policy/actor model
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(120, activation='relu', input_shape=obs_spc.shape),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(act_spc.shape[0] if act_spc.shape else act_spc.n)
])
if act_spc.shape:
    model.log_std = tf.Variable(tf.fill(env.action_space.shape, -0.5))
model.summary()

# value/critic model
critics = list()

for _ in range(n_critics):
    value_model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=obs_spc.shape),
        tf.keras.layers.Dense(1)
    ])
    value_model.compile('adam', loss='MSE')
    critics.append(value_model)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 120)               600       
                                                                 
 dense_1 (Dense)             (None, 84)                10164     
                                                                 
 dense_2 (Dense)             (None, 2)                 170       
                                                                 
Total params: 10,934
Trainable params: 10,934
Non-trainable params: 0
_________________________________________________________________


  "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "


In [4]:
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd:
    """Tracks the mean, variance and count of values."""

    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        """Tracks the mean, variance and count of values."""
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon

    def update(self, x):
        """Updates the mean, var and count from a batch of samples."""
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Updates from batch mean, variance and count moments."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Updates the mean, var and count using the previous mean, var, count and batch values."""
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

In [5]:
class Buffer(object):
    def __init__(self, obs_spc, act_spc, model, critics, size, gam=0.99, lam=0.97):
        self.ptr = 0
        self.last_idx = 0
        self.size = size
        self.continuous = bool(act_spc.shape)

        self.model = model
        self.critics = critics

        self.obs_buf = tf.TensorArray(obs_spc.dtype, size)
        self.act_buf = tf.TensorArray(act_spc.dtype, size)
        self.rew_buf = tf.TensorArray(tf.float32, size)
        self.prob_buf = tf.TensorArray(tf.float32, size)

        self.rets = []
        self.ret_rms = RunningMeanStd(shape=())
        self.lens = []

        self.V_hats = tf.TensorArray(tf.float32, size)
        self.gae = tf.TensorArray(tf.float32, size)

        self.gam = gam
        self.lam = lam

    # @tf.function
    def store(self, obs, act, rew, prob):
        self.obs_buf = self.obs_buf.write(self.ptr, obs)
        self.act_buf = self.act_buf.write(self.ptr, act)
        self.rew_buf = self.rew_buf.write(self.ptr, rew)
        self.prob_buf = self.prob_buf.write(self.ptr, prob)
        self.ptr += 1

    # @tf.function
    def finish_path(self, last_obs=None):
        current_episode = tf.range(self.last_idx, self.ptr)

        #bootstrapping the remaining values if the episode was interrupted
        if last_obs == None:
            last_val = 0
        else:
            predictions = [tf.squeeze(value_model((tf.expand_dims(last_obs, 0)))) for value_model in self.critics]
            last_val = tf.math.reduce_mean(predictions)

        # last_val = tf.squeeze(self.value_model(tf.expand_dims(last_obs, 0))) if last_obs is not None else 0

        length = self.ptr - self.last_idx
        ep_rew = self.rew_buf.gather(current_episode)
        ret = tf.reduce_sum(ep_rew) + last_val
        self.lens.append(length)
        self.rets.append(ret)

        #(attempt at) scaling the rewards
        if norm_rew:
            self.ret_rms.update(np.array(self.rets))
            ep_rew = ep_rew / tf.sqrt(tf.cast(self.ret_rms.var, tf.float32) + 1e-8)

        # v_hats = discounted cumulative sum
        discounts = tf.math.cumprod(tf.fill(ep_rew.shape, self.gam), exclusive=True)
        v_hats = tf.math.cumsum(discounts * ep_rew, reverse=True)


        self.V_hats = self.V_hats.scatter(current_episode, v_hats)

        #Vs = tf.squeeze(value_model(self.obs_buf.gather(current_episode)), axis=1)

        predictions = [tf.squeeze(value_model(self.obs_buf.gather(current_episode)), axis=1) for value_model in self.critics]
        Vs = tf.math.reduce_mean(predictions, axis=0)
        Vsp1 = tf.concat([Vs[1:], [last_val]], axis=0)
        deltas = self.rew_buf.gather(current_episode) + self.gam * Vsp1 - Vs

        # compute the advantage function (gae)
        discounts = tf.math.cumprod(tf.fill(deltas.shape, self.gam * self.lam), exclusive=True)
        gae = tf.math.cumsum(discounts * deltas, reverse=True)

        #Normalise the advantage
        gae = (gae - tf.math.reduce_mean(gae)) / (tf.math.reduce_std(gae) + 1e-8)

        self.gae = self.gae.scatter(current_episode, gae)

        self.last_idx = self.ptr

        if self.ptr == self.size:
            self.obs_buf = self.obs_buf.stack()
            self.act_buf = self.act_buf.stack()
            self.rew_buf = self.rew_buf.stack()
            self.prob_buf = self.prob_buf.stack()

            self.V_hats = self.V_hats.stack()
            self.gae = self.gae.stack()

    def approx_kl(self):
        obs, act, logprob = self.obs_buf, self.act_buf, self.prob_buf

        if self.continuous:
            dist = tfd.MultivariateNormalDiag(model(obs), tf.exp(self.model.log_std))
        else:
            dist = tfd.Categorical(logits=model(obs))

        new_logprob = dist.log_prob(act)

        return tf.reduce_mean(logprob - new_logprob)

    # @tf.function
    def loss(self):
        eps = 0.1
        obs, act, adv, logprob = self.obs_buf, self.act_buf, self.gae, self.prob_buf

        if self.continuous:
            dist = tfd.MultivariateNormalDiag(model(obs), tf.exp(self.model.log_std))
        else:
            dist = tfd.Categorical(logits=model(obs))

        new_logprob = dist.log_prob(act)

        mask = tf.cast(adv >= 0, tf.float32)
        epsilon_clip = mask * (1 + eps) + (1 - mask) * (1 - eps)
        ratio = tf.exp(new_logprob - logprob)

        return -tf.reduce_mean(tf.minimum(ratio * adv, epsilon_clip * adv))

    #%%

@tf.function
def action(model, obs, env):
    est = tf.squeeze(model(tf.expand_dims(obs, 0)), axis=0)
    if env.action_space.shape:
        dist = tfd.MultivariateNormalDiag(est, tf.exp(model.log_std))
    else:
        dist = tfd.Categorical(logits=est, dtype=env.action_space.dtype)

    action = dist.sample()
    logprob = tf.reduce_sum(dist.log_prob(action))

    return action, logprob

In [6]:
def save_model(model, save_path):
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, save_path, max_to_keep=None)
    manager.save()

def load_model(model, load_path):
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, load_path, max_to_keep=None)
    ckpt.restore(manager.latest_checkpoint)
    print("Restoring from {}".format(manager.latest_checkpoint))

In [7]:
def run_one_episode(env, buf):
    obs_dtype = env.observation_space.dtype

    obs = env.reset()
    obs = tf.cast(obs, obs_dtype)
    done = False

    for i in range(buf.ptr, buf.size):
        act, prob = action(buf.model, obs, env)
        new_obs, rew, done, _ = env.step(act.numpy())

        rew = tf.cast(rew, 'float32')

        buf.store(obs, act, rew, prob)
        obs = tf.cast(new_obs, obs_dtype)

        if done:
            break

    critic_start = time.time()
    if done:
        buf.finish_path()
    else:
        buf.finish_path(obs)

    return time.time() - critic_start

In [8]:
def train_one_epoch(env, batch_size, model, critics, γ, λ, save_dir):
    obs_spc = env.observation_space
    act_spc = env.action_space

    batch = Buffer(obs_spc, act_spc, model, critics, batch_size, gam=γ, lam=λ)
    start_time = time.time()

    critic_time = 0
    while batch.ptr < batch.size:
        critic_time += run_one_episode(env, batch)

    train_start_time = time.time()

    var_list = list(model.trainable_weights)
    if act_spc.shape:
        var_list.append(model.log_std)

    for i in range(80):
        save_model(model, save_dir)
        opt.minimize(batch.loss, var_list=var_list)

        # do we want early stopping?
        if not kl_stop:
            continue

        if batch.approx_kl() > 1.5 * kl_target:
            print(f"Early stopping at step {i}")
            # rollback if asked to
            if kl_rollback:
                load_model(model, save_dir)
            break

    train_time = time.time() - train_start_time
    run_time = train_start_time - start_time

    print('run time', run_time, 'critic time (included in run time):', critic_time, 'train time', train_time)
    print('AvgEpRet:', tf.reduce_mean(batch.rets).numpy())

    for i in range(len(critics)):
        bootstrap_value = 0.9 if bootstrap else 1
        mask = tf.random.uniform([batch.size]) < bootstrap_value
        masked_obs = tf.boolean_mask(batch.obs_buf, mask)
        masked_vhats = tf.boolean_mask(batch.V_hats, mask)
        hist = critics[i].fit(batch.obs_buf.numpy(), batch.V_hats.numpy(), epochs=80, steps_per_epoch=1, verbose=0)
        tf.reduce_mean(hist.history['loss']).numpy()

    return batch.rets, batch.lens

In [9]:
def train(epochs, env, batch_size, model, critics, γ, λ, save_dir):
    for i in range(1, epochs + 1):
        start_time = time.time()
        print('Epoch: ', i)
        batch_loss = train_one_epoch(env, batch_size, model, critics, γ, λ, save_dir)
        now = time.time()

In [None]:
train(epochs, env, batch_size, model, critics, γ, λ, save_dir)

Epoch:  1


  return _tree.flatten(structure)


run time 15.501555919647217 critic time (included in run time): 4.530521392822266 train time 3.1431703567504883
AvgEpRet: 21.691923
Epoch:  2
run time 14.278815984725952 critic time (included in run time): 3.7764406204223633 train time 2.9412949085235596
AvgEpRet: 25.512894
Epoch:  3
run time 13.319865703582764 critic time (included in run time): 3.192547559738159 train time 2.9303970336914062
AvgEpRet: 33.564663
Epoch:  4
