
Shentao-YANG/AMPL_NeurIPS2022


Offline Alternating Model-Policy Learning

Source code for the experiments in the paper A Unified Framework for Alternating Offline Model Training and Policy Learning (NeurIPS 2022). [Paper], [Poster], [Slides].

BibTeX:

@inproceedings{yang2022unified,
  title={A Unified Framework for Alternating Offline Model Training and Policy Learning},
  author={Shentao Yang and Shujian Zhang and Yihao Feng and Mingyuan Zhou},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022},
  url={https://arxiv.org/abs/2210.05922}
}

Installation

  1. Install the basic packages, e.g.:
conda create -n ampl python=3.8.5
conda activate ampl
pip install numpy matplotlib seaborn gym==0.17.0 torch==1.10.1 cudatoolkit==11.1.74
     then add any other dependencies your setup requires.
  2. Install MuJoCo and mujoco-py.
  3. Install D4RL.
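A quick check that the pinned versions were actually picked up can save debugging later. The script below is a minimal sketch, not part of the repository: it reads installed versions with the standard-library importlib.metadata (available from Python 3.8) and compares them with the pins from the pip command above. cudatoolkit is left out because it is usually supplied by conda rather than as a regular Python distribution, and MuJoCo, mujoco-py and D4RL are not checked.

```python
"""Check that the pinned packages from the install step are present.

Minimal sketch, not part of the repository. The PINS list is copied
from the pip command in this README.
"""
from importlib import metadata
from typing import Dict, List, Optional

# Unpinned packages map to None (presence check only).
PINS = ["numpy", "matplotlib", "seaborn", "gym==0.17.0", "torch==1.10.1"]


def parse_pins(pins: List[str]) -> Dict[str, Optional[str]]:
    """Turn 'name==version' strings into {name: version or None}."""
    parsed = {}
    for pin in pins:
        name, _, version = pin.partition("==")
        parsed[name] = version or None
    return parsed


def check_pins(pins: List[str]) -> List[str]:
    """Return a list of problems; an empty list means every pin matches."""
    problems = []
    for name, wanted in parse_pins(pins).items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        # Ignore local suffixes such as '+cu111' on CUDA builds of torch.
        if wanted is not None and installed.split("+")[0] != wanted:
            problems.append(f"{name}: found {installed}, expected {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_pins(PINS) or ["all pinned packages found"]:
        print(line)
```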

Offline RL Experiment

Main Method

The run files for the experiments are generated by submit_jobs_server_gan.py. For example:

python submit_jobs_server_gan.py

Flags can be appended to this command; see submit_jobs_server_gan.py for the available flags.

The location of the generated run files will be printed out.
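When scripting several launches, it can help to build the command from a dictionary of flags. The helper below is hypothetical, not part of the repository: it assumes the script accepts flags in the --name=value form used by the ablation commands in this README, and it returns the script's stdout, which should contain the printed location of the run files.

```python
"""Compose and launch submit_jobs_server_gan.py with a set of flags.

Hypothetical helper, not part of the repository. Check the script
itself for the full list of flags and their defaults.
"""
import subprocess
import sys
from typing import Dict, List, Union

FlagValue = Union[str, int, float]


def build_command(flags: Dict[str, FlagValue],
                  script: str = "submit_jobs_server_gan.py") -> List[str]:
    """Return an argv list: [python, script, '--name=value', ...]."""
    cmd = [sys.executable, script]
    for name, value in flags.items():
        cmd.append(f"--{name}={value}")
    return cmd


def generate_run_files(flags: Dict[str, FlagValue]) -> str:
    """Run the generator and return its stdout (including the printed
    location of the generated run files)."""
    result = subprocess.run(build_command(flags), capture_output=True,
                            text=True, check=True)
    return result.stdout
```

For example, build_command({"dr_method": "VPM", "weight_output_clipping": "True"}) yields the same arguments as the VPM ablation command; passing a list to subprocess.run avoids shell quoting issues.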

Evaluation

The run files create one folder per (dataset, seed) pair. In each folder, eval_norm.npy stores the normalized scores and eval.npy stores the unnormalized scores; the normalization is computed with the D4RL package.
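The sketch below summarizes results across seeds. It rests on assumptions not stated in this README: each run folder holds a 1-D eval_norm.npy whose last entry is the final score, and the folders for one dataset share a parent directory. It also shows D4RL-style normalization, which maps a random policy to 0 and an expert policy to 100; note that D4RL's get_normalized_score returns the ratio without the factor of 100, so check whether the stored values are already scaled.

```python
"""Summarize eval_norm.npy files across (dataset, seed) run folders.

Sketch under assumptions: 1-D score arrays, last entry = final score,
one sub-folder per run under a common root. Adapt to the real layout.
"""
from pathlib import Path
from typing import Dict, Tuple

import numpy as np


def final_scores(root: str, filename: str = "eval_norm.npy") -> Dict[str, float]:
    """Map each run folder directly under `root` to its last recorded score."""
    scores = {}
    for path in sorted(Path(root).glob(f"*/{filename}")):
        values = np.load(path).ravel()
        scores[path.parent.name] = float(values[-1])
    return scores


def mean_and_std(scores: Dict[str, float]) -> Tuple[float, float]:
    """Mean and population standard deviation across runs."""
    values = np.array(list(scores.values()), dtype=float)
    return float(values.mean()), float(values.std())


def d4rl_normalize(score: float, random_score: float, expert_score: float) -> float:
    """100 * (score - random) / (expert - random): 0 = random, 100 = expert."""
    return 100.0 * (score - random_score) / (expert_score - random_score)
```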

Algorithmic Variants

The commands for the variants used in our ablation study are listed below.

  • No weighted model retraining (the model is trained only once, at the beginning, by MLE)
python submit_jobs_server_gan.py --model_retrain_period=1000
  • Use VPM to train the MIW (marginal importance weight) model
python submit_jobs_server_gan.py --dr_method="VPM" --weight_output_clipping="True"
  • Use GenDICE to train the MIW model
python submit_jobs_server_gan.py --dr_method="GenDICE" --weight_output_clipping="True"
  • Use DualDICE to train the MIW model
python submit_jobs_server_gan.py --dr_method="DualDICE" --weight_output_clipping="True"
  • Use a weighted policy regularizer
python submit_jobs_server_gan.py --weighted_policy_training='True'
  • KL-Dual + weighted policy regularizer
python submit_jobs_server_gan.py --weighted_policy_training='True' --use_kl_dual='True' --use_weight_wpr='True'
  • KL-Dual + unweighted policy regularizer
python submit_jobs_server_gan.py --weighted_policy_training='True' --use_kl_dual='True' --use_weight_wpr='False'
  • Gaussian policy + JSD for the policy training
python submit_jobs_server_gan.py --use_gaussian_policy='True'
  • No regularization in the policy training
python submit_jobs_server_gan.py --remove_reg='True'
  • No model-rollout data (real_data_pct=1)
python submit_jobs_server_gan.py --real_data_pct=1.
  • Use the reward function as the test function when training the MIW model
python submit_jobs_server_gan.py --use_reward_test_func='True'
  • Use the value function as the discriminator when training the model
python submit_jobs_server_gan.py --q_dis_model='True'
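To generate run files for the whole ablation study in one go, the variants above can be kept in a table and looped over. This driver is hypothetical, not part of the repository: the flag values are copied from the commands above, while the variant labels are our own. By default it only prints the commands (dry run).

```python
"""Generate run files for every ablation variant in this README.

Hypothetical driver, not part of the repository. Flag values mirror the
README commands; the dictionary keys are illustrative labels only.
"""
import subprocess
import sys
from typing import Dict, List

VARIANTS: Dict[str, Dict[str, str]] = {
    "no_weighted_model_retraining": {"model_retrain_period": "1000"},
    "miw_vpm": {"dr_method": "VPM", "weight_output_clipping": "True"},
    "miw_gendice": {"dr_method": "GenDICE", "weight_output_clipping": "True"},
    "miw_dualdice": {"dr_method": "DualDICE", "weight_output_clipping": "True"},
    "weighted_policy_regularizer": {"weighted_policy_training": "True"},
    "kl_dual_weighted": {"weighted_policy_training": "True",
                         "use_kl_dual": "True", "use_weight_wpr": "True"},
    "kl_dual_unweighted": {"weighted_policy_training": "True",
                           "use_kl_dual": "True", "use_weight_wpr": "False"},
    "gaussian_policy_jsd": {"use_gaussian_policy": "True"},
    "no_policy_regularization": {"remove_reg": "True"},
    "no_model_rollouts": {"real_data_pct": "1."},
    "reward_test_function": {"use_reward_test_func": "True"},
    "q_discriminator": {"q_dis_model": "True"},
}


def variant_command(name: str, script: str = "submit_jobs_server_gan.py") -> List[str]:
    """argv for one variant: [python, script, '--flag=value', ...]."""
    flags = VARIANTS[name]
    return [sys.executable, script] + [f"--{k}={v}" for k, v in flags.items()]


def run_all(dry_run: bool = True) -> None:
    """Print each command (dry_run=True) or execute it once per variant."""
    for name in VARIANTS:
        cmd = variant_command(name)
        if dry_run:
            print(name, " ".join(cmd[1:]))
        else:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```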

License

MIT License.