
DPO #645

Merged
merged 35 commits into from Apr 11, 2024

Conversation

yechenzhi
Contributor

Context

Integrating DPO into Torchtune; for more details, see here.

Changelog

  • ...

Test plan

  • ....


pytorch-bot bot commented Apr 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/645

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures

As of commit c7f1d53 with merge base e472deb:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 3, 2024
Contributor

@RdoubleA RdoubleA left a comment


Thanks for creating this PR, exciting to see all the components to enable DPO come together! I took an initial pass and had some high level comments.

@yechenzhi yechenzhi changed the title DPO [WIP]DPO Apr 4, 2024
@@ -256,3 +256,11 @@ def get_merged_lora_ckpt(
del state_dict[f"{module}.lora_a.weight"]
del state_dict[f"{module}.lora_b.weight"]
return state_dict


def disable_adapter(model: nn.Module, disabled: bool) -> None:
Contributor

would be great to add a unit test for this and make sure lora weights are indeed not being used

Contributor Author

would be great to add a unit test for this and make sure lora weights are indeed not being used

I tested this manually: I loaded the original base model (without a LoRA configuration) and then loaded a LoRA model, applied disable_adapter to the LoRA model, and verified that the two outputs matched. Should I include this logic in the unit test?

Contributor

You can replicate this with a single LoRA linear layer initialized with constants so it's easier to write up, but yes that sounds like a good unit test
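
For readers following along, here is a minimal sketch of the kind of test being described. It uses a hypothetical ToyLoRALinear module with constant-initialized weights and a plain `disabled` flag rather than torchtune's actual LoRALinear / disable_adapter API, so all names here are illustrative only:

```python
import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Minimal stand-in for a LoRA linear layer: y = x @ W.T + (x @ A.T) @ B.T when enabled."""

    def __init__(self, dim: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.full((dim, dim), 0.5))
        self.lora_a = nn.Parameter(torch.full((rank, dim), 0.1))
        self.lora_b = nn.Parameter(torch.full((dim, rank), 0.1))
        self.disabled = False  # what a disable_adapter helper would toggle

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.weight.T
        if not self.disabled:
            out = out + (x @ self.lora_a.T) @ self.lora_b.T
        return out


def test_disabled_adapter_matches_base_output():
    layer = ToyLoRALinear(dim=4, rank=2)
    x = torch.ones(1, 4)

    base_out = x @ layer.weight.T  # base linear output, no LoRA contribution
    assert not torch.allclose(layer(x), base_out)  # adapter enabled: outputs differ

    layer.disabled = True  # emulate entering disable_adapter(...)
    assert torch.allclose(layer(x), base_out)  # adapter disabled: matches base output

    layer.disabled = False  # emulate exiting the context manager
    assert not torch.allclose(layer(x), base_out)  # adapter active again
```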

Contributor Author

Could you please review the addition I made here? I don't write unit tests very often.

Contributor

@ebsmothers ebsmothers left a comment

This is awesome progress! I left a bunch of comments, please let me know if any of them are unclear. I need to take a second closer look at the loss a bit later on as well, but otherwise things are shaping up nicely here.

Comment on lines 499 to 509
if self.loss_type == "sigmoid":
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
Contributor

Once we move this to a standalone module, we can consider loss_type as a callable instead of a string, then just define quick functions for each of these that we can pass in (not super important though). We can also consider whether we actually need to support all of these loss types, if one is overwhelmingly prevalent we could just start there.
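
As a rough sketch of that callable idea (illustrative only, not code from this PR; the function names and signatures are assumptions):

```python
from functools import partial
from typing import Callable

import torch
import torch.nn.functional as F


# Each loss variant becomes a small function of the (policy - reference) logits,
# so the module stores a callable instead of branching on a string.
def sigmoid_loss(logits: torch.Tensor, beta: float, label_smoothing: float = 0.0) -> torch.Tensor:
    return (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )


def hinge_loss(logits: torch.Tensor, beta: float) -> torch.Tensor:
    return torch.relu(1 - beta * logits)


def ipo_loss(logits: torch.Tensor, beta: float) -> torch.Tensor:
    # eqn (17) of the IPO paper, where beta plays the role of tau
    return (logits - 1 / (2 * beta)) ** 2


# A recipe or config could then pass one of these in directly:
loss_fn: Callable[[torch.Tensor], torch.Tensor] = partial(sigmoid_loss, beta=0.1)
losses = loss_fn(torch.tensor([0.5, -0.5]))
```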

Contributor

@ebsmothers ebsmothers left a comment

Overall this is looking pretty good! I left a few more comments but no major concerns from my side. The main ask I have is around testing: are you able to kick off a run of this recipe and document the results in the PR's test plan? Please let me know if you need any assistance with this and I'll be happy to help out.



class DPOLoss(nn.Module):
def __init__(self, beta=0.1, label_smoothing=0, loss_type="sigmoid"):
Contributor

Add docstring here please


with disable_adapter(model_lora):
lora_outputs = model_lora(inputs)

Contributor

Nice unit test! One other thing you can assert on here is that model_lora's LoRALinear layer has self.disabled=False once we're outside of the context manager (just to make sure that it returns to its original state)

import torch.nn.functional as F


class DPOLoss(nn.Module):
Contributor

I think it would also be nice to add a unit test for this class.

Contributor Author

Thank you for the suggestion. I don't write unit tests very often, though, so I have a couple of questions. Specifically for the DPOLoss unit test: if the inputs are randomized, how should the expected output be computed? Conversely, if the input is fixed, doesn't that reduce the value of the unit test?

Contributor

Thanks for the questions. In this case I think it's OK to use a fixed input. As you point out, with random inputs we cannot define a fixed expected loss value as it depends on the random inputs. Although without randomization we are only testing against a single input, the idea is that any subsequent breakage should still cause the test to fail. So as long as we do not provide trivial inputs (e.g. something like perfect prediction where the loss is 0), a single fixed test case should be sufficient.

As an added bonus, it provides extra documentation on how the loss is supposed to work. So if I come to the codebase and don't understand what to expect from DPOLoss, I can then go to the unit test and see what I should expect to get given a simple, fixed input. So for example if you set policy_chosen_logps=[log(0.75), log(0.5)], policy_rejected_logps=[log(0.5), log(0.25)], ref_chosen_logps=[log(0.5), log(0.25)], ref_rejected_logps=[log(0.5), log(0.25)] (and add appropriate comments in the code), a reader will then understand the expectation that the loss should be smaller because the policy in this example is better than the reference at discriminating between chosen and rejected samples.

Let me know if this makes sense to you. If it's too much of a pain we can also help out on this front. You've done a ton of work here so don't want to block you on a unit test that can be easily written as a follow-up.
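
For concreteness, here is a rough sketch of the arithmetic behind that suggestion, written directly against the sigmoid branch quoted earlier rather than against DPOLoss's actual interface (which may differ):

```python
import math

import torch
import torch.nn.functional as F

# Example values from the comment above
policy_chosen_logps = torch.tensor([0.75, 0.5]).log()
policy_rejected_logps = torch.tensor([0.5, 0.25]).log()
ref_chosen_logps = torch.tensor([0.5, 0.25]).log()
ref_rejected_logps = torch.tensor([0.5, 0.25]).log()
beta = 0.1

# Standard DPO quantity: (policy log-ratio) - (reference log-ratio)
logits = (policy_chosen_logps - policy_rejected_logps) - (
    ref_chosen_logps - ref_rejected_logps
)
losses = -F.logsigmoid(beta * logits)  # the "sigmoid" branch with no label smoothing

# With logits == 0 (policy identical to the reference) the loss is log(2) ~= 0.693.
# Here the policy separates chosen from rejected better than the reference does,
# so every element of the loss should be strictly below log(2).
assert torch.all(losses < math.log(2))
```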

Comment on lines 462 to 472
"loss": loss.item(),
"lr": self._optimizer.param_groups[0]["lr"],
"rewards/chosen": chosen_rewards.mean().cpu(),
"rewards/rejected": rejected_rewards.mean().cpu(),
"rewards/accuracies": reward_accuracies.mean().cpu(),
"rewards/margins": (chosen_rewards - rejected_rewards).mean().cpu(),
"log_probs/rejected": policy_rejected_log_probs.detach().mean().cpu(),
"log_probs/chosen": policy_chosen_log_probs.detach().mean().cpu(),
"logits/rejected": policy_rejected_logits.detach().mean().cpu(),
"logits/chosen": policy_chosen_logits.detach().mean().cpu(),
"gpu_resources": torch.cuda.memory_allocated(),
Contributor

I'm curious, do we need all of these in general? E.g. mean of logits feels like it may be covered better by log_probs anyways. (Btw I think we should think about a better way to make our metric logging more configurable in general, so it's not necessarily specific to your DPO recipe)

Contributor Author

I'm curious, do we need all of these in general? E.g. mean of logits feels like it may be covered better by log_probs anyways.

I have the same doubts; you can check my training log to see the difference. (By the way, I'm still a bit confused by it myself.)

@yechenzhi
Contributor Author

yechenzhi commented Apr 8, 2024

are you able to kick off a run of this recipe and document the results in the PR's test plan? Please let me know if you need any assistance with this and I'll be happy to help out.

I ran this recipe on a 24 GB 4090; here are the training charts:
loss: [chart]
rewards margin: [chart]
rewards accuracies: [chart]
avg log probs: [chart]
avg logits: [chart]

I also ran DPO using TRL; the differences were:

  1. In TRL I used a batch size of 2 with gradient accumulation of 32, whereas in Torchtune the batch size is 4 and the gradient accumulation is 16 (the effective batch size is 64 in both cases).
  2. In TRL the fine-tuned weights include "q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", and "wte", whereas in Torchtune we only fine-tune "q_proj" and "v_proj".
  3. The learning rate scheduler differs between the two runs.
  4. Prior to training DPO in TRL, I also ran SFT, which was not done in Torchtune.

TRL's training log:
loss (a bit lower than torchtune): [W&B chart]
rewards margin (higher than torchtune): [W&B chart]
rewards accuracies: [W&B chart]

An intriguing observation is that the log probabilities of the chosen (preferred) responses decreased, while the margin of rewards increased. This phenomenon has also been noted in this paper.

Comparing the training results in TRL to those in Torchtune, it appears that the second and fourth differences may have influenced the outcome. Do you have any suggestions or insights regarding this?

@ebsmothers
Contributor

  1. In TRL I used a batch size of 2 with gradient accumulation of 32, whereas in Torchtune the batch size is 4 and the gradient accumulation is 16 (the effective batch size is 64 in both cases).
  2. In TRL the fine-tuned weights include "q_proj", "v_proj", "k_proj", "out_proj", "fc_in", "fc_out", and "wte", whereas in Torchtune we only fine-tune "q_proj" and "v_proj".
  3. The learning rate scheduler differs between the two runs.
  4. Prior to training DPO in TRL, I also ran SFT, which was not done in Torchtune.

Comparing the training results in TRL to those in Torchtune, it appears that the second and fourth differences may have influenced the outcome. Do you have any suggestions or insights regarding this?

@yechenzhi thanks for this analysis! I am curious about the difference in the loss curves between trl and torchtune, it seems like the torchtune loss does not decrease as much as the trl one does.

Regarding the differences in training you point out: for (2) you can enable LoRA on some of the missing modules to get closer to an apples-to-apples comparison, e.g. set lora_attn_modules=['q_proj', 'k_proj', 'v_proj', 'output_proj'] and apply_lora_to_mlp=True in your config. This is still not perfectly 1:1, but it should at least provide a better comparison.

For (3): what is the lr scheduler used in the trl run?

For (4), is it possible to perform the trl run without doing SFT first? (Apologies if this last point doesn't make sense, personally I am not so familiar with the trl setup.)

@yechenzhi
Contributor Author

I am curious about the difference in the loss curves between trl and torchtune, it seems like the torchtune loss does not decrease as much as the trl one does.

@ebsmothers Yes, I found two reasons for this. First, TRL shuffles the data by default, but in torchtune's training I set shuffle to false; perhaps because similar prompts appear back to back in the data (e.g. consecutive Russian questions and answers), training becomes more challenging. Second, there is a typo in the code that causes the learning rate to decrease extremely slowly after warm-up.

set lora_attn_modules=['q_proj', 'k_proj', 'v_proj', 'output_proj'] and apply_lora_to_mlp=True in your config.

I've done that, but if I don't shuffle the data, the loss goes up.

For (3): what is the lr scheduler used in the trl run?

The typo mentioned earlier caused this issue.

For (4), is it possible to perform the trl run without doing SFT first?

I tried that as well, and it turns out it didn't really matter.

Here is the updated training log after addressing the two issues:
loss: [chart]
reward accuracies: [chart]
rewards margins: [chart]

@yechenzhi
Contributor Author

Note: for a fair comparison, I selected approximately 70,000 samples for the 1000-step training in both TRL and Torchtune. I first took the first 64,000 * 5 samples from the dataset, then filtered out those exceeding a length of 1024, leaving roughly 70,000 samples. I did not include this code in the PR; if you would like to reproduce the loss curves I posted, you can add the line 'self._data = self._data.select(range(64000 * 5))' before here.

@ebsmothers
Contributor


@yechenzhi thank you for the thorough analysis here! The loss curves look much better now, and thank you for catching the bug on how we were counting steps in the LR scheduler when gradient accumulation was turned on.

Contributor

@ebsmothers ebsmothers left a comment

Thank you @yechenzhi for your diligence throughout the code review and testing process! This is a major feature and we are very appreciative of you contributing it to our library. I will do a couple of checks on my end and then merge this; we can take care of any smaller follow-ups as needed. We are also planning to update our README soon, and once we do we will cite you as the author of this recipe.

@@ -204,7 +204,7 @@ def setup(self, cfg: DictConfig) -> None:
# has been computed
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
num_training_steps=self.total_epochs * steps_per_epoch,
num_training_steps=self.total_epochs * self._steps_per_epoch,
Contributor

Thank you for catching this! I think we should just remove steps_per_epoch as it is not used anyways and is clearly a potential cause of confusion. (Don't worry about it in this PR, I can do in a follow-up)
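
For readers landing here later: the fix above matters because the LR scheduler must be sized in optimizer steps, not dataloader batches, when gradient accumulation is enabled. A minimal sketch with hypothetical numbers (not torchtune's exact code):

```python
# Hypothetical numbers for illustration only.
batches_per_epoch = 1024          # len(dataloader) in a real recipe
gradient_accumulation_steps = 16
total_epochs = 1

# One optimizer step occurs every `gradient_accumulation_steps` batches.
steps_per_epoch = batches_per_epoch // gradient_accumulation_steps  # 64
num_training_steps = total_epochs * steps_per_epoch                 # 64

# Sizing the schedule from the raw batch count (1024) instead would make a
# warmup + decay schedule think training is 16x longer than it is, so the
# learning rate would barely decay after warmup, which is the symptom
# described in the discussion above.
print(num_training_steps)  # 64
```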

@ebsmothers ebsmothers changed the title [WIP]DPO DPO Apr 11, 2024
@ebsmothers ebsmothers merged commit 8bb3aae into pytorch:main Apr 11, 2024
13 of 25 checks passed