From 8430731af806173c23bc80f824ac1840c99b2afc Mon Sep 17 00:00:00 2001
From: Sotetsu KOYAMADA（小山田創哲）
Date: Sun, 27 Aug 2023 20:40:50 +0900
Subject: [PATCH] [Examples] Add MinAtar PPO training example (#1020)

---
 examples/minatar-ppo/requirements.txt |   8 +
 examples/minatar-ppo/train.py         | 360 ++++++++++++++++++++++++++
 2 files changed, 368 insertions(+)
 create mode 100644 examples/minatar-ppo/requirements.txt
 create mode 100644 examples/minatar-ppo/train.py

diff --git a/examples/minatar-ppo/requirements.txt b/examples/minatar-ppo/requirements.txt
new file mode 100644
index 000000000..baaafebb3
--- /dev/null
+++ b/examples/minatar-ppo/requirements.txt
@@ -0,0 +1,8 @@
+pgx>=1.3.1
+pgx-minatar
+dm-haiku
+optax
+distrax
+omegaconf
+pydantic
+wandb
diff --git a/examples/minatar-ppo/train.py b/examples/minatar-ppo/train.py
new file mode 100644
index 000000000..133d30d06
--- /dev/null
+++ b/examples/minatar-ppo/train.py
@@ -0,0 +1,360 @@
+"""This PPO implementation is modified from PureJaxRL:
+
+ https://github.com/luchris429/purejaxrl
+
+Please refer to their work if you use this example in your research."""
+
+import sys
+import jax
+import jax.numpy as jnp
+import haiku as hk
+import optax
+from typing import NamedTuple, Literal
+import distrax
+import pgx
+from pgx.experimental import auto_reset
+import time
+
+import pickle
+from omegaconf import OmegaConf
+from pydantic import BaseModel
+import wandb
+
+
+class PPOConfig(BaseModel):
+    env_name: Literal[
+        "minatar-breakout",
+        "minatar-freeway",
+        "minatar-space_invaders",
+        "minatar-asterix",
+        "minatar-seaquest",
+    ] = "minatar-breakout"
+    seed: int = 0
+    lr: float = 0.0003
+    num_envs: int = 4096
+    num_eval_envs: int = 100
+    num_steps: int = 128
+    total_timesteps: int = 20000000
+    update_epochs: int = 3
+    minibatch_size: int = 4096
+    gamma: float = 0.99
+    gae_lambda: float = 0.95
+    clip_eps: float = 0.2
+    ent_coef: float = 0.01
+    vf_coef: float = 0.5
+    max_grad_norm: float = 0.5
+    wandb_project: str = "pgx-minatar-ppo"
+    save_model: bool = False
+
+
+args = PPOConfig(**OmegaConf.to_object(OmegaConf.from_cli()))
+print(args, file=sys.stderr)
+env = pgx.make(str(args.env_name))
+
+
+num_updates = args.total_timesteps // args.num_envs // args.num_steps
+num_minibatches = args.num_envs * args.num_steps // args.minibatch_size
+
+
+class ActorCritic(hk.Module):
+    def __init__(self, num_actions, activation="tanh"):
+        super().__init__()
+        self.num_actions = num_actions
+        self.activation = activation
+        assert activation in ["relu", "tanh"]
+
+    def __call__(self, x):
+        x = x.astype(jnp.float32)
+        if self.activation == "relu":
+            activation = jax.nn.relu
+        else:
+            activation = jax.nn.tanh
+        x = hk.Conv2D(32, kernel_shape=2)(x)
+        x = jax.nn.relu(x)
+        x = hk.avg_pool(x, window_shape=(2, 2),
+                        strides=(2, 2), padding="VALID")
+        x = x.reshape((x.shape[0], -1))  # flatten
+        x = hk.Linear(64)(x)
+        x = jax.nn.relu(x)
+        actor_mean = hk.Linear(64)(x)
+        actor_mean = activation(actor_mean)
+        actor_mean = hk.Linear(64)(actor_mean)
+        actor_mean = activation(actor_mean)
+        actor_mean = hk.Linear(self.num_actions)(actor_mean)
+
+        critic = hk.Linear(64)(x)
+        critic = activation(critic)
+        critic = hk.Linear(64)(critic)
+        critic = activation(critic)
+        critic = hk.Linear(1)(critic)
+
+        return actor_mean, jnp.squeeze(critic, axis=-1)
+
+
+def forward_fn(x, is_eval=False):
+    net = ActorCritic(env.num_actions, activation="tanh")
+    logits, value = net(x)
+    return logits, value
+
+
+forward = hk.without_apply_rng(hk.transform(forward_fn))
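+# hk.transform turns forward_fn into a pure (init, apply) pair, and
+# hk.without_apply_rng drops the rng argument from apply (the network uses no
+# stochastic Haiku ops), so calls below read forward.apply(params, obs) and
+# return (logits, value).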
+
+
+optimizer = optax.chain(optax.clip_by_global_norm(
+    args.max_grad_norm), optax.adam(args.lr, eps=1e-5))
+
+
+class Transition(NamedTuple):
+    done: jnp.ndarray
+    action: jnp.ndarray
+    value: jnp.ndarray
+    reward: jnp.ndarray
+    log_prob: jnp.ndarray
+    obs: jnp.ndarray
+
+
+def make_update_fn():
+    # TRAIN LOOP
+    def _update_step(runner_state):
+        # COLLECT TRAJECTORIES
+        step_fn = jax.vmap(auto_reset(env.step, env.init))
+
+        def _env_step(runner_state, unused):
+            params, opt_state, env_state, last_obs, rng = runner_state
+            # SELECT ACTION
+            rng, _rng = jax.random.split(rng)
+            logits, value = forward.apply(params, last_obs)
+            pi = distrax.Categorical(logits=logits)
+            action = pi.sample(seed=_rng)
+            log_prob = pi.log_prob(action)
+
+            # STEP ENV
+            rng, _rng = jax.random.split(rng)
+            env_state = step_fn(
+                env_state, action
+            )
+            transition = Transition(
+                env_state.terminated,
+                action,
+                value,
+                jnp.squeeze(env_state.rewards),
+                log_prob,
+                last_obs
+            )
+            runner_state = (params, opt_state, env_state,
+                            env_state.observation, rng)
+            return runner_state, transition
+
+        runner_state, traj_batch = jax.lax.scan(
+            _env_step, runner_state, None, args.num_steps
+        )
+
+        # CALCULATE ADVANTAGE
+        params, opt_state, env_state, last_obs, rng = runner_state
+        _, last_val = forward.apply(params, last_obs)
+
+        def _calculate_gae(traj_batch, last_val):
+            def _get_advantages(gae_and_next_value, transition):
+                gae, next_value = gae_and_next_value
+                done, value, reward = (
+                    transition.done,
+                    transition.value,
+                    transition.reward,
+                )
+                delta = reward + args.gamma * next_value * (1 - done) - value
+                gae = (
+                    delta
+                    + args.gamma * args.gae_lambda * (1 - done) * gae
+                )
+                return (gae, value), gae
+
+            _, advantages = jax.lax.scan(
+                _get_advantages,
+                (jnp.zeros_like(last_val), last_val),
+                traj_batch,
+                reverse=True,
+                unroll=16,
+            )
+            return advantages, advantages + traj_batch.value
+
+        advantages, targets = _calculate_gae(traj_batch, last_val)
+
+        # UPDATE NETWORK
+        def _update_epoch(update_state, unused):
+            def _update_minbatch(tup, batch_info):
+                params, opt_state = tup
+                traj_batch, advantages, targets = batch_info
+
+                def _loss_fn(params, traj_batch, gae, targets):
+                    # RERUN NETWORK
+                    logits, value = forward.apply(params, traj_batch.obs)
+                    pi = distrax.Categorical(logits=logits)
+                    log_prob = pi.log_prob(traj_batch.action)
+
+                    # CALCULATE VALUE LOSS
+                    value_pred_clipped = traj_batch.value + (
+                        value - traj_batch.value
+                    ).clip(-args.clip_eps, args.clip_eps)
+                    value_losses = jnp.square(value - targets)
+                    value_losses_clipped = jnp.square(
+                        value_pred_clipped - targets)
+                    value_loss = (
+                        0.5 * jnp.maximum(value_losses,
+                                          value_losses_clipped).mean()
+                    )
+
+                    # CALCULATE ACTOR LOSS
+                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+                    loss_actor1 = ratio * gae
+                    loss_actor2 = (
+                        jnp.clip(
+                            ratio,
+                            1.0 - args.clip_eps,
+                            1.0 + args.clip_eps,
+                        )
+                        * gae
+                    )
+                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
+                    loss_actor = loss_actor.mean()
+                    entropy = pi.entropy().mean()
+
+                    total_loss = (
+                        loss_actor
+                        + args.vf_coef * value_loss
+                        - args.ent_coef * entropy
+                    )
+                    return total_loss, (value_loss, loss_actor, entropy)
+
+                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+                total_loss, grads = grad_fn(
+                    params, traj_batch, advantages, targets)
+                updates, opt_state = optimizer.update(grads, opt_state)
+                params = optax.apply_updates(params, updates)
+                return (params, opt_state), total_loss
+
+            params, opt_state, traj_batch, advantages, targets, rng = update_state
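+            # One epoch: flatten the (num_steps, num_envs, ...) rollout,
+            # shuffle it, and split it into num_minibatches chunks for the
+            # scan over _update_minbatch below.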
+            rng, _rng = jax.random.split(rng)
+            batch_size = args.minibatch_size * num_minibatches
+            assert (
+                batch_size == args.num_steps * args.num_envs
+            ), "batch size must be equal to number of steps * number of envs"
+            permutation = jax.random.permutation(_rng, batch_size)
+            batch = (traj_batch, advantages, targets)
+            batch = jax.tree_util.tree_map(
+                lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
+            )
+            shuffled_batch = jax.tree_util.tree_map(
+                lambda x: jnp.take(x, permutation, axis=0), batch
+            )
+            minibatches = jax.tree_util.tree_map(
+                lambda x: jnp.reshape(
+                    x, [num_minibatches, -1] + list(x.shape[1:])
+                ),
+                shuffled_batch,
+            )
+            (params, opt_state), total_loss = jax.lax.scan(
+                _update_minbatch, (params, opt_state), minibatches
+            )
+            update_state = (params, opt_state, traj_batch,
+                            advantages, targets, rng)
+            return update_state, total_loss
+
+        update_state = (params, opt_state, traj_batch,
+                        advantages, targets, rng)
+        update_state, loss_info = jax.lax.scan(
+            _update_epoch, update_state, None, args.update_epochs
+        )
+        params, opt_state, _, _, _, rng = update_state
+
+        runner_state = (params, opt_state, env_state, last_obs, rng)
+        return runner_state, loss_info
+    return _update_step
+
+
+@jax.jit
+def evaluate(params, rng_key):
+    step_fn = jax.vmap(env.step)
+    rng_key, sub_key = jax.random.split(rng_key)
+    subkeys = jax.random.split(sub_key, args.num_eval_envs)
+    state = jax.vmap(env.init)(subkeys)
+    R = jnp.zeros_like(state.rewards)
+
+    def cond_fn(tup):
+        state, _, _ = tup
+        return ~state.terminated.all()
+
+    def loop_fn(tup):
+        state, R, rng_key = tup
+        logits, value = forward.apply(params, state.observation)
+        # action = logits.argmax(axis=-1)
+        pi = distrax.Categorical(logits=logits)
+        rng_key, _rng = jax.random.split(rng_key)
+        action = pi.sample(seed=_rng)
+        rng_key, _rng = jax.random.split(rng_key)
+        state = step_fn(state, action)
+        return state, R + state.rewards, rng_key
+    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
+    return R.mean()
+
+
+def train(rng):
+    tt = 0
+    st = time.time()
+    # INIT NETWORK
+    rng, _rng = jax.random.split(rng)
+    init_x = jnp.zeros((1, ) + env.observation_shape)
+    params = forward.init(_rng, init_x)
+    opt_state = optimizer.init(params=params)
+
+    # INIT UPDATE FUNCTION
+    _update_step = make_update_fn()
+    jitted_update_step = jax.jit(_update_step)
+
+    # INIT ENV
+    rng, _rng = jax.random.split(rng)
+    reset_rng = jax.random.split(_rng, args.num_envs)
+    env_state = jax.jit(jax.vmap(env.init))(reset_rng)
+
+    rng, _rng = jax.random.split(rng)
+    runner_state = (params, opt_state, env_state, env_state.observation, _rng)
+
+    # warm up
+    _, _ = jitted_update_step(runner_state)
+
+    steps = 0
+
+    # initial evaluation
+    et = time.time()  # exclude evaluation time
+    tt += et - st
+    rng, _rng = jax.random.split(rng)
+    eval_R = evaluate(runner_state[0], _rng)
+    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
+    print(log)
+    wandb.log(log)
+    st = time.time()
+
+    for i in range(num_updates):
+        runner_state, loss_info = jitted_update_step(runner_state)
+        steps += args.num_envs * args.num_steps
+
+        # evaluation
+        et = time.time()  # exclude evaluation time
+        tt += et - st
+        rng, _rng = jax.random.split(rng)
+        eval_R = evaluate(runner_state[0], _rng)
+        log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
+        print(log)
+        wandb.log(log)
+        st = time.time()
+
+    return runner_state
+
+
+if __name__ == "__main__":
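+    # `args` was parsed at import time from `key=value` CLI overrides
+    # (OmegaConf.from_cli -> PPOConfig); log it to wandb, train, and
+    # optionally pickle the final network parameters.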
+    wandb.init(project=args.wandb_project, config=args.dict())
+    rng = jax.random.PRNGKey(args.seed)
+    out = train(rng)
+    if args.save_model:
+        with open(f"{args.env_name}-seed={args.seed}.ckpt", "wb") as f:
+            pickle.dump(out[0], f)
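
Usage note (illustrative, not part of the patch): train.py takes plain `key=value`
overrides on the command line, e.g. `python train.py env_name=minatar-freeway seed=1`;
OmegaConf.from_cli() parses them and the pydantic PPOConfig model validates and
fills in defaults. Below is a minimal sketch of that parsing path, with PPOConfig
trimmed to three fields for brevity:

    from omegaconf import OmegaConf
    from pydantic import BaseModel


    class PPOConfig(BaseModel):
        env_name: str = "minatar-breakout"
        lr: float = 0.0003
        seed: int = 0


    # Simulate `python train.py env_name=minatar-freeway lr=0.001`;
    # called with no argument, from_cli() reads sys.argv[1:] instead.
    cli = OmegaConf.from_cli(["env_name=minatar-freeway", "lr=0.001"])
    args = PPOConfig(**OmegaConf.to_object(cli))
    print(args)  # env_name='minatar-freeway' lr=0.001 seed=0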