diff --git a/src/libtorchaudio/rnnt/cpu/cpu_kernels.h b/src/libtorchaudio/rnnt/cpu/cpu_kernels.h index a4a8473d35..8f1fb39bb4 100644 --- a/src/libtorchaudio/rnnt/cpu/cpu_kernels.h +++ b/src/libtorchaudio/rnnt/cpu/cpu_kernels.h @@ -3,8 +3,7 @@ #include #include #include - -#include +#include #include #include @@ -50,7 +49,7 @@ class TensorView { } DTYPE& operator()(const std::vector& indices) { - TORCH_CHECK_EQ(indices.size(), dims_.size()); + STD_TORCH_CHECK(indices.size() == dims_.size()); int index = indices.back(); for (int i = indices.size() - 2; i >= 0; --i) { index += indices[i] * strides_[i]; diff --git a/src/libtorchaudio/rnnt/cpu/cpu_transducer.h b/src/libtorchaudio/rnnt/cpu/cpu_transducer.h index c7fd2d289f..878c1ef42e 100644 --- a/src/libtorchaudio/rnnt/cpu/cpu_transducer.h +++ b/src/libtorchaudio/rnnt/cpu/cpu_transducer.h @@ -28,7 +28,7 @@ status_t Compute( DTYPE* gradients = nullptr) { const Options& options = workspace.GetOptions(); - TORCH_CHECK_EQ(options.device_, CPU); + STD_TORCH_CHECK(options.device_ == CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; @@ -91,7 +91,7 @@ status_t ComputeAlphas( DTYPE* alphas) { const Options& options = workspace.GetOptions(); - TORCH_CHECK_EQ(options.device_, CPU); + STD_TORCH_CHECK(options.device_ == CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; @@ -140,7 +140,7 @@ status_t ComputeBetas( DTYPE* betas) { const Options& options = workspace.GetOptions(); - TORCH_CHECK_EQ(options.device_, CPU); + STD_TORCH_CHECK(options.device_ == CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 800f029121..4701580e6b 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -4,6 +4,7 @@ #include #include #include +#include namespace torchaudio { namespace rnnt { @@ -117,33 +118,21 @@ std::tuple compute( /*int_data=*/reinterpret_cast(int_workspace.data_ptr()), /*int_size=*/int_workspace.numel()); - switch (logits.scalar_type()) { - case ScalarType::Float: { - Compute( - /*workspace=*/workspace, - /*logits=*/reinterpret_cast(logits.data_ptr()), - /*targets=*/reinterpret_cast(targets.data_ptr()), - /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), - /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), - /*costs=*/reinterpret_cast(costs.data_ptr()), - /*gradients=*/reinterpret_cast(gradients.data_ptr())); - break; - } - case ScalarType::Half: { - Compute( - /*workspace=*/workspace, - /*logits=*/reinterpret_cast(logits.data_ptr()), - /*targets=*/reinterpret_cast(targets.data_ptr()), - /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), - /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), - /*costs=*/reinterpret_cast(costs.data_ptr()), - /*gradients=*/reinterpret_cast(gradients.data_ptr())); - break; - } - default: { - STD_TORCH_CHECK(false, "unreachable"); - } - }; + THO_DISPATCH_V2( + logits.scalar_type(), + "rnnt:compute", + AT_WRAP([&] { + (Compute( + /*workspace=*/workspace, + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr()))); + }), + ScalarType::Float, + ScalarType::Half); return std::make_tuple(costs, gradients); } diff --git a/src/libtorchaudio/rnnt/gpu/half.cuh b/src/libtorchaudio/rnnt/gpu/half.cuh index 1a9b2c3eda..3d8bfb9e20 100644 --- a/src/libtorchaudio/rnnt/gpu/half.cuh +++ b/src/libtorchaudio/rnnt/gpu/half.cuh @@ -1,7 +1,7 @@ #pragma once #ifdef USE_C10_HALF -#include "c10/util/Half.h" +#include #endif // USE_C10_HALF #include diff --git a/src/libtorchaudio/rnnt/workspace.h b/src/libtorchaudio/rnnt/workspace.h index b4bbb30a43..cb217e8700 100644 --- a/src/libtorchaudio/rnnt/workspace.h +++ b/src/libtorchaudio/rnnt/workspace.h @@ -4,8 +4,7 @@ #include #include - -#include +#include namespace torchaudio { namespace rnnt { @@ -29,7 +28,7 @@ class DtypeWorkspace { ~DtypeWorkspace() {} static int ComputeSizeFromOptions(const Options& options) { - TORCH_CHECK_NE(options.device_, UNDEFINED); + STD_TORCH_CHECK(options.device_ != UNDEFINED); return ComputeSizeForDenominators(options) + ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) + ComputeSizeForBetas(options); @@ -38,7 +37,7 @@ class DtypeWorkspace { void Free(); void Reset(const Options& options, DTYPE* data, int size) { int needed_size = ComputeSizeFromOptions(options); - TORCH_CHECK_LE(needed_size, size); + STD_TORCH_CHECK(needed_size <= size); options_ = options; data_ = data; size_ = size; @@ -100,7 +99,7 @@ class IntWorkspace { void Reset(const Options& options, int* data, int size) { int needed_size = ComputeSizeFromOptions(options); - TORCH_CHECK_LE(needed_size, size); + STD_TORCH_CHECK(needed_size <= size); options_ = options; data_ = data; size_ = size; @@ -111,11 +110,11 @@ class IntWorkspace { } int* GetPointerToAlphaCounters() const { - TORCH_CHECK_EQ(options_.device_, GPU); + STD_TORCH_CHECK(options_.device_ == GPU); return data_; } int* GetPointerToBetaCounters() const { - TORCH_CHECK_EQ(options_.device_, GPU); + STD_TORCH_CHECK(options_.device_ == GPU); return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_); } diff --git a/src/libtorchaudio/utils.cpp b/src/libtorchaudio/utils.cpp index c32368b761..3fb53264d3 100644 --- a/src/libtorchaudio/utils.cpp +++ b/src/libtorchaudio/utils.cpp @@ -1,6 +1,4 @@ -#include #include -#include #ifdef USE_CUDA #include