tf.signal CPU FFT implementation is slower than NumPy, PyTorch, etc. #6541
Likely you are measuring the overhead of .run() and the copying of the fed variable `spec`, which doesn't happen in numpy. In general, you will get good performance if each .run() does a significant portion of work. For example, I moved your computation into a while loop inside the TensorFlow graph and got a much lower reported time.

Here's the code:

```python
import numpy as np
import tensorflow as tf
import time

niterations = 1000
N = 1024
wav = np.random.random_sample((N,))
spec = np.fft.fft(wav)[:N // 2 + 1]

sess = tf.Session()

# Create and initialize variables
cnt = tf.Variable(tf.constant(niterations))
specVar = tf.Variable(spec, dtype=tf.complex64)
sess.run(tf.variables_initializer([specVar, cnt]))

# While loop that counts down to zero and computes reverse and forward FFTs
def condition(x, cnt):
    return cnt <= 0

def body(x, cnt):
    xrev = tf.ifft(x)
    xnew = tf.fft(xrev)
    cntnew = cnt - 1
    return xnew, cntnew

start = time.time()
tf.while_loop(condition, body, [specVar, cnt], parallel_iterations=1)
print('tensorflow: {}s'.format(time.time() - start))

# Equivalent numpy loop
start = time.time()
x = spec
for i in range(niterations):
    xrev = np.fft.ifft(x)
    x = np.fft.fft(xrev)
print('numpy: {}s'.format(time.time() - start))
```
@aselle do you run the ops in your version or just construct them?
Good catch @vrv, I forgot the run() call. When I put in a run() call (see the code below), I get a much slower TensorFlow time, which I suppose is comparable to your results (yours was numpy 66x faster, and mine was more like numpy 33x faster). One explanation is that the GPU FFT implementation is really not tuned to small sizes, so it can't achieve the same performance as the CPU FFT on a relatively small 513-element array. That's only 2 KBytes of data, which is not much for throughput-optimized devices. So I ran another experiment where I set N=16384:

```python
import numpy as np
import tensorflow as tf
import time

niterations = 1000
N = 16384
wav = np.random.random_sample((N,))
spec = np.fft.fft(wav)[:N // 2 + 1]

with tf.Session() as sess:
    # Create and initialize variables
    cnt = tf.Variable(tf.constant(niterations))
    specVar = tf.Variable(spec, dtype=tf.complex64)
    sess.run(tf.variables_initializer([specVar, cnt]))

    # While loop that counts down to zero and computes reverse and forward FFTs
    def condition(x, cnt):
        return cnt > 0

    def body(x, cnt):
        xrev = tf.ifft(x)
        xnew = tf.fft(xrev)
        cntnew = cnt - 1
        return xnew, cntnew

    start = time.time()
    final, cnt = tf.while_loop(condition, body, [specVar, cnt], parallel_iterations=1)
    final, cnt = sess.run([final, cnt])
    print('tensorflow: {}s'.format(time.time() - start))

    # Equivalent numpy loop
    start = time.time()
    x = spec
    for i in range(niterations):
        xrev = np.fft.ifft(x)
        x = np.fft.fft(xrev)
    print('numpy: {}s'.format(time.time() - start))
```
Closing due to lack of recent activity. We will reopen when additional information becomes available. Thanks!
I've done some simple profiling and it seems the CPU version of tf.signal.fft is slow for small sizes.
I have done some profiling of my own using the trace profiler, and the results suggest that the IFFT2D op is indeed the bottleneck. Here is what I have done (with IFFT2D, but the results are the same for FFT2D):

```python
from keras.layers import Layer, Input, Lambda
from keras.models import Model
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.signal import ifft2d

def tf_unmasked_adj_op(x):
    return tf.expand_dims(ifft2d(x[..., 0]), axis=-1)

# Model definition (we basically perform an inverse fft and then take the modulus)
input_size = (320, None, 1)
kspace_input = Input(input_size, dtype='complex64', name='kspace_input')
zero_filled = Lambda(tf_unmasked_adj_op, output_shape=input_size, name='ifft_simple')(kspace_input)
zero_filled_abs = Lambda(tf.math.abs)(zero_filled)
model = Model(inputs=kspace_input, outputs=zero_filled_abs)

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
model.compile(
    optimizer='adam',
    loss='mse',
    options=run_options,
    run_metadata=run_metadata,
)

# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)

# a single pass
model.fit(
    x=data_x,
    y=data_y,
    batch_size=35,
    epochs=1,
    verbose=2,
    shuffle=False,
)

# profiling output
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()
with open('timeline_fft_test.json', 'w') as f:
    f.write(ctf)
```

On 8 CPUs (Intel i7), Ubuntu 16.04, TensorFlow 1.14.0, Keras 2.2.4, the fitting time is 5s. Looking into the timeline (chrome://tracing/), I find that the IFFT2D has a duration of 5000ms (= 5s). With a GPU, on Ubuntu 18.04 and otherwise similar versions, the IFFT2D takes 500ms. As a comparison, I timed the equivalent computation in NumPy; here is the code I used:

```python
def np_inv_fft(x):
    return np.fft.ifft2(x)

# In a separate notebook cell:
%%timeit
for i in range(len(data_x)):
    np_inv_fft(np.squeeze(data_x[i]))
```

Maybe I am doing the profiling in a wrong way, but it does feel like the FFT2D is the bottleneck (and a very big one) in my actual models (this is just a minimal reproducible example). I haven't taken the time to see if it was due to the 2D Fourier transform being wrapped in a Lambda layer.
This is what I tried in pure TensorFlow:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.signal import ifft2d

def tf_unmasked_adj_op(x):
    return tf.expand_dims(ifft2d(x[..., 0]), axis=-1)

# Model definition (we basically perform an inverse fft and then take the modulus)
kspace_input = tf.placeholder(dtype='complex64', shape=(None, 320, None, 1))
fake_output = tf.placeholder(dtype='float32', shape=(None, 320, None, 1))
zero_filled = tf_unmasked_adj_op(kspace_input)
zero_filled_abs = tf.math.abs(zero_filled)

# this is just to have an adam optimizer similar to the keras training in order to keep
# things comparable (we just compare the IFFT2D compute time in the timeline anyway)
w = tf.Variable((1,), trainable=True, dtype=tf.float32)
zero_filled_abs = tf.math.multiply(zero_filled_abs, w)
loss = tf.losses.mean_squared_error(fake_output, zero_filled_abs)
adam = tf.train.AdamOptimizer(learning_rate=0.3)
a = adam.minimize(loss, var_list=[w])

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

# fake data
data_x = np.random.rand(35, 320, 320, 1) + 1j * np.random.rand(35, 320, 320, 1)
data_y = np.random.rand(35, 320, 320, 1)

# a single pass
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1):
        sess.run(a, feed_dict={kspace_input: data_x, fake_output: data_y},
                 options=run_options, run_metadata=run_metadata)
    # profiling output
    tl = timeline.Timeline(run_metadata.step_stats)
    ctf = tl.generate_chrome_trace_format()
    with open('timeline_fft_test.json', 'w') as f:
        f.write(ctf)
```
So I tried one further thing to see how big the problem was. This is the code I used to create the FFT layer, wrapping NumPy's FFT in a py_func with a custom gradient:

```python
import numpy as np
import uuid
import tensorflow as tf
from tensorflow.python.framework import ops

# Define custom py_func which also takes a grad op as argument:
def py_func(func, inp, Tout, name=None, grad=None):
    if grad is None:
        return tf.py_func(func, inp, Tout, stateful=False, name=name)
    else:
        override_name = 'PyFuncStateless'
        # Need to generate a unique name to avoid duplicates:
        rnd_name = override_name + 'Grad' + str(uuid.uuid4())
        tf.RegisterGradient(rnd_name)(grad)
        g = tf.get_default_graph()
        with g.gradient_override_map({override_name: rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=False, name=name)

def fft_layer(mask, op_name='forward'):
    def forward_op(imgs):
        fft_coeffs = np.empty_like(imgs)
        for i, img in enumerate(imgs):
            fft_coeffs[i] = mask * np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(img[..., 0]), norm='ortho'))[..., None]
        return fft_coeffs

    def adj_op(kspaces):
        imgs = np.empty_like(kspaces)
        for i, kspace in enumerate(kspaces):
            masked_fft_coeffs = mask * kspace[..., 0]
            imgs[i] = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(masked_fft_coeffs), norm='ortho'))[..., None]
        return imgs

    if op_name == 'forward':
        op = forward_op
        grad_op = adj_op
    else:
        op = adj_op
        grad_op = forward_op

    def tf_grad_op(x, dy, name):
        with tf.name_scope(name):
            out_shape = x.get_shape()
            with ops.name_scope(name + '_pyfunc', values=[x, dy]) as name_call:
                result = py_func(
                    grad_op,
                    [dy],
                    [tf.complex64],
                    name=name_call,
                )
            # We must manually set the output shape since tensorflow cannot
            # figure it out
            result = result[0]
            result.set_shape(out_shape)
            return result

    # Define the custom op using np.fft instead of tf.signal:
    def tf_op(x, name=None):
        with tf.name_scope(name, op_name, values=[x]) as name:
            x_shape = x.get_shape()

            def tensorflow_layer_grad(op, grad):
                """Thin wrapper for the gradient."""
                x = op.inputs[0]
                return tf_grad_op(x, grad, name=name + '_grad')

            with ops.name_scope(name + '_pyfunc', values=[x]) as name_call:
                result = py_func(
                    op,
                    [x],
                    [tf.complex64],
                    name=name_call,
                    grad=tensorflow_layer_grad,
                )
            # We must manually set the output shape since tensorflow cannot
            # figure it out
            result = result[0]
            result.set_shape(x_shape)
            return result

    return tf_op
```

(This is a bit more complex than a simple inverse Fourier transform, but I need it for my application.) Even with this slightly more complex function, the py_func version backed by NumPy is much faster than the tf.signal one. I don't know if I am doing things incorrectly; if so, please tell me. I don't think I could help solve this problem unfortunately, as I don't know how to code in C (I guess if there is a problem it's related to the Eigen implementation of FFT2D). @aselle @vrv I see that you were involved in this discussion earlier, do you have any idea of what's going on? (Sorry to ping you directly, but this is really a bottleneck for me.) @rryan I also know from my PR that you are involved with the FFT stuff, do you have any idea of what's happening by any chance?
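As a quick illustration of how this factory is meant to be used, here is a hypothetical usage sketch (not from the original comment; the mask, shapes, and names are assumptions), continuing from the definitions above in TF 1.x graph mode:

```python
# Hypothetical usage of fft_layer above; mask/shapes are assumptions.
mask = np.ones((320, 320), dtype=np.complex64)
ifft_op = fft_layer(mask, op_name='adjoint')  # any non-'forward' name selects the inverse transform

kspace = tf.placeholder(tf.complex64, shape=(None, 320, 320, 1))
image = ifft_op(kspace, name='ifft_pyfunc')

with tf.Session() as sess:
    data = (np.random.rand(2, 320, 320, 1)
            + 1j * np.random.rand(2, 320, 320, 1)).astype(np.complex64)
    img_val = sess.run(image, feed_dict={kspace: data})
```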
The main problem with FFT ops in TensorFlow that makes them slow is that we compute the FFT plan on every execution instead of caching it for a given size. Due to the multi-threaded nature of op execution, nobody has done the work of implementing a plan cache that would be thread-safe. Beyond this, Eigen's "TensorFFT" itself is not particularly fast when compared to other libraries like FFTW (which we can't use in TensorFlow due to lack of legal approval).
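For illustration only, here is a minimal Python sketch of the kind of thread-safe, shape-keyed plan cache described above; the real ops are implemented in C++, and `make_plan` is a hypothetical stand-in for whatever planning step the FFT backend performs.

```python
# Illustrative sketch only: a thread-safe FFT plan cache keyed by input shape.
# `make_plan` is a hypothetical planning function, not a real TensorFlow/Eigen API.
import threading

class FFTPlanCache:
    def __init__(self, make_plan):
        self._make_plan = make_plan
        self._plans = {}
        self._lock = threading.Lock()

    def get(self, shape):
        """Return a cached plan for `shape`, creating it once under the lock."""
        key = tuple(shape)
        with self._lock:
            plan = self._plans.get(key)
            if plan is None:
                plan = self._make_plan(key)  # expensive; done once per shape instead of per run
                self._plans[key] = plan
            return plan
```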
But even for a batch size of 1 (i.e. the caching isn't involved), TensorFlow's FFT is still slower than NumPy's wrapped in a TensorFlow py_func.
Even at batch size 1, for every invocation of the op (i.e. every sess.run call) we are re-computing a plan for the FFT, executing the single FFT, then returning. Thankfully, when the batch size is greater than 1 we re-use the plan within that execution :), just not across executions. NumPy caches the FFT plan in a process-wide cache, so separate calls to ifft2d will re-use the plan made on the first call with that shape.
It looks like you're producing trace timelines (great!) -- can you post them? I would like to see the actual measured time within the op instead of the overall time of calling sess.run (which, as @aselle points out, can be a source of measurement error).
Sure. Attached you will find a zip tf_fft_timelines.zip containing 4 files. They correspond to the experiments I showed above. Basically the network is doing an IFFT2D, masking the result, and taking the complex modulus. The batch size by default is 35. I am fitting the network, but since there are no parameters to fit, it's just doing a forward pass (results are roughly the same for just a batch prediction).
This benchmark was done with Ubuntu 16.04, on CPU only.
Thanks for that! What CPU are you using? Does it have AVX support, and was TensorFlow built with AVX support?
I did not build TensorFlow from source, so it was not built with AVX support. However, even when I use my GPU (Quadro P5000), the IFFT2D is still slower with tf.signal than with NumPy.
I just did another simple benchmark, this time comparing TensorFlow (used with Keras) and PyTorch.
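The exact script behind this comparison is not shown here; below is a rough sketch of what such a CPU-side timing could look like, using the old torch.ifft API (PyTorch 1.x, the same API used later in this thread). The sizes and iteration counts are assumptions.

```python
# Rough sketch (assumed sizes/iterations), comparing PyTorch's CPU FFT with NumPy.
import time
import numpy as np
import torch

x_np = np.random.rand(35, 320, 320) + 1j * np.random.rand(35, 320, 320)
# torch.ifft (PyTorch <= 1.7) represents complex numbers as a trailing real/imag axis.
x_pt = torch.from_numpy(np.stack([x_np.real, x_np.imag], axis=-1).astype(np.float32))

start = time.perf_counter()
for _ in range(100):
    _ = torch.ifft(x_pt, 2)  # 2-D inverse FFT over the last two signal dimensions
print('pytorch cpu: {:.3f}s'.format(time.perf_counter() - start))

start = time.perf_counter()
for _ in range(100):
    _ = np.fft.ifft2(x_np)
print('numpy: {:.3f}s'.format(time.perf_counter() - start))
```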
Thanks @zaccharieramzi! Could you please run a GPU benchmark as well? Based on this I think PyTorch uses cuFFT and MKL, so on GPU TensorFlow and PyTorch should be using the same underlying FFT implementation. It would be nice to confirm whether or not there is a difference when both systems are using cuFFT, because that would point to inefficiencies outside the particular FFT implementation that need to be addressed.
I'm reopening because it is a known issue that TensorFlow's CPU FFT implementation is based on Eigen's TensorFFT, which is not fast compared to FFTW, MKL (what PyTorch uses), or FFTPACK (what NumPy uses). For projects I work on at Google, we use TPUs and GPUs for FFTs in both training and serving, so I think we are not noticing the pain of the CPU FFTs being slow. Sorry about that :(. Unfortunately we are a little stuck in terms of legal approval for the above-mentioned libraries. When we implemented tf.signal, Eigen's TensorFFT was the best option in terms of features (non-power-of-2 FFTs via Bluestein's algorithm are a hard requirement) and legal compatibility.
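For context on the Bluestein requirement mentioned above, here is a small, illustrative NumPy sketch of the chirp-z trick that handles non-power-of-2 lengths; this is not TensorFlow's or Eigen's actual implementation.

```python
# Illustrative sketch of Bluestein's (chirp-z) algorithm for arbitrary-length FFTs.
import numpy as np

def bluestein_fft(x):
    x = np.asarray(x, dtype=np.complex128)
    n = len(x)
    k = np.arange(n)
    chirp = np.exp(-1j * np.pi * k * k / n)
    m = 1 << (2 * n - 1).bit_length()  # power-of-2 length >= 2n-1 for the linear convolution
    a = np.zeros(m, dtype=np.complex128)
    a[:n] = x * chirp
    b = np.zeros(m, dtype=np.complex128)
    b[:n] = np.conj(chirp)
    b[-(n - 1):] = np.conj(chirp[1:][::-1])  # wrap negative lags for the circular convolution
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return chirp * conv[:n]

# Matches the reference FFT for an awkward, non-power-of-2 length such as 513.
x = np.random.rand(513) + 1j * np.random.rand(513)
assert np.allclose(bluestein_fft(x), np.fft.fft(x))
```

The arbitrary-length transform is reduced to power-of-2 FFTs of a padded length, which is how a backend can support sizes like 513 or 320.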
I am training a model on my main GPU right now, but as soon as I can I will run the same benchmarks. But as I recall (and wrote above), even with the GPU the IFFT2D was still slower. I will come back with exact numbers coming from the trace timelines (and a zip with all of them).
I ran the benchmarks on GPU. You can find a zip tf_fft_timelines_gpu.zip attached, with a similar nomenclature as before. (You will notice that only one run "epoch" appears in each trace; I did not find a way, and did not look very hard either, to have all the "epochs" in one trace.) The first "epoch" is way slower (that's why I did more than 1 "epoch" this time). Was the remark about not caching the plan only applicable to CPU? This time I also included the pure TensorFlow version, and I enhanced my NumPy comparison as well.
Thanks again @zaccharieramzi. While an end-to-end benchmark is useful, for the purpose of this bug I'm only concerned with the FFT ops themselves, so I think it could be simplified; I noticed a few issues with the benchmark setup.
Here's a more "raw" benchmark for TensorFlow that can give a better sense of what's going on. import tensorflow as tf
import numpy as np
tf.reset_default_graph()
x_np = np.random.rand(35, 320, 320) + 1j * np.random.rand(35, 320, 320)
with tf.device('/gpu:0'):
x_tf = tf.constant(x_np)
y = tf.signal.ifft2d(x_tf)
sess = tf.Session() And then: %%timeit -n1000
sess.run(y.op) # Note the .op avoids fetching y from device to host. On a public Colab runtime with a K80 GPU, I'm observing a very high initial run time, followed by a steady result of:
Compared to PyTorch, which seems to be faster by about 2x: import torch
dtype = torch.float32
device = torch.device("cuda")
x = torch.randn(35, 320, 320, 2, device=device, dtype=dtype) And then: %%timeit -n1000
y_pt = torch.ifft(x, 2) I'm getting:
I believe this is apples-to-apples, since we're not measuring device/host transfer time, and the input shapes are the same. Here is a screenshot of a trace of the above TensorFlow graph executing on the same K80: You can see there are no memcpy D2H/H2Ds (good). You can see there is over 0.5 ms of time between the IFFT2D op execution ( I'll take a look at how PyTorch invokes cuFFT to see how they differ from TensorFlow. The performance should be identical on GPU. Here is the PyTorch trace of the above: I'm less familiar with how to interpret PyTorch traces, so I don't know how to figure out what the latency was between CPU scheduling the kernel and it starting, but it at least confirms the overall execution time took just over 2 ms, and I don't see anything that looks like a D2H/H2D copy. |
Totally agree. Thanks for the remarks on the different aspects of the benchmark. However, the trace timelines I got (still relevant I think) are not exactly the same as yours for TensorFlow. For example, I only have 2 items in the CUDA stream. Anyway, I do think your benchmark shows that it might be worth investigating.
I added JAX to the comparison, which is also much faster than TensorFlow (timings in seconds, lower is better, are in the Colab linked below). JAX uses pocketfft (https://jax.readthedocs.io/en/latest/_modules/jax/_src/lax/fft.html#fft), which seems to be much faster than TensorFlow's Eigen. Updated colab: https://colab.research.google.com/gist/Andreas5739738/fc603468829ee0fc7e40a2e27d8a6661/fft.ipynb
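For reference, a minimal sketch of the JAX side of such a comparison (the sizes mirror the earlier benchmarks and are assumptions; block_until_ready avoids timing only the asynchronous dispatch):

```python
# Minimal sketch of the JAX timing (assumed sizes; run in a notebook for %timeit).
import numpy as np
import jax.numpy as jnp

x = np.random.rand(35, 320, 320) + 1j * np.random.rand(35, 320, 320)

# Warm-up call so compilation is not included in the timing.
jnp.fft.ifft2(x).block_until_ready()

%timeit jnp.fft.ifft2(x).block_until_ready()
```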
Yes! In line with @Andreas5739738's gist, I was able to replicate the issue in a CPU environment for TF 2.5. But in a GPU environment it is approximately 405 times faster than NumPy; I am providing a gist for reference. Thanks!
The situation has not improved with TF 2.6.
This is still an issue in current tf-nightly (2.8+). I also noticed something interesting: when using double precision, TF is more than twice as fast as with single precision (although still ~8x slower than JAX):
https://colab.research.google.com/gist/Andreas5739738/fc603468829ee0fc7e40a2e27d8a6661/fft.ipynb
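A small sketch of how the single- vs double-precision observation can be reproduced in TF 2.x eager mode (the sizes and iteration count are assumptions, not the Colab's exact values):

```python
# Sketch: compare complex64 vs complex128 ifft2d timings in TF 2.x eager mode.
import time
import numpy as np
import tensorflow as tf

for dtype in (np.complex64, np.complex128):
    x = tf.constant((np.random.rand(35, 320, 320)
                     + 1j * np.random.rand(35, 320, 320)).astype(dtype))
    start = time.perf_counter()
    for _ in range(100):
        _ = tf.signal.ifft2d(x)
    print('{}: {:.3f}s'.format(np.dtype(dtype).name, time.perf_counter() - start))
```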
Looks like the JAX team found the same issue with XLA FFT slowness and integrated PocketFFT as a workaround: jax-ml/jax#2952
Is there any workaround for GPU? Most of my computational cost goes to computing FFTs. PyTorch's implementation is also very fast.
Looks like things have improved significantly with the 2.9 release, although there is still a large gap to NumPy and JAX.
Are there any plans to address this issue any further? It's a huge inconvenience for anyone wanting to perform FFT transformations in a deployed model.
@diggerdu! Thank you!
I tried to execute the mentioned code on tf-nightly (2.17.0-dev20240403) on CPU and observed that the time taken for the execution with TensorFlow is less than with NumPy:

TensorFlow 2.17.0-dev20240403: 1.0675210610000079
Numpy: 1.7257418959999882
Jax: 2.5296104049999997

Kindly find the gist of it here. Thank you!
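The linked gist is not reproduced in this thread; a minimal sketch of this kind of eager-mode CPU comparison (sizes and iteration counts are assumptions) could look like the following:

```python
# Sketch of a TF 2.x eager-mode CPU comparison against NumPy (assumed sizes/iterations).
import time
import numpy as np
import tensorflow as tf

x_np = (np.random.rand(35, 320, 320) + 1j * np.random.rand(35, 320, 320)).astype(np.complex64)
x_tf = tf.constant(x_np)

start = time.perf_counter()
for _ in range(100):
    _ = tf.signal.ifft2d(x_tf)
print('tensorflow: {:.3f}s'.format(time.perf_counter() - start))

start = time.perf_counter()
for _ in range(100):
    _ = np.fft.ifft2(x_np)
print('numpy: {:.3f}s'.format(time.perf_counter() - start))
```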
Yes, we now use DUCC FFT in all of TF, JAX, and XLA. This is resolved.
What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?

Environment info

Operating System:
Ubuntu 16.04 LTS 64-bit

Installed version of CUDA and cuDNN (please attach the output of `ls -l /path/to/cuda/lib/libcud*`):

```
-rw-r--r-- 1 root root   558720 Sep 15 07:02 libcudadevrt.a
lrwxrwxrwx 1 root root       16 Sep 15 07:05 libcudart.so -> libcudart.so.8.0
lrwxrwxrwx 1 root root       19 Sep 15 07:05 libcudart.so.8.0 -> libcudart.so.8.0.44
-rw-r--r-- 1 root root   415432 Sep 15 07:02 libcudart.so.8.0.44
-rw-r--r-- 1 root root   775162 Sep 15 07:02 libcudart_static.a
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so.5
-rwxr-xr-x 1 root root 79337624 Oct 27 23:13 libcudnn.so.5.1.5
-rw-r--r-- 1 root root 69756172 Oct 27 23:13 libcudnn_static.a
```

If installed from binary pip package, provide:

A link to the pip package you installed:
The output from `python -c "import tensorflow; print(tensorflow.__version__)"`: 0.12.head

If installed from source, provide:

The commit hash (`git rev-parse HEAD`):
The output of `bazel version`:

If possible, provide a minimal reproducible example (we usually don't have time to read hundreds of lines of your code).

What other attempted solutions have you tried?

Logs or other output that would be helpful
(If logs are large, please upload as attachment or provide link).