
Poor performance on llama3 #27

Open
JasonZhu1313 opened this issue May 1, 2024 · 3 comments

Comments

@JasonZhu1313

Hey,

I am wondering if you have tried ORPO on Llama-3. I was using the same hyperparameters as for Mistral training, but the MT-Bench score is quite low compared to both Llama-3-Instruct and a similarly trained Mistral model.

I used to see a lot of Llama-3-based models on this hub, https://huggingface.co/orpo-explorers, but they are suddenly all gone. Do you have some reference for me to compare against? And could you share some pointers on what could be going wrong with Llama-3? I am using the same chat template as shown in the repo during both training and inference.


Model                         MT-Bench
Llama-3-8B-Instruct           8.078125
llama3_orpo_trl_tt_uf_epoch4  6.187500

@JasonZhu1313
Author

@jiwooya1000, if you have any pointers or some Llama-3 checkpoints for ORPO, please share. Thanks a lot!!

@jiwooya1000
Contributor

Hello @JasonZhu1313,

Yes, there were a lot of models in orpo-explorers, but we made the hub private for now: we were running ablations to build a general recipe for using ORPO, and there were way too many checkpoints in orpo-explorers 😅

I have not tried Llama-3 + UltraFeedback yet, but I did try it with the Capybara-Preferences dataset. Since I ran it with alignment-handbook + 4× A100 + FSDP, here is a .yaml file you could easily try:

# Model arguments
model_name_or_path: meta-llama/Meta-Llama-3-8B
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
dataset_mixer:
  argilla/Capybara-Preferences: 1.0
dataset_splits:
- train
preprocessing_num_workers: 48

# ORPOTrainer arguments
bf16: true
beta: 0.05
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
hub_model_id:   #### your/repo/id
learning_rate: 5.0e-6
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1792
num_train_epochs: 3
optim: paged_adamw_8bit
output_dir:    #### your/output/dir
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- tensorboard
- wandb
save_strategy: "epoch"
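
As a sanity check on the chat_template string in the config above, you can render it directly with Jinja2 and inspect the prompt it produces (an illustrative sketch; the example messages and the eos_token value are my assumptions, so check your tokenizer's actual EOS token):

```python
from jinja2 import Template

# Same Jinja template as the chat_template field above, written as a
# Python string (each \n becomes a real newline, as in the YAML config).
chat_template = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)

# Illustrative conversation; at inference time add_generation_prompt=True
# appends the assistant header so the model continues from there.
messages = [{"role": "user", "content": "What is ORPO?"}]

rendered = Template(chat_template).render(
    messages=messages,
    eos_token="<|end_of_text|>",  # assumed EOS; read it from your tokenizer
    add_generation_prompt=True,
)
print(rendered)
```

Rendering the template this way lets you confirm that training and inference prompts match before launching a run.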

As a result, I got 7.19 on MT-Bench. We have not fully studied the optimal setting for Llama-3 yet, but I will let you know if we find some good insights 🙂 Given that Llama-3-8B-Instruct (8.08) was trained on a heavily filtered 10M-sample human preference dataset, we might be able to get close to it with some hyperparameter search!
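
For reference, here is a minimal sketch of the odds-ratio penalty that the `beta: 0.05` field in the config scales, following the ORPO formulation (odds(p) = p / (1 - p), penalty = -log sigmoid of the log-odds ratio). This is my own illustrative code, not the trainer's exact implementation; the function name and inputs are assumptions:

```python
import math

def orpo_odds_ratio_term(logp_chosen, logp_rejected, beta=0.05):
    """Odds-ratio penalty ORPO adds to the chosen completion's NLL loss.

    logp_chosen / logp_rejected: mean per-token log-probabilities of the
    chosen and rejected completions (both negative).
    beta: weight of the penalty, matching `beta` in the config above.
    """
    def log_odds(logp):
        # log(p / (1 - p)) = logp - log(1 - p), with p = exp(logp)
        return logp - math.log1p(-math.exp(logp))

    delta = log_odds(logp_chosen) - log_odds(logp_rejected)
    # -log(sigmoid(delta)) = log(1 + exp(-delta)), scaled by beta;
    # small when the chosen completion is already much more likely.
    return beta * math.log1p(math.exp(-delta))
```

The penalty shrinks as the chosen completion becomes more likely than the rejected one, so beta trades off preference alignment against the plain SFT objective.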

@JasonZhu1313
Author

Thanks a lot for being responsive as always!! A different question, regarding Figure 5 of your paper: can you explain what the x-axis is? Each method has a different reward function, so how do you plot them on the same axis? Was any transformation applied? Also, is this analysis done on the win candidates of a test set held out from the UltraFeedback dataset? Full details, and any code you can point to, would be very helpful. Please give details about Figure 11 too. Thanks!
