
Can we get PocketFFT ported to Tensorflow? #56685

Closed
ddgonzal3 opened this issue Jul 6, 2022 · 24 comments
Labels: comp:apis (Highlevel API related issues), comp:ops (OPs related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.9 (Issues found in the TF 2.9 release (or RCs)), type:feature (Feature requests)

Comments

@ddgonzal3 commented Jul 6, 2022


Issue Type: Feature Request
Source: source
Tensorflow Version: 2.9.1
Custom Code: No
OS Platform and Distribution: No response
Mobile device: No response
Python version: No response
Bazel version: No response
GCC/Compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response

Current Behaviour?

Is it possible to integrate PocketFFT into Tensorflow's CPU path, so that functions like tf.signal.stft and tf.signal.inverse_stft can leverage it?

Currently, Tensorflow's FFT uses EigenFFT, which is almost 3x slower than Numpy and Jax, both of which use PocketFFT.

There are heaps more details here: https://github.com/tensorflow/tensorflow/issues/6541

I'm sure many projects would benefit from this investment, considering that much of what's done for speech and music these days uses STFT data.

Standalone code to reproduce the issue

print("seconds (lower is better):")
print(f"Tensorflow {tf.__version__}", timeit.timeit('X = tf.signal.rfft(x)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print(f"Tensorflow {tf.__version__}, double precision", timeit.timeit('X = tf.cast(tf.signal.rfft(tf.cast(x, tf.float64)), tf.complex64)', setup='import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Numpy: ", timeit.timeit('X = numpy.fft.rfft(x)', setup='import numpy.fft; import tensorflow as tf; x = tf.random.normal([50000, 512])', number=10))
print("Jax: ", timeit.timeit('jnp.fft.rfft(x).block_until_ready()', setup='import jax.numpy as jnp; import tensorflow as tf; x = tf.random.normal([50000, 512]).numpy()', number=10))

Relevant log output

seconds (lower is better):
Tensorflow 2.9.1 5.495112890999991
Tensorflow 2.9.1, double precision 7.629201937000033
Numpy:  2.1803204349999987
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax:  1.4081462569999985
@google-ml-butler google-ml-butler bot added the type:feature Feature requests label Jul 6, 2022
@sushreebarsa sushreebarsa added comp:apis Highlevel API related issues TF 2.9 Issues found in the TF 2.9 release (or RCs) labels Jul 6, 2022
@chunduriv chunduriv assigned gowthamkpr and unassigned chunduriv Jul 8, 2022
@ddgonzal3 (Author)

MKL FFT could be another alternative if PocketFFT is not possible.

@gowthamkpr gowthamkpr added comp:ops OPs related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jul 18, 2022
@gowthamkpr
Copy link

@ddgonzal3 PocketFFT is still not integrated into Tensorflow. JAX has used PocketFFT as a workaround, as mentioned here, but Tensorflow has not.

@ddgonzal3 (Author)

Hey @gowthamkpr, thanks for your response!

Yeah, I've gotten the JAX FFT to work using jax2tf.convert, but it can't be serialized into a protobuf (it falls back to EigenFFT), so it's only usable in a Python environment.

Do you know if there are currently any plans to integrate PocketFFT into tensorflow?

@cantonios (Contributor)

Eigen can use PocketFFT if we include the pocketfft header (pocketfft_hdronly.h) in tensorflow and define EIGEN_POCKETFFT_DEFAULT=1. That was added in !356. We'd probably be willing to accept a pull-request to enable it.

The header would need to be pulled in under third_party (similar to the relevant JAX change in 4699).
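For illustration, here's a rough sketch of how the backend switch looks from user code, assuming an Eigen version that already contains the PocketFFT backend from !356 and has pocketfft_hdronly.h on the include path (in TF the define would more likely be set as a build flag rather than in source):

// Sketch only: selects the PocketFFT backend for Eigen's unsupported FFT module.
#define EIGEN_POCKETFFT_DEFAULT 1   // instead of Eigen's kissfft default
#include <unsupported/Eigen/FFT>

#include <complex>
#include <vector>

int main() {
  Eigen::FFT<float> fft;
  std::vector<float> x(512, 1.0f);         // real time-domain signal
  std::vector<std::complex<float>> X;      // complex spectrum

  fft.fwd(X, x);   // real-to-complex forward FFT, now routed through PocketFFT
  fft.inv(x, X);   // complex-to-real inverse FFT
  return 0;
}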

@ddgonzal3 (Author) commented Sep 8, 2022

Hey @cantonios, thanks for your reply.

Do you know if that's all that would be needed? If so, that sounds like a very quick change for something that would give immense benefit to the community for many use cases that require STFT transformations in a Tensorflow model inference pipeline.

If one wants to deploy a Tensorflow model in C++ on a CPU that operates on STFTs, they're forced to implement the STFT in C++ and feed that data into Tensorflow, otherwise they face a 3x performance hit. The problem gets immensely more complicated when the STFT transformations are in the middle of a TF inference pipeline.

I haven't been involved in contributing to Tensorflow in the past, so I would likely need to invest quite some time in learning the build system to do it myself. I'm currently swamped with high-priority work before going on vacation, so I won't have the time to do so any time soon.

Is this something someone at Google could quickly try to implement, given the huge benefit it would provide?

Based on this issue, it's a highly requested feature, as expected considering how widely used FFTs are in ML (speech, music, video, imaging, etc.).

@cantonios (Contributor) commented Sep 9, 2022

Looks like JAX switched to ducc (which is an "evolution" of pocketfft written by the same author). We may consider the same.

@ddgonzal3 it's not high priority for anyone on our end because we typically do model serving with GPU/TPU. I've asked around, as well as the JAX team about their experience. Depending on what I hear back, I may put it on my list, though again it would be near the bottom of my priority queue. If anyone else here is interested, I'm happy to look at PRs.

@cantonios (Contributor)

Looks like the PocketFFT author now recommends using DUCC, and the JAX team is happy with the transition, so we should do the same.

@ddgonzal3 (Author)

@cantonios Awesome! That sounds great to me. I really appreciate you taking the time to look into this and check in with the Jax team.

Please let me know if you're ever able to get around to implementing DUCC, and I would be very keen on testing and using it.
Thank you!

@ddgonzal3 (Author)

hey @cantonios, happy new year!

Just wanted to check in to see if you happen to have some bandwidth for pushing this feature onto your priority queue.

Looking forward to hearing from you soon. Thanks!

@mreineck

Hi, author of pocketfft here :-) I'm happy to help if there are any questions.
One thing to keep in mind regarding pocketfft vs. DUCC: pocketfft requires C++11 support in the compilers, DUCC requires C++17. Unless you can guarantee C++17 support on all target platforms, you may want to use pocketfft, but if this is not an issue, I strongly recommend DUCC.

@cantonios (Contributor)

Hi, author of pocketfft here :-) I'm happy to help if there are any questions.
One thing to keep in mind regarding pocketfft vs. DUCC: pocketfft requires C++11 support in the compilers, DUCC requires C++17. Unless you can guarantee C++17 support on all target platforms, you may want to use pocketfft, but if this is not an issue, I strongly recommend DUCC.

Thanks @mreineck !

C++17 isn't an issue - TF now requires C++17 anyway. I think the main blockers are:

  1. We can't actually use std::mutex or std::thread directly in TF (or internally, at all, within Google)
  2. We would like to be able to use TF's threadpool for execParallel(...)

For JAX, we initially simply set DUCC0_NO_THREADING. However, we later found that we ran into race conditions in the get_plan cache due to parallel FFTs. We could disable plan caching as well, but I'm not sure how negatively that would impact performance. We should solve this before adding it to TF.

I think we can work around this with minor modifications to ducc0: add optional inputs with defaults for the parallelizer and plan cache. I haven't looked into it too closely. How open would you be to modifications like this?

@mreineck

We could disable plan caching as well, but I'm not sure how negatively that would impact performance.

Actually that should not be a problem at all ... unless someone is calling 1D FFTs of length 10 over and over :-)
If you have lengths above, say, 256, or multi-D transforms, plan caching isn't an issue.

I think we can work around this with minor modifications to ducc0: add optional inputs with defaults for the parallelizer and plan cache. I haven't looked into it too closely. How open would you be to modifications like this?

I'd have to see the changes to judge this. Can you point me to a code location where you do something equivalent to ducc's execParallel?

@cantonios (Contributor)

Actually that should not be a problem at all ... unless someone is calling 1D FFTs of length 10 over and over :-)

We have people doing many batches of small FFTs. For example, we have an internal TF user who ran into an issue with cuFFT while doing millions of training iterations, each computing 10k batches of 16x16 2D FFTs. What was the motivation for implementing the get_plan LRU cache, and do you think it would apply here?

I'd have to see the changes to judge this. Can you point me to a code location where you do something equivalent to ducc's execParallel?

TF's threadpool has ParallelFor. We would probably need something like:

template<typename Tplan, typename T, typename T0, typename Exec, typename Parallelizer = DefaultParallelizer, typename PlanCache = DefaultPlanCache >
DUCC0_NOINLINE void general_nd(const cfmav<T> &in, vfmav<T> &out,
  const shape_t &axes, T0 fct, size_t nthreads, const Exec &exec,
  const bool /*allow_inplace*/=true,
  Parallelizer parellelizer = get_default_parallelizer(),  
  PlanCache plan_cache = get_default_plan_cache<T>()
)

The defaults would be simple structs that call execParallel(...) and get_plan<T>(...) directly, but for TF's usage we would pass in our own versions that use TF's threadpool, and our own thread-safe caching mechanism.
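To make that concrete, here's a rough sketch of what such structs might look like; the names (DefaultParallelizer, TfThreadpoolParallelizer) and the chunk-callback shape are illustrative only, not actual ducc0 or TF interfaces, and a PlanCache hook would follow the same pattern:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Default: run the whole range in one chunk on the calling thread. In the real
// integration this would delegate to ducc0's own parallel executor instead; it is
// serial here only so the sketch stays self-contained.
struct DefaultParallelizer {
  void operator()(std::size_t nwork, std::size_t /*nthreads*/,
                  const std::function<void(std::size_t, std::size_t)> &body) const {
    body(0, nwork);
  }
};

// TF-backed variant: forward chunks to an externally supplied thread pool.
// `Pool` stands in for tensorflow::thread::ThreadPool, whose ParallelFor takes
// (total, cost_per_unit, fn(begin, end)).
template <typename Pool>
struct TfThreadpoolParallelizer {
  Pool *pool;
  void operator()(std::size_t nwork, std::size_t /*nthreads*/,
                  const std::function<void(std::size_t, std::size_t)> &body) const {
    const std::int64_t cost_per_item = 1000;  // rough per-element cost hint for the scheduler
    pool->ParallelFor(static_cast<std::int64_t>(nwork), cost_per_item,
                      [&](std::int64_t begin, std::int64_t end) {
                        body(static_cast<std::size_t>(begin),
                             static_cast<std::size_t>(end));
                      });
  }
};

int main() {
  // Exercise the serial default: mark every element of a small buffer.
  std::vector<int> data(8, 0);
  DefaultParallelizer par;
  par(data.size(), /*nthreads=*/1, [&](std::size_t begin, std::size_t end) {
    for (std::size_t i = begin; i < end; ++i) data[i] = 1;
  });
  return 0;
}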

@mreineck

We have people doing many batches of small FFTs. For example, we have an internal TF user that ran into an issue with cuFFT that was doing millions of training iterations, each computing 10k batches of 16x16 2D FFTs.

If the 10k FFTs are done in a single call (i.e. you call with a 3D array, but transform only along two axes), then plan caching is irrelevant. The plan for a single length-16 FFT will be computed once at the start, and the cost for this is negligible compared to the actual transforms. (And even if the 16x16 transforms are called individually, the plan will be re-used 32 times, so the actual FFT dominates.) It's important to realize that pocketfft plans are much faster to compute than even FFTW's FFTW_ESTIMATE plans.

That said, if performance of 16x16 FFTs is really important for an application, I'd look into custom-generated transforms for exactly this size... they can be much faster than any general purpose approach.

What was the motivation for implementing the get_plan LRU cache, and do you think it would apply here?

This mainly exists because scipy people wanted to have it. I even have a PR against scipy (scipy/scipy#12307) which would have disabled it because it can cause too much memory consumption in edge cases.
This is no longer a problem with the current ducc version, but I still don't see real benefit in it.

Thanks for the code snippet! I'm sure we can work something out if necessary.
BTW, we can work on a potential solution for this on https://github.com/mreineck/ducc; it should be easier to collaborate there than on the Gitlab instance where I have my master repo.

@cantonios (Contributor)

This mainly exists because scipy people wanted to have it....

Perfect, this simplifies things - we can probably skip caching altogether. Yes, the batch would be called all at once.

Thanks for the code snippet! I'm sure we can work something out if necessary...

Great, I was wondering how best to work together on something like this. I will try to create a prototype then and ping you on the github repo when I have something.

@ddgonzal3 (Author)

Thanks guys!! Excited to be hopefully seeing faster FFTs in Tensorflow soon! :D

@Mddct commented Aug 8, 2023

Any update on this?

@cantonios (Contributor)

In progress as we speak. We needed upstream changes to the ducc0 FFT to be able to use our custom threadpools.

@ddgonzal3 (Author)

Hey @cantonios , happy new year!

Been a while - do you know if there has been any update for the ducc FFT integration?

@cantonios (Contributor)

Yes, this was completed a while ago.

@ddgonzal3 (Author)

This is awesome!! I'm using the nightly build from today and am seeing way faster times for the FFT and IFFT operations!

@ddgonzal3 (Author)

@cantonios Would these changes make their way into TF 2.16? And if so, any idea when it might get released?

Thanks again for the incredible work!!

@cantonios (Contributor)

@cantonios Would these changes make their way into TF 2.16? And if so, any idea when it might get released?

Thanks again for the incredible work!!

It was included in the 2.14 release, and will continue to be in all future releases.

@ddgonzal3 (Author)

Ah whoops, I thought the changes weren't in the release, but I just tried it with 2.15 and it's working well. Thanks again!


7 participants