## NOTICE

This PPO implementation is modified from [PureJaxRL](https://github.com/luchris429/purejaxrl):

- https://github.com/luchris429/purejaxrl

Please cite their work if you use this notebook in your research.

TODO:

- vectorize the validation against the original (non-JAX) MinAtar implementation

In [None]:
!pip install pgx pgx-minatar dm-haiku optax distrax

In [2]:
!nvidia-smi

Mon Jun 19 04:15:25 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    42W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import sys
import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np
import optax
import distrax
import pgx

from typing import NamedTuple, Literal
from dataclasses import dataclass
from rich.progress import track

# Record the installed Pgx version next to the outputs so results can be tied to a release.
print(f"{pgx.__version__=}")

pgx.__version__='0.9.0'


In [4]:
@dataclass
class PPOConfig:
    """Hyper-parameters for PPO on a Pgx MinAtar environment.

    Field names are upper-case so they read like constants at use sites
    (``args.NUM_ENVS``). ``__post_init__`` rejects combinations that would
    otherwise fail only later, inside the jitted update step, or would
    silently run zero training updates.
    """

    ENV_NAME: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-breakout"
    SEED: int = 0
    LR: float = 0.0003
    NUM_ENVS: int = 4096  # parallel training environments
    NUM_EVAL_ENVS: int = 100  # parallel environments used for evaluation
    NUM_STEPS: int = 128  # rollout length per environment per update
    TOTAL_TIMESTEPS: int = 20000000
    UPDATE_EPOCHS: int = 3
    MINIBATCH_SIZE: int = 4096
    GAMMA: float = 0.99
    GAE_LAMBDA: float = 0.95
    CLIP_EPS: float = 0.2
    ENT_COEF: float = 0.01
    VF_COEF: float = 0.5
    MAX_GRAD_NORM: float = 0.5
    ACTIVATION: Literal["relu", "tanh"] = "tanh"

    def __post_init__(self):
        """Validate the configuration eagerly.

        Raises:
            ValueError: if ACTIVATION is unsupported, a size is non-positive,
                MINIBATCH_SIZE does not evenly divide NUM_ENVS * NUM_STEPS, or
                TOTAL_TIMESTEPS is too small for even one update.
        """
        if self.ACTIVATION not in ("relu", "tanh"):
            raise ValueError(
                f"ACTIVATION must be 'relu' or 'tanh', got {self.ACTIVATION!r}")
        for name in ("NUM_ENVS", "NUM_STEPS", "MINIBATCH_SIZE"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive, got {getattr(self, name)}")
        batch_size = self.NUM_ENVS * self.NUM_STEPS
        # The update step reshapes the rollout into equal-sized minibatches.
        if batch_size % self.MINIBATCH_SIZE != 0:
            raise ValueError(
                f"MINIBATCH_SIZE ({self.MINIBATCH_SIZE}) must evenly divide "
                f"NUM_ENVS * NUM_STEPS ({batch_size})")
        # NUM_UPDATES below is TOTAL_TIMESTEPS // NUM_ENVS // NUM_STEPS; it is 0 otherwise.
        if self.TOTAL_TIMESTEPS < batch_size:
            raise ValueError(
                f"TOTAL_TIMESTEPS ({self.TOTAL_TIMESTEPS}) must be at least "
                f"NUM_ENVS * NUM_STEPS ({batch_size}) to run a single update")

args = PPOConfig()
print(args, file=sys.stderr)

# Number of rollout+optimisation iterations, and minibatches per epoch.
NUM_UPDATES = args.TOTAL_TIMESTEPS // args.NUM_ENVS // args.NUM_STEPS
NUM_MINIBATCHES = args.NUM_ENVS * args.NUM_STEPS // args.MINIBATCH_SIZE

PPOConfig(ENV_NAME='minatar-breakout', SEED=0, LR=0.0003, NUM_ENVS=4096, NUM_EVAL_ENVS=100, NUM_STEPS=128, TOTAL_TIMESTEPS=20000000, UPDATE_EPOCHS=3, MINIBATCH_SIZE=4096, GAMMA=0.99, GAE_LAMBDA=0.95, CLIP_EPS=0.2, ENT_COEF=0.01, VF_COEF=0.5, MAX_GRAD_NORM=0.5, ACTIVATION='tanh')


In [5]:
class ActorCritic(hk.Module):
    """Shared convolutional torso feeding separate actor (policy) and critic (value) MLP heads."""

    def __init__(self, num_actions: int, activation: str = "tanh"):
        """Args:
            num_actions: width of the actor output (one logit per action).
            activation: "relu" or "tanh"; applied only in the actor/critic hidden layers.
                The shared torso always uses ReLU.
        """
        super().__init__()
        self.num_actions = num_actions
        self.activation = activation
        assert activation in ["relu", "tanh"]

    def __call__(self, x):
        """Return ``(logits, value)`` for a batch of observations.

        NOTE(review): ``x`` is presumably (batch, height, width, channels) — Haiku's
        Conv2D defaults to channels-last and the flatten below keeps axis 0 as the
        batch axis; confirm against the environment's observation shape.
        ``logits`` has ``num_actions`` entries per example; ``value`` is one scalar per example.
        """
        x = x.astype(jnp.float32)  # cast whatever dtype the environment provides
        if self.activation == "relu":
            activation = jax.nn.relu
        else:
            activation = jax.nn.tanh
        # Shared torso: 2x2 conv (32 filters) -> ReLU -> 2x2 average pool -> flatten -> 64-unit ReLU.
        x = hk.Conv2D(32, kernel_shape=2)(x)
        x = jax.nn.relu(x)
        x = hk.avg_pool(x, window_shape=(2, 2),
                        strides=(2, 2), padding="VALID")
        x = x.reshape((x.shape[0], -1))  # flatten
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        # Actor head: two 64-unit hidden layers, then one logit per action.
        actor_mean = hk.Linear(64)(x)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(64)(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = hk.Linear(self.num_actions)(actor_mean)

        # Critic head: two 64-unit hidden layers, then a single value output.
        critic = hk.Linear(64)(x)
        critic = activation(critic)
        critic = hk.Linear(64)(critic)
        critic = activation(critic)
        critic = hk.Linear(1)(critic)

        # Drop the trailing size-1 axis so value has one scalar per batch element.
        return actor_mean, jnp.squeeze(critic, axis=-1)


def forward_fn(x, is_eval=False):
    """Haiku forward pass mapping a batch of observations to ``(logits, value)``.

    Args:
        x: batch of observations, passed straight to ``ActorCritic``.
        is_eval: unused; kept so existing call sites keep working.

    Returns:
        ``(logits, value)`` as produced by ``ActorCritic``.
    """
    # Use the configured activation instead of hard-coding "tanh"; otherwise
    # changing PPOConfig.ACTIVATION has no effect on the network.
    net = ActorCritic(env.num_actions, activation=args.ACTIVATION)
    logits, value = net(x)
    return logits, value


# The network draws no randomness, so drop the RNG argument from `apply`:
# `forward.apply(params, x)` is then a pure function of parameters and input.
forward = hk.without_apply_rng(
    hk.transform(forward_fn)
)

# Clip gradients to a global L2 norm of MAX_GRAD_NORM, then take an Adam step.
optimizer = optax.chain(
    optax.clip_by_global_norm(args.MAX_GRAD_NORM),
    optax.adam(args.LR, eps=1e-5),
)

In [6]:
# Build the Pgx environment named in the config; print its id/version so the
# recorded outputs identify exactly which environment release was trained on.
env = pgx.make(str(args.ENV_NAME))
print(f"{env.id=}")
print(f"{env.version=}")
print(f"{pgx.__version__=}")

def auto_reset(step_fn, init_fn):
    """Wrap a single-environment ``step_fn`` so finished episodes restart automatically.

    The returned function takes ``(state, action)``. When a step ends an episode
    (terminated or truncated), the result is a fresh state from ``init_fn`` that
    still carries that step's ``terminated``, ``truncated`` and ``rewards``, so the
    caller sees the episode end. On the following call, those carried-over flags and
    rewards are cleared before stepping.
    """

    def _episode_over(s):
        return s.terminated | s.truncated

    def step_and_maybe_reset(state, action):
        # Clear the end-of-episode signals inherited from the previous call's reset.
        cleared = jax.lax.cond(
            _episode_over(state),
            lambda: state.replace(  # type: ignore
                terminated=jnp.bool_(False),
                truncated=jnp.bool_(False),
                rewards=jnp.zeros_like(state.rewards),
            ),
            lambda: state,
        )
        stepped = step_fn(cleared, action)
        # Episode finished: swap in an initial state but keep this step's signals.
        return jax.lax.cond(
            _episode_over(stepped),
            lambda: init_fn(stepped._rng_key).replace(
                terminated=stepped.terminated,
                truncated=stepped.truncated,
                rewards=stepped.rewards,
            ),
            lambda: stepped,
        )

    return step_and_maybe_reset

env.id='minatar-breakout'
env.version='v0'
pgx.__version__='0.9.0'


In [7]:
# Per-step rollout record. `jax.lax.scan` stacks these into arrays with leading
# (NUM_STEPS, NUM_ENVS) axes. Field order matters: instances are built positionally.
Transition = NamedTuple(
    "Transition",
    [
        ("done", jnp.ndarray),  # episode-end flag recorded after the step
        ("action", jnp.ndarray),  # sampled action index
        ("value", jnp.ndarray),  # critic estimate for `obs`
        ("reward", jnp.ndarray),  # reward received for `action`
        ("log_prob", jnp.ndarray),  # behaviour-policy log-probability of `action`
        ("obs", jnp.ndarray),  # observation the action was chosen from
        ("legal_action_mask", jnp.ndarray),  # legal actions at `obs`
    ],
)


def make_update_fn():
    """Build one PPO iteration: collect a rollout, compute GAE, then run minibatch updates.

    Returns ``_update_step(runner_state) -> (runner_state, loss_info)`` where
    ``runner_state = (params, opt_state, env_state, last_obs, rng)`` and
    ``loss_info`` holds the per-epoch, per-minibatch ``(total_loss, aux)``
    stacked by ``jax.lax.scan``. Reads the module-level ``env``, ``forward``,
    ``optimizer``, ``args`` and ``NUM_MINIBATCHES``.
    """

    def _masked_logits(logits, mask):
        # Give illegal actions the most negative finite logit of the logits' own dtype.
        # Do not use `logits + finfo(float64).min * (~mask)`: with float32 logits (JAX's
        # default) that constant overflows to -inf, and -inf * 0 is NaN for every
        # legal action under IEEE arithmetic.
        return jnp.where(mask, logits, jnp.finfo(logits.dtype).min)

    # TRAIN LOOP
    def _update_step(runner_state):
        # COLLECT TRAJECTORIES
        # Batched step; episodes that end are restarted in place by auto_reset.
        step_fn = jax.vmap(auto_reset(env.step, env.init))

        def _env_step(runner_state, unused):
            params, opt_state, env_state, last_obs, rng = runner_state
            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            logits, value = forward.apply(params, last_obs)
            mask = env_state.legal_action_mask
            logits = _masked_logits(logits, mask)
            pi = distrax.Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            # The key from this split is unused (the wrapped step takes none); the split
            # is kept so the RNG stream matches earlier runs.
            rng, _rng = jax.random.split(rng)
            env_state = step_fn(
                env_state, action
            )
            # An episode boundary is termination OR truncation: after either, auto_reset
            # replaces the state, so the next observation belongs to a new episode and
            # GAE must not bootstrap across it.
            done = env_state.terminated | env_state.truncated
            transition = Transition(
                done,
                action,
                value,
                env_state.rewards[:, 0],  # reward for player index 0
                log_prob,
                last_obs,
                mask,
            )
            runner_state = (params, opt_state, env_state,
                            env_state.observation, rng)
            return runner_state, transition

        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, args.NUM_STEPS
        )

        # CALCULATE ADVANTAGE
        params, opt_state, env_state, last_obs, rng = runner_state
        # Bootstrap value for the observation following the final rollout step.
        _, last_val = forward.apply(params, last_obs)

        def _calculate_gae(traj_batch, last_val):
            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                # (1 - done) cuts both the bootstrap and the GAE recursion at episode ends.
                delta = reward + args.GAMMA * next_value * (1 - done) - value
                gae = (
                    delta
                    + args.GAMMA * args.GAE_LAMBDA * (1 - done) * gae
                )
                return (gae, value), gae

            # Walk the rollout backwards in time.
            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            # Value targets are advantages plus the behaviour-time value estimates.
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(tup, batch_info):
                params, opt_state = tup
                traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK
                    logits, value = forward.apply(params, traj_batch.obs)
                    mask = traj_batch.legal_action_mask
                    logits = _masked_logits(logits, mask)
                    pi = distrax.Categorical(logits=logits)
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-args.CLIP_EPS, args.CLIP_EPS)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses,
                                          value_losses_clipped).mean()
                    )

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    # Per-minibatch advantage normalisation.
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - args.CLIP_EPS,
                            1.0 + args.CLIP_EPS,
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    total_loss = (
                        loss_actor
                        + args.VF_COEF * value_loss
                        - args.ENT_COEF * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(
                    params, traj_batch, advantages, targets)
                updates, opt_state = optimizer.update(grads, opt_state)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), total_loss

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = args.MINIBATCH_SIZE * NUM_MINIBATCHES
            assert (
                batch_size == args.NUM_STEPS * args.NUM_ENVS
            ), "batch size must be equal to number of steps * number of envs"
            permutation = jax.random.permutation(_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            # Merge the (NUM_STEPS, NUM_ENVS) leading axes into one sample axis.
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
            )
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x, [NUM_MINIBATCHES, -1] + list(x.shape[1:])
                ),
                shuffled_batch,
            )
            (params, opt_state),  total_loss = jax.lax.scan(
                _update_minbatch, (params, opt_state), minibatches
            )
            update_state = (params, opt_state, traj_batch,
                            advantages, targets, rng)
            return update_state, total_loss

        update_state = (params, opt_state, traj_batch,
                        advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, args.UPDATE_EPOCHS
        )
        params, opt_state, _, _, _, rng = update_state

        runner_state = (params, opt_state, env_state, last_obs, rng)
        return runner_state, loss_info
    return _update_step

