In [None]:
import jax
import jax.numpy as jnp
import jax_ppo
import matplotlib
import matplotlib.pyplot as plt
import optax

In [None]:
import flock_env

In [None]:
# Root PRNG key from a fixed seed; it is threaded through (and re-bound by)
# jax_ppo.init_agent and jax_ppo.train below.
seed = 101
k = jax.random.PRNGKey(seed)

## Hyperparameters

In [None]:
# Environment sizes
n_agents = 20  # agents in each SimpleFlockEnv (also passed to jax_ppo.train)
n_env_steps = 200  # environment steps per rollout, passed to jax_ppo.train

# jax_ppo.train loop sizes
n_train = 1500  # training iterations passed to jax_ppo.train
n_train_env = 32  # presumably environments used for training rollouts
n_test_env = 5  # presumably environments used for the (greedy) test rollouts
n_train_epochs = 2  # presumably passes over each collected batch -- confirm in jax_ppo
mini_batch_size = 256  # samples per optimiser mini-batch

## Training Environment

In [None]:
# Flocking environment whose agents are rewarded by flock_env's exponential
# reward function; env_params holds the environment's default parameters.
reward_fn = flock_env.rewards.exponential_rewards
env = flock_env.SimpleFlockEnv(n_agents=n_agents, reward_func=reward_fn)
env_params = env.default_params

In [None]:
# Number of optimiser updates over the whole run; used below as the length of
# the linear schedule handed to init_agent.
# NOTE(review): this counts one sample per environment step. If jax_ppo.train
# also flattens the n_agents axis into the training batch, the real number of
# mini-batch updates is ~n_agents times larger and the schedule would hit its
# end value early -- confirm against jax_ppo's batching.
samples_per_iteration = n_train_env * n_env_steps
n_steps = n_train * n_train_epochs * samples_per_iteration // mini_batch_size

## PPO Agent

In [None]:
# PPO settings: jax_ppo's defaults with the fields below overridden.
params = jax_ppo.default_params._replace(
    adam_eps=1e-8,
    clip_coeff=0.2,
    entropy_coeff=0.002,
    gae_lambda=0.99,
    gamma=0.99,
)

In [None]:
# Linear decay from 2e-3 to 2e-5 over n_steps updates (constant afterwards);
# passed to init_agent, presumably as the optimiser learning rate.
train_schedule = optax.linear_schedule(
    init_value=2e-3,
    end_value=2e-5,
    transition_steps=n_steps,
)

In [None]:
# Build the PPO agent sized to the environment's action and observation spaces.
# init_agent also returns a new PRNG key, which replaces k.
action_shape = env.action_space(env_params).shape
observation_shape = env.observation_space(env_params).shape

k, agent = jax_ppo.init_agent(
    k, params, action_shape, observation_shape, train_schedule, layer_width=16
)

## Train Agent

In [None]:
# Run PPO training. Besides the updated key and trained agent, this returns the
# losses, the recorded environment states (state_ts) and the rewards used for
# the plot and animation below; the sixth return value is not used here.
k, trained_agent, losses, state_ts, rewards, _ = jax_ppo.train(
    # PRNG key, environment and the freshly initialised agent
    k, env, env_params, agent,
    # loop sizes from the hyperparameter cell
    n_train, n_train_env, n_train_epochs, mini_batch_size, n_test_env,
    # PPO settings
    params,
    greedy_test_policy=True,
    n_env_steps=n_env_steps,
    n_agents=n_agents,
)

In [None]:
# Learning curve. Rewards are summed over axes 2 and 3 (presumably env steps and
# agents -- TODO confirm against jax_ppo.train's output layout), then averaged
# over axis 1 (presumably the test environments).
avg_total_rewards = jnp.mean(jnp.sum(rewards, axis=(2, 3)), axis=1)

fig, ax = plt.subplots()
ax.plot(avg_total_rewards)
ax.set(
    title="Flock PPO training progress",
    xlabel="Train Step",
    ylabel="Avg Total Rewards",
);

In [None]:
import matplotlib
matplotlib.use('Agg')
matplotlib.rc('animation', html='html5')

In [None]:
anim = flock_env.visualisation.animate_agents(
    state_ts.agent_positions[-1, 0],
    state_ts.agent_headings[-1, 0],
    rewards[-1, 0]
)

In [None]:
# Showing the animation as the cell's last expression renders it through its HTML
# repr, using the rc 'animation.html' setting made earlier (assuming animate_agents
# returns a matplotlib Animation).
anim