enable mel_scale option #593

Merged
merged 1 commit into from Mar 2, 2021

Contributor

@vincentqb vincentqb commented Apr 28, 2020

Add mel_scale option enabling the htk or slaney mel scale, see here for discussion and librosa.

Relates #589, internal, (#608?), (#259 for mel-hz functional)

@@ -431,6 +496,7 @@ def create_fb_matrix(
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
htk (bool): Use HTK formula instead of Slaney
Contributor

Should this instead be a flag that allows the user to pass a selection (e.g. "Slaney" or "HTK") so that we could add more formulas later on (if that could come up)?

Contributor Author

The htk boolean flag is a convention from librosa. Someone had also suggested mel_scale="htk"|"slaney", see here.

Contributor

@cpuhrsch cpuhrsch Apr 29, 2020

Is there a reason htk (bool) is preferred over the mel_scale flag (outside of it being a convention from a popular library)? In any case, I, of course, trust your judgement here.

Contributor Author

I don't see other reasons. I prefer mel_scale personally as it is more explicit about which convention is being used.

Contributor Author

Changed to mel_scale
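
For readers comparing the two conventions discussed above, here is a minimal sketch of how the string option maps onto a librosa-style boolean `htk` flag. The helper names `to_htk_flag` and `from_htk_flag` are hypothetical and are not part of torchaudio or librosa; only the accepted values "htk" and "slaney" come from this thread.

```python
_VALID_MEL_SCALES = ("htk", "slaney")


def to_htk_flag(mel_scale: str) -> bool:
    """Translate a mel_scale string into a boolean htk flag (hypothetical helper)."""
    if mel_scale not in _VALID_MEL_SCALES:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')
    return mel_scale == "htk"


def from_htk_flag(htk: bool) -> str:
    """Translate a boolean htk flag back into a mel_scale string (hypothetical helper)."""
    return "htk" if htk else "slaney"


print(to_htk_flag("slaney"), from_htk_flag(True))
```

A string leaves room for further scales later, whereas a boolean can only ever distinguish two.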

@vincentqb vincentqb marked this pull request as ready for review May 4, 2020 22:36
@vincentqb vincentqb changed the title from "enable htk option" to "[WIP] enable htk option" May 5, 2020
@vincentqb vincentqb changed the title from "[WIP] enable htk option" to "enable mel_scale option" May 5, 2020
@vincentqb vincentqb requested review from cpuhrsch and mthrok May 5, 2020 20:24

return mels
else:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
Collaborator

@mthrok mthrok May 6, 2020

In my opinion, it is more readable if argument validation is carried out at the beginning of the function.

Also (this is kind of a nit), when one branch of an if ~ elif ~ else clause is much longer than the others, extracting its content into a separate function improves readability.

i.e.

if mel_scale not in ['slaney', 'htk']:
    raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == 'htk':
    return 2595.0 * torch.log10(torch.tensor(1.0 + (freq / 700.0), dtype=torch.float64))
else:
    return _hz_to_mel_slaney(...)

Contributor Author

I agree with checking at the beginning. As for readability, since this is already a private function, how does the following look to you?

if mel_scale not in ['slaney', 'htk']:
    raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == "htk":
    return 2595.0 * torch.log10(torch.tensor(1.0 + (freq / 700.0), dtype=torch.float64))

# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3

mels = (freq - f_min) / f_sp

# Fill in the log-scale part
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0

log_t = freq >= min_log_hz
mels[log_t] = min_log_mel + torch.log(torch.tensor(freq[log_t] / min_log_hz, dtype=torch.float64)) / logstep

return mels
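
The proposal above can be sketched for a single scalar frequency using only the standard library. This is an illustrative version, not the torchaudio implementation: the name `hz_to_mel` and the scalar signature are assumptions, while the constants and the validation message are taken from the snippet above.

```python
import math


def hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    """Convert a frequency in Hz to mels (scalar sketch, not the torchaudio function)."""
    # Validate the argument first, as suggested in the review.
    if mel_scale not in ("slaney", "htk"):
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + freq / 700.0)

    # Slaney scale: linear below 1000 Hz ...
    f_min = 0.0
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    if freq < min_log_hz:
        return (freq - f_min) / f_sp

    # ... and logarithmic at or above 1000 Hz.
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0
    return min_log_mel + math.log(freq / min_log_hz) / logstep


for hz in (440.0, 1000.0, 6400.0):
    print(hz, hz_to_mel(hz, "htk"), hz_to_mel(hz, "slaney"))
```

In this sketch 1000 Hz maps to about 15 mel on the Slaney scale and to about 1000 mel on the HTK scale, so the two options are not interchangeable.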

Returns:
mels (Tensor): Input frequencies in Mels
"""

Collaborator

I think it's a good idea to add a note on why we have to use float64 for log10.
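
To get a feel for the precision at stake, the following rough sketch rounds every intermediate result to single precision with the standard library and compares against a double-precision evaluation. It does not reproduce the CPU-dependent behavior of torch.log10 mentioned in this thread; it only shows the order of magnitude of float32 rounding error. The helper `to_float32` and both function names are illustrative assumptions.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (double) to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def htk_mel_float64(freq_hz: float) -> float:
    """HTK mel formula evaluated entirely in double precision."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def htk_mel_float32_like(freq_hz: float) -> float:
    """Same formula, rounding each intermediate result to single precision."""
    ratio = to_float32(to_float32(freq_hz) / 700.0)
    arg = to_float32(1.0 + ratio)
    log = to_float32(math.log10(arg))
    return to_float32(2595.0 * log)


for hz in (440.0, 1000.0, 8000.0):
    exact = htk_mel_float64(hz)
    approx = htk_mel_float32_like(hz)
    print(hz, exact, approx, abs(exact - approx))
```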

  def create_fb_matrix(
      n_freqs: int,
      f_min: float,
      f_max: float,
      n_mels: int,
-     sample_rate: int
+     sample_rate: int,
+     mel_scale: str = "htk",
  ) -> Tensor:
Collaborator

It seems like the new implementation returns float64 Tensor whereas the original implementation returns float32. If so, this is BC breaking.

In either case, I think it's a good idea to document the dtype of the resulting tensor.

Contributor Author

Good point. I'll cast to float32 at the end of the function. We can always add a dtype parameter to create_fb_matrix if other types are needed later.

The dtype returned should be documented in all functions indeed.

Contributor

Conversions like that should incur a full copy. I'd double check what the performance implications of using float64 are. We could also set expectations and let the user pass float64 Tensors if they want higher precision.

Contributor Author

I agree with the idea, but I would leave that discussion as a separate PR, since this may affect the outcome of #611.

@@ -446,22 +515,25 @@ def create_fb_matrix(
# Equivalent filterbank construction by Librosa
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

# torch.log10 with float32 produces different results on different CPUs
Contributor Author

note added in code since this is not visible to the end-user

@vincentqb vincentqb requested a review from mthrok May 6, 2020 18:13
    raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == "htk":
    return 2595.0 * torch.log10(1.0 + (freq / 700.0))
Contributor

I wonder if there's a way of rearranging the calculation to improve numerical stability. We know that all values will be greater than 1, if that helps.

Contributor Author

Trying torch.log10((700. + freq) / 700.).

Contributor

What are typical ranges on values within freq and do our tests mimic those adequately?

Contributor Author

Some hints here and here. The former aligns with the tests we already have. Others to suggest?

Contributor Author

@vincentqb vincentqb May 7, 2020

torch.log10((700. + freq) / 700.) behaves better and doesn't require float conversions :)
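
The two arrangements are algebraically identical, since 1 + f/700 = (700 + f)/700. A small standard-library sketch can confirm that they agree in double precision; the function names are illustrative, and the float32 behavior inside torch is not reproduced here.

```python
import math


def htk_mel_original(freq_hz: float) -> float:
    """HTK formula as first written: log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def htk_mel_rearranged(freq_hz: float) -> float:
    """Rearranged HTK formula: log10((700 + f) / 700)."""
    return 2595.0 * math.log10((700.0 + freq_hz) / 700.0)


test_freqs = (0.0, 20.0, 440.0, 1000.0, 8000.0, 22050.0)
max_diff = max(abs(htk_mel_original(f) - htk_mel_rearranged(f)) for f in test_freqs)
print("largest difference in double precision:", max_diff)
```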

Contributor

I wouldn't necessarily just rely on the librosa tests. We should understand how this is used for our own sake and to have something to point to when it fails outside of those domains.

@vincentqb vincentqb changed the title from "enable mel_scale option" to "[WIP] enable mel_scale option" May 15, 2020
@vincentqb vincentqb mentioned this pull request Jul 2, 2020
2 tasks
@vincentqb vincentqb marked this pull request as draft September 30, 2020 18:33
@vincentqb vincentqb changed the title from "[WIP] enable mel_scale option" to "enable mel_scale option" Sep 30, 2020
@vincentqb
Contributor Author

Rebased, and extended to MelScale, MelSpectrogram.

    raise ValueError('mel_scale should be one of "htk" or "slaney".')

if mel_scale == "htk":
    return 2595.0 * torch.log10((700.0 + freq) / 700.0)
Contributor Author

keep math.log10 as before?

Collaborator

Wouldn't that break torchscript compatibility?

Contributor Author

Yes, though I rewrote the function with the math library and it is also torchscriptable now

  librosa_mel = librosa.feature.melspectrogram(
      y=sound_librosa, sr=sample_rate, n_fft=n_fft,
-     hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
+     hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
Collaborator

👍

  sample_rate = 16000
  sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
  sound_librosa = sound.cpu().numpy().squeeze()
  melspect_transform = torchaudio.transforms.MelSpectrogram(
      sample_rate=sample_rate, window_fn=torch.hann_window,
-     hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
+     hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
Collaborator

👍 👍

@vincentqb
Contributor Author

Current error:

torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mel_spectrogram_04 FAILED [ 24%]

...

=================================== FAILURES ===================================
____________________ TestTransforms.test_mel_spectrogram_04 ____________________

a = (<torchaudio_unittest.librosa_compatibility_test.TestTransforms testMethod=test_mel_spectrogram_04>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../env/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
torchaudio_unittest/librosa_compatibility_test.py:70: in test_mel_spectrogram
    self.assertEqual(
../env/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py:1189: in assertEqual
    super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
E   AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1e-05 and atol=0.005, found 157 element(s) (out of 20608) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.0771484375 (2955.172607421875 vs. 2955.095458984375), which occurred at index (16, 128).
=============================== warnings summary ===============================
test/torchaudio_unittest/batch_consistency_test.py::TestTransforms::test_batch_InverseMelScale
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (32) may be set too high. Or, the value for `n_freqs` (5) may be set too low.
    warnings.warn(

test/torchaudio_unittest/batch_consistency_test.py::TestTransforms::test_batch_MelScale
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (31) may be set too low.
    warnings.warn(

test/torchaudio_unittest/batch_consistency_test.py: 2 warnings
test/torchaudio_unittest/librosa_compatibility_test.py: 7 warnings
test/torchaudio_unittest/torchscript_consistency_cpu_test.py: 4 warnings
test/torchaudio_unittest/transforms_test.py: 4 warnings
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.
    warnings.warn(

test/torchaudio_unittest/kaldi_io_test.py::Test_KaldiIO::test_read_mat_ark
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/kaldi_io.py:42: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1614585890319/work/torch/csrc/utils/tensor_numpy.cpp:179.)
    yield key, torch.from_numpy(np_arr)

test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_InverseMelScale
test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_MelScale
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (256) may be set too high. Or, the value for `n_freqs` (1025) may be set too low.
    warnings.warn(

test/torchaudio_unittest/librosa_compatibility_test.py: 12 warnings
test/torchaudio_unittest/functional/librosa_compatibility_test.py: 1 warning
  /root/project/env/lib/python3.8/site-packages/librosa/filters.py:238: UserWarning: Empty filters detected in mel frequency basis. Some channels will produce empty responses. Try increasing your sampling rate (and fmax) or reducing n_mels.
    warnings.warn(

test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mel_spectrogram_08
test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mel_spectrogram_09
test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mel_spectrogram_10
test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mel_spectrogram_11
test/torchaudio_unittest/librosa_compatibility_test.py::TestTransforms::test_mfcc_2
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (101) may be set too low.
    warnings.warn(

test/torchaudio_unittest/torchscript_consistency_cpu_test.py::TestTransformsFloat32::test_GriffinLim
  /root/project/env/lib/python3.8/site-packages/torch/functional.py:654: UserWarning: istft will require a complex-valued input tensor in a future PyTorch release. Matching the output from stft with return_complex=True.  (Triggered internally at  /opt/conda/conda-bld/pytorch_1614585890319/work/aten/src/ATen/native/SpectralOps.cpp:787.)
    return _VF.istft(input, n_fft, hop_length, win_length, window, center,  # type: ignore

test/torchaudio_unittest/torchscript_consistency_cpu_test.py::TestTransformsFloat32::test_MelScale
test/torchaudio_unittest/torchscript_consistency_cpu_test.py::TestTransformsFloat64::test_MelScale
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (6) may be set too low.
    warnings.warn(

test/torchaudio_unittest/backend/soundfile/save_test.py: 10 warnings
  /root/project/test/torchaudio_unittest/common_utils/wav_utils.py:77: WavFileWarning: Chunk (non-data) not understood, skipping it.
    sample_rate, data = scipy.io.wavfile.read(path)

test/torchaudio_unittest/backend/sox_io/info_test.py: 22 warnings
test/torchaudio_unittest/backend/sox_io/load_test.py: 6 warnings
test/torchaudio_unittest/sox_effect/sox_effect_test.py: 4 warnings
  /root/project/test/torchaudio_unittest/common_utils/sox_utils.py:32: UserWarning: Use get_wav_data and save_wav to generate wav file for accurate result.
    warnings.warn('Use get_wav_data and save_wav to generate wav file for accurate result.')

test/torchaudio_unittest/functional/functional_cpu_test.py::TestMaskAlongAxis::test_mask_along_axis_0
  /root/project/test/torchaudio_unittest/functional/functional_cpu_test.py:196: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
  To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /opt/conda/conda-bld/pytorch_1614585890319/work/aten/src/ATen/native/BinaryOps.cpp:329.)
    num_masked_columns //= mask_specgram.size(0)

test/torchaudio_unittest/functional/librosa_compatibility_test.py::TestFunctional::test_create_fb
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (56) may be set too high. Or, the value for `n_freqs` (1025) may be set too low.
    warnings.warn(

test/torchaudio_unittest/functional/librosa_compatibility_test.py::TestFunctional::test_create_fb
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (10) may be set too high. Or, the value for `n_freqs` (1025) may be set too low.
    warnings.warn(

test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py::TestFunctionalFloat32::test_create_fb_matrix
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py::TestFunctionalFloat32::test_create_fb_matrix
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py::TestFunctionalFloat64::test_create_fb_matrix
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py::TestFunctionalFloat64::test_create_fb_matrix
  /root/project/env/lib/python3.8/site-packages/torchaudio-0.8.0a0+0d8c3de-py3.8-linux-x86_64.egg/torchaudio/functional/functional.py:426: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (10) may be set too high. Or, the value for `n_freqs` (100) may be set too low.
    warnings.warn(
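
As a worked check of the failing assertion reported in the log above, the arithmetic can be redone by hand. This sketch assumes the comparison follows the usual closeness rule |actual - expected| <= atol + rtol * |expected| (the convention documented for torch.allclose). Which of the two logged values is the reference is also an assumption; it barely changes the bound.

```python
# Values copied from the assertion message in the log.
actual = 2955.172607421875
expected = 2955.095458984375  # assumed to be the reference value
rtol = 1e-05
atol = 0.005

diff = abs(actual - expected)
allowed = atol + rtol * abs(expected)  # assumed closeness rule

print("difference:", diff)
print("allowed margin:", allowed)
print("within tolerance:", diff <= allowed)
```

Under these assumptions the reported difference of about 0.077 is a little more than twice the allowed margin of about 0.035, which is why the element is flagged.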

@vincentqb
Contributor Author

Tests are now green, except for one unrelated failure here. No test thresholds needed to be adjusted for the current implementation, and the implementation remains the same as before for mel_scale='htk'.

Collaborator

@mthrok mthrok left a comment

Looks good to me. Thanks!

4 participants