81 changes: 40 additions & 41 deletions src/libtorchaudio/iir_cuda.cu
@@ -1,16 +1,18 @@
#include <libtorchaudio/utils.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/ScalarType.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>
#include <c10/core/DeviceGuard.h>

using torch::headeronly::ScalarType;
using torch::stable::Tensor;

template <typename scalar_t>
__global__ void iir_cu_kernel(
const torch::
PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> in,
const torch::
PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
a_flipped,
torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>
out) {
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> in,
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 2> a_flipped,
torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> out) {
int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
int64_t n = in.size(0);
int64_t c = in.size(1);
@@ -33,51 +35,48 @@ __global__ void iir_cu_kernel(
}
}

void cuda_lfilter_core_loop(
const torch::Tensor& in,
const torch::Tensor& a_flipped,
torch::Tensor& padded_out) {
TORCH_CHECK(
in.device().is_cuda() && a_flipped.device().is_cuda() &&
padded_out.device().is_cuda());
Tensor cuda_lfilter_core_loop(
Tensor in,
Tensor a_flipped,
Tensor padded_out) {
STD_TORCH_CHECK(
in.is_cuda() && a_flipped.is_cuda() &&
padded_out.is_cuda());

TORCH_CHECK(
STD_TORCH_CHECK(
(in.get_device_index() == a_flipped.get_device_index()) &&
(in.get_device_index() == padded_out.get_device_index()));

STD_TORCH_CHECK(
in.is_contiguous() && a_flipped.is_contiguous() &&
padded_out.is_contiguous());

TORCH_CHECK(
(in.dtype() == torch::kFloat32 || in.dtype() == torch::kFloat64) &&
(a_flipped.dtype() == torch::kFloat32 ||
a_flipped.dtype() == torch::kFloat64) &&
(padded_out.dtype() == torch::kFloat32 ||
padded_out.dtype() == torch::kFloat64));
STD_TORCH_CHECK(
(in.scalar_type() == ScalarType::Float || in.scalar_type() == ScalarType::Double) &&
(a_flipped.scalar_type() == ScalarType::Float ||
a_flipped.scalar_type() == ScalarType::Double) &&
(padded_out.scalar_type() == ScalarType::Float ||
padded_out.scalar_type() == ScalarType::Double));

const int N = in.size(0);
const int C = in.size(1);
TORCH_CHECK(N == padded_out.size(0));
TORCH_CHECK(C == padded_out.size(1));
STD_TORCH_CHECK(N == padded_out.size(0));
STD_TORCH_CHECK(C == padded_out.size(1));

TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));

const at::cuda::OptionalCUDAGuard device_guard(device_of(in));
const at::cuda::OptionalCUDAGuard device_guard(in.get_device_index());

const dim3 threads(256);
const dim3 blocks((N * C + threads.x - 1) / threads.x);

AT_DISPATCH_FLOATING_TYPES(
in.scalar_type(), "iir_cu_loop", ([&] {
iir_cu_kernel<scalar_t><<<blocks, threads>>>(
in.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
a_flipped.packed_accessor<
scalar_t,
2,
torch::RestrictPtrTraits,
size_t>(),
padded_out.packed_accessor<
scalar_t,
3,
torch::RestrictPtrTraits,
size_t>());
THO_DISPATCH_V2(
in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
(iir_cu_kernel<scalar_t><<<blocks, threads>>>(
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}), AT_FLOATING_TYPES);
return padded_out;
}
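Note on the accessor helpers used above: torchaudio::PackedTensorAccessorSizeT and torchaudio::packed_accessor_size_t come from libtorchaudio/utils.h, which this diff does not touch, so their definitions are not shown here. A minimal sketch of what such helpers could look like, assuming they alias the restrict/size_t packed accessor and build it by hand from the stable Tensor; illustrative only, the real header may use the header-only accessor type instead of the classic ATen one:

// Illustrative sketch only -- the real helpers live in libtorchaudio/utils.h,
// which is not part of this diff. Assumes the stable Tensor exposes size(),
// stride(), and data_ptr(), as used elsewhere in this PR.
#include <ATen/core/TensorAccessor.h>
#include <torch/csrc/stable/tensor.h>
#include <array>
#include <cstdint>

namespace torchaudio {

// Packed accessor with __restrict__ pointers and size_t indexing, matching
// the template arguments the old kernel spelled out explicitly.
template <typename scalar_t, size_t N>
using PackedTensorAccessorSizeT =
    at::GenericPackedTensorAccessor<scalar_t, N, at::RestrictPtrTraits, size_t>;

// Build the accessor by hand, since torch::stable::Tensor has no
// packed_accessor() method.
template <typename scalar_t, size_t N>
PackedTensorAccessorSizeT<scalar_t, N> packed_accessor_size_t(
    const torch::stable::Tensor& t) {
  std::array<int64_t, N> sizes, strides;
  for (size_t i = 0; i < N; ++i) {
    sizes[i] = t.size(i);
    strides[i] = t.stride(i);
  }
  return PackedTensorAccessorSizeT<scalar_t, N>(
      reinterpret_cast<scalar_t*>(t.data_ptr()),
      sizes.data(),
      strides.data());
}

} // namespace torchaudio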
9 changes: 4 additions & 5 deletions src/libtorchaudio/iir_cuda.h
@@ -1,8 +1,7 @@
#pragma once

#include <torch/types.h>
#include <torch/csrc/stable/tensor.h>

void cuda_lfilter_core_loop(
const torch::Tensor& in,
const torch::Tensor& a_flipped,
torch::Tensor& padded_out);
using torch::stable::Tensor;

Tensor cuda_lfilter_core_loop(Tensor in, Tensor a_flipped, Tensor padded_out);
161 changes: 93 additions & 68 deletions src/libtorchaudio/lfilter.cpp
@@ -1,116 +1,141 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/ScalarType.h>

#ifdef USE_CUDA
#include <libtorchaudio/iir_cuda.h>
#endif

namespace {

using torch::headeronly::ScalarType;
using torch::stable::Tensor;

template <typename scalar_t>
void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
const Tensor& input_signal_windows,
const Tensor& a_coeff_flipped,
Tensor& padded_output_waveform) {
int64_t n_batch = input_signal_windows.size(0);
int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();

at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
int64_t offset_input = i * n_samples_input;
int64_t offset_output = i * n_samples_output;
int64_t i_channel = i % n_channel;
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff + i_channel * n_order];
scalar_t* output_data =
reinterpret_cast<scalar_t*>(padded_output_waveform.data_ptr());
const scalar_t* input_data =
reinterpret_cast<scalar_t*>(input_signal_windows.data_ptr());
const scalar_t* a_coeff_flipped_data =
reinterpret_cast<scalar_t*>(a_coeff_flipped.data_ptr());

torch::stable::parallel_for(
0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
int64_t offset_input = i * n_samples_input;
int64_t offset_output = i * n_samples_output;
int64_t i_channel = i % n_channel;
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff + i_channel * n_order];
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
});
});
}

void cpu_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
TORCH_CHECK(
input_signal_windows.device().is_cpu() &&
a_coeff_flipped.device().is_cpu() &&
padded_output_waveform.device().is_cpu());
Tensor cpu_lfilter_core_loop(
Tensor input_signal_windows,
Tensor a_coeff_flipped,
Tensor padded_output_waveform) {
STD_TORCH_CHECK(
input_signal_windows.is_cpu() && a_coeff_flipped.is_cpu() &&
padded_output_waveform.is_cpu());

TORCH_CHECK(
STD_TORCH_CHECK(
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
padded_output_waveform.is_contiguous());

TORCH_CHECK(
(input_signal_windows.dtype() == torch::kFloat32 ||
input_signal_windows.dtype() == torch::kFloat64) &&
(a_coeff_flipped.dtype() == torch::kFloat32 ||
a_coeff_flipped.dtype() == torch::kFloat64) &&
(padded_output_waveform.dtype() == torch::kFloat32 ||
padded_output_waveform.dtype() == torch::kFloat64));
STD_TORCH_CHECK(
(input_signal_windows.scalar_type() == ScalarType::Float ||
input_signal_windows.scalar_type() == ScalarType::Double) &&
(a_coeff_flipped.scalar_type() == ScalarType::Float ||
a_coeff_flipped.scalar_type() == ScalarType::Double) &&
(padded_output_waveform.scalar_type() == ScalarType::Float ||
padded_output_waveform.scalar_type() == ScalarType::Double));

TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
STD_TORCH_CHECK(
input_signal_windows.size(0) == padded_output_waveform.size(0));
STD_TORCH_CHECK(
input_signal_windows.size(1) == padded_output_waveform.size(1));

TORCH_CHECK(
STD_TORCH_CHECK(
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(2));

AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
THO_DISPATCH_V2(
input_signal_windows.scalar_type(),
"lfilter_core_loop",
[&] {
host_lfilter_core_loop<scalar_t>(
input_signal_windows, a_coeff_flipped, padded_output_waveform);
});
},
AT_FLOATING_TYPES);
return padded_output_waveform;
}

void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
Tensor lfilter_core_generic_loop(
Tensor input_signal_windows,
Tensor a_coeff_flipped,
Tensor padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
auto coeff = torchaudio::stable::unsqueeze(a_coeff_flipped, 2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal =
torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order)
.transpose(0, 1);
auto o0 = torch::select(input_signal_windows, 2, i_sample) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_(
{torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
auto windowed_output_signal = torch::stable::transpose(
torch::stable::narrow(
padded_output_waveform, 2, i_sample, i_sample + n_order),
0,
1);
auto o0 = torchaudio::stable::subtract(
torchaudio::stable::select(input_signal_windows, 2, i_sample),
torch::stable::transpose(
torchaudio::stable::squeeze(
torchaudio::stable::matmul(windowed_output_signal, coeff), 2),
0,
1));
auto s = torchaudio::stable::select(
padded_output_waveform, 2, i_sample + n_order - 1);
torch::stable::copy_(s, o0);
}
return padded_output_waveform;
}

} // namespace

TORCH_LIBRARY(torchaudio, m) {
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
"_lfilter_core_loop("
"Tensor input_signal_windows,"
"Tensor a_coeff_flipped,"
"Tensor(a!) padded_output_waveform) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("_lfilter_core_loop", TORCH_BOX(&cpu_lfilter_core_loop));
}

#ifdef USE_CUDA
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
m.impl("_lfilter_core_loop", TORCH_BOX(&cuda_lfilter_core_loop));
}
#endif

TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
m.impl("_lfilter_core_loop", TORCH_BOX(&lfilter_core_generic_loop));
}
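With the schema registered through STABLE_TORCH_LIBRARY_FRAGMENT and the kernels bound via STABLE_TORCH_LIBRARY_IMPL, other stable-ABI C++ code can reach the op through the dispatcher using the same torch_call_dispatcher pattern that the new wrappers in stable/ops.h (below) use. A hedged sketch, with the op name and schema taken from the registration above and everything else illustrative:

// Illustrative only: calling the registered op from stable-ABI C++ code,
// mirroring the torch_call_dispatcher pattern used in stable/ops.h below.
// Includes are assumed to match those pulled in by stable/ops.h.
using torch::stable::Tensor;

inline Tensor call_lfilter_core_loop(
    const Tensor& input_signal_windows,
    const Tensor& a_coeff_flipped,
    const Tensor& padded_output_waveform) {
  std::array<StableIValue, 3> stack{
      torch::stable::detail::from(input_signal_windows),
      torch::stable::detail::from(a_coeff_flipped),
      torch::stable::detail::from(padded_output_waveform)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "torchaudio::_lfilter_core_loop", "", stack.data(), TORCH_ABI_VERSION));
  // The schema returns Tensor(a!), i.e. the mutated padded_output_waveform.
  return torch::stable::detail::to<Tensor>(stack[0]);
}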
47 changes: 47 additions & 0 deletions src/libtorchaudio/stable/ops.h
@@ -182,4 +182,51 @@ T item(const Tensor& self) {
}
}

inline Tensor unsqueeze(const Tensor& self, int64_t dim) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

inline Tensor select(const Tensor& self, int64_t dim, int64_t index) {
const auto num_args = 3;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self),
torch::stable::detail::from(dim),
torch::stable::detail::from(index)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::select", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

inline Tensor squeeze(const Tensor& self, int64_t dim) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

inline Tensor matmul(const Tensor& self, const Tensor& other) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(other)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::matmul", "", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

inline Tensor subtract(const Tensor& self, const Tensor& other) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(other)};
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::subtract", "Tensor", stack.data(), TORCH_ABI_VERSION));
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}

} // namespace torchaudio::stable
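The new wrappers all follow the same recipe: box the arguments into a StableIValue stack, call torch_call_dispatcher with the op name and overload, then unbox the result. Composing them therefore reads much like ordinary ATen code. A small usage sketch, illustrative only and not part of this PR, combining unsqueeze, matmul, squeeze, and subtract, the same calls lfilter_core_generic_loop chains above:

// Illustrative usage of the torchaudio::stable wrappers (not part of this
// diff): computes x - (w @ c) along the last dimension, assuming
// x: (B, N), w: (B, N, K), c: (B, K).
using torch::stable::Tensor;

inline Tensor residual(const Tensor& x, const Tensor& w, const Tensor& c) {
  auto coeff = torchaudio::stable::unsqueeze(c, 2);        // (B, K, 1)
  auto prod = torchaudio::stable::matmul(w, coeff);        // (B, N, 1)
  return torchaudio::stable::subtract(
      x, torchaudio::stable::squeeze(prod, 2));            // (B, N)
}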