
Remove the redundant shift during the loss computation in the Moshi m… #36928

Open
wants to merge 1 commit into base: main

Conversation

@glynpu commented Mar 24, 2025

What does this PR do?

Correct the loss computation in the Moshi model so that the shift is applied only once; it is currently applied twice.

Because the class name MoshiForCausalLM contains 'ForCausalLM', the mapping rules in LOSS_MAPPING resolve the self.loss_function used in MoshiForCausalLM.forward to ForCausalLMLoss. As a result, logits and labels are shifted twice: once before the self.loss_function call and once inside it. This makes tokens < n - 1 predict token n, instead of the expected behavior where tokens < n predict token n.

This PR removes the shift before the self.loss_function call.
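
For illustration, a minimal before/after sketch of the change (the exact variable names in modeling_moshi.py may differ; this assumes the usual shift-then-call pattern and that self.loss_function is passed vocab_size):

# Before (sketch): logits/labels are shifted here, and ForCausalLMLoss shifts again.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = self.loss_function(shift_logits, shift_labels, vocab_size=self.config.vocab_size)

# After (sketch): pass the unshifted tensors and let ForCausalLMLoss apply the single shift.
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)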

References:
LOSS_MAPPING:

LOSS_MAPPING = {
    "ForCausalLM": ForCausalLMLoss,
    "ForMaskedLM": ForMaskedLMLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
    # ... (remaining entries omitted)
}

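A minimal sketch (not the library's exact code) of how this mapping selects the loss: any model class whose name contains one of the keys above is routed to the corresponding function, so MoshiForCausalLM picks up ForCausalLMLoss.

def resolve_loss(model_class_name: str):
    # Match the class name against the LOSS_MAPPING keys by substring.
    for suffix, loss_fn in LOSS_MAPPING.items():
        if suffix in model_class_name:
            return loss_fn
    return None

assert resolve_loss("MoshiForCausalLM") is ForCausalLMLoss
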
ForCausalLMLoss:

def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    shift_labels=None,
    **kwargs,
):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        labels = labels.to(logits.device)
        # Shift so that tokens < n predict n
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(logits.device)
    loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
    return loss
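
A toy example (not from the PR) showing the misalignment: a single shift trains position i to predict token i + 1, while shifting the labels before the call as well trains it to predict token i + 2.

import torch
from torch import nn

ignore_index = -100
labels = torch.tensor([[10, 11, 12, 13]])

# Single shift, as done inside ForCausalLMLoss: position i targets token i + 1.
once = nn.functional.pad(labels, (0, 1), value=ignore_index)[..., 1:]
print(once)   # -> [[11, 12, 13, -100]]

# Pre-shifting the labels before the call and then shifting again inside
# the loss function: position i now targets token i + 2.
pre_shifted = labels[..., 1:]
twice = nn.functional.pad(pre_shifted, (0, 1), value=ignore_index)[..., 1:]
print(twice)  # -> [[12, 13, -100]]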

github-actions bot marked this pull request as draft March 24, 2025 12:49

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

glynpu marked this pull request as ready for review March 24, 2025 12:56
github-actions bot requested a review from eustlb March 24, 2025 12:56
@Rocketknight1 (Member) commented

cc @eustlb!
