New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
disable non-deterministic cudnn ctcloss #22977
disable non-deterministic cudnn ctcloss #22977
Conversation
FYI: I'm not sure if writing "Associated issue" will automatically close the issue, but writing "Fixes: " will. |
aten/src/ATen/native/LossCTC.cpp
Outdated
@@ -364,7 +364,7 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu | |||
|
|||
Tensor res; | |||
if (use_cudnn) { | |||
res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, ctx.deterministicCuDNN(), zero_infinity)); | |||
res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, true, zero_infinity)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
deserves a comment?
aten/src/ATen/native/LossCTC.cpp
Outdated
@@ -364,7 +364,7 @@ Tensor ctc_loss(const Tensor& log_probs, const Tensor& targets, IntArrayRef inpu | |||
|
|||
Tensor res; | |||
if (use_cudnn) { | |||
res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, ctx.deterministicCuDNN(), zero_infinity)); | |||
res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, true, zero_infinity)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you just add a comment explaining the issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Associated issue: pytorch/pytorch#21680 Pull Request resolved: pytorch/pytorch#22977 Differential Revision: D16357873 Pulled By: nairbv fbshipit-source-id: 58711bac7d3e8390e868d594dc265ba053a1537c
Associated issue: #21680