Max pooling causes error on empty batch #21338

Closed
wrongtest-intellif opened this issue Aug 2, 2018 · 16 comments

@wrongtest-intellif

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 3.10.0-693.2.2.el7.x86_64
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.8.0
  • Python version: Python 2.7.14 :: Anaconda
  • Bazel version (if compiling from source): None
  • GCC/Compiler version (if compiling from source): None
  • CUDA/cuDNN version: cuda==9.0, cudnn==7.0.4
  • GPU model and memory: None
  • Exact command to reproduce: See below

Describe the problem

When batch_size is 0, the max pooling operation seems to produce an unhandled cudaError_t status. This can cause subsequent operations to fail with odd error messages, which is extremely difficult to debug.

(This corner case bothers us: we first extract some bounding boxes and then run ordinary convolution operations on the regions they specify. The error above occurs when no bounding boxes are detected, so batch_size becomes 0. However, the Python exception is thrown seemingly at random from a subsequent operation or a later session run step.)

import tensorflow as tf
import numpy as np

x = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4, 1])
pool_op = tf.nn.pool(x, pooling_type="MAX", window_shape=[2, 2], strides=[1, 1], padding="SAME")

y = tf.placeholder(dtype=tf.float32, shape=[None])
other_op = tf.where(tf.equal(y, 1.0))

normal_data = np.zeros([1, 4, 4, 1], dtype="float32")
empty_data = np.zeros([0, 4, 4, 1], dtype="float32")

# cudaError is thread local, limit thread pool size to make it easy to reproduce
config = tf.ConfigProto()
config.inter_op_parallelism_threads = 1
with tf.Session(config=config) as sess:
    # run other_op: succeeds
    print sess.run(other_op, {y: [1.0, 2.0, 3.0, 4.0]})  # [[0]]

    # run pooling on normal and empty data: both appear to succeed
    print sess.run(pool_op, {x: normal_data}).shape  # (1, 4, 4, 1)
    print sess.run(pool_op, {x: empty_data}).shape  # (0, 4, 4, 1)

    # run other_op again: now fails
    print sess.run(other_op, {y: [1.0, 2.0, 3.0, 4.0]})  # err

The above code reports this error:
tensorflow.python.framework.errors_impl.InternalError: WhereOp: Could not launch cub::DeviceReduce::Sum to count number of true / nonzero indices. temp_storage_bytes: 1, status: invalid configuration argument

"invalid configuration argument" seems to be the message for the error returned by cudaGetLastError, which indicates a failed kernel launch due to a zero or too large number of blocks/threads.

Source code / logs

(screenshot of the error attached in the original issue)

@ppwwyyxx
Contributor

ppwwyyxx commented Aug 2, 2018

Interesting findings! The error disappears when you use the NCHW data format.
NCHW max pooling uses cuDNN, and I've fixed the empty-input case for cuDNN in https://github.com/tensorflow/tensorflow/pull/15264/files#diff-13381a722607d8496555bfd0e84c19ba.

However, the NHWC code path uses a custom CUDA kernel which does not explicitly handle an empty input tensor.
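
A guard of the kind needed would look roughly like this sketch (the kernel and launcher names here are made up, not the actual TF code): skip the launch entirely when there is nothing to compute.

#include <cuda_runtime.h>

// Placeholder kernel body; the real NHWC max pooling kernel does the actual work.
__global__ void MaxPoolForwardNHWC(const float* in, float* out, int out_count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < out_count) out[i] = in[i];
}

void LaunchMaxPoolForwardNHWC(const float* in, float* out, int out_count) {
  // Empty batch: the output has zero elements, so there is nothing to do.
  // Launching with zero blocks would set cudaErrorInvalidConfiguration.
  if (out_count == 0) return;
  const int threads = 256;
  const int blocks = (out_count + threads - 1) / threads;
  MaxPoolForwardNHWC<<<blocks, threads>>>(in, out, out_count);
}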

@ppwwyyxx
Contributor

ppwwyyxx commented Aug 2, 2018

Adding an if-else somewhere in the op to check for empty inputs can solve this issue. But before that, I think it's worth fixing the unit test framework first:

There is actually a test that should've triggered this error:

def _testMaxPoolEmptyInput(self, use_gpu):
  self._VerifyValues(
      gen_nn_ops.max_pool_v2,
      input_sizes=[0, 8, 8, 8],
      ksize=[1, 3, 3, 1],
      strides=[1, 2, 2, 1],
      padding="SAME",
      expected=[],
      use_gpu=use_gpu)

Maybe the unit tests should check the CUDA error after each session run.

@tensorflowbutler tensorflowbutler added the stat:awaiting response Status - Awaiting response from author label Aug 2, 2018
@tensorflowbutler
Member

Thank you for your post. We noticed you have not filled out the following field in the issue template. Could you update it if it is relevant in your case, or leave it as N/A? Thanks.
Mobile device

@wrongtest-intellif
Author

wrongtest-intellif commented Aug 3, 2018

@ppwwyyxx Thanks for your nice notes! The error disappears after I change to NCHW.

I went through some test cases and found that most of them just do a session.run() and check the result with a function like assertAllClose(). cudaGetLastError() or cudaPeekAtLastError() are rarely used directly in the TF repo (does that mean a kernel error is somehow expected to "broadcast" into the execution engine?).

In the WhereOp case, the error seems to finally get detected in the NVIDIA cub library: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/where_op_gpu.cu.h#L21

Any idea about a mechanism to do this kind of checking from Python?

@ppwwyyxx
Contributor

ppwwyyxx commented Aug 3, 2018

I don't think the error can be checked from Python. I hope the TF team will find a way to fix the tests soon. It would be even better if there were a way to identify ops that do not support empty inputs. Similar errors probably exist in many places.

I referenced a Horovod issue above which may be related (I was also using empty inputs when I hit that issue, but not NHWC max pooling). Others have seen similar issues, for example: tensorpack/tensorpack#760, CharlesShang/FastMaskRCNN#159, #16035, tensorflow/serving#627. None of these issues has been clearly resolved, and they are all object detection use cases where empty inputs appear quite often, so it sounds like they are all related to the bug you found.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Aug 3, 2018
ppwwyyxx added a commit to ppwwyyxx/tensorflow that referenced this issue Aug 4, 2018
@angerson angerson added the stat:contribution welcome Status - Contributions welcome label Aug 7, 2018
@angerson
Contributor

angerson commented Aug 7, 2018

Thanks @ppwwyyxx! I marked this as "Contributions Welcome" since you're looking at it.

@angerson angerson removed their assignment Aug 7, 2018
@ppwwyyxx
Contributor

ppwwyyxx commented Aug 8, 2018

@angerson No, I'm not. I can fix the pooling ops, but I expect the TF team to find out why the tests did not detect such a failure.

@wrongtest-intellif
Author

At least for synchronous execution on GPU devices, whether the computation failed seems to be checked via the OpKernelContext's status, which requires op developers to follow some error-handling conventions, or else errors may "leak".

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/gpu/gpu_device.cc#L506-L508

I tried adding an additional check here and recompiling, and I do get a nonzero CUDA error code (9). I wonder why this kind of post-op check is not done currently.
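
Roughly, the check I experimented with looks like this sketch (the types here are hypothetical stand-ins, not the real gpu_device.cc code):

#include <cstdio>
#include <cuda_runtime.h>

struct FakeOp {            // stand-in for the real op kernel
  const char* name;
  void (*compute)();       // may launch CUDA kernels internally
};

void ComputeSync(const FakeOp& op) {
  op.compute();
  // Additional post-op check: report any error left behind by a kernel
  // launch inside this op, instead of letting a later op pick it up.
  cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "op '%s' left CUDA error %d: %s\n",
                 op.name, static_cast<int>(err), cudaGetErrorString(err));
  }
}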

@ppwwyyxx
Contributor

ppwwyyxx commented Aug 9, 2018

To summarize, what happened is:

  1. a CUDA kernel launch mykernel<<<0, ...>>> will do nothing (exactly what we need) but sets a CUDA error
  2. the error is not checked after the kernel launch
  3. the error may later surface as an error of another op, if that op happens to check the CUDA error (see the sketch below)
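
A minimal standalone sketch of this sequence (DummyKernel is just a placeholder, not TF code):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyKernel(float* out) {
  if (out != nullptr) out[threadIdx.x] = 0.0f;
}

int main() {
  // 1. zero-block launch, as happens for a batch_size == 0 input:
  //    no work is done, but cudaErrorInvalidConfiguration is recorded.
  DummyKernel<<<0, 32>>>(nullptr);

  // 2. nothing checks cudaGetLastError() here, so the error stays pending.

  // 3. a later, unrelated check observes the stale error.
  cudaError_t err = cudaGetLastError();
  std::printf("later check sees: %s\n", cudaGetErrorString(err));  // "invalid configuration argument"
  return 0;
}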

I think this is a quite serious issue:

  1. other ops (in addition to pooling) may have the same problem
  2. it causes very misleading error messages
  3. the unit test framework currently fails to detect it

Spamming XQ @zheng-xq and Toby @tfboyd to escalate.

@ppwwyyxx
Contributor

ping @tfboyd and @zheng-xq

@ppwwyyxx
Contributor

ping @tfboyd and @zheng-xq
btw, in general I think a clear bug should not be marked "contribution welcome"

@ppwwyyxx
Contributor

@tfboyd

@facaiy facaiy assigned facaiy and tfboyd and unassigned facaiy Sep 20, 2018
@facaiy facaiy added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Sep 20, 2018
@tfboyd
Member

tfboyd commented Oct 4, 2018

@asimshankar Can you review this given the changes to the team? I was not sure which GitHub account is Alek's. This was marked as contributions welcome, and that should be reviewed.

@tfboyd tfboyd removed the stat:contribution welcome Status - Contributions welcome label Oct 4, 2018
@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Oct 5, 2018
@chsigg
Contributor

chsigg commented Nov 27, 2018

Thanks Yuxin, I've approved your fix PR.

I will investigate what we can do about checking CUDA errors where they happen. The correct way to prevent CUDA errors from tunneling to other places is to check cudaGetLastError after each kernel launch. Most kernel launches don't currently do that. We are considering adding infrastructure code that all kernel launches would go through, which would allow for a central place to fix this (among other things related to launching kernels directly through the CUDA runtime). Until this is in place, it might be worthwhile to just add those checks where they are missing.
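
For illustration, the kind of per-launch check I mean might look roughly like this (the macro and kernel names are made up, not an actual TF API):

#include <cstdio>
#include <cuda_runtime.h>

// Wrap every launch so a bad configuration is reported at the launch site
// instead of tunneling into a later, unrelated op.
#define LAUNCH_AND_CHECK(kernel, grid, block, ...)                   \
  do {                                                               \
    kernel<<<(grid), (block)>>>(__VA_ARGS__);                        \
    cudaError_t _err = cudaGetLastError();                           \
    if (_err != cudaSuccess) {                                       \
      std::fprintf(stderr, "launch of %s failed: %s\n", #kernel,     \
                   cudaGetErrorString(_err));                        \
    }                                                                \
  } while (0)

__global__ void PoolLikeKernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];  // stand-in for real pooling work
}

// With an empty batch, n == 0 gives a zero-block grid; the check then fires
// here rather than inside some other op that happens to look at the error.
void RunPoolLike(const float* in, float* out, int n) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;  // 0 when n == 0
  LAUNCH_AND_CHECK(PoolLikeKernel, blocks, threads, in, out, n);
}

A central launch helper would let this live in one place instead of being repeated at every call site.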

@ppwwyyxx
Contributor

Thanks @chsigg!

@mohantym mohantym self-assigned this Mar 15, 2022
@mohantym
Contributor

Hi @wrongtest! Closing this issue as it has been resolved in version 2.8.
