Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "replace reshape by view" #594

Merged
merged 1 commit into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 22 additions & 21 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def istft(

# pack batch
shape = stft_matrix.size()
- stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1])
+ stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])

dtype = stft_matrix.dtype
device = stft_matrix.device
Expand Down Expand Up @@ -196,7 +196,7 @@ def istft(
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)

# unpack batch
- y = y.view(shape[:-3] + y.shape[-1:])
+ y = y.reshape(shape[:-3] + y.shape[-1:])

if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
Expand Down Expand Up @@ -241,15 +241,15 @@ def spectrogram(

# pack batch
shape = waveform.size()
- waveform = waveform.view(-1, shape[-1])
+ waveform = waveform.reshape(-1, shape[-1])

# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = torch.stft(
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
)

# unpack batch
- spec_f = spec_f.view(shape[:-1] + spec_f.shape[-3:])
+ spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])

if normalized:
spec_f /= window.pow(2.).sum().sqrt()
Expand Down Expand Up @@ -314,7 +314,7 @@ def griffinlim(

# pack batch
shape = specgram.size()
- specgram = specgram.view([-1] + list(shape[-2:]))
+ specgram = specgram.reshape([-1] + list(shape[-2:]))

specgram = specgram.pow(1 / power)

Expand Down Expand Up @@ -360,7 +360,7 @@ def griffinlim(
length=length)

# unpack batch
- waveform = waveform.view(shape[:-2] + waveform.shape[-1:])
+ waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

return waveform

Expand Down Expand Up @@ -623,7 +623,7 @@ def phase_vocoder(

# pack batch
shape = complex_specgrams.size()
- complex_specgrams = complex_specgrams.view([-1] + list(shape[-3:]))
+ complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))

time_steps = torch.arange(0,
complex_specgrams.size(-2),
Expand Down Expand Up @@ -663,7 +663,7 @@ def phase_vocoder(
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

# unpack batch
- complex_specgrams_stretch = complex_specgrams_stretch.view(shape[:-3] + complex_specgrams_stretch.shape[1:])
+ complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])

return complex_specgrams_stretch

Expand All @@ -689,7 +689,7 @@ def lfilter(
"""
# pack batch
shape = waveform.size()
- waveform = waveform.view(-1, shape[-1])
+ waveform = waveform.reshape(-1, shape[-1])

assert (a_coeffs.size(0) == b_coeffs.size(0))
assert (len(waveform.size()) == 2)
Expand Down Expand Up @@ -732,7 +732,7 @@ def lfilter(
output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)

# unpack batch
- output = output.view(shape[:-1] + output.shape[-1:])
+ output = output.reshape(shape[:-1] + output.shape[-1:])

return output

Expand Down Expand Up @@ -1362,7 +1362,7 @@ def mask_along_axis(

# pack batch
shape = specgram.size()
- specgram = specgram.view([-1] + list(shape[-2:]))
+ specgram = specgram.reshape([-1] + list(shape[-2:]))

value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
Expand All @@ -1379,7 +1379,7 @@ def mask_along_axis(
raise ValueError('Only Frequency and Time masking are supported')

# unpack batch
- specgram = specgram.view(shape[:-2] + specgram.shape[-2:])
+ specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

return specgram

Expand Down Expand Up @@ -1416,7 +1416,7 @@ def compute_deltas(

# pack batch
shape = specgram.size()
- specgram = specgram.view(1, -1, shape[-1])
+ specgram = specgram.reshape(1, -1, shape[-1])

assert win_length >= 3

Expand All @@ -1432,7 +1432,7 @@ def compute_deltas(
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

# unpack batch
- output = output.view(shape)
+ output = output.reshape(shape)

return output

Expand Down Expand Up @@ -1466,10 +1466,11 @@ def _add_noise_shaping(
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
"""
- waveform = waveform.view(-1, waveform.size()[-1])
+ wf_shape = waveform.size()
+ waveform = waveform.reshape(-1, wf_shape[-1])

dithered_shape = dithered_waveform.size()
- dithered_waveform = dithered_waveform.view(-1, dithered_shape[-1])
+ dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])

error = dithered_waveform - waveform

Expand All @@ -1480,7 +1481,7 @@ def _add_noise_shaping(
error[index] = error_offset[:waveform.size()[1]]

noise_shaped = dithered_waveform + error
- return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
+ return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])


def _apply_probability_distribution(
Expand Down Expand Up @@ -1513,7 +1514,7 @@ def _apply_probability_distribution(

# pack batch
shape = waveform.size()
- waveform = waveform.view(-1, shape[-1])
+ waveform = waveform.reshape(-1, shape[-1])

channel_size = waveform.size()[0] - 1
time_size = waveform.size()[-1] - 1
Expand Down Expand Up @@ -1554,7 +1555,7 @@ def _apply_probability_distribution(
quantised_signal = quantised_signal_scaled / down_scaling

# unpack batch
- return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
+ return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])


def dither(
Expand Down Expand Up @@ -1732,7 +1733,7 @@ def detect_pitch_frequency(
"""
# pack batch
shape = list(waveform.size())
- waveform = waveform.view([-1] + shape[-1:])
+ waveform = waveform.reshape([-1] + shape[-1:])

nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
Expand All @@ -1743,7 +1744,7 @@ def detect_pitch_frequency(
freq = sample_rate / (EPSILON + indices.to(torch.float))

# unpack batch
- freq = freq.view(shape[:-1] + list(freq.shape[-1:]))
+ freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

return freq

Expand Down
8 changes: 4 additions & 4 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def forward(self, specgram: Tensor) -> Tensor:

# pack batch
shape = specgram.size()
- specgram = specgram.view(-1, shape[-2], shape[-1])
+ specgram = specgram.reshape(-1, shape[-2], shape[-1])

if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
Expand All @@ -260,7 +260,7 @@ def forward(self, specgram: Tensor) -> Tensor:
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

# unpack batch
- mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])
+ mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

return mel_specgram

Expand Down Expand Up @@ -485,7 +485,7 @@ def forward(self, waveform: Tensor) -> Tensor:

# pack batch
shape = waveform.size()
- waveform = waveform.view(-1, shape[-1])
+ waveform = waveform.reshape(-1, shape[-1])

mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
Expand All @@ -498,7 +498,7 @@ def forward(self, waveform: Tensor) -> Tensor:
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

# unpack batch
- mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])
+ mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])

return mfcc

Expand Down