-
Notifications
You must be signed in to change notification settings - Fork 732
Add GPU RNNT Loss #1483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
carolineechen
merged 5 commits into
pytorch:master
from
carolineechen:export-D28128853-to-fbsync
May 6, 2021
Merged
Add GPU RNNT Loss #1483
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import torch | ||
| from .rnnt_loss_impl import RNNTLossTest | ||
| from torchaudio_unittest import common_utils | ||
| from .utils import skipIfNoTransducer | ||
|
|
||
|
|
||
| @skipIfNoTransducer | ||
| @common_utils.skipIfNoCuda | ||
| class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase): | ||
| device = torch.device('cuda') | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| #include <c10/cuda/CUDAStream.h> | ||
| #include <torch/script.h> | ||
| #include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h> | ||
|
|
||
| namespace torchaudio { | ||
| namespace rnnt { | ||
| namespace gpu { | ||
|
|
||
| // Entry point into RNNT Loss | ||
| std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute( | ||
| torch::Tensor& logits, | ||
| const torch::Tensor& targets, | ||
| const torch::Tensor& src_lengths, | ||
| const torch::Tensor& tgt_lengths, | ||
| int64_t blank, | ||
| double clamp, | ||
| bool fused_log_smax = true, | ||
| bool reuse_logits_for_grads = true) { | ||
| Options options; | ||
| options.batchSize_ = src_lengths.size(0); | ||
| options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); | ||
| options.maxSrcLen_ = logits.size(1); | ||
| options.maxTgtLen_ = logits.size(2); | ||
| options.numTargets_ = logits.size(3); | ||
| options.blank_ = blank; | ||
| options.clamp_ = clamp; | ||
| options.fusedLogSmax_ = fused_log_smax; | ||
|
|
||
| CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); | ||
| options.stream_ = at::cuda::getCurrentCUDAStream(); | ||
| cudaSetDevice(logits.get_device()); | ||
| options.device_ = GPU; | ||
|
|
||
| torch::Tensor costs = torch::empty( | ||
| options.batchSize_ * options.nHypos_, | ||
| torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); | ||
| c10::optional<torch::Tensor> gradients = c10::nullopt; | ||
| if (logits.requires_grad()) { | ||
| if (reuse_logits_for_grads) { | ||
| gradients = logits; | ||
| } else { | ||
| gradients = torch::zeros_like(logits); | ||
| } | ||
| } | ||
|
|
||
| torch::Tensor int_workspace = torch::empty( | ||
| IntWorkspace::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Int)); | ||
|
|
||
| torch::Tensor float_workspace = torch::empty( | ||
| DtypeWorkspace<float>::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Float)); | ||
|
|
||
| Workspace<float> workspace( | ||
| /*options=*/options, | ||
| /*dtype_data=*/float_workspace.data_ptr<float>(), | ||
| /*dtype_size=*/float_workspace.numel(), | ||
| /*int_data=*/int_workspace.data_ptr<int>(), | ||
| /*int_size=*/int_workspace.numel()); | ||
|
|
||
| switch (logits.scalar_type()) { | ||
| case torch::ScalarType::Float: { | ||
| Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>( | ||
| /*workspace=*/workspace, | ||
| /*logits=*/logits.data_ptr<float>(), | ||
| /*targets=*/targets.data_ptr<int>(), | ||
| /*src_lengths=*/src_lengths.data_ptr<int>(), | ||
| /*tgt_lengths=*/tgt_lengths.data_ptr<int>(), | ||
| /*costs=*/costs.data_ptr<float>(), | ||
| /*gradients=*/ | ||
| (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>()); | ||
| break; | ||
| } | ||
| case torch::ScalarType::Half: { | ||
| Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>( | ||
| /*workspace=*/workspace, | ||
| /*logits=*/logits.data_ptr<c10::Half>(), | ||
| /*targets=*/targets.data_ptr<int>(), | ||
| /*src_lengths=*/src_lengths.data_ptr<int>(), | ||
| /*tgt_lengths=*/tgt_lengths.data_ptr<int>(), | ||
| /*costs=*/costs.data_ptr<c10::Half>(), | ||
| /*gradients=*/ | ||
| (gradients == c10::nullopt) ? nullptr | ||
| : gradients->data_ptr<c10::Half>()); | ||
| break; | ||
| } | ||
| default: { | ||
| break; | ||
| } | ||
| }; | ||
|
|
||
| return std::make_tuple(costs, gradients); | ||
| } | ||
|
|
||
| TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { | ||
| m.impl("rnnt_loss", &compute); | ||
| } | ||
|
|
||
| } // namespace gpu | ||
| } // namespace rnnt | ||
| } // namespace torchaudio |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| #include <c10/cuda/CUDAStream.h> | ||
| #include <torch/script.h> | ||
| #include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h> | ||
|
|
||
| namespace torchaudio { | ||
| namespace rnnt { | ||
| namespace gpu { | ||
|
|
||
| torch::Tensor compute_alphas( | ||
| const torch::Tensor& logits, | ||
| const torch::Tensor& targets, | ||
| const torch::Tensor& src_lengths, | ||
| const torch::Tensor& tgt_lengths, | ||
| int64_t blank, | ||
| double clamp) { | ||
| Options options; | ||
| options.batchSize_ = src_lengths.size(0); | ||
| options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); | ||
| options.maxSrcLen_ = logits.size(1); | ||
| options.maxTgtLen_ = logits.size(2); | ||
| options.numTargets_ = logits.size(3); | ||
| options.blank_ = blank; | ||
| options.clamp_ = clamp; | ||
|
|
||
| CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); | ||
| options.stream_ = at::cuda::getCurrentCUDAStream(); | ||
| cudaSetDevice(logits.get_device()); | ||
| options.device_ = GPU; | ||
|
|
||
| torch::Tensor alphas = torch::zeros( | ||
| {options.batchSize_ * options.nHypos_, | ||
| options.maxSrcLen_, | ||
| options.maxTgtLen_}, | ||
| torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); | ||
|
|
||
| torch::Tensor int_workspace = torch::empty( | ||
| IntWorkspace::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Int)); | ||
|
|
||
| torch::Tensor float_workspace = torch::empty( | ||
| DtypeWorkspace<float>::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Float)); | ||
|
|
||
| Workspace<float> workspace( | ||
| /*options=*/options, | ||
| /*dtype_data=*/float_workspace.data_ptr<float>(), | ||
| /*dtype_size=*/float_workspace.numel(), | ||
| /*int_data=*/int_workspace.data_ptr<int>(), | ||
| /*int_size=*/int_workspace.numel()); | ||
|
|
||
| // Only support float, this is mainly to enable easy | ||
| // unit-testing | ||
| ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>( | ||
| /*workspace=*/workspace, | ||
| /*logits=*/logits.data_ptr<float>(), | ||
| /*targets=*/targets.data_ptr<int>(), | ||
| /*src_lengths=*/src_lengths.data_ptr<int>(), | ||
| /*tgt_lengths=*/tgt_lengths.data_ptr<int>(), | ||
| /*alphas=*/alphas.data_ptr<float>()); | ||
| return alphas; | ||
| } | ||
|
|
||
| TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { | ||
| m.impl("rnnt_loss_alphas", &compute_alphas); | ||
| } | ||
|
|
||
| } // namespace gpu | ||
| } // namespace rnnt | ||
| } // namespace torchaudio |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| #include <c10/cuda/CUDAStream.h> | ||
| #include <torch/script.h> | ||
| #include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h> | ||
|
|
||
| namespace torchaudio { | ||
| namespace rnnt { | ||
| namespace gpu { | ||
|
|
||
| torch::Tensor compute_betas( | ||
| const torch::Tensor& logits, | ||
| const torch::Tensor& targets, | ||
| const torch::Tensor& src_lengths, | ||
| const torch::Tensor& tgt_lengths, | ||
| int64_t blank, | ||
| double clamp) { | ||
| Options options; | ||
| options.batchSize_ = src_lengths.size(0); | ||
| options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0); | ||
| options.maxSrcLen_ = logits.size(1); | ||
| options.maxTgtLen_ = logits.size(2); | ||
| options.numTargets_ = logits.size(3); | ||
| options.blank_ = blank; | ||
| options.clamp_ = clamp; | ||
|
|
||
| CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); | ||
| options.stream_ = at::cuda::getCurrentCUDAStream(); | ||
| cudaSetDevice(logits.get_device()); | ||
| options.device_ = GPU; | ||
|
|
||
| torch::Tensor costs = torch::empty( | ||
| tgt_lengths.size(0), | ||
| torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); | ||
|
|
||
| torch::Tensor betas = torch::zeros( | ||
| {options.batchSize_ * options.nHypos_, | ||
| options.maxSrcLen_, | ||
| options.maxTgtLen_}, | ||
| torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); | ||
|
|
||
| torch::Tensor int_workspace = torch::empty( | ||
| IntWorkspace::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Int)); | ||
|
|
||
| torch::Tensor float_workspace = torch::empty( | ||
| DtypeWorkspace<float>::ComputeSizeFromOptions(options), | ||
| torch::TensorOptions() | ||
| .device(logits.device()) | ||
| .dtype(torch::ScalarType::Float)); | ||
|
|
||
| Workspace<float> workspace( | ||
| /*options=*/options, | ||
| /*dtype_data=*/float_workspace.data_ptr<float>(), | ||
| /*dtype_size=*/float_workspace.numel(), | ||
| /*int_data=*/int_workspace.data_ptr<int>(), | ||
| /*int_size=*/int_workspace.numel()); | ||
|
|
||
| // Only support float, this is mainly to enable easy | ||
| // unit-testing | ||
| ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>( | ||
| /*workspace=*/workspace, | ||
| /*logits=*/logits.data_ptr<float>(), | ||
| /*targets=*/targets.data_ptr<int>(), | ||
| /*src_lengths=*/src_lengths.data_ptr<int>(), | ||
| /*tgt_lengths=*/tgt_lengths.data_ptr<int>(), | ||
| /*costs=*/costs.data_ptr<float>(), | ||
| /*betas=*/betas.data_ptr<float>()); | ||
| return betas; | ||
| } | ||
|
|
||
| TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { | ||
| m.impl("rnnt_loss_betas", &compute_betas); | ||
| } | ||
|
|
||
| } // namespace gpu | ||
| } // namespace rnnt | ||
| } // namespace torchaudio |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: just realized the name has not been standardized from Transducer to RNNT, but this can be changed after this PR