In [None]:
# Install coax straight from the tip of the GitHub `main` branch (unpinned, so
# the installed version can change between runs of this notebook).
%pip install git+https://github.com/coax-dev/coax.git@main --quiet

In [None]:
# Load the TensorBoard notebook extension and open a dashboard on the log
# directory; the training cell writes its logs to ./data/tensorboard/<name>.
%load_ext tensorboard
%tensorboard --logdir ./data/tensorboard

In [1]:
!pip install shutup
##At the top of the code
import shutup;
shutup.please()



In [2]:
# Run this cell to fix rendering errors: it sets the SDL_VIDEODRIVER
# environment variable of this process to 'dummy'.
import os
os.environ.update(SDL_VIDEODRIVER='dummy')

In [49]:
import gymnasium
import jax
import jax.numpy as jnp
import coax
import haiku as hk
from numpy import prod
import optax


# the name of this script (reused below as the TrainMonitor run name, the
# tensorboard sub-directory and the GIF output directory)
name = 'ppo'

# the InvertedDoublePendulum MDP, rendering to RGB arrays instead of a window
env = gymnasium.make('InvertedDoublePendulum-v4', render_mode='rgb_array')
# wrap in TrainMonitor; the training loop relies on the wrapper's env.T,
# env.period(...) and env.record_metrics(...) helpers
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")


def func_pi(S, is_training):
    """Haiku forward pass of the policy network.

    A two-layer ReLU trunk (8 units each) is shared by two heads of identical
    shape. Each head adds one more 8-unit ReLU layer and a final linear layer
    whose weights are zero-initialised, and reshapes its output to
    ``env.action_space.shape``.

    Args:
        S: observations fed to the network
            (presumably batched along the first axis -- TODO confirm).
        is_training: accepted for the coax function signature; not used here.

    Returns:
        dict with keys ``'mu'`` and ``'logvar'`` holding the two head outputs.
    """
    action_shape = env.action_space.shape
    n_outputs = prod(action_shape)

    # Built once and placed in both heads, so both heads reuse the same trunk
    # modules (and therefore the same trunk parameters).
    trunk = hk.Sequential([
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
    ])

    def make_head():
        # trunk -> Linear(8) -> relu -> zero-weight Linear -> reshape
        return hk.Sequential([
            trunk,
            hk.Linear(8), jax.nn.relu,
            hk.Linear(n_outputs, w_init=jnp.zeros),
            hk.Reshape(action_shape),
        ])

    # Construct both heads before applying either one, mean head first.
    mean_head = make_head()
    logvar_head = make_head()
    return {'mu': mean_head(S), 'logvar': logvar_head(S)}


def func_v(S, is_training):
    """Haiku forward pass of the state-value network.

    Three 8-unit ReLU layers followed by a single-output linear layer with
    zero-initialised weights; the result is flattened with ``jnp.ravel``.

    Args:
        S: observations fed to the network
            (presumably batched along the first axis -- TODO confirm).
        is_training: accepted for the coax function signature; not used here.

    Returns:
        The flattened (1-D) output of the network.
    """
    layers = []
    for _ in range(3):
        layers.extend([hk.Linear(8), jax.nn.relu])
    # output layer: one unit, weights start at zero; ravel drops the unit axis
    layers.extend([hk.Linear(1, w_init=jnp.zeros), jnp.ravel])
    value_net = hk.Sequential(layers)
    return value_net(S)


# define function approximators: policy pi and state-value function v, both
# built from the forward-pass functions above and the environment
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)


# target network: a copy of pi that the training loop uses to pick actions;
# it is moved towards pi with soft_update after each round of updates
pi_targ = pi.copy()


# experience tracer (n=5 step rewards, gamma=0.9) and the buffer that collects
# the traced transitions; training starts once the buffer holds `capacity` items
tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=512)


# policy regularizer (avoid premature exploitation)
policy_reg = coax.regularizers.EntropyRegularizer(pi, beta=0.01)


# updaters: simpletd updates v (adam, lr 1e-3); ppo_clip updates pi with the
# entropy regularizer above (adam, lr 1e-4)
simpletd = coax.td_learning.SimpleTD(v, optimizer=optax.adam(1e-3))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=policy_reg, optimizer=optax.adam(1e-4))


