1 change: 1 addition & 0 deletions aten/src/ATen/autocast_mode.cpp
@@ -523,6 +523,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
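
The new `stft.center` kernel registration mirrors the existing `stft` entry, so both overloads are forced to fp32 under CPU autocast. A minimal sketch of the observable behavior from Python (shapes and sizes here are illustrative):

```python
import torch

x = torch.randn(8, 2048)
window = torch.hann_window(512)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # stft is registered as an fp32 op for CPU autocast, so its inputs are
    # cast up to float32 and the result is not produced in bfloat16.
    spec = torch.stft(x, n_fft=512, window=window, center=True,
                      pad_mode="reflect", return_complex=True)

print(spec.dtype)  # expected: torch.complex64
```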
39 changes: 26 additions & 13 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -759,14 +759,11 @@ static Stream& write_opt(Stream& SS, const optional<T>& value) {
*
* This is modeled after librosa but with support for complex time-domain
* signals and complex windows.
-*
-NOTE: librosa's center and pad_mode arguments are currently only implemented
-in python because it uses torch.nn.functional.pad which is python-only.
*/
Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
-const bool normalized, const optional<bool> onesidedOpt,
-const optional<bool> return_complexOpt) {
+const bool center, c10::string_view mode, const bool normalized,
+const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
const Tensor& window = *window_maybe_owned;
@@ -824,6 +821,19 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
if (self.dim() == 1) {
input = input.unsqueeze(0);
}

+if (center) {
+const auto input_shape = input.sizes();
+const auto input_dim = input_shape.size();
+const auto extra_dims = std::max(size_t{3}, input_dim) - input_dim;
+const auto pad_amount = n_fft / 2;

+DimVector extended_shape(extra_dims, 1);
+extended_shape.append(input_shape.begin(), input_shape.end());
+input = at::pad(input.view(extended_shape), {pad_amount, pad_amount}, mode);
+input = input.view(IntArrayRef(input.sizes()).slice(extra_dims));
+}

int64_t batch = input.size(0);
int64_t len = input.size(1);
if (n_fft <= 0 || n_fft > len) {
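
The added block above moves into ATen the centering that `torch/functional.py` used to perform: the input is viewed as at least 3-D, padded by `n_fft / 2` on each side with the requested `mode`, then viewed back to its original rank. A sketch of the equivalence from the Python side (sizes and the hann window are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4096)
n_fft = 400
window = torch.hann_window(n_fft)

# Manual centering: reflect-pad by n_fft // 2 on both sides, then run an
# uncentered stft. F.pad with mode="reflect" needs a 3-D input for 1-D padding.
pad = n_fft // 2
padded = F.pad(x.view(1, 1, -1), (pad, pad), mode="reflect").view(1, -1)
manual = torch.stft(padded, n_fft, window=window, center=False,
                    return_complex=True)

# Built-in centering.
builtin = torch.stft(x, n_fft, window=window, center=True, pad_mode="reflect",
                     return_complex=True)

print(torch.allclose(manual, builtin))  # expected: True
```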
@@ -897,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
}
}

+Tensor stft(
+const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
+const bool normalized,
+const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
+return at::stft(
+self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
+/*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
+return_complexOpt);
+}

// Create complex tensor from the old style of real tensor with size=(..., 2)
// This is to support istft in the transition to requiring complex input.
// NOTE: This may return a view of the input tensor, or might clone if necessary
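
The wrapper above keeps the old schema callable during the forward-compatibility window: it forwards to the new kernel with centering disabled (the pad mode is irrelevant when `center` is false). In user-visible terms, centering only changes how many frames come out, since it pads the signal by `n_fft // 2` on each side. A small sketch (hop length defaults to `n_fft // 4`; the numbers are illustrative):

```python
import torch

x = torch.randn(4096)
n_fft = 400
hop = n_fft // 4
window = torch.hann_window(n_fft)

uncentered = torch.stft(x, n_fft, window=window, center=False, return_complex=True)
centered = torch.stft(x, n_fft, window=window, center=True, return_complex=True)

# Without centering: 1 + (len - n_fft) // hop frames.
# With centering the signal gains n_fft // 2 samples per side: 1 + len // hop frames.
print(uncentered.shape[-1], 1 + (4096 - n_fft) // hop)  # expected: 37 37
print(centered.shape[-1], 1 + 4096 // hop)              # expected: 41 41
```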
@@ -1090,14 +1111,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
#undef REPR
}

-Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
-const optional<int64_t> win_lengthOpt, const Tensor& window,
-const bool normalized, const optional<bool> onesidedOpt) {
-return at::native::stft(
-self, n_fft, hop_lengthOpt, win_lengthOpt, window, normalized, onesidedOpt,
-/*return_complex=*/c10::nullopt);
-}

Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool center, const bool normalized, const optional<bool> onesidedOpt,
9 changes: 5 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -4243,12 +4243,13 @@

- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)

-# The signature is designed to be consistent with librosa except that it is
-# missing the `pad_mode` and `center` arguments, which are taken care of at
-# `torch.functional.py`. They shall be moved here once we have mapping between
-# Python strings and C++ Enum in codegen.
+# Overload without center & pad mode, needed for forward-compatibility
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
variants: function, method
+cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']

+- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+variants: function, method

- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
variants: function, method
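
Both schemas resolve to the same `torch.stft` Python API; the base `stft` entry only exists so that older serialized programs keep a matching overload during the forward-compatibility period. If the overloads are exposed through the dispatcher in the usual way (an assumption here: as `torch.ops.aten.stft` and `torch.ops.aten.stft.center`), the relationship can be checked directly:

```python
import torch

x = torch.randn(4096)
window = torch.hann_window(400)

# Base schema: no center/pad_mode arguments.
old_style = torch.ops.aten.stft(x, 400, None, None, window, False, None, True)

# stft.center overload with center=False should match the call above.
new_style = torch.ops.aten.stft.center(
    x, 400, None, None, window, False, "constant", False, None, True)

print(torch.allclose(old_style, new_style))  # expected: True
```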
4 changes: 2 additions & 2 deletions torch/functional.py
@@ -715,8 +715,8 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex)
-# TODO: after having proper ways to map Python strings to ATen Enum, move
-# this and F.pad to ATen.
+# NOTE: Do not edit. This code will be removed once the forward-compatibility
+# period is over for PR #73432
if center:
signal_dim = input.dim()
extended_shape = [1] * (3 - signal_dim) + list(input.size())
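
During the FC period the Python wrapper keeps doing the centering itself: it reshapes the input to 3-D (the `[1] * (3 - signal_dim)` trick above), calls `F.pad` with the requested `pad_mode`, and then calls into ATen without re-centering. A short sketch of the user-facing effect of `pad_mode` (values and sizes are illustrative; every mode yields the same output shape, only the synthesized edge samples differ):

```python
import torch

x = torch.arange(64, dtype=torch.float32)
n_fft, hop = 16, 4
window = torch.hann_window(n_fft)

for mode in ("constant", "reflect", "replicate", "circular"):
    spec = torch.stft(x, n_fft, hop_length=hop, window=window, center=True,
                      pad_mode=mode, return_complex=True)
    print(mode, tuple(spec.shape))  # expected: (9, 17) for every mode
```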