contrib STFT magnitudes different to librosa's #15134

Closed
fedden opened this issue Dec 5, 2017 · 9 comments
Labels: 1.4.0, comp:signal, stale, stat:awaiting response

Comments

@fedden

fedden commented Dec 5, 2017


System information

  • Have I written custom code: Yes
  • OS Platform and Distribution: Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 1.4
  • Python version: 3.5.2
  • Bazel version: N/A
  • GCC/Compiler version: N/A
  • CUDA/cuDNN version: 6.0
  • GPU model and memory: NVIDIA Quadro M2000M, 4 GB
  • Exact command to reproduce: See below code

Hi all!

I was comparing TensorFlow's contrib STFT against librosa's and noticed some discrepancies between the two outputs. I'm not sure whether this is normal between library implementations, but I wanted to raise it in case it matters!

I'm also aware it could be a small bug on my side, or a difference in the implementation or in the arguments I have supplied.

Code:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)
import librosa

np.random.seed(666)
np.set_printoptions(precision=5, suppress=True)

audio_length_seconds = 2
sample_rate = 44100
audio_frames_length = int(sample_rate * audio_length_seconds)
audio_shape = [None, audio_frames_length]
fft_size = 1024
hop_size = 512

tf.reset_default_graph()

# Batch of mono signals, shape [batch, samples].
audio = tf.placeholder(tf.float32, shape=audio_shape)
# pad_end=True zero-pads the end of the signal so the last samples get a frame.
stfts = tf.contrib.signal.stft(audio,
                               frame_length=fft_size,
                               frame_step=hop_size,
                               fft_length=fft_size,
                               pad_end=True)
real = tf.real(stfts)
imag = tf.imag(stfts)
magnitudes = tf.abs(stfts)
phases = tf.atan2(imag, real)
features = tf.concat([magnitudes, phases], axis=2)  # not used in the comparison

sess = tf.Session()
with sess.as_default():
    data = np.random.random((1, audio_frames_length))
    tf_results = magnitudes.eval({audio: data})  # shape (1, frames, bins)

    # librosa returns shape (bins, frames), hence the transpose below.
    lr_results = librosa.core.stft(y=data.reshape((-1)),
                                   n_fft=fft_size,
                                   hop_length=hop_size,
                                   win_length=fft_size)
    lr_results = np.abs(lr_results)

    difference = np.abs(tf_results - lr_results.T)
    print("Differences:\nmin:", np.min(difference),
          "max:", np.max(difference),
          "mean:", np.mean(difference),
          "std:", np.std(difference))
```

Running it prints:

```
Differences:
min: 6.97374e-05 max: 246.904 mean: 2.92715 std: 2.45132
```
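A quick check on the shapes in the repro: the element-wise subtraction only works because both calls return the same number of frames for this input length. The sketch below recomputes the frame counts from the usual framing formulas; the helper names are hypothetical, and the formulas are my reading of how `pad_end` and `center` behave rather than code taken from either library.

```python
def frames_tf(num_samples, frame_length, frame_step, pad_end):
    """Assumed frame count of tf.contrib.signal.stft."""
    if pad_end:
        # Zero-pad on the right so every sample falls into some frame.
        return -(-num_samples // frame_step)  # ceil division
    return 1 + (num_samples - frame_length) // frame_step


def frames_librosa(num_samples, n_fft, hop_length, center):
    """Assumed frame count of librosa.stft."""
    if center:
        # n_fft // 2 samples of padding are added on each side.
        num_samples += 2 * (n_fft // 2)
    return 1 + (num_samples - n_fft) // hop_length


n = 2 * 44100  # audio_frames_length in the repro
print(frames_tf(n, 1024, 512, pad_end=True), frames_librosa(n, 1024, 512, center=True))
print(frames_tf(n, 1024, 512, pad_end=False), frames_librosa(n, 1024, 512, center=False))
```

Under these formulas both pairs agree (173 and 171 frames), but changing only `center=False` while keeping `pad_end=True` would give 173 versus 171 frames, so the two outputs would no longer line up.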
@drpngx
Contributor

drpngx commented Dec 6, 2017

@rryan do you know anything about that?

@drpngx drpngx added the stat:community support label Dec 6, 2017
@fedden
Author

fedden commented Dec 7, 2017

I was talking to a friend about this, and he tested a number of prominent FFT libraries such as FFTW and noticed that their outputs differed as well. So maybe some difference is to be expected across implementations. Just wanted to let you know, at any rate.
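One way to sanity-check the "different FFT libraries" explanation: two quite different algorithms for the same transform normally agree to within floating-point rounding. The sketch below compares `np.fft.rfft` with a hypothetical direct O(N^2) DFT (`naive_rdft`) on one 1024-sample frame of uniform noise. Disagreements at the level of rounding error would be far too small to explain a max difference of 246 on their own.

```python
import numpy as np


def naive_rdft(x):
    """Direct O(N^2) DFT of a real signal, non-negative frequencies only."""
    n = len(x)
    k = np.arange(n // 2 + 1)[:, None]  # frequency bins
    t = np.arange(n)[None, :]           # sample indices
    return (x[None, :] * np.exp(-2j * np.pi * k * t / n)).sum(axis=1)


rng = np.random.default_rng(0)
frame = rng.random(1024)

# Largest magnitude disagreement between the FFT and the direct DFT.
max_diff = np.abs(np.abs(np.fft.rfft(frame)) - np.abs(naive_rdft(frame))).max()
print(max_diff)
```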

@spbolton

The difference is probably due to the window function. By default, librosa uses a Hann window. You can do the same in TensorFlow with the following. I have not run your code to check whether this actually makes the outputs match.

```python
import functools

stfts = tf.contrib.signal.stft(audio,
                               frame_length=fft_size,
                               frame_step=hop_size,
                               window_fn=functools.partial(tf.contrib.signal.hann_window, periodic=True),
                               pad_end=True)
```
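For what it's worth, my understanding is that `tf.contrib.signal.stft` already defaults to `window_fn=functools.partial(hann_window, periodic=True)`, and that librosa's `window='hann'` goes through `scipy.signal.get_window`, which also returns the periodic form, so the windows may already match. The numpy-only sketch below shows how the periodic Hann window (the one used for spectral analysis) differs from the symmetric one returned by `np.hanning`; the helper `periodic_hann` is hypothetical.

```python
import numpy as np


def periodic_hann(n):
    """Periodic Hann window: 0.5 - 0.5 * cos(2*pi*k / n) for k = 0..n-1."""
    k = np.arange(n)
    return 0.5 - 0.5 * np.cos(2.0 * np.pi * k / n)


n = 1024
periodic = periodic_hann(n)
symmetric = np.hanning(n)  # 0.5 - 0.5 * cos(2*pi*k / (n - 1))

# The periodic window of length n equals the symmetric window of
# length n + 1 with its last sample dropped.
print(np.allclose(periodic, np.hanning(n + 1)[:-1]))
# The two variants differ slightly, so it is worth checking which one is used.
print(np.abs(periodic - symmetric).max())
```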

@bmcfee

bmcfee commented Mar 10, 2018

What version of librosa are you comparing to? Given the date of the post, I'm assuming <=0.5.1.

We recently fixed a long-standing, inherited bug where the STFT was incorrectly conjugated, so that might also contribute to the discrepancies. The fix was included in librosa 0.6; see this thread for an explanation of what happened there.
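On the conjugation point: conjugating a complex spectrum flips the sign of every phase but leaves every magnitude unchanged, so by itself it should not affect a magnitude-only comparison like the one in the repro, although it would affect the `phases` features. A small numpy check on an arbitrary spectrum:

```python
import numpy as np

rng = np.random.default_rng(0)
spectrum = np.fft.rfft(rng.random(1024))  # any complex spectrum will do
conjugated = np.conj(spectrum)

# |conj(z)| == |z| for every bin...
same_magnitudes = np.allclose(np.abs(spectrum), np.abs(conjugated))
# ...while angle(conj(z)) == -angle(z).
negated_phases = np.allclose(np.angle(conjugated), -np.angle(spectrum))
print(same_magnitudes, negated_phases)
```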

It's also worth checking magnitude ratios, instead of absolute differences, just to see if it's a factor of sqrt(2*pi*n) issue.
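A minimal sketch of that ratio check, assuming two magnitude arrays of the same shape (in the repro, `tf_results[0]` and `lr_results.T`). The helper `ratio_stats` is hypothetical. A ratio that is constant across bins and frames would point to a scaling convention; a widely scattered ratio points to something else, such as frames that do not line up.

```python
import numpy as np


def ratio_stats(a, b, eps=1e-12):
    """Summary statistics of the element-wise magnitude ratio a / b."""
    # Guard against division by (near-)zero bins.
    ratio = np.asarray(a) / np.maximum(np.asarray(b), eps)
    return {
        "median": float(np.median(ratio)),
        "min": float(ratio.min()),
        "max": float(ratio.max()),
        "std": float(ratio.std()),
    }


rng = np.random.default_rng(0)
mags = rng.random((173, 513)) + 0.1
# Pure scaling difference: the ratio is the same everywhere.
print(ratio_stats(mags / np.sqrt(2 * np.pi), mags))
# Unrelated magnitudes: the ratio is scattered.
print(ratio_stats(mags, rng.random((173, 513)) + 0.1))
```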

@PetrochukM

Howdy!

@fedden Did you conclude that TensorFlow's contrib STFT is good to go? Are you using it or the librosa one?

@rryan
Member

rryan commented Jun 12, 2018

Hi! Thank you @fedden for the helpful repro.

One difference between librosa and tf.contrib.signal.stft is that librosa center-pads the signal with reflection when framing, while tf.contrib.signal.stft (like tf.spectral.rfft and np.fft.rfft) pads on the right with zeros by default. Adding support for padding options like these to tf.contrib.signal.stft would be a nice addition!

I think that accounts for the difference here. I put your repro in a Colab notebook and set center=False and the max difference in magnitude between TF and librosa becomes 1e-5.
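The padding difference described above can be illustrated with a numpy-only sketch. `stft_magnitudes` is a hypothetical helper: `center=True` reflect-pads on both sides as described for librosa here (the default padding mode has changed across librosa versions, so check yours), and `center=False` zero-pads on the right like `pad_end=True`. With `n_fft == 2 * hop`, as in the repro, centred frame `t` covers the same samples as right-padded frame `t - 1`.

```python
import numpy as np


def stft_magnitudes(x, n_fft, hop, center):
    """Magnitude STFT of a 1-D signal, shape (num_frames, n_fft // 2 + 1).

    center=True: reflect-pad n_fft // 2 samples on both sides, so frame t
    is centred on sample t * hop (librosa-style, as described above).
    center=False: zero-pad only at the end, so frame t starts at sample
    t * hop (like tf.contrib.signal.stft with pad_end=True).
    """
    if center:
        x = np.pad(x, n_fft // 2, mode="reflect")
        num_frames = 1 + (len(x) - n_fft) // hop
    else:
        num_frames = -(-len(x) // hop)  # ceil(len(x) / hop)
        needed = (num_frames - 1) * hop + n_fft
        x = np.pad(x, (0, needed - len(x)), mode="constant")  # zeros on the right
    window = np.hanning(n_fft + 1)[:-1]  # periodic Hann window
    frames = np.stack([x[t * hop:t * hop + n_fft] for t in range(num_frames)])
    return np.abs(np.fft.rfft(frames * window, axis=1))


rng = np.random.default_rng(666)
x = rng.random(8192)
centred = stft_magnitudes(x, 1024, 512, center=True)         # 17 frames
right_padded = stft_magnitudes(x, 1024, 512, center=False)   # 16 frames

# Away from the padded edges, centred frame t == right-padded frame t - 1.
print(np.allclose(centred[1:16], right_padded[:15]))
# Comparing index by index instead pairs frames 512 samples apart.
print(np.abs(centred[:16] - right_padded).max())
```

If this reading is right, most of the large max difference in the repro comes from comparing frames that are half a window apart, not from the FFT itself.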

@rryan rryan added the comp:signal label Sep 19, 2019
@mohantym mohantym self-assigned this Dec 17, 2021
@mohantym mohantym added 1.4.0 and removed stat:community support labels Dec 17, 2021
@mohantym
Contributor

Hi @fedden!
It seems you are using an older version (1.x) of TensorFlow, which is no longer supported. We recommend updating your code base to 2.7 and checking whether the issue still exists. Thanks!

@mohantym mohantym added the stat:awaiting response label Dec 17, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale label Dec 24, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

7 participants