In [8]:
# Initialise network parameters from a dummy batch holding a single zero observation.
rng = jax.random.PRNGKey(args.SEED)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros((1, ) + env.observation_shape)
params = forward.init(_rng, init_x)
opt_state = optimizer.init(params=params)

# INIT UPDATE FUNCTION
_update_step = make_update_fn()
jitted_init = jax.jit(jax.vmap(env.init))  # batched reset: one environment per key
jitted_update_step = jax.jit(_update_step)

# INIT ENV
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, args.NUM_ENVS)  # one reset key per environment
env_state = jitted_init(reset_rng)

rng, _rng = jax.random.split(rng)
runner_state = (params, opt_state, env_state, env_state.observation, _rng)
# Trigger compilation before the progress bar starts; the outputs are discarded,
# so training below still begins from `runner_state` as built above.
_, _ = jitted_update_step(runner_state)  # warming up

steps = 0
for i in track(range(NUM_UPDATES), description=f"Training {args.TOTAL_TIMESTEPS // 1_000_000}M frames..."):
    # NOTE(review): the key from this split is unused — the update step carries its own
    # RNG inside runner_state — it only advances `rng`, which later cells split again.
    rng, _rng = jax.random.split(rng)
    runner_state, loss_info = jitted_update_step(runner_state)
    steps += args.NUM_ENVS * args.NUM_STEPS  # environment steps consumed so far

