Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions src/libtorchaudio/forced_align/compute.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
#include <libtorchaudio/forced_align/compute.h>
#include <torch/script.h>
#include <torch/csrc/stable/library.h>

// Public entry point for torchaudio::forced_align.
//
// Looks up the operator registered under "torchaudio::forced_align" and
// forwards all arguments to whichever backend kernel (CPU or CUDA) the
// dispatcher selects for the input tensors.
std::tuple<torch::Tensor, torch::Tensor> forced_align(
    const torch::Tensor& logProbs,
    const torch::Tensor& targets,
    const torch::Tensor& inputLengths,
    const torch::Tensor& targetLengths,
    const int64_t blank) {
  // Resolve the schema once and cache the typed handle for later calls.
  static auto dispatch_handle =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torchaudio::forced_align", "")
          .typed<decltype(forced_align)>();
  return dispatch_handle.call(
      logProbs, targets, inputLengths, targetLengths, blank);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
"forced_align(Tensor log_probs,"
"Tensor targets,"
"Tensor input_lengths,"
"Tensor target_lengths,"
"int blank) -> (Tensor, Tensor)");
}
9 changes: 0 additions & 9 deletions src/libtorchaudio/forced_align/compute.h
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, torch::Tensor> forced_align(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
const int64_t blank);
32 changes: 8 additions & 24 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ void forced_align_impl(
for (int i = 0; i < T * S; i++) {
backPtr_a[i] = -1;
}
auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
auto logProbs_a = torchaudio::accessor<scalar_t, 3>(logProbs);
auto targets_a = torchaudio::accessor<target_t, 2>(targets);
auto paths_a = torchaudio::accessor<target_t, 2>(paths);
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
Expand Down Expand Up @@ -147,10 +147,10 @@ template <typename scalar_t>
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;

std::tuple<Tensor, Tensor> compute(
const Tensor& logProbs,
const Tensor& targets,
const Tensor& inputLengths,
const Tensor& targetLengths,
Tensor logProbs,
Tensor targets,
Tensor inputLengths,
Tensor targetLengths,
const int64_t blank) {
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
Expand Down Expand Up @@ -224,24 +224,8 @@ std::tuple<Tensor, Tensor> compute(
return std::make_tuple(paths, logProbs);
}

// Boxed-ABI adapter for the CPU forced_align kernel.
//
// Unpacks the five stable-ABI stack slots into typed arguments, invokes
// compute(), and writes the two result tensors (paths, logProbs) back onto
// the stack in-place.
void boxed_forced_align_cpu(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  std::tuple<Tensor, Tensor> res = compute(
      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
      /*inputLengths*/ torch::stable::detail::to<Tensor>(stack[2]),
      /*targetLengths*/ torch::stable::detail::to<Tensor>(stack[3]),
      // Fix: `blank` is an integer label index and compute() takes int64_t.
      // The previous float(...) cast round-tripped it through float, which is
      // lossy for large values and never correct for an index.
      /*blank*/ torch::stable::detail::to<int64_t>(stack[4]));
  stack[0] = torch::stable::detail::from(std::get<0>(res));
  stack[1] = torch::stable::detail::from(std::get<1>(res));
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &boxed_forced_align_cpu);
m.impl("forced_align", TORCH_BOX(&compute));
}

} // namespace cpu
Expand Down
63 changes: 20 additions & 43 deletions src/libtorchaudio/forced_align/gpu/compute.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <libtorchaudio/utils.h>
#include <libtorchaudio/stable/TensorAccessor.h>
#include <torch/csrc/stable/library.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/ScalarType.h>
Expand All @@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;

