
Conv2D with XLA jit_compile=True fails to run #57748

Open
Co1lin opened this issue Sep 19, 2022 · 4 comments
Assignees
Labels
comp:xla XLA stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.10 type:bug Bug

Comments

Co1lin commented Sep 19, 2022

Issue Type

Bug

Source

binary

Tensorflow Version

2.11.0.dev20220914

Custom Code

No

OS Platform and Distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

The following code works well without jit_compile=True. However, with XLA compilation enabled via jit_compile=True, it throws an error. Reproduced in a Colab notebook here.

Standalone code to reproduce the issue

import tensorflow as tf
from keras import layers

class MyModule(tf.Module):
    def __init__(self):
        super().__init__()
        self.conv = layers.Conv2D(2, 1, padding='valid', dtype=tf.float64, autocast=False)

    @tf.function(jit_compile=True) # without jit_compile=True works fine
    def __call__(self, i0):
        o0 = tf.floor(i0)
        o1 = self.conv(o0)
        o2 = tf.add(o1, o0)
        return o2

def simple():
    inp = {
        "i0": tf.constant(
            3.14, shape=[1,1,3,2], dtype=tf.float64
        ),
    }
    m = MyModule()

    out = m(**inp) # Error!

    print(out)
    print(out.shape)

if __name__ == "__main__":
    simple()

Relevant log output

2022-09-18 01:33:53.096156: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x55a9ace73180 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2022-09-18 01:33:53.096176: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): NVIDIA GeForce RTX 3080 Ti, Compute Capability 8.6
2022-09-18 01:33:53.098645: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-09-18 01:33:53.537659: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8100
2022-09-18 01:33:54.161249: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:5341] Disabling cuDNN frontend for the following convolution:
  input: {count: 1 feature_map_count: 2 spatial: 1 3  value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX}
  filter: {output_feature_map_count: 2 input_feature_map_count: 2 layout: OutputInputYX shape: 1 1 }
  {zero_padding: 0 0  pad_alignment: default filter_strides: 1 1  dilation_rates: 1 1 }
  ... because it uses an identity activation.
2022-09-18 01:33:54.749772: I tensorflow/compiler/jit/xla_compilation_cache.cc:476] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Traceback (most recent call last):
  File "/home/colin/code/test_proj/scripts/tflite2.py", line 41, in simple
    out = m(**inp)
  File "/home/colin/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/colin/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnknownError: CUDNN_STATUS_NOT_SUPPORTED
in tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc(5151): 'status' [Op:__inference___call___46]
@mohantym
Contributor

Hi @Co1lin!
It reproduces in 2.9, 2.10, and nightly (but not in 2.8).

The workaround is to use auto-clustering with XLA: workaround gist.

@gadagashwini!
Could you take a look at this issue?

Thank you!
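A minimal sketch of that workaround, under the assumption that the repro module is otherwise unchanged: instead of forcing whole-function compilation with jit_compile=True, enable XLA auto-clustering globally (via tf.config.optimizer.set_jit(True), or equivalently the TF_XLA_FLAGS=--tf_xla_auto_jit=2 environment variable) and let TF decide which op clusters to compile.

```python
import tensorflow as tf
from keras import layers

# Enable XLA auto-clustering globally instead of per-function jit_compile.
# Equivalent environment-variable form: TF_XLA_FLAGS=--tf_xla_auto_jit=2
tf.config.optimizer.set_jit(True)

class MyModule(tf.Module):
    def __init__(self):
        super().__init__()
        self.conv = layers.Conv2D(2, 1, padding='valid',
                                  dtype=tf.float64, autocast=False)

    @tf.function  # note: no jit_compile=True; auto-clustering decides
    def __call__(self, i0):
        o0 = tf.floor(i0)
        o1 = self.conv(o0)
        return tf.add(o1, o0)

m = MyModule()
out = m(tf.constant(3.14, shape=[1, 1, 3, 2], dtype=tf.float64))
print(out.shape)
```

With auto-clustering, TF may leave the problematic convolution un-clustered (or compile it differently), which is presumably why this path avoids the CUDNN_STATUS_NOT_SUPPORTED failure.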

@mohantym mohantym assigned gadagashwini and unassigned mohantym Sep 19, 2022
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Sep 27, 2022
@cheshire
Member

Thanks for the bug report! Filed b/249449866 to track internally.

@bviyer

bviyer commented Oct 7, 2022

Hello Colin,
I see that you are using cuDNN version 8100, which could be the reason for your issue. Please see here (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc#L5309). Is it possible for you to upgrade to a cuDNN version > 8205 on your system? I was able to get your code working with cuDNN 8302.

Here is the screenshot of what I am seeing:

...
I1006 14:03:28.553708 2111518 cuda_dnn.cc:428] Loaded cuDNN version 8302
I1006 14:03:28.972418 2111518 xla_compilation_cache.cc:476] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
tf.Tensor(
[[[[-0.14627518  4.15663878]
   [-0.14627518  4.15663878]
   [-0.14627518  4.15663878]]]], shape=(1, 1, 3, 2), dtype=float64)
(1, 1, 3, 2)
...
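As a quick way to check whether a given environment might be affected (this snippet is an illustration, not part of the thread), tf.sysconfig.get_build_info() reports the cuDNN version the TF wheel was built against; the version actually loaded at runtime is the one in the "Loaded cuDNN version NNNN" log line, and the two can differ.

```python
import tensorflow as tf

# Build-time configuration of this TF wheel; keys such as
# "cudnn_version" may be absent on CPU-only builds, hence .get().
info = tf.sysconfig.get_build_info()
print("CUDA build:", info.get("is_cuda_build"))
print("cuDNN (build-time):", info.get("cudnn_version"))
```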

@ganler
Contributor

ganler commented Oct 25, 2022

@bviyer @cheshire I can confirm that @bviyer's suggestion is correct. The error occurs on a machine with cuDNN 8202, and the code runs smoothly on another machine with cuDNN 8500+. That said, it would be ideal to make XLA more robust by emitting a clearer error message, or by disabling this optimization in TensorFlow builds with cuDNN <= 8205. Thanks!
