Can we get PocketFFT ported to Tensorflow? #56685
Comments
MKL FFT could be another alternative if PocketFFT is not possible.
@ddgonzal3 PocketFFT is still not integrated in TensorFlow. JAX has used PocketFFT as a workaround, as mentioned here, but TensorFlow has not.
Hey @gowthamkpr, thanks for your response! Yeah, I've gotten the JAX FFT to work using jax2tf.convert, but it can't be serialized into a protobuf (it falls back to EigenFFT), so it's only usable in a Python environment. Do you know if there are currently any plans to integrate PocketFFT into TensorFlow?
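For context, a minimal sketch of the jax2tf workaround described above (the wrapped function, shapes, and dtypes are illustrative assumptions, not taken from this thread):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fft(x):
    # JAX's FFT, which at the time of this thread used PocketFFT on CPU.
    return jnp.fft.fft(x)

# Wrap the converted JAX function so it can be called from TF code.
tf_fft = tf.function(jax2tf.convert(jax_fft), autograph=False)

x = tf.complex(tf.random.normal([8, 1024]), tf.random.normal([8, 1024]))
y = tf_fft(x)
```

As noted above, though, serializing such a function into a SavedModel/protobuf loses the fast path and falls back to Eigen's FFT, so this only helps in a Python environment.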
Eigen can use PocketFFT if we include the pocketfft header […]. The header would need to be pulled in under […].
Hey @cantonios, thanks for your reply. Do you know if that's all that would be needed? If so, that sounds like a very quick change for something that would give immense benefit to the community, for the many use cases that require STFT transformations in a TensorFlow model inference pipeline. If one wants to deploy a TensorFlow model in C++ on a CPU that operates on STFTs, they're forced to implement the STFT in C++ and feed that data into TensorFlow; otherwise they face a 3x performance hit. The problem gets immensely more complicated when the STFT transformations are in the middle of a TF inference pipeline. I haven't contributed to TensorFlow in the past, so I would likely need to invest quite some time in learning the build system to do it myself. I'm currently swamped with high-priority work before going on vacation, so I won't have time any time soon. Is this something someone at Google could quickly try to implement, given the huge benefit it would provide? Based on this issue, it's a highly requested feature - as expected, considering how widely used FFTs are in ML (speech, music, video, imaging, etc.).
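For concreteness, the kind of in-graph STFT stage being described might look like this (a hypothetical sketch; the frame parameters and model are placeholders, not from this thread):

```python
import tensorflow as tf

@tf.function
def preprocess(signals):
    # signals: [batch, samples], float32. The STFT runs inside the TF graph,
    # so on CPU its speed is bounded by TF's FFT kernels - the subject of
    # this issue.
    stft = tf.signal.stft(signals, frame_length=512, frame_step=256)
    return tf.abs(stft)  # magnitude spectrogram fed to the rest of the model

signals = tf.random.normal([4, 16000])
features = preprocess(signals)  # shape: [4, num_frames, 257]
```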
Looks like JAX switched to ducc (which is an "evolution" of pocketfft written by the same author). We may consider the same. @ddgonzal3 it's not high priority for anyone on our end because we typically do model serving with GPU/TPU. I've asked around, including the JAX team about their experience. Depending on what I hear back, I may put it on my list, though again it would be near the bottom of my priority queue. If anyone else here is interested, I'm happy to look at PRs.
Looks like the PocketFFT author now recommends using DUCC, and the JAX team is happy with the transition, so we should do the same.
@cantonios Awesome! That sounds great to me. I really appreciate you taking the time to look into this and check in with the JAX team. Please let me know if you're ever able to get around to implementing DUCC; I would be very keen on testing and using it.
Hey @cantonios, happy new year! Just wanted to check in to see if you happen to have some bandwidth for pushing this feature up your priority queue. Looking forward to hearing from you soon. Thanks!
Hi, author of pocketfft here :-) I'm happy to help if there are any questions.
Thanks @mreineck! C++17 isn't an issue - TF now requires C++17 anyways. I think the main blockers are the plan cache and the threading model (TF needs the FFT to run on its own threadpool).
For JAX, we initially simply set […]. I think we can work around this with minor modifications to ducc.
Actually that should not be a problem at all ... unless someone is calling 1D FFTs of length 10 over and over :-)
I'd have to see the changes to judge this. Can you point me to a code location where you do something equivalent to […]?
We have people doing many batches of small FFTs. For example, we have an internal TF user who ran into an issue with cuFFT while doing millions of training iterations, each computing 10k batches of 16x16 2D FFTs. What was the motivation for implementing the […]?
TF's threadpool has […]

```c++
template<typename Tplan, typename T, typename T0, typename Exec,
         typename Parallelizer = DefaultParallelizer,
         typename PlanCache = DefaultPlanCache>
DUCC0_NOINLINE void general_nd(const cfmav<T> &in, vfmav<T> &out,
    const shape_t &axes, T0 fct, size_t nthreads, const Exec &exec,
    const bool /*allow_inplace*/ = true,
    Parallelizer parallelizer = get_default_parallelizer(),
    PlanCache plan_cache = get_default_plan_cache<T>())
```

The defaults would be simple structs that call […].
If the 10k FFTs are done in a single call (i.e. you call with a 3D array, but transform only along two axes), then plan caching is irrelevant. The plan for a single length-16 FFT will be computed once at the start, and its cost is negligible compared to the actual transforms. (And even if the 16x16 transforms are called individually, the plan will be re-used 32 times, so the actual FFT dominates.) It's important to realize that pocketfft plans are much faster to compute than even FFTW's FFTW_ESTIMATE plans. That said, if performance of 16x16 FFTs is really important for an application, I'd look into custom-generated transforms for exactly this size - they can be much faster than any general-purpose approach.
This mainly exists because […]. Thanks for the code snippet! I'm sure we can work something out if necessary.
Perfect, this simplifies things - we can probably skip caching altogether. Yes, the batch would be called all at once.
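To make the batching point concrete, here is the single-call pattern discussed above, expressed in TF (a hypothetical sketch; only the 10k x 16x16 shape comes from the thread):

```python
import tensorflow as tf

# 10k 16x16 complex inputs, stored as one batched tensor.
x = tf.complex(tf.random.normal([10000, 16, 16]),
               tf.random.normal([10000, 16, 16]))

# A single batched call: fft2d transforms the two innermost axes, so the
# length-16 plans are computed once and amortized over all 10k transforms.
y = tf.signal.fft2d(x)
```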
Great, I was wondering how best to work together on something like this. I will try to create a prototype then and ping you on the GitHub repo when I have something.
Thanks guys!! Excited to hopefully see faster FFTs in TensorFlow soon! :D
Any update on this?
In progress as we speak. We needed upstream changes to ducc0 fft to be able to use our custom threadpools.
Hey @cantonios, happy new year! Been a while - do you know if there has been any update on the ducc FFT integration?
Yes, this was completed a while ago.
This is awesome!! I'm using the nightly build from today and am seeing way faster times for the FFT and IFFT operations!
@cantonios Would these changes make their way to TF 2.16? And if so, any idea when it might get released? Thanks again for the incredible work!!
It was included in the 2.14 release, and will continue to be in all future releases.
Ah whoops, I thought I didn't see the changes in the release, but I just tried it with 2.15 and it's working well. Thanks again!
Issue Type: Feature Request
Source: source
Tensorflow Version: 2.9.1
Custom Code: No
Standalone code to reproduce the issue
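(The original report left this section empty. Below is a hypothetical reconstruction of a benchmark that could produce the log output that follows; the array sizes and repetition counts are assumptions, not taken from the report.)

```python
import timeit
import numpy as np
import tensorflow as tf
import jax.numpy as jnp

x_np = (np.random.randn(1024, 4096)
        + 1j * np.random.randn(1024, 4096)).astype(np.complex64)
x_tf = tf.constant(x_np)
x_jx = jnp.asarray(x_np)

# Warm up once so one-time compilation/dispatch cost isn't timed.
tf.signal.fft(x_tf)
jnp.fft.fft(x_jx).block_until_ready()

print("Tensorflow:", timeit.timeit(lambda: tf.signal.fft(x_tf), number=100))
print("Numpy:", timeit.timeit(lambda: np.fft.fft(x_np), number=100))
print("Jax:", timeit.timeit(
    lambda: jnp.fft.fft(x_jx).block_until_ready(), number=100))
```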
Relevant log output
```
seconds (lower is better):
Tensorflow 2.9.1                    5.495112890999991
Tensorflow 2.9.1, double precision  7.629201937000033
Numpy:                              2.1803204349999987
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax:                                1.4081462569999985
```