
tf.keras.layers.BatchNormalization() throws TypeError: Incompatible types: <dtype: 'resource'> vs. int64. Value is 0 #31894

Closed
notabee opened this issue Aug 22, 2019 · 9 comments
Labels
comp:keras (Keras related issues) · stat:awaiting tensorflower (Status - Awaiting response from tensorflower) · TF 1.14 (for issues seen with TF 1.14) · type:bug (Bug)

Comments

@notabee

notabee commented Aug 22, 2019

import tensorflow as tf

batch_size = 20
# Placeholders for (batch, time, height, width, channels) inputs and targets.
inp = tf.placeholder(tf.float32, [batch_size, 19, 64, 64, 3])
out = tf.placeholder(tf.float32, [batch_size, 19, 60, 60, 16])

def model(inp):
  # Conv2D -> BatchNormalization -> Conv2D, each applied per time step.
  enc = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(128, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(inp)
  enc = tf.keras.layers.TimeDistributed(tf.keras.layers.BatchNormalization())(enc)
  enc = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(16, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(enc)
  return enc

pred = model(inp)

loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(out, pred))
lr = 0.0001
train_op = tf.train.AdamOptimizer(lr).minimize(loss)

This throws the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-43-c9408a385d78> in <module>()
      1 lr = 0.0001
----> 2 train_op = tf.train.AdamOptimizer(lr).minimize(reconstuction_loss)

7 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/optimizer.py in minimize(self, loss, global_step, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, name, grad_loss)
    401         aggregation_method=aggregation_method,
    402         colocate_gradients_with_ops=colocate_gradients_with_ops,
--> 403         grad_loss=grad_loss)
    404 
    405     vars_with_grad = [v for g, v in grads_and_vars if g is not None]

/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/optimizer.py in compute_gradients(self, loss, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, grad_loss)
    510         gate_gradients=(gate_gradients == Optimizer.GATE_OP),
    511         aggregation_method=aggregation_method,
--> 512         colocate_gradients_with_ops=colocate_gradients_with_ops)
    513     if gate_gradients == Optimizer.GATE_GRAPH:
    514       grads = control_flow_ops.tuple(grads)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_impl.py in gradients(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients)
    156         ys, xs, grad_ys, name, colocate_gradients_with_ops,
    157         gate_gradients, aggregation_method, stop_gradients,
--> 158         unconnected_gradients)
    159   # pylint: enable=protected-access
    160 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    718               # issue here because of zeros.
    719               if loop_state:
--> 720                 out_grads[i] = loop_state.ZerosLike(op, i)
    721               else:
    722                 out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py in ZerosLike(self, op, index)
   1229       # If the shape is known statically, just create a zero tensor with
   1230       # the right shape in the grad loop context.
-> 1231       result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
   1232       if dead_branch:
   1233         # op is a cond switch. Guard the zero tensor with a switch.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name)
    244   """
    245   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 246                         allow_broadcast=True)
    247 
    248 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    282       tensor_util.make_tensor_proto(
    283           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
--> 284           allow_broadcast=allow_broadcast))
    285   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    286   const_tensor = g.create_op(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    499                             dtype.base_dtype != numpy_dtype.base_dtype):
    500     raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
--> 501                     (dtype, nparray.dtype, values))
    502 
    503   # If shape is not given, get the shape from the numpy array.

TypeError: Incompatible types: <dtype: 'resource'> vs. int64. Value is 0

UPDATE:
The following works, but setting trainable=True brings back the same error:

  enc = tf.keras.layers.TimeDistributed(tf.layers.BatchNormalization(trainable=False))(enc)
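
For reference, a minimal sketch (not part of the original report) of that workaround dropped into the repro model above, reusing the inp placeholder; trainable=False is the only change:

def model(inp):
  enc = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(128, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(inp)
  # Workaround from the update above: trainable=False avoids the TypeError,
  # at the cost of not training the batch-norm scale/offset variables.
  enc = tf.keras.layers.TimeDistributed(
      tf.layers.BatchNormalization(trainable=False))(enc)
  enc = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(16, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(enc)
  return enc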
@gadagashwini-zz gadagashwini-zz self-assigned this Aug 23, 2019
@gadagashwini-zz gadagashwini-zz added comp:keras Keras related issues type:bug Bug labels Aug 23, 2019
@gadagashwini-zz
Contributor

I could reproduce the issue with TensorFlow 1.14.0 and tf-nightly. Here is the gist.
@iamnotahumanbecauseiamabot, which version of TensorFlow are you using? Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Aug 23, 2019
@notabee
Author

notabee commented Aug 23, 2019

@gadagashwini I am using 1.14.0. Note that if you don't wrap BatchNormalization in a TimeDistributed layer, it works.
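
To illustrate the point above (a minimal sketch, not code from this comment), BatchNormalization can be applied to the 5D tensor directly; it normalizes over the last (channel) axis by default, so the TimeDistributed wrapper around it is not strictly needed:

enc = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Conv2D(128, activation='relu', kernel_size=3,
                           kernel_initializer='glorot_uniform'))(inp)
# Applied directly to the (batch, time, H, W, C) tensor instead of being
# wrapped in TimeDistributed; normalization is over axis=-1 by default.
enc = tf.keras.layers.BatchNormalization()(enc)
enc = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Conv2D(16, activation='relu', kernel_size=3,
                           kernel_initializer='glorot_uniform'))(enc)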

@gadagashwini-zz
Contributor

@iamnotahumanbecauseiamabot, Thanks for the update.

@gadagashwini-zz gadagashwini-zz added TF 1.14 for issues seen with TF 1.14 and removed stat:awaiting response Status - Awaiting response from author labels Aug 23, 2019
@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Aug 23, 2019
@notabee
Author

notabee commented Aug 24, 2019

@robieta I have updated the issue; please see the update.

@sseveran

I am seeing the same thing when using MirroredStrategy. The model works fine when executing normally, and we don't have TimeDistributed layers. If you want another ticket I am happy to make one, but I am probably not going to be able to produce a repro case, as the model is quite large.

@jvishnuvardhan
Contributor

@sseveran It would be good if you created another issue for MirroredStrategy. Even better if you can provide simple standalone code. Thanks!

@robieta robieta removed their assignment Feb 8, 2020
@akloss

akloss commented Mar 4, 2020

I'm getting the same error on TensorFlow 1.15 when using batch normalization inside a custom RNNCell + RNN layer. The error only appears inside the RNN's while loop.

The problem also seems to be resolved in TensorFlow 2. It would be great if the fix could be ported back to version 1, if that is possible. (I tried adapting control_flow_ops.py accordingly, but ran into more errors that I don't understand.)
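
Since the issue is reported as resolved in TF 2, here is a minimal sketch (not from this thread) of the original repro ported to TF 2 Keras; it builds the same layer stack as the TF 1 code and trains it with Adam and binary cross-entropy:

import tensorflow as tf  # TF 2.x

def build_model():
  # Same layer stack as the TF 1 repro, expressed as a Keras Model.
  inp = tf.keras.Input(shape=(19, 64, 64, 3))
  x = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(128, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(inp)
  x = tf.keras.layers.TimeDistributed(tf.keras.layers.BatchNormalization())(x)
  out = tf.keras.layers.TimeDistributed(
      tf.keras.layers.Conv2D(16, activation='relu', kernel_size=3,
                             kernel_initializer='glorot_uniform'))(x)
  return tf.keras.Model(inp, out)

model = build_model()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss=tf.keras.losses.BinaryCrossentropy())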

@saikumarchalla

@akloss Closing this issue as it was resolved in TF version 2. Thanks!
