32 changes: 26 additions & 6 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -8,9 +8,23 @@
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
nested_params,
)


class _DeterministicWrapper(torch.nn.Module):
"""Helper transform wrapper to make the given transform deterministic"""
def __init__(self, transform, seed=0):
super().__init__()
self.seed = seed
self.transform = transform

def forward(self, input: torch.Tensor):
torch.random.manual_seed(self.seed)
return self.transform(input)


class AutogradTestMixin(TestBaseMixin):
def assert_grad(
self,
@@ -65,14 +79,20 @@ def test_melspectrogram(self):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(0, ), (0.99, )])
def test_griffinlim(self, momentum):
@nested_params(
[0, 0.99],
[False, True],
)
def test_griffinlim(self, momentum, rand_init):
n_fft = 400
n_frames = 5
power = 1
n_iter = 3
spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft
transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=False)
self.assert_grad(transform, [spec], nondet_tol=1e-10)
spec = get_spectrogram(
get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2),
n_fft=n_fft, power=power)
transform = _DeterministicWrapper(
T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power))
self.assert_grad(transform, [spec])

@parameterized.expand([(False, ), (True, )])
def test_mfcc(self, log_mels):
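For context on the helper added above (this aside is not part of the diff): gradient checking evaluates a module several times with the same input and expects identical outputs, so a transform that draws random numbers in forward (e.g. GriffinLim with rand_init=True) must be made repeatable. A minimal illustration of the wrapper's effect, assuming the _DeterministicWrapper class defined above is in scope; the shapes and parameters below are arbitrary:

    import torch
    import torchaudio.transforms as T

    transform = _DeterministicWrapper(
        T.GriffinLim(n_fft=400, n_iter=3, rand_init=True))
    spec = torch.rand(201, 5)  # (freq, frames) magnitude spectrogram, freq = 400 // 2 + 1

    # Each call reseeds the RNG before running the transform, so the random
    # phase initialization is identical every time and the outputs match exactly.
    assert torch.equal(transform(spec), transform(spec))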
59 changes: 33 additions & 26 deletions torchaudio/functional/functional.py
@@ -125,6 +125,16 @@ def spectrogram(
return spec_f


def _get_complex_dtype(real_dtype: torch.dtype):
if real_dtype == torch.double:
return torch.cdouble
if real_dtype == torch.float:
return torch.cfloat
if real_dtype == torch.half:
return torch.complex32
raise ValueError(f'Unexpected dtype {real_dtype}')


def griffinlim(
specgram: Tensor,
window: Tensor,
@@ -180,23 +190,19 @@ def griffinlim(

specgram = specgram.pow(1 / power)

# randomly initialize the phase
batch, freq, frames = specgram.size()
# initialize the phase
if rand_init:
angles = 2 * math.pi * torch.rand(batch, freq, frames)
angles = torch.rand(
specgram.size(),
dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

Inline review comment on the change above:

    Wait, shouldn't the angles here be a floating-point tensor?

Reply (mthrok, Contributor Author):

    No, angles is complex-valued. The Griffin-Lim algorithm iteratively optimizes the phase (the direction in the complex plane) of each element of the given spectrogram so that, at the end, istft yields the original waveform. (A standalone sketch of this iteration is included after the diff.)

Follow-up (mthrok, Apr 5, 2021):

    *The original implementation reused the variable name: it constructed the complex-valued tensor angles from a real-valued tensor, also called angles, and the magnitude.

else:
angles = torch.zeros(batch, freq, frames)
angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
.to(dtype=specgram.dtype, device=specgram.device)
specgram = specgram.unsqueeze(-1).expand_as(angles)
angles = torch.full(
specgram.size(), 1,
dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

# And initialize the previous iterate to 0
rebuilt = torch.tensor(0.)

tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
for _ in range(n_iter):
# Store the previous iterate
tprev = rebuilt

# Invert with our current estimate of the phases
inverse = torch.istft(specgram * angles,
n_fft=n_fft,
@@ -206,26 +212,27 @@
length=length)

# Rebuild the spectrogram
rebuilt = torch.view_as_real(
torch.stft(
input=inverse,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True,
)
rebuilt = torch.stft(
input=inverse,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
return_complex=True,
)

# Update our phase estimates
angles = rebuilt
if momentum:
angles = angles - tprev.mul_(momentum / (1 + momentum))
angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))
angles = angles.div(angles.abs().add(1e-16))

# Store the previous iterate
tprev = rebuilt

# Return the final phase estimates
waveform = torch.istft(specgram * angles,
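
To make the review discussion above concrete, here is a minimal, self-contained sketch of the complex-phase Griffin-Lim iteration. It is not the torchaudio implementation: it omits batching, dtype/device handling, rand_init, and length trimming, and the hop length, iteration count, momentum, and Hann window below are illustrative assumptions.

    import torch

    def griffin_lim_sketch(magnitude, n_fft=400, hop_length=200,
                           n_iter=32, momentum=0.99):
        # magnitude: real-valued (freq, frames) tensor with freq == n_fft // 2 + 1.
        window = torch.hann_window(n_fft)
        # angles holds unit-modulus complex numbers; start from zero phase.
        angles = torch.full(magnitude.size(), 1, dtype=torch.cfloat)
        tprev = torch.tensor(0., dtype=torch.cfloat)  # previous iterate
        for _ in range(n_iter):
            # Invert the current complex estimate to a waveform ...
            inverse = torch.istft(magnitude * angles, n_fft=n_fft,
                                  hop_length=hop_length, window=window)
            # ... and re-analyze it into a new complex spectrogram.
            rebuilt = torch.stft(inverse, n_fft=n_fft, hop_length=hop_length,
                                 window=window, return_complex=True)
            # Momentum step, then keep only the phase by renormalizing to unit modulus.
            angles = rebuilt - tprev * (momentum / (1 + momentum))
            angles = angles.div(angles.abs().add(1e-16))
            tprev = rebuilt
        return torch.istft(magnitude * angles, n_fft=n_fft,
                           hop_length=hop_length, window=window)

Feeding it the magnitude of a real STFT, e.g. griffin_lim_sketch(torch.stft(waveform, 400, 200, window=torch.hann_window(400), return_complex=True).abs()), recovers an approximation of waveform.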