
E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] layout failed: Invalid argument: size of values 0 does not match size of permutation 4. #34499

Closed
wkdgnsgo opened this issue Nov 21, 2019 · 14 comments
Labels: comp:grappler (Grappler related issues) · comp:ops (OPs related issues) · stat:awaiting response (Status - Awaiting response from author) · TF 2.0 (Issues relating to TensorFlow 2.0) · type:bug (Bug)

Comments

@wkdgnsgo

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below): 2.0
  • Python version: 3.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory: RTX 2080Ti 11GB

You can collect some of this information using our environment capture script.
You can also obtain the TensorFlow version with:
1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior

Originally I built this model in TensorFlow 1.1x, and I manually ported it to TF 2.0 to use tf.keras. It works, but it prints this error message (E tensorflow/core/grappler/optimizers/meta_optimizer.cc:502] layout failed: Invalid argument: size of values 0 does not match size of permutation 4.) and its performance is worse than in TF 1.1x.

I suspect that this error somehow interferes with training.

I didn't put any permutation layer in my model, so the cause is hard to find.

Describe the expected behavior

Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@jjqtony

jjqtony commented Nov 24, 2019

Same issue for me.

@gadagashwini-zz gadagashwini-zz self-assigned this Nov 25, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.0 Issues relating to TensorFlow 2.0 comp:grappler Grappler related issues labels Nov 25, 2019
@gadagashwini-zz
Contributor

@wkdgnsgo, please paste standalone code that reproduces the issue.
Follow the instructions on the TensorFlow site to migrate from TF1 to TF2. Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Nov 25, 2019
@miguelalba96

I have the exact same issue training a CycleGAN; my operations are built in tf.Module and I'm using tf.optimizers.Adam.

@MadsAdrian

I get the same error message (or warning?). Python 3.6.8 [GCC 8.3.0], tf.version v2.0.0-rc2-26-g64c3d38 2.0.0, in the official TF Docker image.

This is quite a large I2I translation project, so it's hard to produce a minimal reproduction. A possibly related issue is a shape error that appears when the train function is decorated with @tf.function; the same function works in eager execution mode.

A pair of subclassed tf.keras.Model instances, f_x and f_y, is used with the following train-step pattern:

self.loss_object = tf.keras.losses.MeanSquaredError()
self.optimizer = tf.keras.optimizers.Adam(tf.keras.optimizers.schedules.ExponentialDecay(...))

def train_step(self, x, y, cross_loss_weight):
    with tf.GradientTape() as tape:
        y_hat, x_hat = self._fx(x), self._fy(y)
        x_tilde, y_tilde = self._fy(y_hat), self._fx(x_hat)
        fx_loss = {
            "cross": self.loss_object(y, y_hat, cross_loss_weight),  # third arg is sample_weight
            "cycle": self.loss_object(x, x_tilde),
            "reg": sum(self._fx.losses),  # regularization from submodel
        }
        fx_loss = {key: self.lambdas[key] * value for key, value in fx_loss.items()}
        fx_total_loss = sum(fx_loss.values())
        # ... same for f_y, giving fy_total_loss ...
    gradient_targets = self._fx.trainable_variables + self._fy.trainable_variables
    gradients = tape.gradient(fx_total_loss + fy_total_loss, gradient_targets)
    self.optimizer.apply_gradients(zip(gradients, gradient_targets))

So the usage seems similar to @miguelalba96's, with a cyclic loss term but no adversarial term in my case. The model is not a single DAG, but rather two DAGs trained in conjunction. Could the issue be related to this?

@AntoinePlumerault

AntoinePlumerault commented Dec 3, 2019

I am training a VAE and encountered the same issue. I managed to remove the error, but I do not understand what caused it. When I run this code I get the error:

@tf.function
def forward(x_real):
    eps = tf.random.normal([FLAGS.batch_size, FLAGS.latent_dim])
    z_mu, z_log_sigma = E(x_real, training=True)
    z = z_mu + tf.exp(z_log_sigma) * eps
    x_real_mu = D(z, training=True)

    kl_loss = tf.reduce_mean(
        kl_divergence(z_mu, z_log_sigma))
    ll_loss = tf.reduce_mean(
        negative_log_likelyhood(x_real, x_real_mu))
    return x_real_mu, ll_loss, kl_loss

but the error vanishes if I run this one:

@tf.function
def forward(x_real):
    eps = tf.random.normal([FLAGS.batch_size, FLAGS.latent_dim])
    z_mu, z_log_sigma = E(x_real, training=True)
    z = z_mu + tf.exp(z_log_sigma) * eps
    x_real_mu = D(z, training=True)

    kl_loss = tf.reduce_mean(
        kl_divergence(z_mu, z_log_sigma))
    ll_loss = tf.reduce_mean(
        negative_log_likelyhood(x_real+0.0, x_real_mu)) # <-- change here
    return x_real_mu, ll_loss, kl_loss

I wonder why the first version raises the error...

@gadagashwini-zz gadagashwini-zz added comp:ops OPs related issues type:bug Bug labels Dec 5, 2019
@gadagashwini-zz
Contributor

@AntoinePlumerault, could you provide a complete code snippet to replicate the reported issue? Thanks!

@gadagashwini-zz
Contributor

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!


@ktatar

ktatar commented Dec 18, 2019

I confirm the same issue. I am working with Python 3.7 and the tensorflow-gpu 2.0 release installed from pip.
Here is code to reproduce the issue; you need a folder of same-sized images to run it:

CVAE-model-error-code.txt

As @AntoinePlumerault mentioned, I could solve the issue by changing line 173,
cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)

to this,

cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x+0.0)

@naturomics

Similar issue. Here is my solution. In my case the error was caused by the following code:

def _ndtr(self, x):
  # x is a tensor with shape [bs, height, width, channels]
  half_sqrt_2 = 0.5 * np.sqrt(2.)
  w = x * half_sqrt_2
  z = tf.abs(w)
  y = tf.where(z < half_sqrt_2, 1. + tf.math.erf(w), tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
  return 0.5 * y

This gives layout failed: Invalid argument: Size of values 0 does not match size of permutation 4 @ fanin shape in gradient_tape/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer.

It looks like tensors are transposed from NHWC to NCHW for performance in the layout optimization stage, but some of the transposes fail. Transposing manually solves it:

def _ndtr(self, x):
  # x is a tensor with shape [bs, height, width, channels]
  x = tf.transpose(x, [0, 3, 1, 2])
  half_sqrt_2 = 0.5 * np.sqrt(2.)
  w = x * half_sqrt_2
  z = tf.abs(w)
  y = tf.where(z < half_sqrt_2, 1. + tf.math.erf(w), tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
  y = tf.transpose(y, [0, 2, 3, 1])
  return 0.5 * y

The best solution is to write your model in NCHW format, so you can skip the layout optimization stage. See the graph optimization guide (https://www.tensorflow.org/guide/graph_optimization) for a tutorial on what TF does during graph optimization.
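For example, a minimal sketch of an NCHW-first Keras model (assuming a GPU build, since channels_first convolutions have limited CPU support; the layer choices here are purely illustrative):

import tensorflow as tf

# Minimal sketch: build the model in NCHW ("channels_first") so the layout
# optimizer has no NHWC -> NCHW transposes to insert. Input shape is (C, H, W).
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu",
                           data_format="channels_first",
                           input_shape=(3, 224, 224)),
    tf.keras.layers.MaxPool2D(data_format="channels_first"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Alternatively, set the default once so every layer picks it up:
# tf.keras.backend.set_image_data_format("channels_first")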

@YanSte

YanSte commented Aug 4, 2023

This problem seems to have come back in the latest TF versions.

@edumotya

Another workaround is to disable the layout_optimizer for the problematic layers only. Following https://www.tensorflow.org/guide/graph_optimization, wrapping the problematic code block in with options({"layout_optimizer": False}): does the trick.
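For reference, the options helper from that guide is essentially a context manager that swaps the Grappler options in and out (a sketch along the lines of the guide's version; train_step below stands in for whatever @tf.function triggers the error):

import contextlib
import tensorflow as tf

@contextlib.contextmanager
def options(opts):
    # Temporarily override the grappler options, then restore the old ones.
    old_opts = tf.config.optimizer.get_experimental_options()
    tf.config.optimizer.set_experimental_options(opts)
    try:
        yield
    finally:
        tf.config.optimizer.set_experimental_options(old_opts)

# Run the problematic step with the layout optimizer disabled.
with options({"layout_optimizer": False}):
    train_step(x, y)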

@stanlee321

I solved it using this workaround in Colab:

file_path = '/content/models/research/object_detection/model_main_tf2.py'

# Read the content of the file
with open(file_path, 'r') as file:
    content = file.readlines()

# Find the index of the `import tensorflow` line
import_line_index = next(i for i, line in enumerate(content) if 'import tensorflow' in line)

# The fix to insert: disable the layout optimizer globally
fix_command = "tf.config.optimizer.set_experimental_options({'layout_optimizer': False})\n"

# Insert the fix right after the TensorFlow import
content.insert(import_line_index + 1, fix_command)

# Write the modified content back to the file
with open(file_path, 'w') as file:
    file.writelines(content)
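If you control the training script yourself, the same global switch can simply go near the top of the script instead of being patched in (assuming TF 2.x):

import tensorflow as tf

# Disable Grappler's layout optimizer globally, before building the model.
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})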

Korred added commits to Korred/unet-pp that referenced this issue Mar 15 and Mar 27, 2024
- added option for batch_normalization
- added option for different activation functions (https://keras.io/api/layers/activations/#available-activations)
- added option for dropout (generic/pixel-wise and spatial/feature-wise) including the option to adjust the dropout rate
  - had to add a workaround to get dropout to work as mentioned here: tensorflow/tensorflow#34499 (comment)
- simplified _conv_block
- removed outdated comments
- improved typing
@nngsam

nngsam commented May 2, 2024

(quoting @naturomics' NCHW transpose workaround above)

Already did this and still had the problem. I reduced the batch size and it is running now; praying it runs smoothly. Anyone else care to share their experience? ^^
