diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp
index 4954271e41..4e3c8bf8c9 100644
--- a/src/libtorchaudio/overdrive.cpp
+++ b/src/libtorchaudio/overdrive.cpp
@@ -1,52 +1,78 @@
-#include <ATen/Parallel.h>
-#include <torch/script.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/headeronly/core/Dispatch_v2.h>
+#include <torch/headeronly/core/TensorAccessor.h>
 
 namespace {
 
+using torch::stable::Tensor;
+
+template <typename T, size_t N>
+using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
+
+// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
+// after Tensor::accessor is supported in stable ABI
+template <typename T, size_t N>
+inline TensorAccessor<T, N> accessor(Tensor t) {
+  return TensorAccessor<T, N>(
+      reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
+}
+
 template <typename scalar_t>
 void overdrive_cpu_kernel(
-    at::TensorAccessor<scalar_t, 2> waveform_accessor,
-    at::TensorAccessor<scalar_t, 2> temp_accessor,
-    at::TensorAccessor<scalar_t, 1> last_in_accessor,
-    at::TensorAccessor<scalar_t, 1> last_out_accessor,
-    at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
+    TensorAccessor<scalar_t, 2> waveform_accessor,
+    TensorAccessor<scalar_t, 2> temp_accessor,
+    TensorAccessor<scalar_t, 1> last_in_accessor,
+    TensorAccessor<scalar_t, 1> last_out_accessor,
+    TensorAccessor<scalar_t, 2> output_waveform_accessor) {
   int64_t n_frames = waveform_accessor.size(1);
   int64_t n_channels = waveform_accessor.size(0);
 
-  at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
-    for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
-      for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
-        last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
-            last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
-        last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
-        output_waveform_accessor[i_channel][i_frame] =
-            waveform_accessor[i_channel][i_frame] * 0.5 +
-            last_out_accessor[i_channel] * 0.75;
-      }
-    }
-  });
+  torch::stable::parallel_for(
+      0, n_channels, 1, [&](int64_t begin, int64_t end) {
+        for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
+          for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
+            last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
+                last_in_accessor[i_channel] +
+                0.995 * last_out_accessor[i_channel];
+            last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
+            output_waveform_accessor[i_channel][i_frame] =
+                waveform_accessor[i_channel][i_frame] * 0.5 +
+                last_out_accessor[i_channel] * 0.75;
+          }
+        }
+      });
 }
 
-void overdrive_core_loop_cpu(
-    at::Tensor& waveform,
-    at::Tensor& temp,
-    at::Tensor& last_in,
-    at::Tensor& last_out,
-    at::Tensor& output_waveform) {
-  AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
-    overdrive_cpu_kernel<scalar_t>(
-        waveform.accessor<scalar_t, 2>(),
-        temp.accessor<scalar_t, 2>(),
-        last_in.accessor<scalar_t, 1>(),
-        last_out.accessor<scalar_t, 1>(),
-        output_waveform.accessor<scalar_t, 2>());
-  }));
+std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
+    Tensor waveform,
+    Tensor temp,
+    Tensor last_in,
+    Tensor last_out,
+    Tensor output_waveform) {
+  THO_DISPATCH_V2(
+      waveform.scalar_type(),
+      "overdrive_cpu",
+      AT_WRAP([&] {
+        overdrive_cpu_kernel<scalar_t>(
+            accessor<scalar_t, 2>(waveform),
+            accessor<scalar_t, 2>(temp),
+            accessor<scalar_t, 1>(last_in),
+            accessor<scalar_t, 1>(last_out),
+            accessor<scalar_t, 2>(output_waveform));
+      }),
+      AT_FLOATING_TYPES);
+  return std::make_tuple(last_in, last_out, output_waveform);
 }
 
 } // namespace
 
-// Note: We want to avoid using "catch-all" kernel.
-// The following registration should be replaced with CPU specific registration.
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("_overdrive_core_loop", TORCH_BOX(&overdrive_core_loop_cpu));
 }
diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu
index 336e3b8abd..77c6a0e268 100644
--- a/src/libtorchaudio/rnnt/gpu/compute.cu
+++ b/src/libtorchaudio/rnnt/gpu/compute.cu
@@ -14,10 +14,10 @@ using torch::headeronly::ScalarType;
 
 // Entry point into RNNT Loss
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logits,
-    const Tensor& targets,
-    const Tensor& logit_lengths,
-    const Tensor& target_lengths,
+    Tensor logits,
+    Tensor targets,
+    Tensor logit_lengths,
+    Tensor target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
@@ -148,23 +148,13 @@ std::tuple<Tensor, Tensor> compute(
   return std::make_tuple(costs, gradients);
 }
 
-void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logits*/torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/float(torch::stable::detail::to<int64_t>(stack[4])),
-      /*clamp*/torch::stable::detail::to<double>(stack[5]),
-      /*fused_log_softmax*/torch::stable::detail::to<bool>(stack[6]));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, double clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
+  m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
 }
 
 } // namespace gpu