
CancelledError: [_Derived_]RecvAsync is cancelled. #45594

Open
stefan-falk opened this issue Dec 11, 2020 · 14 comments
Labels: comp:keras, stat:awaiting tensorflower, TF 2.3, TF 2.4, type:bug

Comments

@stefan-falk

stefan-falk commented Dec 11, 2020

Note

I am opening this issue because the error I am describing seems to affect quite a few people. See Related Issues, but those have been closed "due to inactivity".


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16, 18
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.3.0-54-gfcc4b966f1 2.3.1
  • Python version: 3.8
  • CUDA/cuDNN version: 10.1
  • GPU model and memory: GeForce 1080 TI, 11 GB

Describe the current behavior

The training fails seemingly at random with CancelledError: [_Derived_]RecvAsync is cancelled. All cases seem to have a recurrent layer in common (see Related Issues).

After starting, the training will run (in my case) for some time and then just crash with the above error.

Describe the expected behavior

The training should not crash.

Standalone code to reproduce the issue

There are some examples in #33721.

Other info / logs

2020-12-05 18:06:59.383572: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1517 : Unknown: CUDNN_STATUS_BAD_PARAM
in tensorflow/stream_executor/cuda/cuda_dnn.cc(1484): 'cudnnSetTensorNdDescriptor( tensor_desc.get(), data_type, sizeof(dims) / sizeof(dims[0]), dims, strides)'
2020-12-05 18:06:59.383906: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1517 : Unknown: CUDNN_STATUS_BAD_PARAM
in tensorflow/stream_executor/cuda/cuda_dnn.cc(1484): 'cudnnSetTensorNdDescriptor( tensor_desc.get(), data_type, sizeof(dims) / sizeof(dims[0]), dims, strides)'
2020-12-05 18:06:59.384114: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at cudnn_rnn_ops.cc:1517 : Unknown: CUDNN_STATUS_BAD_PARAM
in tensorflow/stream_executor/cuda/cuda_dnn.cc(1484): 'cudnnSetTensorNdDescriptor( tensor_desc.get(), data_type, sizeof(dims) / sizeof(dims[0]), dims, strides)'
Traceback (most recent call last):
  File "asr/bin/train_keras.py", line 300, in <module>
    app.run(main)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "asr/bin/train_keras.py", line 236, in main
    model.fit(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 807, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1843, in _filtered_call
    return self._call_flat(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1923, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 545, in call
    outputs = execute.execute(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.CancelledError:  [_Derived_]RecvAsync is cancelled.
         [[{{node div_no_nan/ReadVariableOp_5/_1620}}]]
         [[GroupCrossDeviceControlEdges_2/Identity_6/_1699]] [Op:__inference_train_function_159151]

Function call stack:
train_function

Related Issues

@ravikyram
Contributor

@stefan-falk

Please share a colab link or simple standalone code to reproduce the issue in our environment. It helps us localize the issue faster. Thanks!

@ravikyram ravikyram added stat:awaiting response Status - Awaiting response from author TF 2.3 Issues related to TF 2.3 labels Dec 11, 2020
@stefan-falk
Author

@ravikyram I am going to try to build a small example, but the nature of this issue does not make this easy, as nobody seems to know where it is coming from (judging by what people post e.g. in #33721).

At this point I think it would help if somebody could help us track down the source of this issue on our side, e.g. somebody who knows what RecvAsync means and does. From the name and the stack trace it seems to be an issue that occurs when using multiple GPUs, although this might not always be the case.

Could you mention somebody here who you think could help us track this issue down?

@ravikyram ravikyram added comp:keras Keras related issues and removed stat:awaiting response Status - Awaiting response from author labels Dec 11, 2020
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Dec 11, 2020
@jvishnuvardhan
Contributor

@stefan-falk Is it possible to test with recent TF versions (TF2.4rc4 and/or tf-nightly)? Thanks!

@stefan-falk
Author

@jvishnuvardhan I can certainly try that. Can I simply upgrade to TF2.4rc4, or does anything else change, like CUDA/cuDNN?

@stefan-falk
Author

stefan-falk commented Dec 14, 2020

@jvishnuvardhan Apparently I cannot simply switch to tf-nightly just like that. It appears that there are some breaking changes which I have to adapt to first.

Update:

I switched to TF2.4rc4 because tf-nightly seems to require more changes. The training is running. In a few hours I should see whether it is still crashing.

@stefan-falk
Author

stefan-falk commented Dec 15, 2020

@jvishnuvardhan @ravikyram I have upgraded to 2.4.0 (stable; previously 2.4rc4) and started some experiments.

The problem persists though. After a few hours of training, the program crashes.

Is there anything else I can do in order to track this issue down?

@jvishnuvardhan jvishnuvardhan added the TF 2.4 for issues related to TF 2.4 label Dec 15, 2020
@stefan-falk
Author

One of my latest changes was using tf.data.experimental.bucket_by_sequence_length. I've just removed it in order to see whether the issue comes from there, but it does not.

With that being said, I have no clue where this is coming from. I am 100% sure I didn't have this in 2.1.0, and I am not even sure if I had it with 2.3.0, but I certainly got this issue with 2.3.1. I am not saying the problem is TensorFlow; maybe I am doing something wrong somewhere, but I have no idea how I could possibly track down the source of this.

@stefan-falk
Author

stefan-falk commented Dec 17, 2020

@jvishnuvardhan @ravikyram I think I have fixed the issue. The root cause was bucket_by_sequence_length combined with drop_remainder=False.

What seems to happen is that some batches do not contain enough samples, so there weren't enough examples for all cards. Since I set drop_remainder=True I don't get this error anymore; see the sketch below.

I don't know whether it is possible to raise an error or log a warning in such a case, because the current error message is not a good indicator of where to look.
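
For reference, a minimal sketch of the workaround (sample_len, boundaries and batch_sizes refer to the full reproduction in the follow-up comment below):

# Dropping incomplete buckets guarantees that every batch can be split
# evenly across all replicas. With drop_remainder=False, a final batch
# smaller than the number of replicas could slip through and crash training.
bucketing = tf.data.experimental.bucket_by_sequence_length(
    sample_len,
    bucket_boundaries=boundaries,
    bucket_batch_sizes=batch_sizes,
    drop_remainder=True,  # was False
)
dataset = dataset.apply(bucketing)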

@stefan-falk
Author

stefan-falk commented Dec 18, 2020

This code does not reproduce the error from above exactly, but I think this is what is happening.

If I run the code below on e.g. 4 GPUs, it simply crashes because there will be a batch with just one example.

I guess an error like this is to be expected, but to me it was not obvious where to look.

import tensorflow as tf  # v2.4.0
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers


def sample_generator(nb_samples):

    for i in range(nb_samples):
        l = np.random.randint(6, 20)
        yield np.random.rand(l, 8), np.random.rand(1, 1)

    # One example for bucket (1, 5)
    yield np.random.rand(3, 8), np.random.rand(1, 1)


def sample_len(sample, *_):
    return tf.shape(sample)[0]


nb_replica = max(1, len(tf.config.experimental.list_physical_devices('GPU')))
assert nb_replica > 1, f'Number of GPUs must be >1 got {nb_replica}'

dataset = tf.data.Dataset.from_generator(
    lambda: sample_generator(500),
    output_types=(tf.float32, tf.float32),
    output_shapes=((None, 8), (None, 1))
)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)

boundaries = [5, 10]
# Note: i starts at 0, so the first bucket (sequence lengths below 5) gets batch size 0 here.
batch_sizes = [i * nb_replica for i in range(len(boundaries) + 1)]

bucketing = tf.data.experimental.bucket_by_sequence_length(
    sample_len,
    bucket_boundaries=boundaries,
    bucket_batch_sizes=batch_sizes,
    drop_remainder=True
)

dataset = dataset.apply(bucketing).repeat()

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    inputs = layers.Input(shape=(None, 8))
    x = inputs
    x = layers.LSTM(16)(x)
    x = layers.Dense(1)(x)
    model = keras.Model(inputs=inputs, outputs=x)
    model.compile(loss='mse')

model.fit(
    dataset,
    epochs=2,
    steps_per_epoch=100,
)

Output:

tensorflow.python.framework.errors_impl.InvalidArgumentError: 5 root error(s) found.
  (0) Invalid argument:  Window size must be greater than zero, but got 0.
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]
	 [[IteratorGetNext]]
  (1) Invalid argument:  Window size must be greater than zero, but got 0.
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]
	 [[IteratorGetNext]]
	 [[RMSprop/Cast_10/ReadVariableOp/_8]]
  (2) Invalid argument:  Window size must be greater than zero, but got 0.
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]
	 [[IteratorGetNext]]
	 [[div_no_nan/ReadVariableOp_1/_64]]
  (3) Invalid argument:  Window size must be greater than zero, but got 0.
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]
	 [[IteratorGetNext]]
	 [[group_deps/_111]]
  (4) Invalid argument:  Window size must be greater than zero, but got 0.
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]
	 [[IteratorGetNext]]
	 [[RMSprop/Cast_3/ReadVariableOp/_6]]

@FanchenBao

I was running an RNN on Kaggle. The error message I encountered is:

CancelledError:  [_Derived_]RecvAsync is cancelled.
	 [[{{node gradient_tape/sequential_4/embedding_4/embedding_lookup/Reshape/_24}}]] [Op:__inference_train_function_81303]

Apparently something was not right with the embedding layer. This is my embedding layer:

layers.Embedding(
    input_dim=SIZE_VOCAB,
    output_dim=EMBED_DIM,
    mask_zero=True,
    input_length=MAX_SEQ_LEN,
),

The suspect is mask_zero=True. I set it to True following Understanding masking & padding, which allows the embedding layer to ignore the padded zeros.

After I commented out mask_zero=True, the error no longer occurs when using the GPU, with all other settings at their defaults.

The code example is available here
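
For reference, a minimal sketch of that workaround (SIZE_VOCAB, EMBED_DIM and MAX_SEQ_LEN are the same placeholders as in the snippet above):

# Same layer with mask_zero left at its default (False); padded zeros are
# then treated as ordinary tokens, which avoided the cancellation error here.
layers.Embedding(
    input_dim=SIZE_VOCAB,
    output_dim=EMBED_DIM,
    input_length=MAX_SEQ_LEN,
),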

@JustStas

JustStas commented Feb 9, 2021

Also getting this error after training for ~6 epochs. I am using a GRU with embeddings (tf nightly gpu 2.5.0-dev20210115).
Error below:

CancelledError                            Traceback (most recent call last)
<timed exec> in <module>

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1133                 _r=1):
   1134               callbacks.on_train_batch_begin(step)
-> 1135               tmp_logs = self.train_function(iterator)
   1136               if data_handler.should_sync:
   1137                 context.async_wait()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    795     tracing_count = self.experimental_get_tracing_count()
    796     with trace.Trace(self._name) as tm:
--> 797       result = self._call(*args, **kwds)
    798       compiler = "xla" if self._jit_compile else "nonXla"
    799       new_tracing_count = self.experimental_get_tracing_count()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    823       # In this case we have created variables on the first call, so we run the
    824       # defunned version which is guaranteed to never create variables.
--> 825       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    826     elif self._stateful_fn is not None:
    827       # Release the lock early so that multiple threads can perform the call

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2970        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2971     return graph_function._call_flat(
-> 2972         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2973 
   2974   @property

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1946       # No tape is watching; skip to running the function.
   1947       return self._build_call_outputs(self._inference_function.call(
-> 1948           ctx, args, cancellation_manager=cancellation_manager))
   1949     forward_backward = self._select_forward_and_backward_functions(
   1950         args,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    559               inputs=args,
    560               attrs=attrs,
--> 561               ctx=ctx)
    562         else:
    563           outputs = execute.execute_with_cancellation(

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

CancelledError:  [_Derived_]RecvAsync is cancelled.
	 [[{{node gradient_tape/model/txn_type_emb/embedding_lookup/Reshape/_352}}]] [Op:__inference_train_function_6250]

Function call stack:
train_function

@summa-code

Any solution to this on multi-GPU?

W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at cudnn_rnn_ops.cc:1562 : UNKNOWN: CUDNN_STATUS_BAD_PARAM in tensorflow/stream_executor/cuda/cuda_dnn.cc(1588): 'cudnnSetTensorNdDescriptor( tensor_desc.get(), data_type, sizeof(dims) / sizeof(dims[0]), dims, strides)'
Traceback (most recent call last):

 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.CancelledError:  [_Derived_]RecvAsync is cancelled.
	 [[{{node div_no_nan_1/ReadVariableOp_3/_40}}]] [Op:__inference_test_function_29082]

I am using tf.keras.utils.timeseries_dataset_from_array from this link.

@conorg000

@stefan-falk thanks for digging into this! Definitely looks like a batch issue. I was using batch_size=2 across 2 GPUs and it was failing because, like you said, the final batch would only have 1 piece of data to send across both GPUs (odd number of training samples). A bigger batch size solved the problem.

Thanks for your help :)

@pinesnow72

pinesnow72 commented Mar 10, 2022

I have also run into the same issue in my NER model with a mirrored strategy using 2 GPUs.
I have tested with the example code in Distributed training with Keras and came to the conclusion that this issue is related to the number of samples in the last batch of data.

The environment I used is as follows:
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
TensorFlow installed from (source or binary): binary
TensorFlow version (use command below): v2.5.0
Python version: 3.8
CUDA/cuDNN version: 11.2 (V11.2.152)
GPU model and memory: GeForce 1080, 11 GB * 2 EA

The example code uses the MNIST data, which has 60000 training examples and 10000 test examples.
I fixed the (global) BATCH_SIZE at 100 and varied the number of training samples in the range [59900, 60000].
With 2 GPUs, each GPU gets 50 samples per step (this is the replica batch size). As the number of training samples varies, the number of samples in the last batch also varies in the range [1, 100]. I set drop_remainder to False (the default) when batching:

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=False)

In this experiment, when the last batch contained 1 to 50 samples, model.fit raised CUDNN_STATUS_BAD_PARAM, while 51 to 100 samples did not cause any error (the model trained correctly on 2 GPUs). From this I conclude that the issue occurs when the last batch has no more samples than the replica batch size. It seems that the replicas do not receive an even share of such a batch: for example, with 2 GPUs (GPU0 and GPU1), a global batch size of 100 and 50 samples in the last batch, GPU0 seems to get all 50 samples while GPU1 gets none. I am not sure, but this situation appears to trigger the issue. To prevent it, we can either set drop_remainder to True as commented by @stefan-falk, or select the missing number of samples from the training data (randomly or based on some criterion such as sample weight) and append them so that every batch contains the full global batch size. The former is very simple but loses some training samples; the latter does not. With the latter approach I was able to resolve this issue in my NER model training (a sketch follows below).
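
A minimal sketch of the adding approach, assuming the training data is available as NumPy arrays x_train / y_train (these names, and the random re-sampling, are my own illustration rather than the exact code used above):

import numpy as np

GLOBAL_BATCH_SIZE = 100  # 2 replicas x 50 samples per replica, as in the experiment

def pad_to_full_batches(x, y, batch_size, rng=np.random.default_rng(0)):
    # Append re-sampled examples so that len(x) is an exact multiple of
    # batch_size; every batch (including the last) is then full, and each
    # replica always receives a non-empty, equally sized share.
    remainder = len(x) % batch_size
    if remainder == 0:
        return x, y
    extra_idx = rng.choice(len(x), size=batch_size - remainder, replace=False)
    return np.concatenate([x, x[extra_idx]]), np.concatenate([y, y[extra_idx]])

# Usage: x_train, y_train = pad_to_full_batches(x_train, y_train, GLOBAL_BATCH_SIZE)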

Validation / test data may cause the same error. When I used the MNIST test data for validation with Model.fit(), an insufficient number of samples in the last validation batch did not cause any error. However, in my NER model training with CoNLL-2003 data, an insufficient number of samples in the last validation batch caused the same error:

Function call stack:
test_function

For the validation or test data I did not use the dropping or adding method, because that could make evaluation results incomparable. Instead, I dynamically changed the validation or test batch size, increasing it one by one until the number of samples in the last batch became larger than the replica batch size. With this method I was able to solve the same issue for validation and test data; a sketch of that search follows below.
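
A minimal sketch of that batch-size search (the function name and the sample counts in the example are my own illustration, not values from the training run above):

def find_safe_eval_batch_size(num_samples, start_batch_size, num_replicas):
    # Increase the batch size until either every batch is full or the last
    # (partial) batch holds more samples than one replica's share, so that
    # no replica ends up with an empty shard.
    for batch_size in range(start_batch_size, num_samples + 1):
        remainder = num_samples % batch_size
        if remainder == 0 or remainder > batch_size // num_replicas:
            return batch_size
    # Fallback: evaluate everything in a single batch.
    return num_samples

# Example with a hypothetical validation set of 3466 samples on 2 replicas:
# find_safe_eval_batch_size(3466, 32, 2) -> 34 (last batch has 32 > 17 samples)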
