From ccd78ffba3208b4a9f5face4bd76e942ddc2d4a3 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 10 Nov 2023 00:05:45 -0500 Subject: [PATCH] Warn if the input dtype to TimeStretch is not complex (#3695) Addresses #3688 --- src/torchaudio/transforms/_transforms.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 654b25c213..802cbd3d77 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -1066,6 +1066,13 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = Stretched spectrogram. The resulting tensor is of the corresponding complex dtype as the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``. """ + if not torch.is_complex(complex_specgrams): + warnings.warn( + "The input to TimeStretch must be complex type. " + "Providing non-complex tensor produces invalid results.", + stacklevel=4, + ) + if overriding_rate is None: if self.fixed_rate is None: raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")