A JAX implementation of the Twin Delayed Deep Deterministic Policy Gradient (TD3) algorithm.
This code aims to port the PyTorch implementation from the original TD3 repository to JAX with minimal modifications. Training runs about twice as fast as the original PyTorch code on an i7-6700K + GTX 1080 machine.
The code was tested with jaxlib 0.1.61, flax 0.3.0, and Python 3.9.
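For readers unfamiliar with TD3, the following is a minimal sketch of how its critic target (target policy smoothing plus clipped double-Q learning) might be written in JAX. It is not taken from this repository: the function name td3_target, the toy_actor and toy_critic stand-ins, and the default hyperparameters are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def td3_target(key, reward, not_done, next_state, actor_target, critic_target,
               discount=0.99, policy_noise=0.2, noise_clip=0.5, max_action=1.0):
    """Compute a TD3-style critic target (illustrative sketch)."""
    next_action = actor_target(next_state)
    # Target policy smoothing: add clipped Gaussian noise to the target action.
    noise = jnp.clip(
        policy_noise * jax.random.normal(key, next_action.shape),
        -noise_clip, noise_clip)
    next_action = jnp.clip(next_action + noise, -max_action, max_action)
    # Clipped double-Q learning: use the smaller of the two target critics.
    q1, q2 = critic_target(next_state, next_action)
    target_q = reward + not_done * discount * jnp.minimum(q1, q2)
    # The target is treated as a constant when differentiating the critic loss.
    return jax.lax.stop_gradient(target_q)


# Hypothetical stand-ins for the target networks, for illustration only.
def toy_actor(state):
    return jnp.tanh(state[:, :1])


def toy_critic(state, action):
    sa = jnp.concatenate([state, action], axis=1)
    q = sa.sum(axis=1, keepdims=True)
    return q, 2.0 * q


key = jax.random.PRNGKey(0)
state = jnp.ones((4, 3))
reward = jnp.zeros((4, 1))
not_done = jnp.ones((4, 1))
target = td3_target(key, reward, not_done, state, toy_actor, toy_critic)
print(target.shape)  # (4, 1)
```

Unlike PyTorch, JAX has no global random state, so the noise is drawn from an explicit PRNG key that the caller is expected to split on every update.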
Example usage:
python main.py --env HalfCheetah-v3
or
./run_experiments.sh
for full experiments.