template <typename scalar_t, typename target_t>
__global__ void falign_cuda_step_kernel(
const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
const torchaudio::PackedTensorAccessor32<scalar_t, 3>
logProbs_a,
const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
const torchaudio::PackedTensorAccessor32<target_t, 2>
targets_a,
const int T,
const int L,
Expand All @@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
int start,
int end,
int backPtrBufferLen,
torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
torchaudio::PackedTensorAccessor32<scalar_t, 2>
alphas_a,
torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
torchaudio::PackedTensorAccessor32<int8_t, 2>
backPtrBuffer_a) {
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
const int batchIndex =
Expand Down Expand Up @@ -125,7 +124,7 @@ void forced_align_impl(
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
auto paths_a = torchaudio::accessor<target_t, 2>(paths);
const int batchIndex =
0; // TODO: support batch version and use the real batch index
const int T = logProbs.size(1); // num frames
Expand All @@ -150,8 +149,8 @@ void forced_align_impl(
torch::stable::fill_(alphas, kNegInfinity);

// CPU accessors
auto targetsCpu_a = torchaudio::stable::accessor<target_t, 2>(targetsCpu);
auto backPtrCpu_a = torchaudio::stable::accessor<int8_t, 2>(backPtrCpu);
auto targetsCpu_a = torchaudio::accessor<target_t, 2>(targetsCpu);
auto backPtrCpu_a = torchaudio::accessor<int8_t, 2>(backPtrCpu);
// count the number of repeats in label
int R = 0;
for (int i = 1; i < L; ++i) {
Expand Down Expand Up @@ -192,8 +191,8 @@ void forced_align_impl(
}
falign_cuda_step_kernel<scalar_t, target_t>
<<<1, kNumThreads, 0, defaultStream>>>(
torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
torchaudio::packed_accessor32<scalar_t, 3>(logProbs),
torchaudio::packed_accessor32<target_t, 2>(targets),
T,
L,
N,
Expand All @@ -203,8 +202,8 @@ void forced_align_impl(
start,
end,
backPtrBufferLen,
torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
torchaudio::packed_accessor32<scalar_t, 2>(alphas),
torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
C10_CUDA_KERNEL_LAUNCH_CHECK();
++backPtrBufferLen;
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
Expand All @@ -228,9 +227,8 @@ void forced_align_impl(
}
}
cpuDataTranferStream.synchronize();

auto alphasCpu = torchaudio::stable::cpu(alphas);
auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
int curIdxOffset = ((T - 1) % 2);
int ltrIdx =
alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
Expand All @@ -244,18 +242,11 @@ void forced_align_impl(
}
}

template <typename scalar_t>
const auto forced_align_long_impl =
forced_align_impl<scalar_t, ScalarType::Long>;

template <typename scalar_t>
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;

std::tuple<Tensor, Tensor> compute(
const Tensor& logProbs,
const Tensor& targets,
const Tensor& inputLengths,
const Tensor& targetLengths,
Tensor logProbs,
Tensor targets,
Tensor inputLengths,
Tensor targetLengths,
const int64_t blank) {

STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
Expand Down Expand Up @@ -307,31 +298,17 @@ std::tuple<Tensor, Tensor> compute(

THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
if (targets.scalar_type() == ScalarType::Long) {
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
(forced_align_impl<scalar_t, ScalarType::Long>(logProbs, targets, blank, paths));
} else {
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
}
(forced_align_impl<scalar_t, ScalarType::Int>(logProbs, targets, blank, paths));
}
}), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);

Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
return std::make_tuple(pathsCuda, logProbs);
}

// Boxed-ABI adapter for the CUDA forced_align kernel.
//
// Unpacks the five stable-ABI stack slots into typed arguments, invokes
// compute(), and writes the two result tensors (paths, logProbs) back onto
// the stack in-place.
void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
  std::tuple<Tensor, Tensor> res = compute(
      /*logProbs*/torch::stable::detail::to<Tensor>(stack[0]),
      /*targets*/torch::stable::detail::to<Tensor>(stack[1]),
      /*inputLengths*/torch::stable::detail::to<Tensor>(stack[2]),
      /*targetLengths*/torch::stable::detail::to<Tensor>(stack[3]),
      // Fix: `blank` is an integer label index and compute() takes int64_t.
      // The previous float(...) cast round-tripped it through float, which is
      // lossy for large values and never correct for an index.
      /*blank*/torch::stable::detail::to<int64_t>(stack[4]));
  stack[0] = torch::stable::detail::from(std::get<0>(res));
  stack[1] = torch::stable::detail::from(std::get<1>(res));
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("forced_align", &boxed_forced_align_gpu);
m.impl("forced_align", TORCH_BOX(&compute));
}

} // namespace gpu
Expand Down
39 changes: 16 additions & 23 deletions src/libtorchaudio/overdrive.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>

namespace {

using torch::stable::Tensor;

template <typename T, size_t N>
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;

// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
// after Tensor::accessor is supported in stable ABI
template <typename T, size_t N>
inline TensorAccessor<T, N> accessor(Tensor t) {
return TensorAccessor<T, N>(
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
}

template <typename scalar_t>
void overdrive_cpu_kernel(
TensorAccessor<scalar_t, 2> waveform_accessor,
TensorAccessor<scalar_t, 2> temp_accessor,
TensorAccessor<scalar_t, 1> last_in_accessor,
TensorAccessor<scalar_t, 1> last_out_accessor,
TensorAccessor<scalar_t, 2> output_waveform_accessor) {
torchaudio::TensorAccessor<scalar_t, 2> waveform_accessor,
torchaudio::TensorAccessor<scalar_t, 2> temp_accessor,
torchaudio::TensorAccessor<scalar_t, 1> last_in_accessor,
torchaudio::TensorAccessor<scalar_t, 1> last_out_accessor,
torchaudio::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
int64_t n_frames = waveform_accessor.size(1);
int64_t n_channels = waveform_accessor.size(0);

Expand Down Expand Up @@ -56,11 +45,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
"overdrive_cpu",
AT_WRAP([&] {
overdrive_cpu_kernel<scalar_t>(
accessor<scalar_t, 2>(waveform),
accessor<scalar_t, 2>(temp),
accessor<scalar_t, 1>(last_in),
accessor<scalar_t, 1>(last_out),
accessor<scalar_t, 2>(output_waveform));
torchaudio::accessor<scalar_t, 2>(waveform),
torchaudio::accessor<scalar_t, 2>(temp),
torchaudio::accessor<scalar_t, 1>(last_in),
torchaudio::accessor<scalar_t, 1>(last_out),
torchaudio::accessor<scalar_t, 2>(output_waveform));
}),
AT_FLOATING_TYPES);
return std::make_tuple(last_in, last_out, output_waveform);
Expand All @@ -70,7 +59,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(

STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
"_overdrive_core_loop(Tensor waveform,"
"Tensor temp,"
"Tensor(a!) last_in,"
"Tensor(b!) last_out,"
"Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
Expand Down
Loading
Loading