Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 63 additions & 37 deletions src/libtorchaudio/overdrive.cpp
Original file line number Diff line number Diff line change
@@ -1,52 +1,78 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>

namespace {

using torch::stable::Tensor;

template <typename T, size_t N>
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;

// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
// after Tensor::accessor is supported in stable ABI
// Build a header-only accessor over `t`'s raw storage.
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
// once Tensor::accessor lands in the stable ABI.
template <typename T, size_t N>
inline TensorAccessor<T, N> accessor(Tensor t) {
  auto* data = reinterpret_cast<T*>(t.data_ptr());
  return TensorAccessor<T, N>(data, t.sizes().data(), t.strides().data());
}

template <typename scalar_t>
void overdrive_cpu_kernel(
at::TensorAccessor<scalar_t, 2> waveform_accessor,
at::TensorAccessor<scalar_t, 2> temp_accessor,
at::TensorAccessor<scalar_t, 1> last_in_accessor,
at::TensorAccessor<scalar_t, 1> last_out_accessor,
at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
TensorAccessor<scalar_t, 2> waveform_accessor,
TensorAccessor<scalar_t, 2> temp_accessor,
TensorAccessor<scalar_t, 1> last_in_accessor,
TensorAccessor<scalar_t, 1> last_out_accessor,
TensorAccessor<scalar_t, 2> output_waveform_accessor) {
int64_t n_frames = waveform_accessor.size(1);
int64_t n_channels = waveform_accessor.size(0);

at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
output_waveform_accessor[i_channel][i_frame] =
waveform_accessor[i_channel][i_frame] * 0.5 +
last_out_accessor[i_channel] * 0.75;
}
}
});
torch::stable::parallel_for(
0, n_channels, 1, [&](int64_t begin, int64_t end) {
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
last_in_accessor[i_channel] +
0.995 * last_out_accessor[i_channel];
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
output_waveform_accessor[i_channel][i_frame] =
waveform_accessor[i_channel][i_frame] * 0.5 +
last_out_accessor[i_channel] * 0.75;
}
}
});
}

void overdrive_core_loop_cpu(
at::Tensor& waveform,
at::Tensor& temp,
at::Tensor& last_in,
at::Tensor& last_out,
at::Tensor& output_waveform) {
AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
overdrive_cpu_kernel<scalar_t>(
waveform.accessor<scalar_t, 2>(),
temp.accessor<scalar_t, 2>(),
last_in.accessor<scalar_t, 1>(),
last_out.accessor<scalar_t, 1>(),
output_waveform.accessor<scalar_t, 2>());
}));
std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
Tensor waveform,
Tensor temp,
Tensor last_in,
Tensor last_out,
Tensor output_waveform) {
THO_DISPATCH_V2(
waveform.scalar_type(),
"overdrive_cpu",
AT_WRAP([&] {
overdrive_cpu_kernel<scalar_t>(
accessor<scalar_t, 2>(waveform),
accessor<scalar_t, 2>(temp),
accessor<scalar_t, 1>(last_in),
accessor<scalar_t, 1>(last_out),
accessor<scalar_t, 2>(output_waveform));
}),
AT_FLOATING_TYPES);
return std::make_tuple(last_in, last_out, output_waveform);
}

} // namespace

// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  // Schema declares the in-place aliases (a!/b!/c!) so autograd/functionalization
  // know last_in/last_out/output_waveform are mutated and returned.
  m.def(
      "_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  // TORCH_BOX generates the boxed wrapper; no hand-written stack shim needed.
  m.impl("_overdrive_core_loop", TORCH_BOX(&overdrive_core_loop_cpu));
}
26 changes: 8 additions & 18 deletions src/libtorchaudio/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using torch::headeronly::ScalarType;

// Entry point into RNNT Loss
std::tuple<Tensor, Tensor> compute(
const Tensor& logits,
const Tensor& targets,
const Tensor& logit_lengths,
const Tensor& target_lengths,
Tensor logits,
Tensor targets,
Tensor logit_lengths,
Tensor target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
Expand Down Expand Up @@ -148,23 +148,13 @@ std::tuple<Tensor, Tensor> compute(
return std::make_tuple(costs, gradients);
}

void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
std::tuple<Tensor, Tensor> res = compute(
/*logits*/torch::stable::detail::to<Tensor>(stack[0]),
/*targets*/torch::stable::detail::to<Tensor>(stack[1]),
/*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
/*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
/*blank*/float(torch::stable::detail::to<int64_t>(stack[4])),
/*clamp*/torch::stable::detail::to<double>(stack[5]),
/*fused_log_softmax*/torch::stable::detail::to<bool>(stack[6]));
stack[0] = torch::stable::detail::from(std::get<0>(res));
stack[1] = torch::stable::detail::from(std::get<1>(res));
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, double clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
}

} // namespace gpu
Expand Down
Loading