Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_TRANSDUCER "Enable transducer" OFF)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)

if(USE_CUDA)
enable_language(CUDA)
endif()

find_package(Torch REQUIRED)

Expand Down
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _get_build(var, default=False):
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
_USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = torch.cuda.is_available()


def get_ext_modules():
Expand Down Expand Up @@ -76,6 +77,7 @@ def build_extension(self, ext):
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
]
build_args = [
'--target', 'install'
Expand Down
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from .rnnt_loss_impl import RNNTLossTest
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer


@skipIfNoTransducer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: just realized the name has not been standardized from Transducer to RNNT, but this can be changed after this PR

@common_utils.skipIfNoCuda
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
    # Runs the shared RNNTLossTest suite on a CUDA device; the whole class is
    # skipped when no GPU is available.
    device = torch.device('cuda')
15 changes: 15 additions & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ if(BUILD_TRANSDUCER)
rnnt/compute_betas.cpp
rnnt/compute.cpp
)

# When CUDA support is enabled, add the GPU RNNT kernels to the transducer
# source list so they are built into the library.
if (USE_CUDA)
set(
CUDA_TRANSDUCER_SOURCES
rnnt/gpu/compute_alphas.cu
rnnt/gpu/compute_betas.cu
rnnt/gpu/compute.cu
)
list(APPEND TRANSDUCER_SOURCES ${CUDA_TRANSDUCER_SOURCES})
endif()

list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES})
endif()

Expand Down Expand Up @@ -105,6 +116,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
endif()

# Expose USE_CUDA as a preprocessor definition so C++ sources can compile
# CUDA-only code paths conditionally (e.g. #ifdef USE_CUDA).
if (USE_CUDA)
target_compile_definitions(_torchaudio PRIVATE USE_CUDA)
endif()

target_include_directories(
_torchaudio
PRIVATE
Expand Down
105 changes: 105 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& src_lengths,
const torch::Tensor& tgt_lengths,
int64_t blank,
double clamp,
bool fused_log_smax = true,
bool reuse_logits_for_grads = true) {
Options options;
options.batchSize_ = src_lengths.size(0);
options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
options.maxSrcLen_ = logits.size(1);
options.maxTgtLen_ = logits.size(2);
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_smax;

CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
cudaSetDevice(logits.get_device());
options.device_ = GPU;

torch::Tensor costs = torch::empty(
options.batchSize_ * options.nHypos_,
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt;
if (logits.requires_grad()) {
if (reuse_logits_for_grads) {
gradients = logits;
} else {
gradients = torch::zeros_like(logits);
}
}

torch::Tensor int_workspace = torch::empty(
IntWorkspace::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Int));

torch::Tensor float_workspace = torch::empty(
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
torch::TensorOptions()
.device(logits.device())
.dtype(torch::ScalarType::Float));

Workspace<float> workspace(
/*options=*/options,
/*dtype_data=*/float_workspace.data_ptr<float>(),
/*dtype_size=*/float_workspace.numel(),
/*int_data=*/int_workspace.data_ptr<int>(),
/*int_size=*/int_workspace.numel());

switch (logits.scalar_type()) {
case torch::ScalarType::Float: {
Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<float>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<float>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
break;
}
case torch::ScalarType::Half: {
Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
/*workspace=*/workspace,
/*logits=*/logits.data_ptr<c10::Half>(),
/*targets=*/targets.data_ptr<int>(),
/*src_lengths=*/src_lengths.data_ptr<int>(),
/*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
/*costs=*/costs.data_ptr<c10::Half>(),
/*gradients=*/
(gradients == c10::nullopt) ? nullptr
: gradients->data_ptr<c10::Half>());
break;
}
default: {
break;
}
};

return std::make_tuple(costs, gradients);
}

