
GRPO: Group Robust Preference Optimization

This codebase builds on the DPO codebase publicly available on GitHub: https://github.com/eric-mitchell/direct-preference-optimization

What is this repo?

This repo includes a reference implementation of the GRPO algorithm for training language models from preference data, as described in our paper Group Robust Preference Optimization in Reward-free RLHF.

Similar to DPO, our pipeline has two stages:

  1. Run supervised fine-tuning (SFT) on the dataset(s) of interest.
  2. Run robust preference learning (GRIPO) on the model from step 1, using preference data.

The important files in this repo are:

  • train.py: the main entry point for training (SFT, or IPO/GRIPO preference-based training)
  • src/trainers_factory.py: instantiates the trainer classes from src/trainers
  • src/utils.py: common functions used by multiple methods
  • src/preference_datasets.py: dataset processing logic for both SFT and IPO/GRIPO preference-based training

In this codebase we use the Gemma-2b model; the configuration used is detailed in config/model/gemma-2b.yaml. To download and use the Gemma-2b model, refer to https://huggingface.co/google/gemma-2b. It is a gated model and therefore requires requesting access on Hugging Face.
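For reference, here is a minimal sketch of downloading the gated model with the transformers library once access has been granted (illustrative only; the actual loading path in this repo is driven by config/model/gemma-2b.yaml and train.py):

# Minimal sketch: load the gated Gemma-2b model after requesting access on
# Hugging Face and logging in via `huggingface-cli login`.
# Illustrative only; the repo's own loading code lives in train.py and is
# configured by config/model/gemma-2b.yaml.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.bfloat16,  # fits comfortably on a single 40GB A100
)
print(model.config.model_type)   # "gemma"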

Our dataset is the global opinion data from https://huggingface.co/datasets/Anthropic/llm_global_opinions.
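The raw data can be inspected directly with the datasets library. A quick look, for illustration only (training uses the processing logic in src/preference_datasets.py, and column names follow the dataset card):

# Inspect the raw global-opinions data (illustrative only; training goes
# through src/preference_datasets.py).
from datasets import load_dataset

ds = load_dataset("Anthropic/llm_global_opinions", split="train")
print(ds.column_names)  # e.g. survey question, answer options, per-country selections
print(ds[0])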

Set up environment

First, create a virtualenv and install the dependencies. Python 3.10+ is recommended.

python3 -m venv env
source env/bin/activate
pip install -r main_requirements.txt
pip install scikit-learn

In config.yaml, set up your wandb details so that results can be visualized there.
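If wandb has not been used on the machine before, a quick sanity check that logging works can help; the entity/project names below are placeholders and should match whatever you set in config.yaml:

# Quick wandb smoke test; replace the placeholder entity/project with the
# values configured in config.yaml.
import wandb

run = wandb.init(entity="your-entity", project="your-project", name="smoke-test")
run.log({"dummy_metric": 0.0})
run.finish()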

Running SFT

sh scripts/run_sft.sh

Please run this command to reproduce the SFT used in our setup.

Running IPO/GRIPO

Running IPO requires the reference policy, i.e. the path to the SFT checkpoint. Run the following command for IPO, with path-to-sft-file replaced by the actual path to your SFT-trained policy:

sh scripts/run_multi.sh --model.archive path-to-sft-file

The exact configurations we used in our IPO training are already set in scripts/run_multi.sh.

Similarly, for GRIPO:

sh scripts/run_multi_robust.sh --model.archive path-to-sft-file

Note that these commands were run on a machine with a single 40GB A100 GPU. We use single-GPU training with GroupEarlyStopTrainer, which reduces the learning rate if the loss does not improve after a certain (tunable) number of iterations.
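The learning-rate schedule this describes is roughly reduce-on-plateau; the sketch below illustrates the idea under assumed names and default values, and is not the actual GroupEarlyStopTrainer implementation (see src/trainers for that):

# Illustrative reduce-LR-on-plateau sketch; names, defaults and the real
# logic in src/trainers differ.
class PlateauLRReducer:
    def __init__(self, optimizer, patience=5, factor=0.5, min_lr=1e-7):
        self.optimizer = optimizer
        self.patience = patience        # iterations without improvement tolerated
        self.factor = factor            # multiplicative LR decay
        self.min_lr = min_lr
        self.best_loss = float("inf")
        self.bad_steps = 0

    def step(self, loss):
        if loss < self.best_loss:       # improvement: reset the counter
            self.best_loss = loss
            self.bad_steps = 0
            return
        self.bad_steps += 1
        if self.bad_steps >= self.patience:
            for group in self.optimizer.param_groups:
                group["lr"] = max(group["lr"] * self.factor, self.min_lr)
            self.bad_steps = 0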

Plotting results

To visualize the results, we collect data directly from wandb and plot it. The plotting scripts in the plot_scripts folder perform this; change the wandb details and path-to-sft-file in the plot scripts to reproduce the plots.

  • plot_scripts/plot_from_wandb_full_metrics.py: plots all the relevant metrics tracked in our experiments
  • plot_scripts/plot_from_wandb_paper_plots.py: reproduces the plots presented in the paper
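Both scripts read run data back from wandb. A minimal sketch of fetching logged metrics with the wandb public API (the entity/project path is a placeholder, and the real scripts add run filtering and plotting on top):

# Fetch logged metrics from wandb (placeholder entity/project; the plot
# scripts build run filtering and matplotlib figures on top of this).
import wandb

api = wandb.Api()
for run in api.runs("your-entity/your-project"):
    history = run.history()          # pandas DataFrame of logged metrics
    print(run.name, list(history.columns)[:5])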

Citation

Please cite our paper if you find the repo helpful in your work:

@article{ramesh2024grpo,
    title={Group Robust Preference Optimization in Reward-free RLHF},
    author={Shyam Sundhar Ramesh and Iason Chaimalas and Viraj Mehta and Haitham Bou Ammar and
            Pier Giuseppe Sessa and Yifan Hu and Ilija Bogunovic},
    year={2024}
}
