In [None]:
import gym
import tensorflow as tf
# from VPG import VPG_agent
from PPO import PPO_agent
import time
import json
import os
import sys

%matplotlib notebook

In [None]:
# Load the run's hyper-parameters from configs/<run_name>.json.
run_name = 'cartpole'
cfg_fp = os.path.join('configs', '{}.json'.format(run_name))
with open(cfg_fp) as cfg_file:
    config = json.load(cfg_file)

In [None]:
# Create the gym environment named by config["env"].
# "use_raw_env" is treated as a boolean flag: only a truthy value selects the
# `.env` attribute of the object gym.make returns. The original code checked
# only whether the key was present, so `"use_raw_env": false` still unwrapped.
# NOTE(review): `.env` presumably strips one outer gym wrapper (e.g. the episode
# TimeLimit); confirm against the installed gym version.
env_name = config['env']
if config.get('use_raw_env', False):
    env = gym.make(env_name).env
else:
    env = gym.make(env_name)

In [None]:
# Quick sanity check: show an initial observation, then the action space.
for item in (env.reset(), env.action_space):
    print(item)

In [None]:
def _build_net(input_shape, hidden_units, n_outputs, output_activation, conv_filters=()):
    """Build a feed-forward tf.keras.Sequential network.

    Args:
        input_shape: shape of one (preprocessed) observation, without the batch axis.
        hidden_units: widths of the ReLU Dense layers, in order.
        n_outputs: width of the final Dense layer.
        output_activation: activation of the final layer ('softmax' for the
            policy head, None for the linear value head).
        conv_filters: filter counts of optional leading 3x3 ReLU Conv2D layers.
            When non-empty, their output is flattened before the Dense layers.

    Returns:
        An un-compiled tf.keras.Sequential model.
    """
    # tf.keras.Input(shape=...) is accepted by both tf.keras 2.x and Keras 3,
    # unlike InputLayer's `shape` / `input_shape` keywords.
    layers = [tf.keras.Input(shape=input_shape)]
    layers += [tf.keras.layers.Conv2D(n_filters, (3, 3), activation='relu')
               for n_filters in conv_filters]
    if conv_filters:
        layers.append(tf.keras.layers.Flatten())
    layers += [tf.keras.layers.Dense(units, activation='relu') for units in hidden_units]
    layers.append(tf.keras.layers.Dense(n_outputs, activation=output_activation))
    return tf.keras.Sequential(layers)


# Per-environment architecture: observation input shape, hidden Dense widths,
# number of discrete actions (policy output width), optional conv trunk.
# TODO: move these architectures into the JSON config.
NETWORK_SPECS = {
    "CartPole-v0": dict(input_shape=(4,), hidden_units=(16, 16), n_actions=2),
    "MountainCar-v0": dict(input_shape=(2,), hidden_units=(16, 16), n_actions=3),
    "Acrobot-v1": dict(input_shape=(6,), hidden_units=(32, 48, 16), n_actions=3),
    "gym_snake:snake-v0": dict(input_shape=(15, 15, 3), hidden_units=(256, 64, 32),
                               n_actions=4, conv_filters=(32, 64)),
    "Taxi-v3": dict(input_shape=(1,), hidden_units=(16, 16), n_actions=6),
}

if env_name not in NETWORK_SPECS:
    # Fail here rather than with a NameError on `model`/`value` further down.
    raise ValueError(
        "No network architecture defined for env {!r}; supported: {}".format(
            env_name, sorted(NETWORK_SPECS)))

spec = NETWORK_SPECS[env_name]
# Policy network: softmax over the discrete actions.
model = _build_net(spec['input_shape'], spec['hidden_units'], spec['n_actions'],
                   'softmax', spec.get('conv_filters', ()))
# State-value network: same trunk, single linear output. This is built for
# every env; the original MountainCar-v0 branch never defined `value`.
value = _build_net(spec['input_shape'], spec['hidden_units'], 1,
                   None, spec.get('conv_filters', ()))

In [None]:
# Wrap the policy (`model`) and state-value (`value`) networks in a PPO agent.
# The gym id was read from config['env'] earlier in the notebook. Use
# config['env_name'] when a config provides it; otherwise fall back to that
# id, so configs that only define "env" do not raise KeyError.
agent = PPO_agent(
    model,
    value,
    env=env,
    learning_rate=config['learning_rate'],
    minibatch_size=config['minibatch_size'],
    env_name=config.get('env_name', env_name),
)

In [None]:
# Step cap per episode, passed to agent.train and reused by the greedy demo rollout.
t_max = config["t_max"]

In [None]:
# Train the agent for the configured number of epochs with the per-episode step cap.
agent.train(t_max=t_max, epochs=config["train_epochs"])

In [None]:
# Greedy evaluation: run the trained policy for at most t_max steps, rendering
# each frame and printing the policy output, then report the summed reward.
RENDER_DELAY_S = 0.005  # pause between rendered frames so the animation is watchable

obs = agent.preprocess(env.reset())
reward = 0
try:
    for step in range(t_max):
        print(agent.get_policy(obs))
        act = agent.get_action(obs, greedy=True)[0]
        # Old-style gym API: step returns (obs, reward, done, info).
        obs, r, dn, info = env.step(agent.action_wrapper(act))
        env.render()
        time.sleep(RENDER_DELAY_S)
        obs = agent.preprocess(obs)
        reward += r
        if dn:
            break
finally:
    # Release the environment/render window even if the rollout raises or is interrupted.
    env.close()

print("Total reward: {}".format(reward), file=sys.stderr)