// Register `compute` as the CUDA-backend implementation of the
// `torchaudio::rnnt_loss` operator (the schema is declared elsewhere via
// TORCH_LIBRARY).
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
73 changes: 73 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute_alphas.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Computes the RNNT alphas on the GPU for the given logits/targets and
// returns them as a (batch * nHypos, max_src_len, max_tgt_len) tensor.
// Float32 only; this entry point exists mainly to make unit testing easy.
torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp) {
  const int64_t num_sequences = src_lengths.size(0);

  Options opts;
  opts.batchSize_ = num_sequences;
  opts.nHypos_ = tgt_lengths.size(0) / num_sequences;
  opts.maxSrcLen_ = logits.size(1);
  opts.maxTgtLen_ = logits.size(2);
  opts.numTargets_ = logits.size(3);
  opts.blank_ = blank;
  opts.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  opts.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  opts.device_ = GPU;

  const auto on_device = torch::TensorOptions().device(logits.device());

  // Alpha lattice: one (max_src_len x max_tgt_len) slab per
  // (sequence, hypothesis) pair.
  torch::Tensor alphas = torch::zeros(
      {opts.batchSize_ * opts.nHypos_, opts.maxSrcLen_, opts.maxTgtLen_},
      on_device.dtype(logits.dtype()));

  // Scratch buffers sized by the kernel implementation.
  torch::Tensor scratch_int = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(opts),
      on_device.dtype(torch::ScalarType::Int));
  torch::Tensor scratch_float = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(opts),
      on_device.dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/opts,
      /*dtype_data=*/scratch_float.data_ptr<float>(),
      /*dtype_size=*/scratch_float.numel(),
      /*int_data=*/scratch_int.data_ptr<int>(),
      /*int_size=*/scratch_int.numel());

  // Only float32 is supported here.
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*src_lengths=*/src_lengths.data_ptr<int>(),
      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

// Register `compute_alphas` as the CUDA-backend implementation of the
// `torchaudio::rnnt_loss_alphas` operator.
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
78 changes: 78 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute_betas.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Computes the RNNT betas on the GPU for the given logits/targets and
// returns them as a (batch * nHypos, max_src_len, max_tgt_len) tensor.
// A per-hypothesis costs buffer is also filled by the kernel but not
// returned. Float32 only; exists mainly to make unit testing easy.
torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp) {
  const int64_t num_sequences = src_lengths.size(0);

  Options opts;
  opts.batchSize_ = num_sequences;
  opts.nHypos_ = tgt_lengths.size(0) / num_sequences;
  opts.maxSrcLen_ = logits.size(1);
  opts.maxTgtLen_ = logits.size(2);
  opts.numTargets_ = logits.size(3);
  opts.blank_ = blank;
  opts.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  opts.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  opts.device_ = GPU;

  const auto on_device = torch::TensorOptions().device(logits.device());

  // One cost per (sequence, hypothesis) pair; written by the kernel.
  torch::Tensor costs =
      torch::empty(tgt_lengths.size(0), on_device.dtype(logits.dtype()));

  // Beta lattice: one (max_src_len x max_tgt_len) slab per
  // (sequence, hypothesis) pair.
  torch::Tensor betas = torch::zeros(
      {opts.batchSize_ * opts.nHypos_, opts.maxSrcLen_, opts.maxTgtLen_},
      on_device.dtype(logits.dtype()));

  // Scratch buffers sized by the kernel implementation.
  torch::Tensor scratch_int = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(opts),
      on_device.dtype(torch::ScalarType::Int));
  torch::Tensor scratch_float = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(opts),
      on_device.dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/opts,
      /*dtype_data=*/scratch_float.data_ptr<float>(),
      /*dtype_size=*/scratch_float.numel(),
      /*int_data=*/scratch_int.data_ptr<int>(),
      /*int_size=*/scratch_int.numel());

  // Only float32 is supported here.
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*src_lengths=*/src_lengths.data_ptr<int>(),
      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

// Register `compute_betas` as the CUDA-backend implementation of the
// `torchaudio::rnnt_loss_betas` operator.
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
Loading