Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable log probs input for rnnt loss #2798

Closed
wants to merge 5 commits into from

Conversation

carolineechen
Copy link
Contributor

@carolineechen carolineechen commented Oct 26, 2022

Add fused_log_softmax argument (default/current behavior = True) to rnnt loss.

If setting it to False, call log_softmax on the logits prior to passing it in to the rnnt loss function.

The following should produce the same output:

rnnt_loss(logits, targets, logit_lengths, target_lengths, fused_log_softmax=True)
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
rnnt_loss(log_probs, targets, logit_lengths, target_lengths, fused_log_softmax=False)

testing -- unit tests + get same results on the conformer rnnt recipe

@carolineechen carolineechen changed the title [WIP] Enable log probs input for rnnt loss Enable log probs input for rnnt loss Nov 7, 2022
@carolineechen carolineechen marked this pull request as ready for review November 7, 2022 15:27
@carolineechen carolineechen requested a review from a team November 7, 2022 15:27
Copy link
Contributor

@xiaohui-zhang xiaohui-zhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. thanks @carolineechen for addressing this quickly to unblock @BriansIDP !

@facebook-github-bot
Copy link
Contributor

@carolineechen has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@carolineechen has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@carolineechen has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@github-actions
Copy link

github-actions bot commented Nov 8, 2022

Hey @carolineechen.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants