Remove the redundant shift during the loss computation in the Moshi model #36928
+3 −13
What does this PR do?
Corrects the loss computation in the Moshi model so that the causal shift is applied only once; it is currently applied twice.
Because the class name MoshiForCausalLM contains 'ForCausalLM', the mapping rules in LOSS_MAPPING resolve the self.loss_function used in MoshiForCausalLM.forward to ForCausalLMLoss. As a result, logits and labels are shifted twice: once before the self.loss_function call and once again inside it. The net effect is that tokens < n - 1 predict n (position t is paired with label t + 2) instead of the expected causal alignment where tokens < n predict n, as the sketch below demonstrates.
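A minimal runnable sketch of the misalignment. The helper below mirrors the single shift performed inside ForCausalLMLoss (see the references at the end); the tensor shapes and values are made up for illustration:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 1, 6, 10
logits = torch.randn(batch, seq_len, vocab)          # stand-in model outputs
labels = torch.randint(0, vocab, (batch, seq_len))   # stand-in target ids

def causal_lm_loss(logits, labels):
    # Mirrors the shift inside ForCausalLMLoss: tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))

# Expected usage: pass unshifted tensors; the loss function shifts once.
loss_once = causal_lm_loss(logits, labels)

# Current Moshi behavior: the forward pass pre-shifts, then the loss
# function shifts again, pairing position t with label t + 2.
loss_twice = causal_lm_loss(logits[..., :-1, :], labels[..., 1:])
print(loss_once.item(), loss_twice.item())  # the two losses differ
```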
This PR removes the extra shift that precedes the self.loss_function call; a before/after sketch of the call site follows.
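A sketch of the shape of the change in MoshiForCausalLM.forward, not the exact diff; the vocab_size keyword follows the ForCausalLMLoss signature referenced below:

```python
# Before (sketch): a manual shift duplicates the one inside ForCausalLMLoss.
if labels is not None:
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = self.loss_function(shift_logits, shift_labels, vocab_size=self.config.vocab_size)

# After (sketch): pass the tensors through untouched; ForCausalLMLoss
# applies the single shift so that tokens < n predict n.
if labels is not None:
    loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
```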
References:
- LOSS_MAPPING: transformers/src/transformers/loss/loss_utils.py, lines 130 to 135 at 9e125d9
- ForCausalLMLoss: transformers/src/transformers/loss/loss_utils.py, lines 33 to 57 at 9e125d9
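For illustration, a simplified, hypothetical sketch of the suffix lookup described above (the real LOSS_MAPPING at the referenced lines maps suffixes to loss callables, not strings, and the resolution details live in the library):

```python
# Hypothetical simplification: string values stand in for loss callables
# such as ForCausalLMLoss, only to show how the class name selects the loss.
LOSS_MAPPING = {"ForCausalLM": "ForCausalLMLoss"}

def resolve_loss(class_name: str):
    # The loss is selected because the class name contains the suffix.
    for suffix, loss_fn in LOSS_MAPPING.items():
        if suffix in class_name:
            return loss_fn
    return None

assert resolve_loss("MoshiForCausalLM") == "ForCausalLMLoss"
```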