
Performance of nn.conv1d and keras.layers.Conv1D is low the first time any given input size is processed even if retracing is prevented! #54456

Open
roebel opened this issue Feb 19, 2022 · 9 comments
Labels: comp:ops (OPs related issues) · stat:awaiting tensorflower (Status - Awaiting response from tensorflower) · TF 2.8 · type:performance (Performance Issue)


roebel commented Feb 19, 2022

Please make sure that this is an issue related to performance of TensorFlow. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:performance_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): tested with TF 2.4, 2.6, 2.8
  • Python version: 3.7
  • CUDA/cuDNN version: CUDA Toolkit 11.3.1, cuDNN 8.3.1
  • GPU model and memory: tested with GeForce GTX 1050 Ti and GeForce GTX 1080 Ti

Describe the current behavior

In both eager mode and inside a tf.function with experimental_relax_shapes=True, tf.nn.conv1d is slow the first time a tensor of any given size is processed; processing another tensor of the same size is then 4 times (GTX 1050 Ti) to 10 times (GTX 1080 Ti) faster on the second and subsequent calls.

The observed behavior is a severe problem for running inference on audio signals, because audio signals generally have very different sizes. In a production environment the code therefore runs at only about 10% of peak performance (on a GTX 1080 Ti) for the first few 100k examples, until the model has seen sufficiently many lengths to reach peak performance.

Describe the expected behavior

conv1d processing time should depend only on the size of the input tensor, not on the number of times the same size has been seen.

Standalone code to reproduce the issue

A colab notebook is here.
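
For readers who cannot open the notebook, the benchmark is roughly of the following form (a sketch only; the exact filter shape and channel counts are assumptions reconstructed from the output shapes in the log below, the notebook may differ in details):

```python
import time
import tensorflow as tf

# Assumed parameters, chosen to match the shapes printed below.
batch, channels_in, channels_out, kernel_width = 10, 1, 100, 10
filters = tf.random.normal([kernel_width, channels_in, channels_out])

@tf.function(experimental_relax_shapes=True)
def conv(x):
    return tf.nn.conv1d(x, filters, stride=1, padding="VALID")

for run in (1, 2):
    print(f"=== run {run} ===")
    for length in range(10000, 10005):
        x = tf.random.normal([batch, length, channels_in])
        start = time.time()
        y = conv(x)
        _ = y.numpy()  # force the GPU work to finish before stopping the clock
        dt = time.time() - start
        print(f"len {length} time: {dt:6.3f} "
              f"speed {batch * length / dt / 1000:.2f}kHz su {tuple(y.shape)}")
```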

Other info / logs: Include any logs or source code that would be helpful to diagnose the problem.

Result of running on colab with a GPU:

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
=== run 1 ===
len 10000 time: 10.187 speed 9.82kHz su (10, 9991, 100)
len 10001 time: 0.603 speed 165.93kHz su (10, 9992, 100)
len 10002 time: 0.613 speed 163.08kHz su (10, 9993, 100)
len 10003 time: 0.628 speed 159.26kHz su (10, 9994, 100)
len 10004 time: 0.587 speed 170.35kHz su (10, 9995, 100)
=== run 2 ===
len 10000 time: 0.051 speed 1946.11kHz su (10, 9991, 100)
len 10001 time: 0.050 speed 2000.30kHz su (10, 9992, 100)
len 10002 time: 0.066 speed 1518.87kHz su (10, 9993, 100)
len 10003 time: 0.056 speed 1788.66kHz su (10, 9994, 100)
len 10004 time: 0.059 speed 1705.02kHz su (10, 9995, 100)

Notes:

  • Note the roughly 10× speed increase the second time the inner loop is run. If the same notebook is run again, the first run is fast as well, which indicates the GPU is caching something.
  • To see the effect again you need to restart the notebook.
  • The behavior is not imposed by CUDA, because an equivalent PyTorch script (roughly as sketched below) runs the first and second pass of the inner loop without any speed difference.
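
The PyTorch comparison mentioned in the last note is roughly of the following form (a sketch with the same assumed parameters as above, not the exact script):

```python
import time
import torch

device = torch.device("cuda")
# PyTorch expects NCW layout, hence the transposed shape below.
conv = torch.nn.Conv1d(in_channels=1, out_channels=100, kernel_size=10).to(device)

with torch.no_grad():
    for run in (1, 2):
        print(f"=== run {run} ===")
        for length in range(10000, 10005):
            x = torch.randn(10, 1, length, device=device)
            torch.cuda.synchronize()
            start = time.time()
            y = conv(x)
            torch.cuda.synchronize()
            dt = time.time() - start
            print(f"len {length} time: {dt:6.3f} out {tuple(y.shape)}")
```

Note that PyTorch only benchmarks cuDNN algorithms per input shape when torch.backends.cudnn.benchmark is set to True; it is False by default, which is presumably why both passes run at the same speed there.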
roebel added the type:performance (Performance Issue) label on Feb 19, 2022

roebel commented Feb 20, 2022

I believe I have made some progress using the TensorBoard profiler. The first time the convolution is run on a given size, the profiler shows all of the following kernels running:

  • maxwell_scudnn_128x128_relu_small_nn_v1,
  • implicit_convolve_sgemm,
  • cudnn::cnn::im2col4d_kernel,
  • explicit_convolve_sgemm,
  • fft1d_c2r_256,
  • maxwell_gcgemm_32x32_nt.

The second time any given size is processed, only a single kernel is used; in the present case this is either maxwell_gcgemm_32x32_nt or maxwell_scudnn_128x128_relu_small_nn_v1. So it appears TensorFlow tries to optimize by adaptively selecting the best kernel for each input size.
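
For completeness, the traces were captured roughly like this (reusing the conv function from the sketch in the issue description; "logdir" is just a placeholder path):

```python
import tensorflow as tf

tf.profiler.experimental.start("logdir")
_ = conv(tf.random.normal([10, 10000, 1]))  # first call for this length: several kernels are tried
_ = conv(tf.random.normal([10, 10000, 1]))  # second call: a single kernel
tf.profiler.experimental.stop()
# Inspect the GPU kernel stats with: tensorboard --logdir logdir
```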

This is problematic notably for inference on audio, where file sizes can vary between 1 second (16,000 samples) and 40 seconds (640,000 samples): as mentioned above, the software spends most of its time trying to optimize and therefore needs more than a day before it starts to perform optimally. The question here is whether this trial stage can be prevented by manually preselecting a strategy.
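
One thing I have not verified yet is whether the existing TF_CUDNN_USE_AUTOTUNE environment variable already sidesteps this; if I understand it correctly, setting it to 0 disables the cuDNN algorithm search entirely, at the price of a possibly slower default kernel. An untested sketch:

```python
import os

# Must be set before TensorFlow builds its cuDNN convolution ops;
# whether this fully removes the per-shape warm-up is untested.
os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"

import tensorflow as tf
# ... then run the benchmark from the issue description as before ...
```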

mohantym added the comp:ops (OPs related issues) and TF 2.8 labels on Feb 21, 2022
mohantym (Contributor) commented

Hi @chunduriv! Could you please look at this issue? It replicates in 2.8 and throws a different error in 2.7 and nightly. Thanks!

mohantym (Contributor) commented

@roebel! Did you try the same with distributed training yet?

mohantym assigned chunduriv and unassigned mohantym on Feb 21, 2022

roebel commented Feb 21, 2022

@mohantym Thanks for your reply. But I wonder why you suggest distributed training? I don't have a problem with training, I have a problem with inference. Also, for my use cases I don't have access to multiple GPUs, so I cannot use distributed inference - or could I? If these tests were run in parallel on the same GPU this might help, but it would also require more memory.

chunduriv assigned sachinprasadhs and unassigned chunduriv on Mar 1, 2022
sachinprasadhs added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Mar 2, 2022

LerysG commented Jun 14, 2022

Hi, I have a similar problem (running on a GTX 1080): using conv1d layers makes my kernel die every time, as my GPU memory is fully saturated even with very low-dimensional input shapes. I tried TensorFlow versions 2.5, 2.6, 2.7 and 2.9. I am running on CUDA v11.6 and cuDNN v8.4.1 (Windows). I could not find anything to avoid the crash. If anyone has an idea, you are very much welcome!

sachinprasadhs (Contributor) commented

@roebel, could you try the solution from the above link? Also check the solution in #56387, which is similar to your issue and was solved by exporting a path.

sachinprasadhs added the stat:awaiting response (Status - Awaiting response from author) label and removed the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Jun 14, 2022

roebel commented Jun 15, 2022

@sachinprasadhs

I am sorry but the report of @LerysG is completely unrelated to what I describe here.

  • I don't have a crash.
  • I don't have a missing library.

So there is nothing I could reasonably check here.

Similarly, issue #56387 is not related at all.

The problem I describe is an implementation problem in the TensorFlow code.

google-ml-butler bot removed the stat:awaiting response (Status - Awaiting response from author) label on Jun 15, 2022
sachinprasadhs added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Jun 15, 2022

roebel commented Jul 10, 2022

I think I found the source code that is responsible for the issue: CUDA supports a number of convolution kernels, which in the current version of TensorFlow (2.9.0) are obtained by means of CudnnSupport::GetConvolveRunners here

port::Status CudnnSupport::GetConvolveRunners(

This is then used here in the autotune functions:

TF_RETURN_IF_ERROR(stream->parent()->GetConvolveRunners(

It appears that each time a configuration consisting of data shape, filter shape, and perhaps other parameters is encountered, the CUDA driver tests all of the kernels and retains the most efficient one. This is a very nice optimization for most cases, notably training with constant batch shapes or inference with constant image sizes. For inference on audio signals, however, the implementation spends most of its time testing all kernel versions and hardly ever benefits from knowing which kernel is the most efficient one for a given configuration, because the same configuration is hardly ever encountered a second time.

All this reminds me of fftw3, which also uses this kind of autotuning depending on the FFT size, but there the FFT sizes do not change as much, and moreover it is possible to store the results in the form of wisdom.

I think for the present case it would be nice to be able to affect the kernel selection explicitly. I wonder whether one could define a maximum size above which the selected kernel is no longer adapted. E.g., assuming one could select a maximum adaptation length via an environment variable TENSORFLOW_CONV_MAXADAPLEN, autotuning would then always be run with min(shape[i], TENSORFLOW_CONV_MAXADAPLEN). In that case the current problem could be solved, if needed, without any negative impact on the cases where networks work on constant shapes (image processing).
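
To make the proposal concrete, here is a rough illustration of the intended behavior in Python pseudocode (TENSORFLOW_CONV_MAXADAPLEN is the hypothetical variable proposed above, not an existing TensorFlow option):

```python
import os

def autotune_length(length):
    """Length that would be used as the autotuning key for a conv input."""
    max_adapt_len = int(os.environ.get("TENSORFLOW_CONV_MAXADAPLEN", "0"))
    if max_adapt_len > 0:
        # All inputs longer than the threshold share a single autotuning result.
        return min(length, max_adapt_len)
    return length  # current behavior: every length is autotuned separately
```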
