Skip to content

Commit

Permalink
Warn if the input dtype to TimeStretch is not complex (#3695)
Browse files Browse the repository at this point in the history
Addresses #3688
  • Loading branch information
mthrok committed Nov 10, 2023
1 parent 172260f commit ccd78ff
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,13 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
Stretched spectrogram. The resulting tensor is of the corresponding complex dtype
as the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
"""
if not torch.is_complex(complex_specgrams):
warnings.warn(
"The input to TimeStretch must be complex type. "
"Providing non-complex tensor produces invalid results.",
stacklevel=4,
)

if overriding_rate is None:
if self.fixed_rate is None:
raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")
Expand Down

0 comments on commit ccd78ff

Please sign in to comment.