
πŸš€ PyTorch Implementation of GEM 🌟

Welcome to the official PyTorch implementation of GEM! πŸŽ‰

GEM was introduced in our ICLR 2025 paper "Preserving Diversity in Supervised Fine-tuning of Large Language Models". This work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity" and received the Best Paper Runner-up Award at the NeurIPS 2024 FITML Workshop.

GEM can replace the cross-entropy (CE) loss during SFT to preserve output diversity and mitigate overfitting. 🌍✨
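For intuition, here is a minimal sketch (not this repo's actual code) of where a drop-in loss replacement sits in an SFT step. gem_loss_fn is a hypothetical handle for the GEM loss implemented in this repo; see train.py and the utils folder for the real implementation.

import torch.nn.functional as F

def sft_step(model, batch, gem_loss_fn=None):
    # Standard causal-LM SFT forward pass.
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"]).logits
    # Shift so that position t predicts token t+1.
    logits = logits[:, :-1, :]
    labels = batch["labels"][:, 1:]
    if gem_loss_fn is None:
        # Baseline: token-level cross-entropy (the loss GEM replaces).
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               labels.reshape(-1), ignore_index=-100)
    else:
        # GEM: same inputs, different training signal (hypothetical interface).
        loss = gem_loss_fn(logits, labels)
    return loss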

For more insights on GEM's potential to enhance RL training through improved cold-start strategies, check out our blog post: "Can Better Cold-Start Strategies Improve RL Training for LLMs?"

Quickstart Guide πŸ’»

Setup πŸ”§

First, create a new environment and install the required packages:

conda create -n gem python=3.10
conda activate gem
pip install -r requirements.txt

Note that the package versions pinned in requirements.txt are the ones used in the paper. You may use a newer version of transformers (>= 4.46.0), which fixes a potential bug in gradient accumulation.
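If you are unsure which version is installed, a quick check (illustrative only) is:

import transformers
from packaging import version

# Versions below 4.46.0 may be affected by the gradient-accumulation loss-normalization bug mentioned above.
print(transformers.__version__)
assert version.parse(transformers.__version__) >= version.parse("4.46.0")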

We also provide a Triton implementation of the GEM loss in the utils folder, which may be faster than the original implementation when training large-scale models. Please refer to the README in that folder for more details. You can enable it with the following command:

python train.py --loss gem_triton

Training πŸ‹οΈβ€β™‚οΈ

Kickstart your training process using the UltraFeedback dataset from HuggingFace. Here's how:

Tokenize Data

bash scripts/tokenize_data.sh

Training

bash scripts/train_gem_ultrafeedback.sh
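Once training completes, you can sample a few completions from the resulting checkpoint to inspect output diversity. A minimal sketch using the Hugging Face API follows; the checkpoint path and decoding settings are placeholders, not taken from this repo.

from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/to/your/sft_checkpoint"  # placeholder: whichever output_dir you trained to
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto")

prompt = "Write a short poem about the sea."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Sample several completions; a diversity-preserving model should not collapse to one answer.
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True,
                         temperature=1.0, num_return_sequences=4)
for seq in outputs:
    print(tokenizer.decode(seq[inputs["input_ids"].shape[-1]:], skip_special_tokens=True))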

Evaluation πŸ§ͺ

Run evaluations for different tasks:

GSM8K

bash scripts/eval/gsm8k_eval.sh

GSM8K (Voting)

bash scripts/eval/gsm8k_voting_eval.sh
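The voting variant follows the usual self-consistency recipe: sample several solutions per question and take a majority vote over the extracted final answers. A rough sketch is below; the answer-extraction rule here is an assumption, and the actual script may parse answers differently.

import re
from collections import Counter

def extract_answer(completion):
    # Assume the last number in the completion is the predicted answer.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    return numbers[-1] if numbers else None

def majority_vote(completions):
    answers = [a for a in map(extract_answer, completions) if a is not None]
    return Counter(answers).most_common(1)[0][0] if answers else None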

Creative Writing

bash scripts/eval/creative_writing.sh
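Creative-writing evaluation looks at both the quality and the diversity of generated text. As one illustration of a diversity measure (not necessarily the metric used by scripts/eval/creative_writing.sh), a distinct-n score counts the fraction of unique n-grams across samples:

def distinct_n(samples, n=2):
    # Fraction of unique n-grams across all generated samples (higher = more diverse).
    ngrams, total = set(), 0
    for text in samples:
        tokens = text.split()
        for i in range(len(tokens) - n + 1):
            ngrams.add(tuple(tokens[i:i + n]))
            total += 1
    return len(ngrams) / total if total else 0.0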

To Do

  • Add an adaptive mechanism for choosing the hyper-parameter β.

πŸ“œ Citation

If you find this repository helpful in your research or projects, please consider citing the GEM paper in your academic work. Your support is much appreciated! πŸ™Œ

@inproceedings{li2025preserving,
  title={Preserving Diversity in Supervised Fine-Tuning of Large Language Models},
  author={Ziniu Li and Congliang Chen and Tian Xu and Zeyu Qin and Jiancong Xiao and Zhi-Quan Luo and Ruoyu Sun},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=NQEe7B7bSw}
}

Our work was previously titled "Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity", available on arXiv.

@article{li2024entropic,
  title={Entropic Distribution Matching in Supervised Fine-tuning of LLMs: Less Overfitting and Better Diversity},
  author={Li, Ziniu and Chen, Congliang and Xu, Tian and Qin, Zeyu and Xiao, Jiancong and Sun, Ruoyu and Luo, Zhi-Quan},
  journal={arXiv preprint arXiv:2408.16673},
  year={2024}
}

Ziniu Li would like to acknowledge Zhengyang Tang for his minimalistic and clean implementation of SFT.