params = runner_state[0]  # trained parameters, used by the evaluation cells

Output()

In [9]:
@jax.jit
def evaluate(params, rng_key):
    """Play NUM_EVAL_ENVS environments with the sampled (stochastic) policy until all have ended.

    Args:
        params: Haiku parameters for ``forward``.
        rng_key: JAX PRNG key for environment resets and action sampling.

    Returns:
        ``R``: rewards accumulated per environment, with the same shape as ``state.rewards``.
    """
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.NUM_EVAL_ENVS)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)

    def cond_fn(tup):
        state, _, _ = tup
        # An episode may end by termination OR truncation; waiting on `terminated`
        # alone could keep looping on environments that were truncated.
        return ~(state.terminated | state.truncated).all()

    def loop_fn(tup):
        state, R, rng_key = tup
        logits, value = forward.apply(params, state.observation)
        # Most negative finite value of the logits' dtype (float64's min would
        # overflow to -inf once cast to float32).
        logits = jnp.where(state.legal_action_mask, logits,
                           jnp.finfo(logits.dtype).min)
        pi = distrax.Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)
        rng_key, _rng = jax.random.split(rng_key)
        state = step_fn(state, action)
        return state, R + state.rewards, rng_key

    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
    return R

In [10]:
# Evaluate the trained policy on NUM_EVAL_ENVS Pgx environments and print the mean return.
rng, _rng = jax.random.split(rng)
R = evaluate(params, _rng)
print(float(R.mean()))

