Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning


This repo is the official code release for the ICLR 2024 conference paper:

 

Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning
Ruizhe Shi*^1, Yuyao Liu*^1, Yanjie Ze^2, Simon Shaolei Du^3, Huazhe Xu^{1,2,4}
The International Conference on Learning Representations (ICLR) 2024
^1 Tsinghua University, IIIS   ^2 Shanghai Qi Zhi Institute   ^3 University of Washington   ^4 Shanghai AI Lab
*Equal contribution. Order is decided by coin flip.

 

🧾 Introduction

We propose LaMo, an offline RL framework that leverages pre-trained Language Models (LMs) for low-level Motion control. On sparse-reward tasks, LaMo achieves strong results and surpasses recent strong algorithms such as CQL, IQL, TD3+BC, and DT; on dense-reward tasks, LaMo significantly improves over Decision Transformer and closes the gap between value-based methods and DT-based methods. Notably, in low-data scenarios, our method demonstrates powerful few-shot learning ability, which we attribute to the inductive bias from pre-trained LMs.

We also study the relationship between the performance of various algorithms and the scale of the data. As depicted in the figure, LaMo achieves excellent performance even with relatively small datasets. For example, on Hopper, LaMo surpasses CQL and DT when the sample ratio is as low as 0.5%, and it maintains this advantage as the sample ratio increases.

Below, we visualize 8 tasks across 3 domains that we consider.

  • D4RL
    • MuJoCo: Hopper, Walker2d, HalfCheetah, Reacher2d
    • Kitchen
  • Atari: Breakout, Qbert, Pong

💻 Installation

D4RL-tasks

Environment

We can only guarantee reproducibility with the environment configuration below.

Install MuJoCo

First, download the archive from this link and extract it with tar -xvf the_file_name into the ~/.mujoco folder. Then, run the following commands.

cd experiment-d4rl
conda create -n lamo-d4rl python=3.8.17
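
For reference, the download-and-extract step above might look like the following; the archive name here is an assumption based on the mujoco210 path used below, so substitute whatever file the link provides.

mkdir -p ~/.mujoco
# archive name assumed (MuJoCo 2.1.0 Linux release); use the file downloaded from the link above
tar -xvf mujoco210-linux-x86_64.tar.gz -C ~/.mujoco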

After that, add the following lines to your ~/.bashrc file:

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/YOUR_PATH_TO_THIS/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia

Remember to source ~/.bashrc to make the changes take effect.

Install D4RL

Install D4RL by following the guidance in the D4RL repository.
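
As a rough sketch (per the D4RL repository's own instructions at the time of writing; defer to that guidance if it differs), an editable install looks like:

git clone https://github.com/Farama-Foundation/D4RL.git
cd D4RL
pip install -e .
cd ..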

Then, downgrade the dm-control and mujoco packages:

pip install mujoco==2.3.7
pip install dm-control==1.0.14

Install torch and other dependencies

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
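
Optionally, you can sanity-check that the GPU build of PyTorch installed correctly; this check is our addition, not part of the original setup.

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"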

Dataset

To download the original D4RL data, run:

cd data
python download_d4rl_datasets.py

For the downsampled data, if you want to reproduce our experiments, you should directly download our pre-processed data from this link.

You can also generate more downsampled data by modifying line 10 of 'data/mujoco/ratio_dataset.py' and line 10 of 'data/kitchen/ratio_dataset.py' as

suffix = [your data version name]

and then run

cd data
cd mujoco
python ratio_dataset.py
cd ..
cd kitchen
python ratio_dataset.py
cd ..
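
For example, to create a hypothetical data version named "d2", you would set line 10 in both files to

suffix = "d2"

and run the commands above; the resulting data can then be selected at training time with --data_suffix d2 (see the D4RL usage arguments below).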

You can also try generating the data with a PPO agent trained by yourself (only Reacher2d is supported), as provided in 'data/data_generation_PPO'.

Atari-tasks

Environment

First, make sure you have the system dependencies required to install Atari.

sudo apt install cmake
sudo apt install zlib1g-dev

Then run the following commands.

cd experiment-atari
conda env create -f env.yml

Dataset

The dataset will be downloaded automatically and cached locally by the d4rl-atari package once you launch an experiment. To reproduce our results on the downsampled datasets, set the seed to one of ours (we use 3 seeds: 0, 1, and 2); our implementation in experiment-atari/buffer.py then ensures that the downsampled dataset is identical to ours.

🛠️ Usage

D4RL

After installing the packages and data, to reproduce our results on D4RL you only need to run the scripts provided in this link.

If you encounter errors when running those scripts, try

dos2unix [the-script-name].sh

If you encounter errors about D4RL or MuJoCo at runtime, these tips (1, 2) may help.

If you want to view results on Weights & Biases, modify lines 435 and 436 of 'experiment.py' as:

entity=[your-group-name],
project=[your-project-name],

You can also design your own script following 'run.sh':

cd experiment-d4rl
bash run.sh [env_name] [dataset_name] [sample_ratio] [description] [seed] [gpu]

An example is:

bash run.sh hopper medium 0.1 reproduce 0 0

Trying more configurations is encouraged! Important arguments are explained below; a sketch of how they might be combined follows the list.

-w # enable wandb
--sample_ratio your_sample_ratio # determine the size of the data you are training on, like 0.1
--data_suffix your_data_version_name # you could downsample the data by yourself, default is "d1"
--mlp_embedding # use MLP as embeddings and projections
--adapt_mode # otherwise fully fine-tuning
--adapt_embed # fine-tune embeddings and projections when adapt_mode is ON
--lora # fine-tune low rank matrices of Transformer when adapt_mode is ON
--pretrained_lm language_model_name # you could try 'gpt2' and 'gpt2-medium'
--co_training # use language loss as auxiliary objective
--co_lambda # the weight of language loss, like 0.1
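
For example, a custom script might invoke experiment.py with the documented optional flags. This is a hedged sketch only: how run.sh wires the environment, dataset, and seed arguments into experiment.py is not shown here, so copy that part from the repo's run.sh.

# hypothetical custom script: combines the documented optional flags
# (the environment/dataset/seed arguments that run.sh already supplies are omitted; mirror run.sh for those)
python experiment.py \
    -w \
    --sample_ratio 0.1 \
    --data_suffix d1 \
    --mlp_embedding \
    --adapt_mode --adapt_embed --lora \
    --pretrained_lm gpt2 \
    --co_training --co_lambda 0.1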

Atari

To reproduce our results on Breakout with one click, run the following commands

cd experiment-atari
bash run.sh 

Since we use Hydra to manage the configuration of the Atari experiments, you can conveniently override hyperparameters. If you want to run experiments on more environments, add a configuration for the corresponding environment under experiment-atari/cfgs/env. Refer to the Hydra documentation for more details. Here are a few important hyperparameters (see the sketch after the list for an example of overriding them):

env # environment name (breakout, qbert, pong, or any atari environment you want to explore)
pretrained_lm # gpt2, gpt2-medium or none
seed # 0, 1, 2
sample_ratio # the ratio of dataset you train on
model.random_initialize # randomly initialize the model weights (overwriting the pretrained weights) or not
model.adapt_cfg.use_adapt # use adapt mode or not (as opposed to full fine-tuning)
model.adapt_cfg.adapt_embed # unfreeze embeddings or not
model.lora_cfg.use_lora # use lora or not
model.lora_cfg.lora_attn_dim # the dimension of lora
model.context_len # the context length of the transformer model
train.lr # learning rate
train.weight_decay # weight decay
train.batch_size # batch size
nlp_train.co_training # use language joint training or not
nlp_train.co_lambda # the weight of language joint training loss
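
As an illustration, Hydra overrides are plain key=value pairs on the command line. The sketch below assumes run.sh forwards extra arguments to the Hydra entry point (an assumption, so check run.sh); it sweeps our three seeds on Qbert using only the keys listed above.

cd experiment-atari
# hedged sketch: assumes run.sh passes "$@" through to the Hydra-based training script
for seed in 0 1 2; do
  bash run.sh env=qbert pretrained_lm=gpt2 seed=$seed sample_ratio=0.1 \
    model.adapt_cfg.use_adapt=true model.lora_cfg.use_lora=true
done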

🙏 Acknowledgement

LaMo is based on many open-source projects, including Decision Transformer, Can Wikipedia Help Offline Reinforcement Learning, LoRA, DeFog, and d4rl-atari. We thank all these authors for their nicely open-sourced code and their great contributions to the community.

🏷️ License

LaMo is licensed under the MIT license. See the LICENSE file for details.

📝 Citation

If you find our work useful, please consider citing:

@article{Shi2024LaMo,
  title={Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning},
  author={Ruizhe Shi and Yuyao Liu and Yanjie Ze and Simon S. Du and Huazhe Xu},
  journal={International Conference on Learning Representations}, 
  year={2024}
}