# train: run episodes until the monitor's global step counter env.T reaches 500000
while env.T < 500000:
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
        # act with pi_targ (not pi); logp is recorded alongside the transition
        a, logp = pi_targ(s, return_logp=True)
        s_next, r, done, truncated, info = env.step(a)
        # include for inverted pendulum: replace the reward by a large penalty
        # when the episode terminates (done); truncation alone keeps the reward
        if done:  r=-10000
        # trace rewards
        tracer.add(s, a, r, done or truncated, logp)
        # move every transition the tracer has ready into the buffer
        while tracer:
            buffer.add(tracer.pop())

        # learn: only once the buffer has filled up to its capacity
        if len(buffer) >= buffer.capacity:
            for _ in range(int(4 * buffer.capacity / 32)):  # 4 passes per round
                transition_batch = buffer.sample(batch_size=32)
                # the TD error from the value update is passed on to the policy update
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                env.record_metrics(metrics_v)
                env.record_metrics(metrics_pi)

            # start the next round with an empty buffer and nudge the acting
            # policy towards the freshly updated pi
            buffer.clear()
            pi_targ.soft_update(pi, tau=0.1)

        if done or truncated:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    # (NOTE(review): env.period presumably fires once per T_period steps -- confirm)
    if env.period(name='generate_gif', T_period=1000) and env.T > 5000:
        T = env.T - env.T % 1000  # round down to a multiple of 1000
        coax.utils.generate_gif(
            env=env, policy=pi, filepath=f"./data/gifs/{name}/T{T:08d}.gif")


INFO:TrainMonitor:ep: 1,	T: 11,	G: 91.5,	avg_r: 9.15,	avg_G: 91.5,	t: 10,	dt: 28.426ms
INFO:TrainMonitor:ep: 2,	T: 20,	G: 73.4,	avg_r: 9.18,	avg_G: 82.4,	t: 8,	dt: 2.522ms
INFO:TrainMonitor:ep: 3,	T: 27,	G: 54.2,	avg_r: 9.03,	avg_G: 73,	t: 6,	dt: 2.747ms
INFO:TrainMonitor:ep: 4,	T: 32,	G: 36.1,	avg_r: 9.02,	avg_G: 63.8,	t: 4,	dt: 3.249ms
INFO:TrainMonitor:ep: 5,	T: 38,	G: 45.6,	avg_r: 9.13,	avg_G: 60.2,	t: 5,	dt: 3.205ms
INFO:TrainMonitor:ep: 6,	T: 44,	G: 45.6,	avg_r: 9.11,	avg_G: 57.7,	t: 5,	dt: 2.607ms
INFO:TrainMonitor:ep: 7,	T: 51,	G: 54.7,	avg_r: 9.12,	avg_G: 57.3,	t: 6,	dt: 3.506ms
INFO:TrainMonitor:ep: 8,	T: 60,	G: 72.9,	avg_r: 9.11,	avg_G: 59.2,	t: 8,	dt: 3.067ms
INFO:TrainMonitor:ep: 9,	T: 69,	G: 73.6,	avg_r: 9.2,	avg_G: 60.8,	t: 8,	dt: 3.248ms
INFO:TrainMonitor:ep: 10,	T: 80,	G: 91.8,	avg_r: 9.18,	avg_G: 63.9,	t: 10,	dt: 2.900ms
INFO:TrainMonitor:ep: 11,	T: 93,	G: 111,	avg_r: 9.23,	avg_G: 68.6,	t: 12,	dt: 2.735ms
INFO:TrainMonitor:ep: 12,	T: 100,	G: 54.5,	avg_r: 9.08,	avg_G: 

KeyboardInterrupt: 

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

# Watch the trained policy act in an on-screen ("human") window.
#
# `pi` was built from the InvertedDoublePendulum-v4 spaces, so it has to be
# rolled out on that same environment id; a different environment such as
# "Pendulum" has a different observation layout and cannot be fed to this
# policy.
#
# A separate name (eval_env) is used so the TrainMonitor-wrapped training `env`
# is not overwritten, and the window is closed even if the rollout raises.
eval_env = gymnasium.make('InvertedDoublePendulum-v4', render_mode='human')
try:
    s, info = eval_env.reset()
    for i in range(500):
        a = pi(s)
        print(i, a)
        # with render_mode='human' the frame is drawn as part of step(), so no
        # explicit render() call is needed
        s, reward, terminated, truncated, info = eval_env.step(a)
        if terminated or truncated:
            break
finally:
    eval_env.close()