32.06999969482422


In [11]:
!pip install minatar

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting minatar
  Downloading MinAtar-1.0.15-py3-none-any.whl (16 kB)
Installing collected packages: minatar
Successfully installed minatar-1.0.15


In [12]:
from minatar import Environment

In [13]:
# TODO: vectorize
def eval_minatar(env, model, key, eval_episodes=10, random=True):
    """Evaluate trained parameters in the original (non-JAX) MinAtar environment.

    Args:
        env: a ``minatar.Environment``; it is reset at the start of every episode.
        model: Haiku parameters for ``forward``.
        key: JAX PRNG key, split once per step when ``random`` is True.
        eval_episodes: number of episodes to play.
        random: sample actions from the policy if True; otherwise act greedily (argmax).

    Returns:
        ``np.ndarray`` with one undiscounted return per episode.
    """
    action_set = env.minimal_action_set()  # policy logits index into the minimal action set
    # Compile the forward pass once per call; calling `forward.apply` eagerly
    # dispatches every layer op-by-op on each environment step.
    policy = jax.jit(forward.apply)
    Rs = []
    for _ in track(range(eval_episodes), description="Evaluating ..."):
        env.reset()
        done = False
        total_reward = 0
        while not done:
            obs = env.state()
            # Add a leading batch axis of size 1 for the network, then drop it.
            logits, _ = policy(model, jnp.expand_dims(jnp.array(obs), axis=0))
            logits = logits[0]
            assert len(logits) == len(action_set)
            if random:
                key, _key = jax.random.split(key)
                action_idx = int(jax.random.categorical(_key, logits))
            else:
                action_idx = int(jnp.argmax(logits))
            reward, done = env.act(action_set[action_idx])
            total_reward += reward
        Rs.append(total_reward)
    return np.array(Rs)

In [14]:
# Cross-check the trained policy in the original MinAtar implementation of the same game.
# Derive the MinAtar game name from the config (e.g. "minatar-space_invaders" ->
# "space_invaders") instead of hard-coding "breakout", so changing ENV_NAME also
# changes which game is validated.
minatar_env = Environment(args.ENV_NAME.split("-", 1)[1])
rng, _rng = jax.random.split(rng)
R = eval_minatar(minatar_env, params, _rng, 10, True)
print(float(R.mean()))

Output()

35.7
