In [None]:
import jax
import jax.numpy as jnp
import jax_ppo
import flock_env
import matplotlib.pyplot as plt
import optax

In [None]:
import ppo_training

In [None]:
# Root PRNG key for the whole run; a fixed seed keeps training reproducible.
k = jax.random.PRNGKey(seed=101)

## Hyperparameters

In [None]:
# Start from jax_ppo's defaults; `_replace` returns a copy in which only the
# fields listed here are overridden and every other field keeps its default.
ppo_params = jax_ppo.default_params._replace(
    adam_eps=1e-8,
    clip_coeff=0.2,
    entropy_coeff=0.002,
    gae_lambda=0.99,
    gamma=0.99,
)

In [None]:
# Flock environment settings: agent size, speed bounds, turning/acceleration
# limits and the collision penalty.
env_params = flock_env.EnvParams(
    agent_radius=0.01,
    min_speed=0.02,
    max_speed=0.05,
    max_accelerate=0.01,
    max_rotate=0.075,
    collision_penalty=0.7,
)

In [None]:
# Keyword arguments for ppo_training.training(**p); the learning-rate
# schedule is added to this dict in the next cell.
p = {
    # Randomness
    "rng": k,
    # Run sizes
    "n_agents": 100,
    "n_train_steps": 1_000,
    "test_every": 100,
    "n_env_steps": 200,
    "n_train_env": 4,
    "n_test_env": 2,
    # Optimisation / batching
    "n_update_epochs": 2,
    "mini_batch_size": 512,
    "max_mini_batches": 20,
    # Policy network
    "network_layer_width": 16,
    "n_network_layers": 2,
    # Configuration objects from earlier cells
    "env_params": env_params,
    "ppo_params": ppo_params,
    "env_type": flock_env.VisionEnv,
}

In [None]:
# Size the schedule by: mini-batches per rollout (capped at max_mini_batches)
# x update epochs x train steps.
# NOTE(review): presumably this matches the number of optimiser updates that
# ppo_training.training performs — confirm.
samples_per_rollout = p["n_train_env"] * p["n_env_steps"] * p["n_agents"]
total_mini_batches = min(samples_per_rollout // p["mini_batch_size"], p["max_mini_batches"])
total_steps = total_mini_batches * p["n_update_epochs"] * p["n_train_steps"]

# Linear decay from 2e-3 to 2e-5 over `total_steps` (presumably the learning rate).
p["training_schedule"] = optax.linear_schedule(
    init_value=2e-3,
    end_value=2e-5,
    transition_steps=total_steps,
)

## Train Agent

In [None]:
# Run training, forwarding every entry of `p` as a keyword argument.
training_results = ppo_training.training(**p)
rewards, losses, test_rewards, test_trajectories = training_results

## Analysis

In [None]:
# Sum rewards over the last axis, then average over axes 1 and 2.
# NOTE(review): assumes `rewards` is 4-D, e.g. (train_step, env, agent, env_step),
# so the result is one value per train step — confirm against ppo_training.
avg_total_rewards = jnp.mean(jnp.sum(rewards, axis=-1), axis=(1, 2))

fig, ax = plt.subplots()
ax.plot(avg_total_rewards)
ax.set(title="Training rewards", xlabel="Train Step", ylabel="Avg Total Rewards");

In [None]:
# Configure matplotlib for rendering the animation below.
import matplotlib

# NOTE(review): switching to the non-interactive Agg backend stops later figures
# from being shown inline automatically; presumably done so the animation's
# figure is not also displayed as a static image — confirm this is intended.
matplotlib.use("Agg")
# Same as matplotlib.rc("animation", html="html5"): animations' rich HTML repr
# is rendered as an HTML5 video.
matplotlib.rcParams["animation.html"] = "html5"

In [None]:
# Animate a single test rollout: the last entry along the leading axis and
# index 1 along the second axis.
# NOTE(review): assumes axes are (test_round, test_env, ...) so this is the final
# evaluation of the second test environment (n_test_env=2) — confirm.
final_eval, env_idx = -1, 1
anim = flock_env.visualisation.animate_agents(
    test_trajectories.position[final_eval, env_idx, :],
    test_trajectories.heading[final_eval, env_idx, :],
    test_rewards[final_eval, env_idx, :],
    cmap="cool",
)

In [None]:
# Display the animation inline via its rich HTML repr; the "animation.html"
# rc setting configured earlier selects the HTML5-video format.
anim

In [None]:
# Save the animation as a GIF named after the number of agents actually trained,
# so the filename stays correct if p["n_agents"] changes.
# The Pillow writer relies only on Pillow (a required matplotlib dependency),
# so no external ImageMagick binary is needed.
anim.save(f"flock_{p['n_agents']}.gif", writer="pillow", fps=16, dpi=90)