In [1]:

from ppo_actor_model import PPOActorModel
from ppo_critic_model import PPOCriticModel
from ppo_clipped_agent import PPOAgent

In [2]:
import gym

env = gym.make('MountainCar-v0')

# hyperparameters
iterations                = 400
n_steps                   = 80
n_epochs                  = 10
batch_size                = 8
n_actions                 = env.action_space.n
learning_rate             = 1e-4
gamma                     = 0.99
gae_lambda                = 0.95
clip_epsilon              = 0.2
critic_loss_coefficient   = 0.5
entropy_bonus_coefficient = 0 #0.001

policy = PPOActorModel(n_actions, entropy_bonus_coefficient, learning_rate, epsilon=clip_epsilon)
critic = PPOCriticModel(critic_loss_coefficient, learning_rate, epsilon=clip_epsilon)
agent  = PPOAgent(n_actions, policy, critic, gamma, gae_lambda)

agent.train(env, iterations, n_steps, n_epochs, batch_size)

Iteration    0 score -200 avg_score -200.0 actor_loss 13.421 critic_loss 95.547


In [None]:
agent.train(env, iterations, n_steps, n_epochs, batch_size)

In [15]:

done = False
score = 0
action_list = []
ap_list = []
state = env.reset()
while not done:
    action, ap = agent.choose_action(state)
    action_list.append(action)
    ap_list.append(ap)
    next_state, reward, done, info = env.step(action)
    score += reward
    state = next_state

print(score, action_list)

10.0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [4]:
import tensorflow as tf
tf_ap_list = tf.convert_to_tensor(ap_list, dtype=tf.float32)

In [5]:
print(tf_ap_list)

tf.Tensor(
[0.999738   0.99996006 0.999993   0.99999857 0.99999964 0.9999999
 1.         1.         1.        ], shape=(9,), dtype=float32)


In [16]:
agent.policy.trainable_variables

[<tf.Variable 'ppo_actor_model/dense/kernel:0' shape=(4, 128) dtype=float32, numpy=
 array([[-2.39920184e-01, -2.09141970e-01,  1.00247018e-01,
         -6.19176626e-02, -7.35186972e-03,  1.89309031e-01,
          2.28419289e-01, -5.15020974e-02,  2.78319381e-02,
          7.03939274e-02, -1.72077805e-01,  1.58341125e-01,
         -1.38560221e-01, -9.80596393e-02, -2.72068650e-01,
          6.36870712e-02,  1.13004521e-01, -1.56172052e-01,
          8.63138810e-02, -1.65591076e-01, -1.62730485e-01,
         -6.85890242e-02,  5.36173321e-02, -3.36036921e-01,
         -2.75272399e-01, -5.35711311e-02,  1.71537995e-01,
         -5.32090403e-02, -2.06921875e-01,  2.25400567e-01,
          1.82146683e-01, -5.42874299e-02, -2.10020840e-01,
          2.63293922e-01, -1.64571673e-01, -4.19965200e-02,
          2.04847738e-01,  1.52621508e-01, -1.21874265e-01,
         -6.08422495e-02, -3.22520472e-02,  8.23470205e-02,
          1.79520711e-01,  1.75210401e-01,  1.40816882e-01,
          2.8807