BC-Breaking: Remove deprecated normalized argument from griffinlim #1369

Merged: 9 commits, Mar 8, 2021
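This PR removes the `normalized` argument that `F.griffinlim` and `transforms.GriffinLim` accepted but never used (it has been deprecated with a warning since it had no effect). Call sites that passed it positionally must drop it. A minimal before/after sketch, assuming a build that includes this change; the constants are the values used in the tests below:

import torch
import torchaudio.functional as F

n_fft, hop, ws = 400, 200, 400
window = torch.hann_window(ws)
specgram = torch.rand(1, 201, 6)  # (channel, freq, frames) power spectrogram

# Before: F.griffinlim(specgram, window, n_fft, hop, ws, 2., False, 32, 0.99, 1000, False)
#                                                          ^ unused `normalized`
# After: the seventh positional slot is now n_iter.
waveform = F.griffinlim(specgram, window, n_fft, hop, ws, 2., 32, 0.99, 1000, False)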
@@ -52,14 +52,13 @@ def test_griffinlim(self):
         hop = 200
         window = torch.hann_window(ws)
         power = 2
-        normalize = False
         momentum = 0.99
         n_iter = 32
         length = 1000
         torch.random.manual_seed(0)
         batch = torch.rand(self.batch_size, 1, 201, 6)
         self.assert_batch_consistency(
-            F.griffinlim, batch, window, n_fft, hop, ws, power, normalize,
+            F.griffinlim, batch, window, n_fft, hop, ws, power,
             n_iter, momentum, length, 0, atol=5e-5)
@@ -38,7 +38,7 @@ def test_griffinlim(self):
         init = 'random' if rand_init else None

         specgram = F.spectrogram(tensor, 0, window, n_fft, hop, ws, 2, normalize).sqrt()
-        ta_out = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize,
+        ta_out = F.griffinlim(specgram, window, n_fft, hop, ws, 1,
                               n_iter, momentum, length, rand_init)
         lr_out = librosa.griffinlim(specgram.squeeze(0).numpy(), n_iter=n_iter, hop_length=hop,
                                     momentum=momentum, init=init, length=length)
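For reference, a standalone sketch of the librosa parity check above; the input shape, tolerance, and the use of a random tensor as a stand-in magnitude spectrogram are assumptions carried over from the tests:

import torch
import librosa
import torchaudio.functional as F

magnitude = torch.rand(1, 201, 6)   # stand-in magnitude spectrogram (power=1)
window = torch.hann_window(400)
ta_out = F.griffinlim(magnitude, window, 400, 200, 400, 1., 32, 0.99, 1000, False)
lr_out = librosa.griffinlim(magnitude.squeeze(0).numpy(), n_iter=32, hop_length=200,
                            momentum=0.99, init=None, length=1000)
# rand_init=False and init=None both start from zero phase.
print(torch.allclose(ta_out.squeeze(0), torch.from_numpy(lr_out), atol=5e-5))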
@@ -41,12 +41,11 @@ def func(tensor):
             hop = 200
             window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
             power = 2.
-            normalize = False
             momentum = 0.99
             n_iter = 32
             length = 1000
             rand_int = False
-            return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_int)
+            return F.griffinlim(tensor, window, n_fft, hop, ws, power, n_iter, momentum, length, rand_int)

         tensor = torch.rand((1, 201, 6))
         self._assert_consistency(func, tensor)
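The TorchScript path keeps working with the shorter signature. A hedged sketch of scripting a caller (the wrapper name `invert` is illustrative, and this assumes a torchaudio build containing this PR):

import torch
import torchaudio.functional as F

@torch.jit.script
def invert(specgram: torch.Tensor) -> torch.Tensor:
    window = torch.hann_window(400, device=specgram.device, dtype=specgram.dtype)
    # Positional layout after this PR: power is followed directly by n_iter.
    return F.griffinlim(specgram, window, 400, 200, 400, 2., 32, 0.99, 1000, False)

waveform = invert(torch.rand(1, 201, 6))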
torchaudio/functional/functional.py (8 changes: 0 additions & 8 deletions)

@@ -118,7 +118,6 @@ def griffinlim(
         hop_length: int,
         win_length: int,
         power: float,
-        normalized: bool,
         n_iter: int,
         momentum: float,
         length: Optional[int],
@@ -148,7 +147,6 @@
         win_length (int): Window size. (Default: ``n_fft``)
         power (float): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc.
-        normalized (bool): Whether to normalize by magnitude after stft.
         n_iter (int): Number of iteration for phase recovery process.
         momentum (float): The momentum parameter for fast Griffin-Lim.
             Setting this to 0 recovers the original Griffin-Lim method.
@@ -162,12 +160,6 @@
     assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
     assert momentum >= 0, 'momentum={} < 0'.format(momentum)

-    if normalized:
-        warnings.warn(
-            "The argument normalized is not used in Griffin-Lim, "
-            "and will be removed in v0.9.0 release. To suppress this warning, "
-            "please use `normalized=False`.")
-
     # pack batch
     shape = specgram.size()
     specgram = specgram.reshape([-1] + list(shape[-2:]))
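The argument could be dropped because Griffin-Lim only consumes the magnitude spectrogram; nothing in the iteration reads an STFT normalization flag. A minimal sketch of the fast (momentum) variant, using the `momentum / (1 + momentum)` damping visible in the transform below; this is an illustration, not torchaudio's exact implementation:

import torch

def griffinlim_sketch(magnitude: torch.Tensor, window: torch.Tensor, n_fft: int,
                      hop_length: int, win_length: int,
                      n_iter: int = 32, momentum: float = 0.99) -> torch.Tensor:
    # magnitude: (freq, frames) real tensor; the phase is recovered iteratively.
    alpha = momentum / (1 + momentum)
    angles = torch.ones_like(magnitude, dtype=torch.complex64)  # zero-phase init
    tprev = torch.zeros_like(angles)
    for _ in range(n_iter):
        # Project onto the set of consistent spectrograms: istft, then stft.
        inverse = torch.istft(magnitude * angles, n_fft, hop_length, win_length, window)
        rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window,
                             return_complex=True)
        # Momentum step, then renormalize to keep unit-magnitude phase factors.
        angles = rebuilt - alpha * tprev
        angles = angles / angles.abs().clamp_min(1e-16)
        tprev = rebuilt
    return torch.istft(magnitude * angles, n_fft, hop_length, win_length, window)

# Usage: wav = griffinlim_sketch(torch.rand(201, 6), torch.hann_window(400), 400, 200, 400)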
torchaudio/transforms.py (5 changes: 1 addition & 4 deletions)

@@ -122,7 +122,6 @@ class GriffinLim(torch.nn.Module):
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         power (float, optional): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
-        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
         wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
         momentum (float, optional): The momentum parameter for fast Griffin-Lim.
             Setting this to 0 recovers the original Griffin-Lim method.
@@ -158,7 +157,6 @@ def __init__(self,
                  hop_length: Optional[int] = None,
                  window_fn: Callable[..., Tensor] = torch.hann_window,
                  power: float = 2.,
-                 normalized: bool = False,
                  wkwargs: Optional[dict] = None,
                  momentum: float = 0.99,
                  length: Optional[int] = None,
@@ -174,7 +172,6 @@
         self.hop_length = hop_length if hop_length is not None else self.win_length // 2
         window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.register_buffer('window', window)
-        self.normalized = normalized
         self.length = length
         self.power = power
         self.momentum = momentum / (1 + momentum)
@@ -191,7 +188,7 @@ def forward(self, specgram: Tensor) -> Tensor:
             Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
         """
         return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
-                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
+                            self.n_iter, self.momentum, self.length, self.rand_init)


 class AmplitudeToDB(torch.nn.Module):
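With the transform updated, constructing `GriffinLim` no longer takes `normalized` either. A small round-trip sketch, with illustrative values and assuming the post-PR signatures:

import torch
import torchaudio.transforms as T

spec = T.Spectrogram(n_fft=400, hop_length=200, power=2)
glim = T.GriffinLim(n_fft=400, hop_length=200, power=2, n_iter=32)

waveform = torch.rand(1, 8000)   # (channel, time) dummy audio
specgram = spec(waveform)        # (channel, freq, frames) power spectrogram
restored = glim(specgram)        # waveform with phase recovered by Griffin-Lim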