tf.signal CPU FFT implementation is slower than NumPy, PyTorch, etc. #6541

Closed
diggerdu opened this issue Dec 28, 2016 · 49 comments
Labels: comp:signal, stat:awaiting response, stat:awaiting tensorflower, TF 2.11, type:performance

Comments


diggerdu commented Dec 28, 2016

What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?

Environment info

Operating System:
Ubuntu 16.04 LTS 64bit

Installed version of CUDA and cuDNN:
(please attach the output of ls -l /path/to/cuda/lib/libcud*):
-rw-r--r-- 1 root root   558720 Sep 15 07:02 libcudadevrt.a
lrwxrwxrwx 1 root root       16 Sep 15 07:05 libcudart.so -> libcudart.so.8.0
lrwxrwxrwx 1 root root       19 Sep 15 07:05 libcudart.so.8.0 -> libcudart.so.8.0.44
-rw-r--r-- 1 root root   415432 Sep 15 07:02 libcudart.so.8.0.44
-rw-r--r-- 1 root root   775162 Sep 15 07:02 libcudart_static.a
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so.5
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so.5.1.5
-rw-r--r-- 1 root root 69756172 Oct 27 23:13 libcudnn_static.a
If installed from binary pip package, provide:

  1. A link to the pip package you installed:

  2. The output from python -c "import tensorflow; print(tensorflow.__version__)": 0.12.head

If installed from source, provide:

  3. The commit hash (git rev-parse HEAD):

  4. The output of bazel version:
If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)

import numpy as np
import tensorflow as tf
import time

wav = np.random.random_sample((1024,))
spec = np.fft.fft(wav)[:513]


x = tf.placeholder(dtype=tf.complex64, shape=[513])
result = tf.ifft(x) 
sess = tf.Session()

start = time.time()
for i in range(10000):
    something = sess.run(result, feed_dict={x:spec})
print 'tensorflow:{}s'.format(time.time()-start)

start = time.time()
for i in range(10000):
   	something = np.fft.ifft(spec)
print 'numpy:{}s'.format(time.time() - start)

tensorflow:25.7219519615s
numpy:0.391902923584s

What other attempted solutions have you tried?

Logs or other output that would be helpful

(If logs are large, please upload as attachment or provide link).


aselle commented Dec 28, 2016

Likely you are measuring the overhead of .run() and of copying the feed value spec, neither of which happens in NumPy. In general, you will get good performance if each .run() does a significant amount of work. For example, I moved your computation into a while loop inside the TensorFlow graph and get

tensorflow:0.0105199813843s
numpy:0.040412902832s

Here's the code:

import numpy as np
import tensorflow as tf
import time

niterations=1000
N=1024
wav = np.random.random_sample((N,))
spec = np.fft.fft(wav)[:N/2+1]

sess = tf.Session()

# Create and initialize variables
cnt = tf.Variable(tf.constant(niterations))
specVar = tf.Variable(spec, dtype=tf.complex64)
sess.run(tf.variables_initializer([specVar,cnt]))

# While loop that counts down to zero and computes reverse and forward fft's
def condition(x,cnt):
  return cnt <= 0

def body(x,cnt):
  xrev=tf.ifft(x)
  xnew=tf.fft(xrev)
  cntnew=cnt-1
  return xnew, cntnew

start = time.time()
tf.while_loop(condition, body, [specVar,cnt], parallel_iterations=1)

print 'tensorflow:{}s'.format(time.time()-start)

# Equivalent numpy loop
start = time.time()
x = spec
for i in range(niterations):
   xrev = np.fft.ifft(x)
   x= np.fft.fft(xrev)
   
print 'numpy:{}s'.format(time.time() - start)


vrv commented Dec 28, 2016

@aselle do you run the ops in your version or just construct them?


aselle commented Dec 28, 2016

Good catch @vrv, I forgot the run() call. When putting in a run() call (see the code below), I get

tensorflow:1.78562903404s
numpy:0.0487790107727s

which I suppose is comparable to your results (yours was NumPy 66x faster, and mine was NumPy roughly 33x faster). One explanation is that the GPU FFT implementation is really not tuned for small sizes, so it can't match the performance of a CPU FFT on a relatively small 513-element array. That's only a few kilobytes of data, which is not much for throughput-optimized devices. So I ran another test where I set N=16384, which uses an 8193-element array. Then I got

tensorflow:5.21
numpy:41.5
import numpy as np
import tensorflow as tf
import time

niterations=1000
N=16384
wav = np.random.random_sample((N,))
spec = np.fft.fft(wav)[:N/2+1]


with tf.Session() as sess:
  # Create and initialize variables
  cnt = tf.Variable(tf.constant(niterations))
  specVar = tf.Variable(spec, dtype=tf.complex64)
  sess.run(tf.variables_initializer([specVar,cnt]))

  # While loop that counts down to zero and computes reverse and forward fft's
  def condition(x,cnt):
    return cnt > 0

  def body(x,cnt):
    xrev=tf.ifft(x)
    xnew=tf.fft(xrev)
    cntnew=cnt-1
    return xnew, cntnew

  start = time.time()

  final, cnt= tf.while_loop(condition, body, [specVar,cnt], parallel_iterations=1)
  final, cnt =  sess.run([final,cnt])

  print 'tensorflow:{}s'.format(time.time()-start)

# Equivalent numpy loop
start = time.time()
x = spec
for i in range(niterations):
   xrev = np.fft.ifft(x)
   x= np.fft.fft(xrev)
   
print 'numpy:{}s'.format(time.time() - start)


aselle commented Jan 17, 2017

Closing due to lack of recent activity. We will reopen when additional information becomes available. Thanks!

@aselle aselle closed this as completed Jan 17, 2017

Czxck001 commented Jun 1, 2017

I've done some simple profiling, and it seems the CPU version of tf.spectral.fft is slow for small nfft because of Eigen's FFT. Most of the time spent executing the CPU FFT kernel goes into the Eigen FFT routine.
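
A minimal sketch (not from the original thread; it assumes a recent TF 2.x build running eagerly on CPU) that sweeps FFT sizes to show how the gap between tf.signal.fft and np.fft.fft depends on the transform size:

import time
import numpy as np
import tensorflow as tf

def bench(fn, x, reps=100):
    fn(x)  # warm-up call so one-time setup stays out of the timed loop
    start = time.perf_counter()
    for _ in range(reps):
        fn(x)
    return (time.perf_counter() - start) / reps

for n in (256, 1024, 4096, 16384, 65536):
    x_np = (np.random.rand(n) + 1j * np.random.rand(n)).astype(np.complex64)
    x_tf = tf.constant(x_np)
    t_tf = bench(tf.signal.fft, x_tf)
    t_np = bench(np.fft.fft, x_np)
    print('n=%6d  tf=%8.1fus  numpy=%8.1fus  ratio=%.1fx'
          % (n, t_tf * 1e6, t_np * 1e6, t_tf / t_np))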


zaccharieramzi commented Aug 23, 2019

I have done some profiling of my own using the trace profiler and the results suggest that the tensorflow FFT2D is about 50 times slower than numpy's.

Here is what I have done (with IFFT2D but the results are the same for FFT2D):

from keras.layers import Layer, Input, Lambda
from keras.models import Model
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.signal import ifft2d


def tf_unmasked_adj_op(x):
    return tf.expand_dims(ifft2d(x[..., 0]), axis=-1)

# Model definition (we basically perform an inverse FFT and then take the modulus)
input_size=(320, None, 1)
kspace_input = Input(input_size, dtype='complex64', name='kspace_input')
zero_filled = Lambda(tf_unmasked_adj_op, output_shape=input_size, name='ifft_simple')(kspace_input)
zero_filled_abs = Lambda(tf.math.abs)(zero_filled)
model = Model(inputs=kspace_input, outputs=zero_filled_abs)
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

model.compile(
    optimizer='adam',
    loss='mse',
    options=run_options,
    run_metadata=run_metadata,
)

# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)


# a single pass
model.fit(
    x=data_x, 
    y=data_y, 
    batch_size=35, 
    epochs=1,
    verbose=2, 
    shuffle=False,
)


# profiling output
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()
with open('timeline_fft_test.json', 'w') as f:
    f.write(ctf)

On 8 CPUs (Intel i7), Ubuntu 16.04, TensorFlow 1.14.0, Keras 2.2.4, the fitting time is 5s. Looking into the timeline (chrome://tracing/), I find that the IFFT2D has a duration of 5000ms (= 5s).

With a GPU, on Ubuntu 18.04, other versions similar, the IFFT2D takes 500ms.

As a comparison, the NumPy version takes only 80ms.

Here is the code I used to perform the numpy profiling (in a jupyter notebook, with the same data_x):

def np_inv_fft(x):
    return np.fft.ifft2(x)

%%timeit
for i in range(len(data_x)):
    np_inv_fft(np.squeeze(data_x[i]))

Maybe I am doing the profiling the wrong way, but it does feel like the FFT2D is the bottleneck (and a very big one) in my actual models (this is just a minimal reproducible example).

I haven't taken the time to check whether it is due to the 2D Fourier transform being in a Lambda layer in a Keras model, but I doubt it matters. I will try to do this in pure TensorFlow to see how it behaves.

@zaccharieramzi

This is what I tried in pure tensorflow with the same results:

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.signal import ifft2d


def tf_unmasked_adj_op(x):
    return tf.expand_dims(ifft2d(x[..., 0]), axis=-1)

# Model definition (we basically perform an inverse FFT and then take the modulus)
kspace_input = tf.placeholder(dtype='complex64', shape=(None, 320, None, 1))
fake_output = tf.placeholder(dtype='float32', shape=(None, 320, None, 1))
zero_filled = tf_unmasked_adj_op(kspace_input)
zero_filled_abs = tf.math.abs(zero_filled)
# this is just to have an adam optimizer similar to the keras training in order to keep things similar (we just compare the IFFT2D compute time in the timeline anyway)
w = tf.Variable((1,), trainable=True, dtype=tf.float32) 
zero_filled_abs = tf.math.multiply(zero_filled_abs, w)

loss = tf.losses.mean_squared_error(fake_output, zero_filled_abs)

adam = tf.train.AdamOptimizer(learning_rate=0.3)
a = adam.minimize(loss, var_list=[w])


run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)


# a single pass
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1):
        sess.run(a, feed_dict={kspace_input: data_x, fake_output: data_y}, options=run_options, run_metadata=run_metadata)

# profiling output
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()
with open('timeline_fft_test.json', 'w') as f:
    f.write(ctf)


zaccharieramzi commented Aug 26, 2019

So I tried one further thing to see how big the problem was.
I wrapped the numpy IFFT2D in a tensorflow layer using concepts from odl (itself inspired by gists shared in #1095).

This is the code I used to create the tensorflow layer:

import numpy as np
import uuid
import tensorflow as tf
from tensorflow.python.framework import ops


# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, name=None, grad=None):
    if grad is None:
        return tf.py_func(func, inp, Tout, stateful=False, name=name)
    else:
        override_name = 'PyFuncStateless'

        # Need to generate a unique name to avoid duplicates:
        rnd_name = override_name + 'Grad' + str(uuid.uuid4())

        tf.RegisterGradient(rnd_name)(grad)
        g = tf.get_default_graph()

        with g.gradient_override_map({override_name: rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=False,
                              name=name)


def fft_layer(mask, op_name='forward'):
    def forward_op(imgs):
        fft_coeffs = np.empty_like(imgs)
        for i, img in enumerate(imgs):
            fft_coeffs[i] = mask * np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(img[..., 0]), norm='ortho'))[..., None]
        return fft_coeffs

    def adj_op(kspaces):
        imgs = np.empty_like(kspaces)
        for i, kspace in enumerate(kspaces):
            masked_fft_coeffs = mask * kspace[..., 0]
            imgs[i] = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(masked_fft_coeffs), norm='ortho'))[..., None]
        return imgs

    if op_name == 'forward':
        op = forward_op
        grad_op = adj_op
    else:
        op = adj_op
        grad_op = forward_op

    def tf_grad_op(x, dy, name):
        with tf.name_scope(name):
            out_shape = x.get_shape()
            with ops.name_scope(name + '_pyfunc', values=[x, dy]) as name_call:
                result = py_func(
                    grad_op,
                    [dy],
                    [tf.complex64],
                    name=name_call,
                )

                # We must manually set the output shape since tensorflow cannot
                # figure it out
                result = result[0]
                result.set_shape(out_shape)
                return result

    # Define the custom op wrapping the NumPy FFT, with a custom gradient:
    def tf_op(x, name=None):
        with tf.name_scope(name, op_name, values=[x]) as name:
            x_shape = x.get_shape()
            def tensorflow_layer_grad(op, grad):
                """Thin wrapper for the gradient."""
                x = op.inputs[0]
                return tf_grad_op(x, grad, name=name + '_grad')
            with ops.name_scope(name + '_pyfunc', values=[x]) as name_call:
                result = py_func(
                    op,
                    [x],
                    [tf.complex64],
                    name=name_call,
                    grad=tensorflow_layer_grad,
                )
                # We must manually set the output shape since tensorflow cannot
                # figure it out
                result = result[0]
                result.set_shape(x_shape)
                return result

    return tf_op

(a bit more complex than a simple inverse Fourier transform, but I need it for my application)

Even with this slightly more complex function, in NumPy, on CPU, the operation now takes 200ms.

I don't know if I am doing things incorrectly; if so, please tell me. Unfortunately I don't think I could help solve this problem myself, as I don't know how to code in C (I guess if there is a problem it's related to the Eigen support for FFT2D).

@aselle @vrv I see that you were involved in this discussion earlier, do you have any idea of what's going on? (Sorry to ping you directly, but this is really a bottleneck for me.)

@rryan I also know from my PR that you are involved with the FFT code, do you have any idea of what's happening by any chance?


rryan commented Aug 27, 2019

The main problem with FFT ops in TensorFlow that makes them slow is that we compute the FFT plan on every execution instead of caching it for a given size. Due to the multi-threaded nature of op execution, nobody has done the work of implementing a plan cache that would be thread-safe. Beyond this, Eigen's "TensorFFT" itself is not particularly fast compared to other libraries like FFTW (which we can't use in TensorFlow due to lack of legal approval).
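
A minimal sketch (an editorial illustration, not part of the original comment; it assumes TF 1.x graph mode on CPU) of why this matters in practice: many sess.run calls of a single FFT pay the per-execution plan and dispatch cost every time, while one batched FFT per sess.run amortizes it.

import time
import numpy as np
import tensorflow as tf

batch = 64
x_np = (np.random.rand(batch, 320, 320) + 1j * np.random.rand(batch, 320, 320)).astype(np.complex64)

single = tf.placeholder(tf.complex64, shape=[320, 320])
single_ifft = tf.signal.ifft2d(single)

batched = tf.constant(x_np)
batched_ifft = tf.signal.ifft2d(batched)  # plan is reused across the batch within one execution

with tf.Session() as sess:
    start = time.time()
    for i in range(batch):
        # one FFT per run: plan construction, feed copy and Session overhead every time
        sess.run(single_ifft, feed_dict={single: x_np[i]})
    print('one FFT per run :', time.time() - start, 's')

    start = time.time()
    sess.run(batched_ifft.op)  # .op avoids fetching the result back to the host
    print('batched, one run:', time.time() - start, 's')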

@zaccharieramzi

But even for a batch size of 1 (i.e. the caching isn't involved), TensorFlow's FFT is still slower than numpy's wrapped (in a TensorFlow py_func) IFFT2D (130ms compared to 25ms on CPU).
I also don't understand how the benchmark I did relates to @aselle 's one. Is it because I am considering 2D?


rryan commented Aug 28, 2019

But even for a batch size of 1 (i.e. the caching isn't involved), TensorFlow's FFT is still slower than numpy's wrapped (in a TensorFlow py_func) IFFT2D (130ms compared to 25ms on CPU).
I also don't understand how the benchmark I did relates to @aselle 's one. Is it because I am considering 2D?

Even at batch size 1, for every invocation of the op (i.e. every sess.run call) we are re-computing a plan for the FFT, executing the single FFT, then returning. Thankfully, when the batch size is greater than 1 we re-use the plan within that execution :), just not across executions.

NumPy caches the FFT plan in a process-wide cache, so separate calls to ifft2d re-use the plan made on the first call with that shape.
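
An illustration (not from the thread) of that process-wide caching: the first np.fft.ifft2 call for a given shape pays the plan/twiddle-factor setup cost, and subsequent calls with the same shape are cheaper; the size of the gap depends on the NumPy version and FFT backend.

import time
import numpy as np

x = np.random.rand(320, 320) + 1j * np.random.rand(320, 320)

t0 = time.perf_counter()
np.fft.ifft2(x)                      # first call: includes plan/twiddle setup
first = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(100):
    np.fft.ifft2(x)                  # repeated calls: setup is reused
repeat = (time.perf_counter() - t0) / 100

print('first call: %.2f ms, subsequent calls: %.2f ms' % (first * 1e3, repeat * 1e3))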


rryan commented Aug 28, 2019

It looks like you're producing trace timelines (great!) -- can you post them?

I would like to see the actual measured time within the op instead of the overall time of calling sess.run (which as @aselle points out can be a source of measurement error)

@zaccharieramzi

It looks like you're producing trace timelines (great!) -- can you post them?

Sure. Attached you will find a zip, tf_fft_timelines.zip, containing 4 files. They correspond to the experiments I showed above. Basically the network is doing an IFFT2D, masking the result, and taking the complex modulus. The batch size by default is 35. I am fitting the network, but since there are no parameters to fit, it's just doing a forward pass (results are roughly the same for a single batch prediction).

  • timeline_fft_test_tf_bs1.json: the IFFT2D is done with tf.signal.ifft2d and the batch size is 1. Here the measured time for the op is 130ms.
  • timeline_fft_test_tf.json: the IFFT2D is done with tf.signal.ifft2d. Here the measured time for the op is 3900ms.
  • timeline_fft_test_tf_numpy_bs1.json: the IFFT2D is done with the layer described in this comment (i.e. wrapping np.fft.ifft2d in a TensorFlow layer with some extra tweaks) and the batch size is 1. Here the measured time for the op is 28ms.
  • timeline_fft_test_tf_numpy.json: the IFFT2D is done with the layer described in this comment (i.e. wrapping np.fft.ifft2d in a TensorFlow layer with some extra tweaks). Here the measured time for the op is 228ms.

This benchmark was done with Ubuntu 16.04, on CPU only.


rryan commented Aug 29, 2019

  • timeline_fft_test_tf_bs1.json : the IFFT2D is done with tf.signal.ifft2d and the batch size is 1. Here the measured time for the op is 130ms.
  • timeline_fft_test_tf.json : the IFFT2D is done with tf.signal.ifft2d. Here the measured time for the ops is 3900ms.
  • timeline_fft_test_tf_numpy_bs1.json : the IFFT2D is done with the layer described in this comment (i.e. wrapping np.fft.ifft2d in a tensorFlow layer with some extra tweaks) and the batch size is 1. Here the measured time for the op is 28ms.
  • timeline_fft_test_tf_numpy.json : the IFFT2D is done with the layer described in this comment (i.e. wrapping np.fft.ifft2d in a tensorFlow layer with some extra tweaks). Here the measured time for the op is 228ms.

Thanks for that! What CPU are you using? Does it have AVX support and was TensorFlow built with AVX support?

@zaccharieramzi

What CPU are you using?

Output of lshw:

*-cpu:0
          product: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz
          vendor: Intel Corp.
          physical id: 6
          bus info: cpu@0
          size: 800MHz
          capacity: 3GHz
          width: 64 bits

Does it have AVX support and was TensorFlow built with AVX support?

I did not build TensorFlow from source so it was not built with AVX support. However, even when I use my GPU (Quadro P5000), the IFFT2D is still slower with tf.signal.ifft2d than with np.fft.ifft2d wrapped (don't know if it's supposed to make a big difference), but faster than on CPU.
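
As an aside (not from the thread): on Linux, a quick way to check whether the CPU exposes AVX is to read the flags in /proc/cpuinfo; prebuilt TensorFlow wheels also print a startup message such as "Your CPU supports instructions that this TensorFlow binary was not compiled to use: ..." when the binary was built without some of them. A small sketch:

def cpu_has_flag(flag):
    # Linux-only: look for the flag in the "flags" line of /proc/cpuinfo.
    with open('/proc/cpuinfo') as f:
        for line in f:
            if line.startswith('flags'):
                return flag in line.split()
    return False

print('AVX available: ', cpu_has_flag('avx'))
print('AVX2 available:', cpu_has_flag('avx2'))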

@zaccharieramzi

I just did another simple benchmark, this time comparing TensorFlow (used with Keras) to PyTorch for the IFFT2D. I find that for prediction, PyTorch is 40 times faster than TensorFlow when using the IFFT2D.


rryan commented Aug 29, 2019

Thanks @zaccharieramzi!

Could you please run a GPU benchmark as well? Based on this I think PyTorch uses cuFFT and MKL, so on GPU TensorFlow and PyTorch should be using the same underlying FFT implementation. It would be nice to confirm whether or not there is a difference when both systems are using cuFFT because that would point to inefficiencies outside the particular FFT implementation that need to be addressed.

@rryan rryan changed the title Why fft implementation in tensorflow is so slow? tf.signal CPU FFT implementation is slower than NumPy, PyTorch, etc. Aug 29, 2019
@rryan rryan reopened this Aug 29, 2019

rryan commented Aug 29, 2019

I'm reopening because it is a known issue that TensorFlow's CPU FFT implementation is based on Eigen's TensorFFT, which is not fast compared to FFTW, MKL (what PyTorch uses), or FFTPACK (what NumPy uses).

For projects I work on at Google, we use TPUs and GPUs for FFTs in both training and serving, so I think we are not noticing the pain of the CPU FFTs being slow. Sorry about that :(. Unfortunately we are a little stuck in terms of legal approval for the above-mentioned libraries. When we implemented tf.signal, Eigen's TensorFFT was the best option in terms of features (non-power-of-2 FFTs via Bluestein's algorithm is a hard requirement) and legal compatibility.

@zaccharieramzi

I am training a model on my main GPU right now, but as soon as I can I will run the same benchmarks. But as I recall (and wrote here), even with GPU, tf.signal.ifft2d is quite slow.

I will come back with exact numbers coming from the trace timelines (and a zip with all of them).


zaccharieramzi commented Aug 30, 2019

I ran the benchmarks on GPU. You can find a zip, tf_fft_timelines_gpu.zip, attached, with the same nomenclature as before. (You will notice that only one run "epoch" appears in each trace; I did not find a way, admittedly without looking very hard, to get all the "epochs" into one trace.)

The first "epoch" is much slower (that's why I did more than one "epoch" this time). Was the remark about not caching the plan only applicable to CPU?

This time the pure tf.signal op takes only 1.2 ms on GPU after being cached (500 ms before being cached, so still slower than NumPy on CPU in that regard, which matters in my case where the images don't all have the same shape) for a batch size of 35.

I also enhanced my PyTorch vs Keras benchmark to feature trace timelines, keras_vs_pytorch_fft_timelines.zip, for both forward passes, taking caching into account. I think the IFFT2D times are equivalent when looking at the timelines (but there is so much stuff in there that I am not sure how to read them; if you have a resource it would be welcome).
However, when I look at the overall run time for, for example, prediction, PyTorch is 4 times faster (7 times for fitting). I don't think it's only because the tensors are already on the GPU for PyTorch, but maybe it's due to something else that I'm not getting.


rryan commented Aug 30, 2019

Thanks again @zaccharieramzi. While an end-to-end benchmark is useful, for the purpose of this bug I'm only concerned with the FFT ops themselves, so I think it could be simplified.

Some issues I noticed:

  • The PyTorch benchmark starts with x already on the GPU, while the TensorFlow one has to transfer the NumPy array from host to device.
  • The PyTorch benchmark does not measure device-to-host time for the result tensor (the final r in the predict benchmark is a device=cuda object), while I believe Keras would implicitly fetch the tensor from device to host when running predict_on_batch.
  • The TensorFlow benchmark is running with tracing enabled -- tracing has some overhead, and so should not be enabled in a benchmark.
  • Instead of %%time, you would get a lower-variance estimate with %%timeit. Both of these methods are subject to variance introduced by overheads in the Python interpreter, etc.

Here's a more "raw" benchmark for TensorFlow that can give a better sense of what's going on.

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

x_np = np.random.rand(35, 320, 320) + 1j * np.random.rand(35, 320, 320)
with tf.device('/gpu:0'):
  x_tf = tf.constant(x_np)
  y = tf.signal.ifft2d(x_tf)

sess = tf.Session()

And then:

%%timeit -n1000
sess.run(y.op)  # Note the .op avoids fetching y from device to host.

On a public Colab runtime with a K80 GPU, I'm observing a very high initial run time, followed by a steady result of:

1000 loops, best of 3: 3.47 ms per loop

Compared to PyTorch, which seems to be faster by about 2x:

import torch
dtype = torch.float32
device = torch.device("cuda")
x = torch.randn(35, 320, 320, 2, device=device, dtype=dtype)

And then:

%%timeit -n1000
y_pt = torch.ifft(x, 2)

I'm getting:

1000 loops, best of 3: 1.45 ms per loop

I believe this is apples-to-apples, since we're not measuring device/host transfer time, and the input shapes are the same.

Here is a screenshot of a trace of the above TensorFlow graph executing on the same K80:

[Screenshot: TensorFlow trace timeline, 2019-08-30]

You can see there are no memcpy D2H/H2Ds (good). You can see there is over 0.5 ms of time between the IFFT2D op execution (GPU:0 on the bottom) and the CUDA stream beginning work (stream:all). Then the CUDA FFT kernel itself takes about 3ms total.

I'll take a look at how PyTorch invokes cuFFT to see how they differ from TensorFlow. The performance should be identical on GPU.

Here is the PyTorch trace of the above:

[Screenshot: PyTorch trace timeline, 2019-08-30]

I'm less familiar with how to interpret PyTorch traces, so I don't know how to figure out what the latency was between CPU scheduling the kernel and it starting, but it at least confirms the overall execution time took just over 2 ms, and I don't see anything that looks like a D2H/H2D copy.

@zaccharieramzi

While an end-to-end benchmark is useful, for the purpose of this bug I'm only concerned with the FFT ops themselves, so I think it could be simplified.

Totally agree. Thanks for the remarks on the different aspects of the benchmark.

However, the trace timelines I got (still relevant, I think) are not exactly the same as yours for TensorFlow. For example, I only have 2 items in the CUDA stream.
The overall time taken by the IFFT (1.7ms) is also much closer to that of PyTorch (1.4ms). Do you think it might be due to differences in the GPUs we are using?

Anyway, I do think your benchmark shows that it might be worth investigating.

@Andreas5739738

I added JAX to the comparison, which is also much faster than TensorFlow:

seconds (lower is better):
Tensorflow 2.5.0: 23.75205922899977
Numpy: 2.10841278099997
Jax: 1.2626468630001

JAX uses pocketfft (https://jax.readthedocs.io/en/latest/_modules/jax/_src/lax/fft.html#fft), which seems to be much faster than TensorFlow's Eigen.

Updated colab: https://colab.research.google.com/gist/Andreas5739738/fc603468829ee0fc7e40a2e27d8a6661/fft.ipynb


mohantym commented Aug 17, 2021

Yes! In line with @Andreas5739738's gist, I was able to replicate the issue in a CPU environment for TF 2.5. In a GPU environment, however, it is approximately 405 times faster than NumPy; I am providing a gist for reference. Thanks!


Andreas5739738 commented Aug 17, 2021

The situation has not improved with TF 2.6:

seconds (lower is better):

Tensorflow 2.6.0 26.172108803000015
Numpy:  1.4308665990000122
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax:  1.412082421999969


Andreas5739738 commented Jan 26, 2022

This is still an issue in current tf-nightly (2.8+).

Also noticed something interesting: when using double precision, TF is more than twice as fast as with single precision (although still ~8x slower than Jax):

print("seconds (lower is better):")
print(f"Tensorflow {tf.__version__}", timeit.timeit('X = tf.signal.rfft(x)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print(f"Tensorflow {tf.__version__}, double precision", timeit.timeit('X = tf.cast(tf.signal.rfft(tf.cast(x, tf.float64)), tf.complex64)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Numpy: ", timeit.timeit('X = numpy.fft.rfft(x)', setup='import numpy.fft; import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Jax: ", timeit.timeit('jnp.fft.rfft(x).block_until_ready()', setup='import jax.numpy as jnp; import tensorflow as tf; x = tf.random.normal([50000, 512]).numpy()', number=10))
seconds (lower is better):
Tensorflow 2.9.0-dev20220126 24.444629474999942
Tensorflow 2.9.0-dev20220126, double precision 11.099884848999977
Numpy:  2.1127802319999773
Jax:  1.2639448529999981

https://colab.research.google.com/gist/Andreas5739738/fc603468829ee0fc7e40a2e27d8a6661/fft.ipynb


Andreas5739738 commented Feb 1, 2022

Looks like the JAX team found the same issue with XLA FFT slowness and integrated PocketFFT as a workaround: google/jax#2952
It would be great if the PocketFFT op could also be integrated into XLA itself, so that both TF and JAX benefit from the speedup.


ornob39 commented Mar 25, 2022

Is there any workaround for GPU? Most of my computational cost goes to computing FFTs. PyTorch's implementation is also very fast.

@Andreas5739738

Looks like things have improved significantly with the 2.9 release, although there is still a large gap to NumPy and JAX:

seconds (lower is better):
Tensorflow 2.9.1 5.495112890999991
Tensorflow 2.9.1, double precision 7.629201937000033
Numpy:  2.1803204349999987
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax:  1.4081462569999985

@ddgonzal3

Are there any plans to address this issue further? It's a huge inconvenience for anyone wanting to perform FFTs in a deployed model.

@mohantym

@diggerdu!
I could replicate this issue with the 2.11 CPU version. Attached is a gist and output for reference.

Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.
2.11.0
 GPUs exists False
seconds (lower is better):
Tensorflow 2.11.0 8.614034000000004
Numpy:  2.022543939000002
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax:  1.7004664330000026

Thank you!

@tilakrayal

I tried to execute the mentioned code on tf-nightly (2.17.0-dev20240403) on CPU and observed that the execution time for TensorFlow is now lower than NumPy's.

Tensorflow 2.17.0-dev20240403 1.0675210610000079
Numpy:  1.7257418959999882
Jax:  2.5296104049999997

Kindly find the gist of it here. Thank you!

@cantonios

Yes, we now use DUCC FFT in all of TF, JAX, XLA. This is resolved.
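
For readers who want to re-check this on a current install, a minimal timing along the lines of the earlier benchmarks in this thread (an illustration only; absolute numbers depend on the machine):

import timeit
print("TF rfft:   ", timeit.timeit('tf.signal.rfft(x)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("NumPy rfft:", timeit.timeit('numpy.fft.rfft(x)', setup='import numpy, numpy.fft; x = numpy.random.randn(50000, 512).astype("float32")', number=10))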

