JAX implementation of the TD3 algorithm for reinforcement learning

yifan12wu/td3-jax

JAX Implementation of TD3

JAX implementation of the algorithm from the Twin Delayed Deep Deterministic Policy Gradients (TD3) paper.

This code ports the PyTorch implementation from the original TD3 repository to JAX with minimal modifications. Training runs about twice as fast as the original PyTorch code on an i7-6700K + GTX 1080 machine.
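For readers new to TD3, the critic target at the heart of the algorithm can be sketched in JAX as follows. This is a minimal illustration of target policy smoothing and the clipped double-Q target as described in the TD3 paper, not the code in this repository. The function name `td3_target`, the toy actor and critic, and the argument layout are all assumptions made for the example; the default hyperparameters are the commonly cited TD3 values.

```python
import jax
import jax.numpy as jnp


def td3_target(key, next_state, reward, not_done, actor_target, critic_target,
               max_action=1.0, discount=0.99, policy_noise=0.2, noise_clip=0.5):
    """Clipped double-Q target with target policy smoothing (illustrative)."""
    next_action = actor_target(next_state)
    # Target policy smoothing: add clipped Gaussian noise to the target action.
    noise = policy_noise * jax.random.normal(key, next_action.shape)
    noise = jnp.clip(noise, -noise_clip, noise_clip)
    next_action = jnp.clip(next_action + noise, -max_action, max_action)
    # Clipped double-Q: bootstrap from the smaller of the two target critics.
    q1, q2 = critic_target(next_state, next_action)
    return reward + not_done * discount * jnp.minimum(q1, q2)


# Hypothetical stand-ins for the target networks, used only for this demo.
def toy_actor(state):
    return jnp.tanh(state.sum(axis=-1, keepdims=True))


def toy_critic(state, action):
    base = state.sum(axis=-1, keepdims=True) + action
    return base, base - 1.0  # the second estimate is always the smaller one


key = jax.random.PRNGKey(0)
next_state = jnp.ones((4, 3))
reward = jnp.ones((4, 1))
not_done = jnp.array([[1.0], [1.0], [0.0], [1.0]])  # third transition is terminal

target = td3_target(key, next_state, reward, not_done, toy_actor, toy_critic)
print(target.shape)
```

For the terminal transition the target reduces to the reward alone, because `not_done` masks out the bootstrapped term.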

The code was tested with jaxlib 0.1.61, flax 0.3.0, and Python 3.9.
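To compare a local environment against the tested versions, a small check along these lines may help. It is a sketch that uses only the standard library's `importlib.metadata`; the helper name `compare_versions` and the `TESTED` table are assumptions for illustration and are not part of this repository.

```python
from importlib import metadata

# Versions listed in this README as tested.
TESTED = {"jaxlib": "0.1.61", "flax": "0.3.0"}


def compare_versions(tested=TESTED):
    """Return {package: (tested_version, installed_version_or_None)}."""
    report = {}
    for name, wanted in tested.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed in this environment
        report[name] = (wanted, found)
    return report


for name, (wanted, found) in compare_versions().items():
    status = "matches" if found == wanted else "differs"
    print(f"{name}: tested {wanted}, installed {found} ({status})")
```

A differing version does not necessarily mean the code will fail; it only indicates that the setup is outside what was tested.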

Example usage:

python main.py --env HalfCheetah-v3

or

./run_experiments.sh

for full experiments.

Example Plots
