Error in tf.contrib.layers.batch_norm when center=False, data_format='NCHW' and zero_debias_moving_mean=True #11673

Closed
Kipok opened this issue Jul 21, 2017 · 1 comment · Fixed by #12865
Labels
stat:contribution welcome Status - Contributions welcome

Comments

Kipok commented Jul 21, 2017

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 14.04
  • TensorFlow installed from (source or binary): source
  • TensorFlow version (use command below): 1.2.1
  • Python version: 3.4.3
  • Bazel version (if compiling from source): 0.5.2
  • CUDA/cuDNN version: 8.0/6.0
  • GPU model and memory: Tesla P100 (16 gb)

Describe the problem

The batch norm layer fails with an error when center=False, data_format='NCHW' and zero_debias_moving_mean=True are all passed together. It looks like the fix would be to add an if that checks whether beta is None, the same way it is already done for gamma, but there may be other dependencies.
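The proposed guard can be sketched without TensorFlow. This is only an illustration of the pattern, not the actual patch: `fake_reshape` and `reshape_params` are hypothetical names, `fake_reshape` is a stand-in that rejects None the way the real reshape op does in the log below, and the shape `[1, 10, 1, 1]` is an assumed broadcast shape for the 10-channel NCHW input in the reproduction.

```python
def fake_reshape(tensor, shape):
    """Hypothetical stand-in for array_ops.reshape: rejects None inputs."""
    if tensor is None:
        raise ValueError("None values not supported.")
    return (tensor, tuple(shape))


def reshape_params(mean, variance, beta, gamma, shape, reshape=fake_reshape):
    """Reshape each parameter for broadcasting, skipping those that are None."""
    mean = reshape(mean, shape)
    variance = reshape(variance, shape)
    # beta is None when center=False, so guard it the same way gamma is guarded.
    if beta is not None:
        beta = reshape(beta, shape)
    if gamma is not None:
        gamma = reshape(gamma, shape)
    return mean, variance, beta, gamma


# Simulates center=False: beta is None and is passed through untouched.
result = reshape_params("mean", "variance", None, "gamma", [1, 10, 1, 1])
```

Without the `if beta is not None` check, the same call would raise the ValueError shown in the error log.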

Source code / logs

The following code could be used to reproduce the issue:

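import tensorflow as tf
# 4-D input; with data_format='NCHW' the channel axis is axis 1.
a = tf.placeholder(tf.float32, shape=(10, 10, 10, 10))
# The ValueError is raised here, while the graph is being built.
b = tf.contrib.layers.batch_norm(a, center=False, data_format='NCHW',
                                 zero_debias_moving_mean=True)
sess = tf.Session()
sess.run(tf.global_variables_initializer())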

The error log:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    489                 as_ref=input_arg.is_ref,
--> 490                 preferred_dtype=default_dtype)
    491           except TypeError as err:

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype)
    675         if ret is None:
--> 676           ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    677 

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    120   _ = as_ref
--> 121   return constant(v, dtype=dtype, name=name)
    122 

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    101   tensor_value.tensor.CopyFrom(
--> 102       tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    103   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    363     if values is None:
--> 364       raise ValueError("None values not supported.")
    365     # if dtype is provided, forces numpy array to be the type

ValueError: None values not supported.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    503               observed = ops.internal_convert_to_tensor(
--> 504                   values, as_ref=input_arg.is_ref).dtype.name
    505             except ValueError as err:

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype)
    675         if ret is None:
--> 676           ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    677 

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    120   _ = as_ref
--> 121   return constant(v, dtype=dtype, name=name)
    122 

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    101   tensor_value.tensor.CopyFrom(
--> 102       tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    103   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    363     if values is None:
--> 364       raise ValueError("None values not supported.")
    365     # if dtype is provided, forces numpy array to be the type

ValueError: None values not supported.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-1-c9cf0f67668a> in <module>()
      3 
      4 a = tf.placeholder(tf.float32, shape=(10, 10, 10, 10))
----> 5 b = tf.contrib.layers.batch_norm(a, center=False, data_format='NCHW', zero_debias_moving_mean=True)
      6 sess = tf.Session()
      7 sess.run(tf.global_variables_initializer())

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py in func_with_args(*args, **kwargs)
    179       current_args = current_scope[key_func].copy()
    180       current_args.update(kwargs)
--> 181     return func(*args, **current_args)
    182   _add_op(func)
    183   setattr(func_with_args, '_key_op', _key_op(func))

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/contrib/layers/python/layers/layers.py in batch_norm(inputs, decay, center, scale, epsilon, activation_fn, param_initializers, param_regularizers, updates_collections, is_training, reuse, variables_collections, outputs_collections, trainable, batch_weights, fused, data_format, zero_debias_moving_mean, scope, renorm, renorm_clipping, renorm_decay)
    806       mean = array_ops.reshape(mean, params_shape_broadcast)
    807       variance = array_ops.reshape(variance, params_shape_broadcast)
--> 808       beta = array_ops.reshape(beta, params_shape_broadcast)
    809       if gamma is not None:
    810         gamma = array_ops.reshape(gamma, params_shape_broadcast)

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
   2500   """
   2501   result = _op_def_lib.apply_op("Reshape", tensor=tensor, shape=shape,
-> 2502                                 name=name)
   2503   return result
   2504 

~/Documents/weight-normalization-exps/venv/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    506               raise ValueError(
    507                   "Tried to convert '%s' to a tensor and failed. Error: %s" %
--> 508                   (input_name, err))
    509             prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
    510                       (input_name, op_type_name, observed))

ValueError: Tried to convert 'tensor' to a tensor and failed. Error: None values not supported.
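
Reading the log: the failure comes from `array_ops.reshape(beta, params_shape_broadcast)` in layers.py, where beta is None because center=False. The reshape is there because per-channel parameters of shape [C] do not broadcast directly against NCHW input. Below is a minimal sketch of how such a broadcast shape can be derived, assuming the channel axis is 1 for NCHW and the last axis for NHWC. The helper name is hypothetical and is not part of TensorFlow.

```python
def broadcast_shape_for_params(input_shape, data_format):
    """Return a shape that lets per-channel params broadcast over the input."""
    rank = len(input_shape)
    if data_format == "NCHW":
        # Channels sit on axis 1; every other axis becomes size 1.
        return [1, input_shape[1]] + [1] * (rank - 2)
    if data_format == "NHWC":
        # Channels are the last axis, so a [C] vector already broadcasts.
        return [input_shape[-1]]
    raise ValueError("data_format must be 'NCHW' or 'NHWC'")


# Same input shape as the placeholder in the reproduction above.
shape = broadcast_shape_for_params((10, 10, 10, 10), "NCHW")
print(shape)  # [1, 10, 1, 1]
```
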
shreyneil added a commit to shreyneil/tensorflow that referenced this issue Jul 23, 2017
I have added a check on beta with reference to the following issue raised.
tensorflow#11673

Please verify and get back.
Thank you.
Member

reedwm commented Jul 24, 2017

Marking as contributions welcome, since it seems #11687 will fix this issue.

reedwm added the "stat:contribution welcome" label on Jul 24, 2017
shreyneil added a commit to shreyneil/tensorflow that referenced this issue Oct 20, 2017
The following PR contains the required unit test case for issue tensorflow#11673;
a fix is also available in another pull request, tensorflow#13829.
This was referenced Oct 20, 2017
shreyneil added a commit to shreyneil/tensorflow that referenced this issue Nov 22, 2017
The changes suggested in my previous pull request, tensorflow#13829, have been made.

The unit test case has been updated, along with layers.py, with reference to issue tensorflow#11673.
Please verify and get back.