
bug with frame_step in tf.contrib.signal overlap_and_add inverse_stft #16465

Closed
memo opened this issue Jan 26, 2018 · 14 comments
Assignees
Labels
comp:signal tf.signal related issues

Comments

@memo
Contributor

memo commented Jan 26, 2018

System information

  • Based on example
  • Linux Ubuntu 16.04
  • installed from binary
  • v1.4.0-19-ga52c8d9, 1.4.1; also 1.5.0
  • Python 2.7.14 |Anaconda custom (64-bit)| (default, Oct 16 2017, 17:29:19). IPython 5.4.1
  • Cuda release 8.0, V8.0.61, cuDNN 6; also Cuda release 9.0, V9.0.176, cuDNN 7.0.5
  • Geforce GTX 970M, also GTX 1070, Driver Version: 384.111

Describe the problem

A.) When I create frames from a signal with frame_length=1024 and frame_step=256 (i.e. 25% hop size, 75% overlap) using a hann window (also tried hamming), and then I reconstruct with overlap_and_add, I'd expect the signal to be reconstructed correctly (because of COLA etc). But instead it comes out exactly double the amplitude. I need to divide the resulting signal by two for it to be correct.

B.) If I use STFT to create a series of overlapping spectrograms, and then reconstruct with inverse STFT, again with frame_length=1024 and frame_step=256, the signal is again reconstructed at double amplitude.

I realise why these might be the case (unity gain at 50% overlap for hann, so 75% overlap will double the signal). But is it not normal for the reconstruction function to take this into account? E.g. librosa's istft returns the signal at the correct amplitude, while tensorflow returns double.
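The doubling, and the amplitude modulation at other hop sizes, can be reproduced with plain numpy by summing shifted copies of a periodic Hann window; this is just a sketch of the COLA argument, independent of TensorFlow:

```python
import numpy as np

# Periodic Hann window, matching tf.contrib.signal.hann_window(periodic=True).
frame_length = 1024
window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(frame_length) / frame_length)

def overlap_gain(window, frame_step, num_frames=16):
    """Overlap-add shifted copies of the window and return the steady-state gain."""
    total = np.zeros(frame_step * (num_frames - 1) + len(window))
    for i in range(num_frames):
        total[i * frame_step:i * frame_step + len(window)] += window
    mid = len(total) // 2  # inspect the middle, away from edge effects
    return total[mid - frame_step:mid + frame_step]

print(overlap_gain(window, 512).round(3))  # constant 1.0: COLA at 50% overlap
print(overlap_gain(window, 256).round(3))  # constant 2.0: hence the doubled amplitude
print(overlap_gain(window, 768).round(3))  # non-constant: amplitude modulation
```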

C.) At any other frame_step there is severe amplitude modulation going on. See images below. This doesn't seem right at all.

UPDATE: If I explicitly set window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step) in inverse_stft the output is correct. So it seems the frame_step parameter in inverse_stft is not being passed into the window function (which is also what the results hint at).
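For reference, the reconstruction window that inverse_stft_window_fn produces can be sketched in numpy: the forward window divided by the overlapped sum of its squared copies for the given frame_step, so that analysis and synthesis windows together overlap-add to unity. This is an illustration of the idea, not a line-by-line port of the TF op:

```python
import numpy as np

def inverse_stft_window(forward_window, frame_step):
    """Synthesis window such that analysis * synthesis windows overlap-add to 1."""
    frame_length = len(forward_window)
    denom = np.square(forward_window)
    overlaps = -(-frame_length // frame_step)  # ceil(frame_length / frame_step)
    denom = np.pad(denom, (0, overlaps * frame_step - frame_length))
    denom = denom.reshape(overlaps, frame_step).sum(axis=0)
    denom = np.tile(denom, overlaps)[:frame_length]
    return forward_window / denom

# With a periodic Hann window and frame_step=256 (75% overlap), the product of
# forward and inverse windows overlap-adds to unity gain in steady state.
hann = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(1024) / 1024)
inv_window = inverse_stft_window(hann, 256)
```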

Source code / logs

original data:
22050 orig

tensorflow output from frames + overlap_and_add:
tensorflow 22050 frame l1024 s256
tensorflow 22050 frame l1024 s512
tensorflow 22050 frame l1024 s768
tensorflow 22050 frame l1024 s1024

tensorflow output from stft+istft:
tensorflow 22050 stft l1024 s256
tensorflow 22050 stft l1024 s512
tensorflow 22050 stft l1024 s768
tensorflow 22050 stft l1024 s1024

librosa output from stft+istft:
librosa 22050 stft l1024 s256
librosa 22050 stft l1024 s512
librosa 22050 stft l1024 s768
librosa 22050 stft l1024 s1024

tensorflow code:

from __future__ import print_function
from __future__ import division

import numpy as np
import scipy.io.wavfile
import math
import random
import matplotlib.pyplot as plt

import tensorflow as tf
out_prefix = 'tensorflow'


def plot(data, title, do_save=True):
    plt.figure(figsize=(20,5))
    plt.plot(data[:3*frame_length])
    plt.ylim([-1, 1])
    plt.title(title)
    plt.grid()
    if do_save: plt.savefig(title + '.png')
    plt.show()


def reconstruct_from_frames(x, frame_length, frame_step):
    name = 'frame'
    frames_T = tf.contrib.signal.frame(x, frame_length=frame_length, frame_step=frame_step)
    windowed_frames_T = frames_T * tf.contrib.signal.hann_window(frame_length, periodic=True)
    output_T = tf.contrib.signal.overlap_and_add(windowed_frames_T, frame_step=frame_step)
    return name, output_T


def reconstruct_from_stft(x, frame_length, frame_step):
    name = 'stft'
    spectrograms_T = tf.contrib.signal.stft(x, frame_length, frame_step)
    output_T = tf.contrib.signal.inverse_stft(spectrograms_T, frame_length, frame_step)
    return name, output_T


def test(fn, input_data):
    print('-'*80)
    tf.reset_default_graph()
    input_T = tf.placeholder(tf.float32, [None]) 
    name, output_T = fn(input_T, frame_length, frame_step)

    title = "{}.{}.{}.l{}.s{}".format(out_prefix, sample_rate, name, frame_length, frame_step)
    print(title)

    with tf.Session():
        output_data = output_T.eval({input_T: input_data})

#    output_data /= frame_length/frame_step/2 # tensorflow needs this to normalise amp
    plot(output_data, title)
    scipy.io.wavfile.write(title+'.wav', sample_rate, output_data)


def generate_data(duration_secs, sample_rate, num_sin, min_freq=10, max_freq=500, rnd_seed=0, max_val=0):
    '''generate signal from multiple random sin waves'''
    if rnd_seed>0: random.seed(rnd_seed)
    data = np.zeros([duration_secs*sample_rate], np.float32)
    for i in range(num_sin):
        w = np.float32(np.sin(np.linspace(0, math.pi*2*random.randrange(min_freq, max_freq), num=duration_secs*sample_rate)))
        data += random.random() * w
    if max_val>0:
        data *= max_val / np.max(np.abs(data))
    return data
    

frame_length = 1024
sample_rate = 22050

input_data = generate_data(duration_secs=1, sample_rate=sample_rate, num_sin=1, rnd_seed=2, max_val=0.5)

title = "{}.orig".format(sample_rate)
plot(input_data, title)
scipy.io.wavfile.write(title+'.wav', sample_rate, input_data)

for frame_step in [256, 512, 768, 1024]:
    test(reconstruct_from_frames, input_data)
    test(reconstruct_from_stft, input_data)

print('done.')

librosa code:

from __future__ import print_function
from __future__ import division

import numpy as np
import scipy.io.wavfile
import math
import random
import matplotlib.pyplot as plt

import librosa.core as lc
out_prefix = 'librosa'


def plot(data, title, do_save=True):
    plt.figure(figsize=(20,5))
    plt.plot(data[:3*frame_length])
    plt.ylim([-1, 1])
    plt.title(title)
    plt.grid()
    if do_save: plt.savefig(title + '.png')
    plt.show()


def reconstruct_from_stft(x, frame_length, frame_step):
    name = 'stft'
    stft = lc.stft(x, n_fft=frame_length, hop_length=frame_step)
    istft = lc.istft(stft, hop_length=frame_step)
    return name, istft


def test(fn, input_data):
    print('-'*80)
    name, output_data = fn(input_data, frame_length, frame_step)

    title = "{}.{}.{}.l{}.s{}".format(out_prefix, sample_rate, name, frame_length, frame_step)
    print(title)

#    output_data /= frame_length/frame_step/2 # tensorflow needs this to normalise amp
    plot(output_data, title)
    scipy.io.wavfile.write(title+'.wav', sample_rate, output_data)


def generate_data(duration_secs, sample_rate, num_sin, min_freq=10, max_freq=500, rnd_seed=0, max_val=0):
    '''generate signal from multiple random sin waves'''
    if rnd_seed>0: random.seed(rnd_seed)
    data = np.zeros([duration_secs*sample_rate], np.float32)
    for i in range(num_sin):
        w = np.float32(np.sin(np.linspace(0, math.pi*2*random.randrange(min_freq, max_freq), num=duration_secs*sample_rate)))
        data += random.random() * w
    if max_val>0:
        data *= max_val / np.max(np.abs(data))
    return data
    

frame_length = 1024
sample_rate = 22050

input_data = generate_data(duration_secs=1, sample_rate=sample_rate, num_sin=1, rnd_seed=2, max_val=0.5)

title = "{}.orig".format(sample_rate)
plot(input_data, title)
scipy.io.wavfile.write(title+'.wav', sample_rate, input_data)

for frame_step in [256, 512, 768, 1024]:
    test(reconstruct_from_stft, input_data)

print('done.')
@memo memo changed the title bug with frame_step in tf.contrib.signal.frame (or am I doing something wrong?) bug with frame_step in tf.contrib.signal.frame or overlap_and_add (or am I doing something wrong?) Jan 26, 2018
@memo memo changed the title bug with frame_step in tf.contrib.signal.frame or overlap_and_add (or am I doing something wrong?) bug with frame_step in tf.contrib.signal frame overlap_and_add stft inverse_stft Jan 30, 2018
@memo memo changed the title bug with frame_step in tf.contrib.signal frame overlap_and_add stft inverse_stft bug with frame_step in tf.contrib.signal overlap_and_add inverse_stft Jan 30, 2018
@reedwm
Member

reedwm commented Jan 30, 2018

/CC @rryan, can you take a look?

@reedwm reedwm added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 30, 2018
@rryan
Member

rryan commented Feb 2, 2018

Thanks very much for the detailed bug report @memo! I'll take a look, though I probably won't have time until 2/9.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 6, 2018
@andimarafioti

I am having a (probably) related problem when I try to use the istft to reconstruct a signal.

figure_1

import functools

import tensorflow as tf
from tensorflow.contrib.signal.python.ops import window_ops

sampling_rate = 44000
freq = 440
countOfCycles = 4
_time = tf.range(0, 1024 / sampling_rate, 1 / sampling_rate, dtype=tf.float32)
firstSignal = tf.sin(2 * 3.14159 * freq * _time)

with tf.name_scope('Energy_Spectogram'):
    fft_frame_length = 128
    fft_frame_step = 32
    window_fn = functools.partial(window_ops.hann_window, periodic=True)
    stft = tf.contrib.signal.stft(signals=firstSignal, frame_length=fft_frame_length, frame_step=fft_frame_step,
                                  window_fn=window_fn)
    istft = tf.contrib.signal.inverse_stft(
        stfts=stft, frame_length=fft_frame_length, frame_step=fft_frame_step,
        window_fn=tf.contrib.signal.inverse_stft_window_fn(
            fft_frame_step, forward_window_fn=window_fn))

with tf.Session() as sess:
    original, reconstructed = sess.run([firstSignal, istft])

import matplotlib.pyplot as plt

plt.plot(original)
plt.plot(reconstructed)
plt.show()

Note that the problem is worse when you don't explicitly pass the window (which in this case is the default one), giving amplitude modulation all across the signal.

@nuchi
Contributor

nuchi commented Mar 10, 2018

For anyone else still having problems even when manually passing in the inverse window function to inverse_stft:

I still had problems reconstructing the original signal, but I managed to fix the issue by manually zero-padding the signal by frame_length - frame_step on both sides. After taking the inverse transform, the zero-padded signal is reconstructed perfectly.

>>> signal = tf.constant(0.5 * np.sin(np.linspace(0., 440*2*np.pi, 16000)), dtype=tf.float32)
>>> frame_length = 400
>>> frame_step = 100
>>> fft_length = 512
>>> pad = frame_length - frame_step
>>> stft = tf.contrib.signal.stft(tf.pad(signal, [[pad, pad]]), frame_length, frame_step, fft_length)
>>> reconstructed = tf.contrib.signal.inverse_stft(
...   stft, frame_length, frame_step, fft_length, window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step))
>>> error = (reconstructed[pad:-pad] - signal).eval()
>>> np.max(np.abs(error))
2.0861626e-06

@andimarafioti

I've also noticed that this is mostly a border effect. Zero padding fixes it to some extent, but it's not optimal for my application.
I'm also concerned about where the problem comes from. Is either implementation, the stft or the istft, reliable?

@PetrochukM

@rryan How is this going?

@rryan
Member

rryan commented Jun 12, 2018

Thanks a ton @nuchi, @memo, and @andimarafioti for your patience and helpful repro code. As you've summarized nicely, this is caused by at least two issues:

  • By default, tf.contrib.signal.inverse_stft does not assume that the input STFT was generated from tf.contrib.signal.stft, and therefore does not divide the window by the squared sum of its magnitude as librosa does by default. To get this behavior, pass window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step), which is designed to compute a reconstruction window given the window and frame_step used in a forward STFT.
  • tf.contrib.signal.stft does not center the framed windows as librosa does by default with center=True. This option doesn't exist in tf.contrib.signal yet (contrib STFT magnitudes different to librosa's #15134), but it's simple to work around: you just reflect-pad the input to stft and slice the result of inverse_stft as librosa does here and here.

Here is a replacement for reconstruct_from_stft that works with @memo's test case:

def reconstruct_from_stft(x, frame_length, frame_step):
    name = 'stft'
    center = True
    if center:
        # librosa pads by frame_length, which almost works perfectly here, except with frame_step 256.
        pad_amount = 2 * (frame_length - frame_step)
        x = tf.pad(x, [[pad_amount // 2, pad_amount // 2]], 'REFLECT')
    
    f = tf.contrib.signal.frame(x, frame_length, frame_step, pad_end=False)
    w = tf.contrib.signal.hann_window(frame_length, periodic=True)
    spectrograms_T = tf.spectral.rfft(f * w, fft_length=[frame_length])
        
    output_T = tf.contrib.signal.inverse_stft(
        spectrograms_T, frame_length, frame_step,
        window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step))
    if center and pad_amount > 0:
        output_T = output_T[pad_amount // 2:-pad_amount // 2]
    return name, output_T

Here is a Colab notebook demonstrating this.

@kouohhashi

@rryan
Thank you for your code!
And I have a question.

How can I use my own wave file?

I mean when I use my own wave file like...

sample_rate, input_data = wavfile.read('my-wave.wav')

instead of input_data = generate_data(duration_secs=1, sample_rate=sample_rate, num_sin=1, rnd_seed=2, max_val=0.5)

Files like tensorflow.16000.stft.1024.wav, tensorflow.16000.stft.256.wav, tensorflow.16000.stft.768.wav, and tensorflow.16000.stft.512.wav have significant noise.

How can I apply reconstruct_from_stft function for my own wave file?

I want to try to train end-to-end noise reduction model like below.
wave file > input ( stft data ) > NN > output data ( stft data ) > wave file.

Since I'm new to DSP, I'm probably missing some basic things...

My wave file's sample rate is 16000 and it is 10 seconds long.

Thanks in advance.

@rryan
Member

rryan commented May 15, 2019

@kouohhashi, what shape and type is input_data? If it's a [samples] or [channels, samples] ndarray with type float32 (scaled to the range [-1, 1]) then it should work with the example.
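For the scaling described above, a minimal sketch, assuming the common case of a 16-bit PCM wav (which scipy.io.wavfile.read returns as int16; the filename below is illustrative):

```python
import numpy as np

def to_float32(raw):
    """Convert integer PCM samples to float32 in [-1, 1]; pass floats through."""
    if raw.dtype == np.int16:
        return raw.astype(np.float32) / 32768.0
    if raw.dtype == np.int32:
        return raw.astype(np.float32) / 2147483648.0
    return raw.astype(np.float32)

# e.g.:
# sample_rate, raw = scipy.io.wavfile.read('my-wave.wav')
# input_data = to_float32(raw)            # [samples] float32 in [-1, 1]
# if input_data.ndim == 2:                # stereo: [samples, channels]
#     input_data = input_data[:, 0]       # take one channel
```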

@kouohhashi

@rryan
Thank you for responding.

input_data was like: [ 0. 0. 0. ... -5206. -4761. -3248.].
So I scaled it to the [-1, 1] range by dividing by 32768.0, since that was the largest magnitude.
But the result was the same...

BTW, I noticed one thing.
Bits per sample seems to have changed.

Original file:
Sample rate: 16 kHz
Bits per sample: 16

New file:
Sample rate: 16 kHz
Bits per sample: 32

Is the change in bits per sample the cause of the problem?

Thanks,
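On the bits-per-sample question: scipy.io.wavfile.write infers the output format from the array dtype, so writing float32 data produces a 32-bit float wav. That by itself is lossless and unlikely to be the source of the noise, but if a 16-bit file is wanted, one can convert back before writing (a sketch; the write call is commented out):

```python
import numpy as np

def float32_to_int16(samples):
    """Clip to [-1, 1] and scale to 16-bit PCM."""
    clipped = np.clip(samples, -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)

# scipy.io.wavfile.write(title + '.wav', sample_rate, float32_to_int16(output_data))
```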

@rryan rryan added the comp:signal tf.signal related issues label Sep 19, 2019
@YoavRamon
Contributor


Thank you @rryan, that worked well. Just wanted to note that for me, what gave the same results as librosa was:
pad_amount = 2 * (frame_length - (frame_step * 2))

Used:

  • librosa==0.6.3
  • tensorflow==2.1.0

@sushreebarsa sushreebarsa self-assigned this Dec 24, 2021
@sushreebarsa
Contributor

@memo It seems you are using an older version (1.x) of TensorFlow, which is not actively supported. Since contrib has been deprecated in TensorFlow 2.x, please upgrade to the latest TensorFlow version. Attaching the migration guide for reference. Thanks!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label Jan 4, 2022
@memo
Contributor Author

memo commented Jan 4, 2022

@sushreebarsa yes, this issue is from 4 years ago! :) I'm not sure if it is still relevant.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 6, 2022
@sushreebarsa
Contributor

@memo Thank you for your response!
As TF v1.x is not actively supported, we recommend upgrading to 2.4 or a later version. If you face any issues after rewriting the code in TF v2, please raise a new ticket.
Closing this issue for now; please feel free to reopen it if you have any concerns.
Thanks!
