JAX (using Flax) implementation of the Proximal Policy Optimisation (PPO) algorithm, designed for continuous action spaces.
The base implementation largely follows the cleanrl implementation; the recurrent LSTM variant was motivated by these blog posts (a generic sketch of the clipped objective follows the links):
- https://npitsillos.github.io/blog/2021/recurrent-ppo/
- https://medium.com/@ngoodger_7766/proximal-policy-optimisation-in-pytorch-with-recurrent-models-edefb8a72180
- https://kam.al/blog/ppo_stale_states/
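As a reminder of the core update, the clipped surrogate objective at the heart of PPO can be sketched in JAX as follows. This is a generic sketch, not this package's actual loss function (whose signature and batching will differ):

```python
import jax.numpy as jnp


def clipped_policy_loss(log_prob, old_log_prob, advantage, clip_coefficient=0.2):
    """Generic PPO clipped surrogate objective (negated for gradient descent)."""
    # Probability ratio between the current and behaviour policies
    ratio = jnp.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = jnp.clip(ratio, 1.0 - clip_coefficient, 1.0 + clip_coefficient) * advantage
    # Take the pessimistic (minimum) objective, then average over the batch
    return -jnp.mean(jnp.minimum(unclipped, clipped))
```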
See example/gym_usage.ipynb for an example of using this implementation with a gymnax environment.
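For context, the gymnax side of that example looks roughly like the snippet below; how the environment is passed into this package's training loop is not shown here, so refer to the notebook for the actual entry point:

```python
import jax
import gymnax

# Pendulum is the environment used in the examples below
env, env_params = gymnax.make("Pendulum-v1")

rng = jax.random.PRNGKey(0)
rng, reset_key, act_key, step_key = jax.random.split(rng, 4)

obs, state = env.reset(reset_key, env_params)
action = env.action_space(env_params).sample(act_key)
obs, state, reward, done, info = env.step(step_key, state, action, env_params)
```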
Dependencies can be installed with poetry by running
poetry install
Total rewards per train step, with parameters (see example/gym_usage.ipynb):
- n-train: 2,500
- n-steps: 2,048
- n-train-epochs: 2
- mini-batch-size: 256
- n-test-steps: 2,000
- gamma: 0.95
- gae-lambda: 0.9
- entropy-coefficient: 0.0001
- adam-eps: 1e-8
- clip-coefficient: 0.2
- critic-coefficient: 0.5
- max-grad-norm: 0.75
- LR: 2e-3 → 2e-5
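The LR entry is read here as the learning rate being annealed from 2e-3 down to 2e-5 over training. Assuming a linear schedule (the actual schedule shape used in the notebook may differ), this can be expressed with optax roughly as:

```python
import optax

# Assumed linear decay from 2e-3 to 2e-5; transition_steps is illustrative and
# in practice depends on the number of optimiser updates (epochs x mini-batches)
schedule = optax.linear_schedule(init_value=2e-3, end_value=2e-5, transition_steps=2_500)

optimiser = optax.chain(
    optax.clip_by_global_norm(0.75),               # max-grad-norm
    optax.adam(learning_rate=schedule, eps=1e-8),  # adam-eps
)
```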
Mean and std of total rewards during training, averaged over random seeds.
This was tested against the pendulum environment with the velocity component of the observation masked (a sketch of such a mask is given below).
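The Pendulum observation is [cos θ, sin θ, angular velocity], so masking the velocity makes the task partially observable, which is what motivates the recurrent variant below. A minimal sketch of such a mask, which may differ from the notebook's actual implementation:

```python
import jax.numpy as jnp


def mask_velocity(obs):
    """Zero the angular-velocity component of a Pendulum observation
    [cos(theta), sin(theta), theta_dot], leaving only the angle visible."""
    return obs.at[-1].set(0.0)
```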
Total rewards per train step, with parameters (see example/lstm_usage.ipynb):
- n-train: 2,500
- n-train-env: 32
- n-test-env: 5
- n-train-epochs: 2
- mini-batch-size: 512
- n-test-steps: 2,000
- sequence-length: 8
- n-burn-in: 8
- gamma: 0.95
- gae-lambda: 0.99
- entropy-coefficient: 0.0001
- adam-eps: 1e-8
- clip-coefficient: 0.1
- critic-coefficient: 0.5
- max-grad-norm: 0.75
- LR: 2e-3 → 2e-6
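Both parameter sets use generalised advantage estimation (the gamma and gae-lambda values above). A generic sketch of the GAE recursion in JAX, independent of this package's actual implementation:

```python
import jax
import jax.numpy as jnp


def gae_advantages(rewards, values, dones, last_value, gamma=0.95, gae_lambda=0.99):
    """Generalised advantage estimation over a single trajectory via a reverse scan."""
    next_values = jnp.append(values[1:], last_value)
    # One-step TD errors, zeroing the bootstrap value on terminal steps
    deltas = rewards + gamma * next_values * (1.0 - dones) - values

    def step(carry, xs):
        delta, done = xs
        advantage = delta + gamma * gae_lambda * (1.0 - done) * carry
        return advantage, advantage

    _, advantages = jax.lax.scan(step, 0.0, (deltas, dones), reverse=True)
    # Value targets are advantages plus the value baseline
    return advantages, advantages + values
```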
NOTE: This achieves good results but seems to be somewhat unstable. I suspect this might be due to stale hidden states (see the stale-states post linked above).
Average total rewards during training across test environments, generated from 10 random seeds.
Recurrent Hidden States Initialisation
At the start of each episode the LSTM hidden states are reset to zero and then burnt in before trajectories are collected (and likewise during evaluation); a sketch of this burn-in step is given after the note below. I also tried carrying hidden states over between training steps, with good results, but when training across multiple environments this becomes harder to reason about.
Note that this may lead to strange behaviour if the training environment quickly reaches a terminal state (i.e. if the episode completes during the burn-in period).
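A minimal sketch of the burn-in step, assuming a hypothetical apply_fn(params, carry, obs) -> (carry, output) wrapper around the recurrent policy (the function names and carry structure in this package will differ):

```python
import jax


def burn_in_carry(apply_fn, params, zero_carry, burn_in_obs):
    """Warm up an LSTM carry from zeros by rolling the recurrent policy over
    the first n-burn-in observations; the network outputs are discarded and
    only the final carry is kept for the subsequent trajectory collection.

    apply_fn, params, zero_carry and burn_in_obs are hypothetical placeholders
    for this package's actual recurrent policy interface.
    """

    def step(carry, obs):
        new_carry, _ = apply_fn(params, carry, obs)
        return new_carry, None

    warmed_carry, _ = jax.lax.scan(step, zero_carry, burn_in_obs)
    return warmed_carry
```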
TODO
- Implement early stopping based on the KL-divergence.
- Benchmark against other reference implementations.
- Recalculate hidden states during the policy update.
Pre-commit hooks can be installed by running
pre-commit install
Pre-commit checks can then be run using
task lint
Tests can be run with
task test