[TF 2.0] constant folding failed: invalid argument: unsupported type: 21 #29525

Closed
llan-ml opened this issue Jun 7, 2019 · 27 comments
Labels: comp:grappler (Grappler related issues), comp:ops (OPs related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:bug (Bug)


llan-ml commented Jun 7, 2019

System information

  • TensorFlow installed from (source or binary): binary

  • TensorFlow version (use command below): tf-nightly-gpu-2.0-preview 2.0.0.dev20190606

  • Python version: 3.6.5

Code to reproduce the issue

import numpy as np
import tensorflow as tf


class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense(inputs)


model = Model()


def forward(x):
    batch_size = x.shape[0]
    ys = tf.TensorArray(tf.float32, size=batch_size)
    for i in tf.range(batch_size):
        y = model(x[i][tf.newaxis, :])
        ys = ys.write(i, y)
    return ys.stack()


def train(x, forward_func):
    with tf.GradientTape() as tape:
        ys = forward_func(x)
        loss = tf.reduce_mean(ys)
    grads = tape.gradient(loss, model.trainable_weights)
    return grads


def big_train(x):
    with tf.GradientTape() as tape:
        batch_size = x.shape[0]
        ys = tf.TensorArray(tf.float32, size=batch_size)
        for i in tf.range(batch_size):
            y = model(x[i][tf.newaxis, :])
            ys = ys.write(i, y)
        ys = ys.stack()
        loss = tf.reduce_mean(ys)
    grads = tape.gradient(loss, model.trainable_weights)
    return grads


x = np.random.rand(10, 5).astype(np.float32)

codes_buggy = [
    "tf.function(train)(x, forward)",
    "tf.function(big_train)(x)"
]

codes_normal = [
    "tf.function(train)(x, tf.function(forward))",
    "train(x, tf.function(forward))",
    "train(x, forward)",
    "big_train(x)"
]


def test(code):
    tf.print("==========================")
    tf.print(f"{code}:")
    exec(code)


test(codes_buggy[0])
test(codes_buggy[1])

test(codes_normal[0])
test(codes_normal[1])
test(codes_normal[2])
test(codes_normal[3])

Other info / logs
Output:

==========================
tf.function(train)(x, forward):
2019-06-07 16:46:23.314712: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-07 16:46:23.357137: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-07 16:46:23.460568: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
==========================
tf.function(big_train)(x):
2019-06-07 16:46:24.139754: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-07 16:46:24.180814: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
==========================
tf.function(train)(x, tf.function(forward)):
==========================
train(x, tf.function(forward)):
==========================
train(x, forward):
==========================
big_train(x):

Related to #28626 .


vejvarm commented Jun 10, 2019

I have the same issue on a TF 2.0 GPU beta0. It really hurts performance.


llan-ml commented Jun 10, 2019

Hi @vejvarm, what kind of performance do you mean? Training speed or accuracy?


vejvarm commented Jun 10, 2019

Hi @llan-ml,

sorry for not elaborating on that. By performance I mean training speed. If I remember correctly, with the warning it took about 2 seconds/batch, while without it I'm at 2 to 4 batches/second, so roughly a 4 to 8 times slowdown with the warning. I'm not sure about the exact number, but it was significant.

As for accuracy, I haven't had the time to run the model for long enough to see whether it has any impact on that.

@gadagashwini-zz gadagashwini-zz self-assigned this Jun 14, 2019
gadagashwini-zz (Contributor) commented:

@llan-ml I tried reproducing the issue on Colab with the latest tf-nightly-gpu-2.0-preview but I did not get any error. Can you try again and let us know if it is still an issue. Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Jun 14, 2019

vejvarm commented Jun 14, 2019

@llan-ml I tried reproducing the issue on Colab with the latest tf-nightly-gpu-2.0-preview but I did not get any error. Can you try again and let us know if it is still an issue. Thanks!

Just tried it, and to my knowledge it is still there as of 2.0.0.dev20190614. It just isn't written directly to the cell output, since it is not an error but a warning. It can be found in the runtime logs of the notebook:

WARNING | 2019-06-14 10:21:43.847057: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21


llan-ml commented Jun 15, 2019

@gadagashwini I tested with 2.0.0.dev20190615, and the error still appears.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jun 15, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.0 Issues relating to TensorFlow 2.0 comp:grappler Grappler related issues type:bug Bug labels Jun 17, 2019
bionicles commented:

same issue on tf-gpu 1.14 now

@ymodak ymodak added comp:ops OPs related issues and removed comp:grappler Grappler related issues labels Jun 24, 2019
@ymodak ymodak assigned alextp and unassigned ymodak Jun 24, 2019
@alextp alextp assigned rmlarsen and unassigned alextp Jun 24, 2019
@alextp alextp added the comp:grappler Grappler related issues label Jun 24, 2019

alextp commented Jun 24, 2019

@rmlarsen this looks like a grappler issue, can you triage?


pandrey-fr commented Jun 28, 2019

I am having a similar issue, also on TensorFlow 2.0 beta with GPU enabled.

Interestingly, hiding the GPU from TensorFlow (using export CUDA_VISIBLE_DEVICES=-1 before running the script) lets the code run (but it still prints the error message this issue is about, and feels slower than it should), while using the GPU results in a memory leak that ends with the system crashing once GPU memory is saturated.
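
For reference, a minimal sketch of the same workaround applied from inside Python rather than the shell, assuming TF 2.0; the variable must be set before TensorFlow initializes CUDA:

import os

# Hide all GPUs from TensorFlow; must run before TensorFlow touches CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

# Expected to print an empty list, confirming no GPU is visible.
print(tf.config.experimental.list_physical_devices("GPU"))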

System information

  • OS Platform and Distribution: Linux Mint 19.1
  • TensorFlow installed from: binary (using pip)
  • TensorFlow version: v2.0.0-beta0-16-g1d91213fe7
  • Python version: 3.6.8
  • CUDA/cuDNN version: 10.0 / 7
  • GPU model and memory: QUADRO P-1000 with 4 GB of dedicated RAM (+ 16 GB of system RAM)

Code to reproduce the issue

import numpy as np
import tensorflow as tf


def build_autoencoder(input_dim, embed_dim=100): 
    """Set up an auto-encoder model made of two BiLSTM layers.""" 
    # Set up input tensors.
    inputs = tf.keras.Input((None, input_dim), dtype=tf.float32) 
    mask = tf.keras.Input((None,), dtype=tf.bool) 
    # Set up encoder and decoder BiLSTM layers.
    encoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(embed_dim, return_sequences=True),
        merge_mode='sum' 
    ) 
    decoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(input_dim, return_sequences=True),
        merge_mode='sum' 
    ) 
    # Build the outputs tensor.
    outputs = decoder(encoder(inputs, mask=mask), mask=mask) 
    # Set up, compile and return the model.
    model = tf.keras.Model(inputs=[inputs, mask], outputs=outputs) 
    model.compile('adam', tf.keras.losses.mse) 
    return model


def build_mock_data(dim, nsamples, maxlen, seed=0):
    """Build some mock data for bug demonstration purposes.

    Return an array of zero-padded sequences of random
    actual length, and an associated boolean mask Tensor.
    
    Use a random seed for reproducibility.
    """
    np.random.seed(seed)
    sizes = np.random.choice(maxlen, size=nsamples)
    inputs = np.random.normal(size=(nsamples, max(sizes), dim))
    for i, size in enumerate(sizes):
        inputs[i, size:] = 0.
    mask = tf.sequence_mask(sizes, dtype=tf.bool)
    return inputs.astype(np.float32), mask


if __name__ == '__main__':
    # Generate the mock data. Instantiate the model.
    inputs, mask = build_mock_data(dim=100, nsamples=64, maxlen=500, seed=0)
    model = build_autoencoder(input_dim=100, embed_dim=50)

    # This works fine.
    model.predict([inputs, mask])

    # This also works.
    model.evaluate([inputs, mask], inputs)

    # This is where things go wrong.
    model.fit([inputs, mask], inputs)

Error with GPU enabled

Train on 64 samples
2019-06-28 17:11:29.014129: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-28 17:11:29.778881: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-28 17:11:42.317560: E tensorflow/stream_executor/cuda/cuda_driver.cc:890] failed to alloc 8589934592 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-06-28 17:11:42.317583: W ./tensorflow/core/common_runtime/gpu/gpu_host_allocator.h:44] could not allocate pinned host memory of size: 8589934592
2019-06-28 17:11:42.317604: E tensorflow/stream_executor/cuda/cuda_driver.cc:890] failed to alloc 7730940928 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2019-06-28 17:11:42.317609: W ./tensorflow/core/common_runtime/gpu/gpu_host_allocator.h:44] could not allocate pinned host memory of size: 7730940928
2019-06-28 17:11:53.241866: W tensorflow/core/common_runtime/bfc_allocator.cc:314] Allocator (GPU_0_bfc) ran out of memory trying to allocate 12.5KiB (rounded to 12800).  Current allocation summary follows.
Killed

Error with GPU disabled

Train on 64 samples
2019-06-28 17:20:25.088606: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
2019-06-28 17:20:25.709958: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21
64/64 [==============================] - 3s 44ms/sample - loss: 0.4860

pandrey-fr (Contributor) commented:

Hi,

I did some additional testing based on my previous bug-yielding example and would like to report on it, in the hope that it may help track down, and ultimately fix, the issue at stake.

Setting and consequences

I removed sequence masking for the BiLSTM layers, thus using a less general model that expects batches of same-length sequences. In this case I no longer encounter the GPU memory leak (at least, not something that makes my computer crash on the first run of fitting the model); however, an optimization warning is raised, and I have no idea whether it relates to the initial issue or not. It shows up both with and without the GPU enabled, and on each use of the model (not just during fitting).

Warning message

2019-07-01 09:22:25.637712: W tensorflow/core/grappler/optimizers/implementation_selector.cc:199] Skipping optimization due to error while loading function libraries: Invalid argument: Functions '__inference___backward_cudnn_lstm_860_1038' and '__inference___backward_cudnn_lstm_860_1038_specialized_for_Adam_gradients_encoder_StatefulPartitionedCall_1_grad_StatefulPartitionedCall_at___inference_keras_scratch_graph_5563' both implement 'lstm_7a1d4064-50de-41c0-86d3-5a99f303e8d7' but their signatures do not match.

Code

In the code below, I allow distinct batches to contain sequences of different lengths; however, I also ran a test with a strict shape (i.e. setting the InputLayer's shape to (length, input_dim) with length an integer instead of None), which yields exactly the same error message.

import numpy as np
import tensorflow as tf


def build_autoencoder(input_dim, embed_dim=100): 
    """Set up an auto-encoder model made of two BiLSTM layers.""" 
    # Set up the input tensor.
    inputs = tf.keras.Input((None, input_dim), dtype=tf.float32) 
    # Set up encoder and decoder BiLSTM layers.
    encoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(embed_dim, return_sequences=True),
        merge_mode='sum', name='encoder'
    ) 
    decoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(input_dim, return_sequences=True),
        merge_mode='sum', name='decoder'
    ) 
    # Build the outputs tensor.
    outputs = decoder(encoder(inputs)) 
    # Set up, compile and return the model.
    model = tf.keras.Model(inputs=inputs, outputs=outputs) 
    model.compile('adam', tf.keras.losses.mse) 
    return model


def build_mock_data(dim, nsamples, length, seed=0):
    """Build some mock data for bug demonstration purposes.

    Return an array of shape (nsamples, length, dim) filled
    with random normally-distributed data.
    
    Use a random seed for reproducibility.
    """
    np.random.seed(seed)
    return np.random.normal(size=(nsamples, length, dim))


if __name__ == '__main__':
    # Generate the mock data. Instantiate the model.
    inputs = build_mock_data(dim=100, nsamples=64, length=500, seed=0)
    model = build_autoencoder(input_dim=100, embed_dim=50)

    # This works but prints the error warning.
    model.predict(inputs)

    # Same thing here.
    model.evaluate(inputs, inputs)

    # Same thing here.
    model.fit(inputs, inputs)

I hope this helps solve the initial issue. Please let me know if there is any additional info I can provide or test I can run to help. At the moment, not being able to fit models with LSTM layers using properly masked variable-length sequences is a serious obstacle to putting code into production under TensorFlow 2.0. I know this is the whole point of a beta release (having a not-yet-quite-stable version out to identify issues that need solving before the actual release), but the programming logic has been so greatly altered compared with TF 1.x that it would also be impractical not to start making the transition (getting used to eager execution demands a significant effort after extensive use of the low-level placeholder / session API)...

pandrey-fr (Contributor) commented:

Note: this issue is quite similar to the newly-opened #30263

pandrey-fr (Contributor) commented:

Additional tests/results (sorry for the multiplication of messages; I really want to provide as much info as possible, hoping it can help solve the issue):

  • Changing my code to feed the model with a numpy array containing the batched sequences' lengths (then converting it to a sequence mask Tensor using tf.sequence_mask within the model) partly fixes the issue.

  • E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] constant folding failed: Invalid argument: Unsupported type: 21 still shows up twice when the fit method of the built model is first called (both with GPU and with CPU only).

  • The model can be run and fit, both with and without the GPU. After the first call to fit, the error message no longer shows up, and I can see the loss decreasing along the iterations (up to some point).

Code:

def build_autoencoder(input_dim, embed_dim=100): 
    """Set up an auto-encoder model made of two BiLSTM layers.""" 
    # Set up the input tensors.
    inputs = tf.keras.Input((None, input_dim), dtype=tf.float32)
    sizes = tf.keras.Input((), dtype=tf.int32)
    # Set up encoder and decoder BiLSTM layers.
    encoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(embed_dim, return_sequences=True),
        merge_mode='sum', name='encoder'
    ) 
    decoder = tf.keras.layers.Bidirectional( 
        tf.keras.layers.LSTM(input_dim, return_sequences=True),
        merge_mode='sum', name='decoder'
    ) 
    # Build the outputs tensor.
    mask = tf.sequence_mask(sizes, maxlen=tf.shape(inputs)[1])
    outputs = decoder(encoder(inputs, mask=mask), mask=mask) 
    # Set up, compile and return the model.
    model = tf.keras.Model(inputs=[inputs, sizes], outputs=outputs) 
    model.compile('adam', tf.keras.losses.mse) 
    return model

def build_mock_data(dim, nsamples, maxlen, seed=0):
    """Build some mock data for bug demonstration purposes.

    Return an array of zero-padded sequences of random
    actual length, and an array containing those lengths.
    
    Use a random seed for reproducibility.
    """
    np.random.seed(seed)
    sizes = np.random.choice(maxlen, size=nsamples)
    inputs = np.random.normal(size=(nsamples, max(sizes), dim))
    for i, size in enumerate(sizes):
        inputs[i, size:] = 0.
    return inputs.astype(np.float32), sizes

if __name__ == '__main__':
    # Generate the mock data. Instantiate the model.
    inputs, sizes = build_mock_data(dim=100, nsamples=64, maxlen=500, seed=0)
    model = build_autoencoder(input_dim=100, embed_dim=50)

    # This works fine.
    model.predict([inputs, sizes])

    # This also works.
    model.evaluate([inputs, sizes], inputs)

    # This prints out the error messages, but works.
    model.fit([inputs, sizes], inputs)

    # Further calls no longer print errors, and the loss decreases.
    model.fit([inputs, sizes], inputs)
    model.fit([inputs, sizes], inputs)
    model.fit([inputs, sizes], inputs)

Conclusion:

  • So, I guess the initial issue (the error showing up) is not solved.

  • There is, additionally, the issue previously pointed out (and also the object of issue [TF2.0]: Skipping optimization due to error while loading function #30263) of an optimization error (and an apparent failure to fit models) when using fixed-length sequences.

  • However, the GPU memory issue I was personally encountering seems to have been related to the use of a Tensor (instead of a numpy array) as input to my model; see the sketch after this list. I don't know whether this is by design (in which case it might be worth adding a warning when users do that?) or a separate issue, but I was able to fix it with better code design.
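
To make that last bullet concrete, here is a minimal sketch of the distinction, reusing the inputs, sizes and model names from the mock-data example above; converting the eager mask Tensor to NumPy is just one way to avoid feeding a Tensor to fit (the code above instead feeds the raw sizes and builds the mask inside the model):

# Build the mask eagerly, then convert it to a plain NumPy array before
# handing it to Keras, instead of passing the eager Tensor directly.
mask_tensor = tf.sequence_mask(sizes, dtype=tf.bool)  # eager tf.Tensor
mask_array = mask_tensor.numpy()                      # plain np.ndarray

model.fit([inputs, mask_array], inputs)               # instead of [inputs, mask_tensor]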

ialdencoots commented:

I also run into this issue when using masking on a GRU/LSTM layer, though running on CPU does not prevent the memory from blowing up and crashing the machine. In fact, even when running on GPU, system memory maxes out, and the printed errors imply that GPU memory has been completely filled as well. Removing the masking, however, allows training to proceed without issue, though the "constant folding failed: Invalid argument: Unsupported type: 21" message still appears.

pandrey-fr (Contributor) commented:

Hi,
Could anyone from the TF team confirm that this issue is being investigated? It appears that it does not show up in every setting (hence the difficulty of getting past the first round of issue screening, as in the newly opened #30533), but it causes major performance problems for the people who are confronted with it (see my performance tests on issue #30263). It should also be noted that this affects not only TF 2.0 but also 1.14 when eager execution is enabled...

pandrey-fr (Contributor) commented:

Hi, I can confirm that this is not inherent to the LSTM implementation. I have just reproduced the same error with GRU. Maybe it has something to do with the GPU-optimized cuDNN implementations?

Thank you for sharing this. The issue seems to be at the Grappler level, which, if I am not mistaken, is indeed the mechanism that chooses the backend kernel to use, which can be a cuDNN one...

pandrey-fr (Contributor) commented:

Interestingly, I am encountering this issue in TF 1.14 and in TF 2.0b1 installed through pip, but not in TF 2.0b1 installed from source using the r2.0 branch, and not always in TF 2.0b1 installed from source using yesterday's state of the master branch.

Using this issue's code on the latter installation, I get a distinct bug, namely repeated prints similar to: W tensorflow/core/grappler/costs/virtual_scheduler.cc:794] Output node [ gradients/while_grad/while_grad/gradients/zeros_1_switch/_43 ] has alread seen this input node [ gradients/while_grad/while_grad/merge/_25 -- possibly due to Swith-Merge in previous nodes. Skip to increment num_inputs_ready.


jkamalu commented Jul 11, 2019

Edit: I should note that I am running the GPU nightly pip build as of the timestamp on this comment.

Another interesting piece of narrowing information: in the code below, everything runs without a hitch if the for loop (a tf.while_loop behind the scenes) is removed. That is...

Without the for loop: the tf.function routine runs twice, and the code runs ad infinitum.

With the for loop: the tf.function routine runs twice, then a graph placement issue occurs and the code breaks.

Here's the code:

https://github.com/jkamalu/tensorflow_bugs/blob/master/LSTMGraphPlacement.py

Another thing worth noting is that this issue appears even without the while loop when using TensorFlow GPU distributed strategies, as seen in #29189.

@pandrey-fr a note: if the cuDNN implementation is not important to you (I don't know why it wouldn't be, but just in case), you can wrap an LSTMCell in the RNN layer and it works fine... another hint that this error might be in the optimized implementation.
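
A minimal sketch of that RNN-wrapped LSTMCell workaround, with illustrative sizes; the wrapped form always uses the generic (non-cuDNN) kernel:

import tensorflow as tf

# cuDNN-eligible layer (the kind this issue is about).
lstm_fused = tf.keras.layers.LSTM(128, return_sequences=True)

# Workaround: an LSTMCell wrapped in the generic RNN layer; this never
# dispatches to the cuDNN kernel, so it sidesteps the implementation selector.
lstm_generic = tf.keras.layers.RNN(
    tf.keras.layers.LSTMCell(128), return_sequences=True)

inputs = tf.keras.Input((None, 64), dtype=tf.float32)
outputs = lstm_generic(inputs)  # same interface as the fused layer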

rmlarsen (Member) commented:

The warning should go away in the next nightly. I'm looking into the original issue with unsupported types in constant folding.


rmlarsen commented Jul 11, 2019

The issue is that the error handling in many places in Grappler is much too conservative. In this case, we bail out of constant folding completely because we fail to convert a constant of an unknown type early on. I'll work on making the code more robust in this respect.


rmlarsen commented Jul 11, 2019

The particular error in this case was due to ZerosLike being overloaded for DT_VARIANT types: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/constant_op.cc#L267

I am submitting a fix now.


tensorflow-copybara pushed a commit that referenced this issue Jul 11, 2019
…or DT_RESOURCE.

Addresses: #29525

PiperOrigin-RevId: 257703800
rmlarsen (Member) commented:

Fix submitted: 2417464

pandrey-fr (Contributor) commented:

Great, thank you @rmlarsen!


pandrey-fr commented Jul 12, 2019

As announced by @rmlarsen, the fix (which is now included in the nightly build) removes the error message; however, it appears (at least in my case) that LSTM layers with masking still won't be placed on the GPU (at least when eager execution is enabled; I am still trying to figure out whether it is also the case with eager disabled), which is somewhat confusing. Do you have any idea why this is the case?


jkamalu commented Jul 12, 2019

Do you mean they won't be placed on the GPU, or that the graph won't be built with the cuDNN implementation? My bootleg LSTM layers (see below) exist on the GPU with the standard implementation (I verify this by watching nvidia-smi). I use masking (right-padding, so TF 2.0 cuDNN compatible), but end up having to use RNN-wrapped LSTMCell instances, which don't use the cuDNN implementation.

It should be noted that in a while loop for dynamic decoding, the GPU-enabled, cuDNN-compatible tf.keras.layers.LSTM implementation does not work, nor does this specific setup work (even without the while loop) on multiple GPUs via a distributed strategy.

pandrey-fr (Contributor) commented:

To be honest, I am not quite sure... What I did was use a tf.keras.callbacks.TensorBoard callback to trace the fitting of my model (in eager mode), and I found that on TensorBoard the LSTM unit is shown in a different color than the other parts (when I set the visualization parameter to "device used"), with the other parts' color labeled "GPU:0". I also verified that when I use custom layers of mine that make use of masking, they are clearly shown as having been placed on the GPU.

If you have any advice on how to properly keep track of where operations are being performed (perhaps also when eager execution is disabled), I would be glad to use it!
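
One way to inspect placement, assuming TF 2.0, is device placement logging; a minimal sketch:

import tensorflow as tf

# Log the device chosen for every executed op (output goes to stderr).
tf.debugging.set_log_device_placement(True)

a = tf.random.normal((2, 3))
b = tf.random.normal((3, 2))
c = tf.matmul(a, b)  # the log line for MatMul names the executing device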


jkamalu commented Aug 13, 2019

Hi @rmlarsen, I just wanted to let you know that errors resembling these, along with the slowdown, were reintroduced by later nightly builds. This isn't a request for a fix (I downgraded to the July 24 nightly and everything works fine now), but I thought you might like to know, just in case it's a simple thing.

With the same code (a multi-GPU setting on TF 2 with LSTM)...

On the July 24 build... the model trains quickly on all GPUs and is correct, but it gives spurious error messages:

2019-08-13 11:31:15.551518: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] implementation_selector failed: Invalid argument: Invalid format of input node name: Expected: {forward_node_name}:{index}
2019-08-13 11:31:32.258063: W tensorflow/core/grappler/optimizers/implementation_selector.cc:310] Skipping optimization due to error while loading function libraries: Invalid argument: Functions '__inference_cudnn_lstm_with_fallback_186209' and '__inference_standard_lstm_185862_specialized_for_model_lstm_2_StatefulPartitionedCall_at___inference_step_311955' both implement 'lstm_03256996-2770-4288-91ed-338407bd3cc3' but their signatures do not match.

On the August 12 build... the model trains on all GPUs and is correct, but takes ~50 times longer. Not an exaggeration.

2019-08-13 09:39:40.107723: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 1 of node se_q3/seq_encoder/while/body/_1/TensorListPushBack_42 was passed float from se_q3/seq_encoder/while/body/_1/decoder_c/lstm_3/StatefulPartitionedCall:9 incompatible with expected variant.
2019-08-13 09:39:53.596733: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] function_optimizer failed: Invalid argument: Input 1 of node se_q3/seq_encoder/while/body/_1/TensorListPushBack_42 was passed float from se_q3/seq_encoder/while/body/_1/decoder_c/lstm_3/StatefulPartitionedCall:9 incompatible with expected variant.
2019-08-13 09:39:58.604309: W tensorflow/core/common_runtime/process_function_library_runtime.cc:686] Ignoring multi-device function optimization failure: Invalid argument: Input 1 of node se_q3/seq_encoder/while/body/_1/TensorListPushBack_69 was passed float from se_q3/seq_encoder/while/body/_1/decoder_c/lstm_2/StatefulPartitionedCall:9 incompatible with expected variant.

danieldk added a commit to stickeritis/sticker that referenced this issue Nov 4, 2019
Recent versions of Tensorflow Keras will automatically switch
between cuDNN and Tensorflow implementations. The trained parameters
work regardless of the selected implementation.

The conditions for using the cuDNN implementation are documented at:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM

They boil down to: 1. a NVIDIA GPU is available, 2. certain hyper
parameters (e.g. activations) are set to specific values. If the cuDNN
implementation is selected, this results in a nice speedup.

The Tensorflow requirements are bumped to 1.15.0. This setup fails with
1.14.0 with a constant folding error in Grappler:

tensorflow/tensorflow#29525
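
To make the conditions in the commit message above concrete, here is a minimal sketch of an LSTM configured so that Keras can select the cuDNN kernel; the parameter values follow the tf.keras.layers.LSTM documentation, and the layer size is illustrative:

import tensorflow as tf

lstm = tf.keras.layers.LSTM(
    units=128,
    activation="tanh",               # default; required for the cuDNN kernel
    recurrent_activation="sigmoid",  # default; required for the cuDNN kernel
    recurrent_dropout=0,             # required for the cuDNN kernel
    unroll=False,                    # required for the cuDNN kernel
    use_bias=True,                   # required for the cuDNN kernel
    return_sequences=True,
)
# In addition, inputs must be unmasked or strictly right-padded, and an
# NVIDIA GPU must be available at runtime.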