diff --git a/src/libtorchaudio/iir_cuda.cu b/src/libtorchaudio/iir_cuda.cu index 2f6b75b239..dd679d9681 100644 --- a/src/libtorchaudio/iir_cuda.cu +++ b/src/libtorchaudio/iir_cuda.cu @@ -1,16 +1,18 @@ +#include +#include +#include #include #include -#include +#include + +using torch::headeronly::ScalarType; +using torch::stable::Tensor; template __global__ void iir_cu_kernel( - const torch:: - PackedTensorAccessor in, - const torch:: - PackedTensorAccessor - a_flipped, - torch::PackedTensorAccessor - out) { + const torchaudio::PackedTensorAccessorSizeT in, + const torchaudio::PackedTensorAccessorSizeT a_flipped, + torchaudio::PackedTensorAccessorSizeT out) { int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; int64_t n = in.size(0); int64_t c = in.size(1); @@ -33,51 +35,48 @@ __global__ void iir_cu_kernel( } } -void cuda_lfilter_core_loop( - const torch::Tensor& in, - const torch::Tensor& a_flipped, - torch::Tensor& padded_out) { - TORCH_CHECK( - in.device().is_cuda() && a_flipped.device().is_cuda() && - padded_out.device().is_cuda()); +Tensor cuda_lfilter_core_loop( + Tensor in, + Tensor a_flipped, + Tensor padded_out) { + STD_TORCH_CHECK( + in.is_cuda() && a_flipped.is_cuda() && + padded_out.is_cuda()); - TORCH_CHECK( + STD_TORCH_CHECK( + (in.get_device_index() == a_flipped.get_device_index()) && + (in.get_device_index() == padded_out.get_device_index())); + + STD_TORCH_CHECK( in.is_contiguous() && a_flipped.is_contiguous() && padded_out.is_contiguous()); - TORCH_CHECK( - (in.dtype() == torch::kFloat32 || in.dtype() == torch::kFloat64) && - (a_flipped.dtype() == torch::kFloat32 || - a_flipped.dtype() == torch::kFloat64) && - (padded_out.dtype() == torch::kFloat32 || - padded_out.dtype() == torch::kFloat64)); + STD_TORCH_CHECK( + (in.scalar_type() == ScalarType::Float || in.scalar_type() == ScalarType::Double) && + (a_flipped.scalar_type() == ScalarType::Float || + a_flipped.scalar_type() == ScalarType::Double) && + (padded_out.scalar_type() == ScalarType::Float || + padded_out.scalar_type() == ScalarType::Double)); const int N = in.size(0); const int C = in.size(1); - TORCH_CHECK(N == padded_out.size(0)); - TORCH_CHECK(C == padded_out.size(1)); + STD_TORCH_CHECK(N == padded_out.size(0)); + STD_TORCH_CHECK(C == padded_out.size(1)); - TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2)); + STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in)); + const at::cuda::OptionalCUDAGuard device_guard(in.get_device_index()); const dim3 threads(256); const dim3 blocks((N * C + threads.x - 1) / threads.x); - AT_DISPATCH_FLOATING_TYPES( - in.scalar_type(), "iir_cu_loop", ([&] { - iir_cu_kernel<<>>( - in.packed_accessor(), - a_flipped.packed_accessor< - scalar_t, - 2, - torch::RestrictPtrTraits, - size_t>(), - padded_out.packed_accessor< - scalar_t, - 3, - torch::RestrictPtrTraits, - size_t>()); + THO_DISPATCH_V2( + in.scalar_type(), "iir_cu_loop", AT_WRAP([&] { + (iir_cu_kernel<<>>( + torchaudio::packed_accessor_size_t(in), + torchaudio::packed_accessor_size_t(a_flipped), + torchaudio::packed_accessor_size_t(padded_out))); C10_CUDA_KERNEL_LAUNCH_CHECK(); - })); + }), AT_FLOATING_TYPES); + return padded_out; } diff --git a/src/libtorchaudio/iir_cuda.h b/src/libtorchaudio/iir_cuda.h index 9cad1134bb..c1209fe154 100644 --- a/src/libtorchaudio/iir_cuda.h +++ b/src/libtorchaudio/iir_cuda.h @@ -1,8 +1,7 @@ #pragma once -#include +#include -void cuda_lfilter_core_loop( - const torch::Tensor& in, - const torch::Tensor& a_flipped, - torch::Tensor& padded_out); +using torch::stable::Tensor; + +Tensor cuda_lfilter_core_loop(Tensor in, Tensor a_flipped, Tensor padded_out); diff --git a/src/libtorchaudio/lfilter.cpp b/src/libtorchaudio/lfilter.cpp index 9d9b05c7d8..b89cec27d2 100644 --- a/src/libtorchaudio/lfilter.cpp +++ b/src/libtorchaudio/lfilter.cpp @@ -1,5 +1,9 @@ -#include -#include +#include +#include +#include +#include +#include +#include #ifdef USE_CUDA #include @@ -7,110 +11,131 @@ namespace { +using torch::headeronly::ScalarType; +using torch::stable::Tensor; + template void host_lfilter_core_loop( - const torch::Tensor& input_signal_windows, - const torch::Tensor& a_coeff_flipped, - torch::Tensor& padded_output_waveform) { + const Tensor& input_signal_windows, + const Tensor& a_coeff_flipped, + Tensor& padded_output_waveform) { int64_t n_batch = input_signal_windows.size(0); int64_t n_channel = input_signal_windows.size(1); int64_t n_samples_input = input_signal_windows.size(2); int64_t n_samples_output = padded_output_waveform.size(2); int64_t n_order = a_coeff_flipped.size(1); - scalar_t* output_data = padded_output_waveform.data_ptr(); - const scalar_t* input_data = input_signal_windows.data_ptr(); - const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr(); - - at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) { - for (auto i = begin; i < end; i++) { - int64_t offset_input = i * n_samples_input; - int64_t offset_output = i * n_samples_output; - int64_t i_channel = i % n_channel; - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - scalar_t a0 = input_data[offset_input + i_sample]; - for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { - a0 -= output_data[offset_output + i_sample + i_coeff] * - a_coeff_flipped_data[i_coeff + i_channel * n_order]; + scalar_t* output_data = + reinterpret_cast(padded_output_waveform.data_ptr()); + const scalar_t* input_data = + reinterpret_cast(input_signal_windows.data_ptr()); + const scalar_t* a_coeff_flipped_data = + reinterpret_cast(a_coeff_flipped.data_ptr()); + + torch::stable::parallel_for( + 0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) { + for (auto i = begin; i < end; i++) { + int64_t offset_input = i * n_samples_input; + int64_t offset_output = i * n_samples_output; + int64_t i_channel = i % n_channel; + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { + scalar_t a0 = input_data[offset_input + i_sample]; + for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { + a0 -= output_data[offset_output + i_sample + i_coeff] * + a_coeff_flipped_data[i_coeff + i_channel * n_order]; + } + output_data[offset_output + i_sample + n_order - 1] = a0; + } } - output_data[offset_output + i_sample + n_order - 1] = a0; - } - } - }); + }); } -void cpu_lfilter_core_loop( - const torch::Tensor& input_signal_windows, - const torch::Tensor& a_coeff_flipped, - torch::Tensor& padded_output_waveform) { - TORCH_CHECK( - input_signal_windows.device().is_cpu() && - a_coeff_flipped.device().is_cpu() && - padded_output_waveform.device().is_cpu()); +Tensor cpu_lfilter_core_loop( + Tensor input_signal_windows, + Tensor a_coeff_flipped, + Tensor padded_output_waveform) { + STD_TORCH_CHECK( + input_signal_windows.is_cpu() && a_coeff_flipped.is_cpu() && + padded_output_waveform.is_cpu()); - TORCH_CHECK( + STD_TORCH_CHECK( input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() && padded_output_waveform.is_contiguous()); - TORCH_CHECK( - (input_signal_windows.dtype() == torch::kFloat32 || - input_signal_windows.dtype() == torch::kFloat64) && - (a_coeff_flipped.dtype() == torch::kFloat32 || - a_coeff_flipped.dtype() == torch::kFloat64) && - (padded_output_waveform.dtype() == torch::kFloat32 || - padded_output_waveform.dtype() == torch::kFloat64)); + STD_TORCH_CHECK( + (input_signal_windows.scalar_type() == ScalarType::Float || + input_signal_windows.scalar_type() == ScalarType::Double) && + (a_coeff_flipped.scalar_type() == ScalarType::Float || + a_coeff_flipped.scalar_type() == ScalarType::Double) && + (padded_output_waveform.scalar_type() == ScalarType::Float || + padded_output_waveform.scalar_type() == ScalarType::Double)); - TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); - TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1)); + STD_TORCH_CHECK( + input_signal_windows.size(0) == padded_output_waveform.size(0)); + STD_TORCH_CHECK( + input_signal_windows.size(1) == padded_output_waveform.size(1)); - TORCH_CHECK( + STD_TORCH_CHECK( input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 == padded_output_waveform.size(2)); - AT_DISPATCH_FLOATING_TYPES( - input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { + THO_DISPATCH_V2( + input_signal_windows.scalar_type(), + "lfilter_core_loop", + [&] { host_lfilter_core_loop( input_signal_windows, a_coeff_flipped, padded_output_waveform); - }); + }, + AT_FLOATING_TYPES); + return padded_output_waveform; } -void lfilter_core_generic_loop( - const torch::Tensor& input_signal_windows, - const torch::Tensor& a_coeff_flipped, - torch::Tensor& padded_output_waveform) { +Tensor lfilter_core_generic_loop( + Tensor input_signal_windows, + Tensor a_coeff_flipped, + Tensor padded_output_waveform) { int64_t n_samples_input = input_signal_windows.size(2); int64_t n_order = a_coeff_flipped.size(1); - auto coeff = a_coeff_flipped.unsqueeze(2); + auto coeff = torchaudio::stable::unsqueeze(a_coeff_flipped, 2); for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - auto windowed_output_signal = - torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order) - .transpose(0, 1); - auto o0 = torch::select(input_signal_windows, 2, i_sample) - - at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); - padded_output_waveform.index_put_( - {torch::indexing::Slice(), - torch::indexing::Slice(), - i_sample + n_order - 1}, - o0); + auto windowed_output_signal = torch::stable::transpose( + torch::stable::narrow( + padded_output_waveform, 2, i_sample, i_sample + n_order), + 0, + 1); + auto o0 = torchaudio::stable::subtract( + torchaudio::stable::select(input_signal_windows, 2, i_sample), + torch::stable::transpose( + torchaudio::stable::squeeze( + torchaudio::stable::matmul(windowed_output_signal, coeff), 2), + 0, + 1)); + auto s = torchaudio::stable::select( + padded_output_waveform, 2, i_sample + n_order - 1); + torch::stable::copy_(s, o0); } + return padded_output_waveform; } } // namespace -TORCH_LIBRARY(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( - "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()"); + "_lfilter_core_loop(" + "Tensor input_signal_windows," + "Tensor a_coeff_flipped," + "Tensor(a!) padded_output_waveform) -> Tensor(a!)"); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("_lfilter_core_loop", TORCH_BOX(&cpu_lfilter_core_loop)); } #ifdef USE_CUDA -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("_lfilter_core_loop", TORCH_BOX(&cuda_lfilter_core_loop)); } #endif -TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { - m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop); +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { + m.impl("_lfilter_core_loop", TORCH_BOX(&lfilter_core_generic_loop)); } diff --git a/src/libtorchaudio/stable/ops.h b/src/libtorchaudio/stable/ops.h index fa5a89345d..c9b088e53e 100644 --- a/src/libtorchaudio/stable/ops.h +++ b/src/libtorchaudio/stable/ops.h @@ -182,4 +182,51 @@ T item(const Tensor& self) { } } +inline Tensor unsqueeze(const Tensor& self, int64_t dim) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +inline Tensor select(const Tensor& self, int64_t dim, int64_t index) { + const auto num_args = 3; + std::array stack{ + torch::stable::detail::from(self), + torch::stable::detail::from(dim), + torch::stable::detail::from(index)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::select", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +inline Tensor squeeze(const Tensor& self, int64_t dim) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(dim)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +inline Tensor matmul(const Tensor& self, const Tensor& other) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(other)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::matmul", "", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + +inline Tensor subtract(const Tensor& self, const Tensor& other) { + const auto num_args = 2; + std::array stack{ + torch::stable::detail::from(self), torch::stable::detail::from(other)}; + TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( + "aten::subtract", "Tensor", stack.data(), TORCH_ABI_VERSION)); + return torch::stable::detail::to(stack[0]); +} + } // namespace torchaudio::stable diff --git a/src/libtorchaudio/utils.h b/src/libtorchaudio/utils.h index a700e30efa..725b6c9699 100644 --- a/src/libtorchaudio/utils.h +++ b/src/libtorchaudio/utils.h @@ -48,6 +48,24 @@ inline PackedTensorAccessor32 packed_accessor32(Tensor t) { t.sizes().data(), t.strides().data()); } + +template +using PackedTensorAccessorSizeT = + torch::headeronly::HeaderOnlyGenericPackedTensorAccessor< + T, + N, + torch::headeronly::RestrictPtrTraits, + size_t>; + +template +inline PackedTensorAccessorSizeT packed_accessor_size_t(Tensor t) { + return PackedTensorAccessorSizeT( + static_cast::PtrType>( + t.data_ptr()), + t.sizes().data(), + t.strides().data()); +} + #endif } // namespace torchaudio