Official Jax implementation of EDP, from the following paper:
Efficient Diffusion Policies for Offline Reinforcement Learning. NeurIPS 2023.
Bingyi Kang, Xiao Ma, Chao Du, Tianyu Pang, Shuicheng Yan
Sea AI Lab
[arxiv]
We propose a class of diffusion policies (EDP) that are efficient to train and generally compatible with a variety of RL algorithms. EDP serves as a more powerful policy representation for decision making and can be used as a plug-in replacement for feed-forward policies (e.g., Gaussian policies). It has the following features:
- Enabling diffusion training with a large number of diffusion steps, e.g., 1000 steps.
- $25\times$ boost in training speed, reducing training time from 5 days to 5 hours.
- Generally applicable to both likelihood-based methods (PG, CRR, AWR, IQL) and value-maximization-based methods (DDPG, TD3).
- Setting a new state of the art on all four domains in D4RL.
Before you start, make sure to run
pip install -e .
Apart from this, you'll have to set up your MuJoCo environment and key as well. Please follow the D4RL repo and set up the environment accordingly.
You can run EDP experiments using the following command:
python -m diffusion.trainer --env 'walker2d-medium-v2' --logging.output_dir './experiment_output' --algo_cfg.loss_type=TD3
To use other offline RL algorithms, simply change the --algo_cfg.loss_type parameter. For example:
python -m diffusion.trainer --env 'walker2d-medium-v2' --logging.output_dir './experiment_output' --algo_cfg.loss_type=IQL --norm_reward=True
By default we use the ddpm solver. To use dpm, set --sample_method=dpm and --algo_cfg.num_timesteps=1000.
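When switching between algorithms and solvers, it can help to build the command line programmatically. The following is a minimal sketch, not part of the codebase: `build_trainer_command` is a hypothetical helper, and it uses only the flags that appear in this README (`--env`, `--logging.output_dir`, `--algo_cfg.loss_type`, `--norm_reward`, `--sample_method`, `--algo_cfg.num_timesteps`).

```python
import shlex


def build_trainer_command(env, output_dir, loss_type="TD3",
                          norm_reward=None, sample_method=None,
                          num_timesteps=None):
    """Assemble a `diffusion.trainer` command from flags shown in this README.

    Hypothetical helper for illustration; optional flags are only added
    when a value is given, so the trainer's own defaults apply otherwise.
    """
    cmd = [
        "python", "-m", "diffusion.trainer",
        "--env", env,
        "--logging.output_dir", output_dir,
        f"--algo_cfg.loss_type={loss_type}",
    ]
    if norm_reward is not None:
        cmd.append(f"--norm_reward={norm_reward}")
    if sample_method is not None:
        cmd.append(f"--sample_method={sample_method}")
    if num_timesteps is not None:
        cmd.append(f"--algo_cfg.num_timesteps={num_timesteps}")
    return cmd


# TD3 loss with the dpm solver and 1000 timesteps, as described above.
dpm_cmd = build_trainer_command(
    "walker2d-medium-v2", "./experiment_output",
    loss_type="TD3", sample_method="dpm", num_timesteps=1000,
)
print(shlex.join(dpm_cmd))
```

The returned list can be passed to `subprocess.run(dpm_cmd, check=True)` from the repository root to launch the run; printing it first makes it easy to confirm the flags before starting a long job.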
This codebase can also log to the W&B online visualization platform. To log to W&B, you first need to set your W&B API key environment variable. Alternatively, you could simply run wandb login.
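A quick pre-flight check can catch a missing key before a long run starts. The sketch below assumes the key is supplied through `WANDB_API_KEY`, the variable name read by the W&B client (this README does not name it). `wandb_key_available` is a hypothetical helper, and it does not detect credentials stored by `wandb login`.

```python
import os


def wandb_key_available(environ=os.environ):
    """Return True if a non-empty WANDB_API_KEY is present in `environ`.

    Hypothetical helper for illustration; it only inspects the environment
    mapping and does not see credentials saved by `wandb login`.
    """
    return bool(environ.get("WANDB_API_KEY", "").strip())


if not wandb_key_available():
    print("WANDB_API_KEY is not set; export it or run `wandb login` "
          "before enabling W&B logging.")
```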
The project structure borrows from the Jax CQL implementation.
We also refer to the diffusion model implementation from OpenAI and the official diffusion Q learning implementation.