DPO #645

Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/645.
Note: links to docs will display an error until the docs builds have been completed. ❌ As of commit c7f1d53 with merge base e472deb: 12 new failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for creating this PR, exciting to see all the components to enable DPO come together! I took an initial pass and had some high level comments.
torchtune/modules/peft/peft_utils.py
Outdated
@@ -256,3 +256,11 @@ def get_merged_lora_ckpt(
        del state_dict[f"{module}.lora_a.weight"]
        del state_dict[f"{module}.lora_b.weight"]
    return state_dict


def disable_adapter(model: nn.Module, disabled: bool) -> None:
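For context, later comments in this thread use this helper as a context manager, so here is a minimal sketch in that form. It assumes LoRALinear exposes a `disabled` flag that its forward checks; this is illustrative, not necessarily the final torchtune API:

```python
import contextlib

import torch.nn as nn

from torchtune.modules.peft import LoRALinear


@contextlib.contextmanager
def disable_adapter(model: nn.Module):
    """Temporarily run `model` as if its LoRA adapters were not present."""
    lora_modules = [m for m in model.modules() if isinstance(m, LoRALinear)]
    try:
        for m in lora_modules:
            # Assumes LoRALinear.forward skips the lora_a/lora_b branch
            # when this flag is set.
            m.disabled = True
        yield model
    finally:
        for m in lora_modules:
            m.disabled = False
```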
would be great to add a unit test for this and make sure lora weights are indeed not being used
I tested this manually by loading the original base model (without the LoRA configuration) and then loading a LoRA model. I then applied disable_adapter to the LoRA model and verified that the outputs matched. Should I include this logic in the unit test?
You can replicate this with a single LoRA linear layer initialized with constants so it's easier to write up, but yes that sounds like a good unit test
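A sketch of the kind of test being suggested, using the disable_adapter helper sketched above. The LoRALinear constructor arguments and the assumption that the disabled path reduces to F.linear(x, lora.weight) are illustrative:

```python
import torch
import torch.nn.functional as F

from torchtune.modules.peft import LoRALinear


def test_disable_adapter_bypasses_lora_weights():
    lora = LoRALinear(in_dim=4, out_dim=4, rank=2, alpha=1.0)
    with torch.no_grad():
        # Constant, nonzero adapter weights so the adapter visibly
        # contributes to the output.
        lora.lora_a.weight.fill_(0.5)
        lora.lora_b.weight.fill_(0.5)
    x = torch.randn(2, 4)

    with_adapter = lora(x)
    with disable_adapter(lora):
        without_adapter = lora(x)

    # The adapter was contributing, so disabling it must change the output...
    assert not torch.allclose(with_adapter, without_adapter)
    # ...and the disabled path should reduce to the base projection alone.
    torch.testing.assert_close(without_adapter, F.linear(x, lora.weight))
```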
Could you please review the addition I made here? I don't write unit tests often.
This is awesome progress! I left a bunch of comments, please let me know if any of them are unclear. I need to take a second closer look at the loss a bit later on as well, but otherwise things are shaping up nicely here.
recipes/lora_dpo_single_device.py
Outdated
if self.loss_type == "sigmoid":
    losses = (
        -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
        - F.logsigmoid(-self.beta * logits) * self.label_smoothing
    )
elif self.loss_type == "hinge":
    losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
    # eqn (17) of the paper, where beta is the regularization parameter
    # for the IPO loss, denoted by tau in the paper
    losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
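For reference, the `logits` consumed by these branches are the difference of the policy and reference log-ratios, and the standard (sigmoid) DPO objective from Rafailov et al. is:

```latex
\mathcal{L}_{\mathrm{DPO}}
  = -\log \sigma\!\left(\beta \left[
      \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right]\right)
```

With label smoothing $\varepsilon$, the sigmoid branch above computes the conservative variant $-(1-\varepsilon)\log\sigma(\beta z) - \varepsilon\log\sigma(-\beta z)$, where $z$ is the bracketed log-ratio difference.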
Once we move this to a standalone module, we can consider loss_type as a callable instead of a string, then just define quick functions for each of these that we can pass in (not super important though). We can also consider whether we actually need to support all of these loss types, if one is overwhelmingly prevalent we could just start there.
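A rough sketch of the callable idea (illustrative only, not part of the PR):

```python
import torch
import torch.nn.functional as F


# Each loss variant becomes a small function of (logits, beta)...
def sigmoid_loss(logits: torch.Tensor, beta: float) -> torch.Tensor:
    return -F.logsigmoid(beta * logits)


def hinge_loss(logits: torch.Tensor, beta: float) -> torch.Tensor:
    return torch.relu(1 - beta * logits)


# ...and the loss module would accept the callable directly:
# loss_fn = DPOLoss(loss_type=sigmoid_loss)
```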
Overall this is looking pretty good! I left a few more comments but no major concerns from my side. The main ask I have is around testing: are you able to kick off a run of this recipe and document the results in the PR's test plan? Please let me know if you need any assistance with this and I'll be happy to help out.
torchtune/modules/loss/dpo.py
Outdated


class DPOLoss(nn.Module):
    def __init__(self, beta=0.1, label_smoothing=0, loss_type="sigmoid"):
Add docstring here please
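For instance, something along these lines (wording is illustrative):

```python
class DPOLoss(nn.Module):
    """Direct Preference Optimization loss (Rafailov et al.,
    https://arxiv.org/abs/2305.18290).

    Args:
        beta (float): temperature scaling the implicit-reward
            (policy/reference log-ratio) difference. Default: 0.1
        label_smoothing (float): amount of smoothing applied to the
            preference labels for robustness to label noise. Default: 0
        loss_type (str): one of "sigmoid", "hinge", "ipo", or "kto_pair".
            Default: "sigmoid"
    """
```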
with disable_adapter(model_lora):
    lora_outputs = model_lora(inputs)
Nice unit test! One other thing you can assert on here is that model_lora's LoRALinear layer has self.disabled=False once we're outside of the context manager (just to make sure that it returns to its original state)
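Concretely, the extra assertion might look like this, extending the test snippet quoted above (and assuming the `disabled` flag discussed earlier):

```python
from torchtune.modules.peft import LoRALinear

with disable_adapter(model_lora):
    lora_outputs = model_lora(inputs)

# Outside the context manager, every adapter should be active again.
for module in model_lora.modules():
    if isinstance(module, LoRALinear):
        assert module.disabled is False
```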
import torch.nn.functional as F


class DPOLoss(nn.Module):
I think it would also be nice to add a unit test for this class.
Thank you for the suggestion on adding a unit test for this class. However, I must admit that I don't write unit tests often, and I have a few questions. Specifically for the DPOLoss unit test: when the inputs are randomized, how should one compute the expected output? Conversely, if the input is fixed, doesn't the usefulness of the unit test diminish?
Thanks for the questions. In this case I think it's OK to use a fixed input. As you point out, with random inputs we cannot define a fixed expected loss value as it depends on the random inputs. Although without randomization we are only testing against a single input, the idea is that any subsequent breakage should still cause the test to fail. So as long as we do not provide trivial inputs (e.g. something like perfect prediction where the loss is 0), a single fixed test case should be sufficient.

As an added bonus, it provides extra documentation on how the loss is supposed to work. So if I come to the codebase and don't understand what to expect from DPOLoss, I can then go to the unit test and see what I should expect to get given a simple, fixed input. So for example if you set policy_chosen_logps=[log(0.75), log(0.5)], policy_rejected_logps=[log(0.5), log(0.25)], ref_chosen_logps=[log(0.5), log(0.25)], and ref_rejected_logps=[log(0.5), log(0.25)] (and add appropriate comments in the code), a reader will then understand the expectation that the loss should be smaller, because the policy in this example is better than the reference at discriminating between chosen and rejected samples.

Let me know if this makes sense to you. If it's too much of a pain we can also help out on this front. You've done a ton of work here, so I don't want to block you on a unit test that can be easily written as a follow-up.
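A sketch of that fixed-input test; the forward argument order and the (losses, chosen_rewards, rejected_rewards) return value are assumptions based on the recipe code shown elsewhere in this thread:

```python
import math

import torch

from torchtune.modules.loss import DPOLoss  # import path assumed


def test_dpo_loss_fixed_input():
    loss_fn = DPOLoss(beta=0.1, label_smoothing=0, loss_type="sigmoid")

    log = math.log
    # The policy separates chosen from rejected better than the reference does.
    policy_chosen_logps = torch.tensor([log(0.75), log(0.5)])
    policy_rejected_logps = torch.tensor([log(0.5), log(0.25)])
    ref_chosen_logps = torch.tensor([log(0.5), log(0.25)])
    ref_rejected_logps = torch.tensor([log(0.5), log(0.25)])

    losses, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps,
        policy_rejected_logps,
        ref_chosen_logps,
        ref_rejected_logps,
    )

    # When policy and reference agree exactly, the sigmoid loss is
    # -log(sigmoid(0)) = log(2); a better-discriminating policy must do better.
    assert (losses < math.log(2)).all()
    # The implied reward margin should also be positive.
    assert (chosen_rewards > rejected_rewards).all()
```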
recipes/lora_dpo_single_device.py
Outdated
"loss": loss.item(),
"lr": self._optimizer.param_groups[0]["lr"],
"rewards/chosen": chosen_rewards.mean().cpu(),
"rewards/rejected": rejected_rewards.mean().cpu(),
"rewards/accuracies": reward_accuracies.mean().cpu(),
"rewards/margins": (chosen_rewards - rejected_rewards).mean().cpu(),
"log_probs/rejected": policy_rejected_log_probs.detach().mean().cpu(),
"log_probs/chosen": policy_chosen_log_probs.detach().mean().cpu(),
"logits/rejected": policy_rejected_logits.detach().mean().cpu(),
"logits/chosen": policy_chosen_logits.detach().mean().cpu(),
"gpu_resources": torch.cuda.memory_allocated(),
I'm curious, do we need all of these in general? E.g. mean of logits feels like it may be covered better by log_probs anyways. (Btw I think we should think about a better way to make our metric logging more configurable in general, so it's not necessarily specific to your DPO recipe)
I have the same doubts; you can check my training log to see the difference. (By the way, I'm still confused.)
I ran this recipe on a 24 GB 4090; here is the log. I also ran DPO using TRL; the differences include:

TRL's training log: an intriguing observation is that the log probabilities of the chosen (preferred) responses decreased while the reward margin increased. This phenomenon has also been noted in this paper. Comparing the training results in TRL to those in Torchtune, it appears that the second and fourth differences may have influenced the outcome. Do you have any suggestions or insights regarding this?
@yechenzhi thanks for this analysis! I am curious about the difference in the loss curves between trl and torchtune; it seems like the torchtune loss does not decrease as much as the trl one does. Regarding the differences in training you point out: for (2), you can enable LoRA on some of the missing modules to get closer to an apples-to-apples comparison (e.g. set …).

For (3): what is the lr scheduler used in the trl run?

For (4): is it possible to perform the trl run without doing SFT first? (Apologies if this last point doesn't make sense; personally I am not so familiar with the trl setup.)
@ebsmothers Yes, I found two reasons for this. First, TRL defaults to shuffling the data, but in Torchtune's training I set shuffle to false; perhaps the continuous appearance of similar prompts in the data, such as consecutive Russian questions and answers, makes training more challenging. Second, there is a typo in the code which causes the learning rate to decrease extremely slowly after warming up.

I've done that, but if I don't shuffle the data, the loss goes up.

The typo mentioned earlier caused this issue.

I tried that as well, and it turns out it didn't really matter. Here is the updated training log after addressing the two issues:
To note: for a fair comparison, I selected approximately 70,000 samples for the 1000-step training in both TRL and Torchtune. Initially I chose the first 64,000 * 5 samples from the dataset, then filtered out those exceeding a length of 1024, leaving approximately 70,000 samples. I did not include this code in the PR. If you would like to reproduce the loss curves I provided, you can add the line self._data = self._data.select(range(64000 * 5)) before here.
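A hypothetical reconstruction of that subsampling (not in the PR; it assumes self._data is a Hugging Face datasets.Dataset, that length is measured in tokens of the prompt plus the chosen response, and that the column names are as written):

```python
# Keep the first 64,000 * 5 samples, then drop anything longer than 1024 tokens.
self._data = self._data.select(range(64000 * 5))
self._data = self._data.filter(
    lambda ex: len(self._tokenizer.encode(ex["prompt"] + ex["chosen"])) <= 1024
)
```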
@yechenzhi thank you for the thorough analysis here! The loss curves look much better now, and thank you for catching the bug in how we were counting steps in the LR scheduler when gradient accumulation was turned on.
Thank you @yechenzhi for your diligence throughout the code review and testing process! This is a major feature and we are very appreciative of you for contributing it to our library. I will do a couple checks on my end then merge this, we can take care of any smaller follow-ups on our end as needed. We are also planning to update our README soon, once we do we will cite you as the author of this recipe.
@@ -204,7 +204,7 @@ def setup(self, cfg: DictConfig) -> None:
     # has been computed
     self._lr_scheduler = self._setup_lr_scheduler(
         cfg_lr_scheduler=cfg.lr_scheduler,
-        num_training_steps=self.total_epochs * steps_per_epoch,
+        num_training_steps=self.total_epochs * self._steps_per_epoch,
Thank you for catching this! I think we should just remove steps_per_epoch as it is not used anyway and is clearly a potential cause of confusion. (Don't worry about it in this PR, I can do it in a follow-up.)
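For context, the confusion comes from the recipe tracking two step counts; a paraphrased sketch of the pattern inside the recipe's setup (the exact code is an assumption):

```python
# steps_per_epoch counts dataloader batches; with gradient accumulation the
# optimizer (and therefore the LR scheduler) steps only once every N batches.
steps_per_epoch = len(self._dataloader)
self._steps_per_epoch = steps_per_epoch
if self._gradient_accumulation_steps > 1:
    self._steps_per_epoch = steps_per_epoch // self._gradient_accumulation_steps

# The scheduler must be sized with the accumulated count, hence the fix above:
num_training_steps = self.total_epochs * self._steps_per_epoch
```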
Context
Integrating DPO into Torchtune; for more details, see here.
Changelog
Test plan