
tf.Module doesn't recognise non trainable variables from keras layers [TF 2.0] #29329

Closed
n3011 opened this issue Jun 2, 2019 · 8 comments
Labels: comp:keras (Keras related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:bug (Bug)

n3011 (Contributor) commented Jun 2, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tf-nightly-2.0-preview==2.0.0.dev20190602
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: CPU
  • GPU model and memory:

When using Keras layers inside a tf.Module, setting trainable=False on a Keras layer does not result in non-trainable variables at the tf.Module scope. In the example below, module M's trainable_variables should contain 6 variables, but it contains 8.

Code to reproduce the issue



class M(tf.Module):

    def __init__(self):
        super(M, self).__init__()
        self.list = []
        self.list.append([tf.keras.layers.Dense(5, trainable=False), tf.keras.layers.Dense(5)])
        self.list.append([tf.keras.layers.Dense(5), tf.keras.layers.Dense(5)])

    def __call__(self, inputs):
        output = inputs
        for l_list in self.list:
            for l in l_list:
                output = l(output)
        return output

m = M()
m(tf.ones((10, 10)))
Got: print(len(m.trainable_variables)) prints 8

Expected: print(len(m.trainable_variables)) should print 6
@n3011 n3011 changed the title tf.Module doesn't recognise non trainable variables from keras layers tf.Module doesn't recognise non trainable variables from keras layers [TF 2.0] Jun 2, 2019
@gadagashwini-zz gadagashwini-zz self-assigned this Jun 3, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.0 Issues relating to TensorFlow 2.0 comp:keras Keras related issues type:bug Bug labels Jun 3, 2019
gadagashwini-zz (Contributor) commented

I could reproduce the reported issue with the tf-nightly-2.0-preview version. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 4, 2019
qlzh727 (Member) commented Jun 4, 2019

Thanks for reporting the issue. Here is some context about the root cause.

There are two "trainable" concepts here: one is whether the variable itself is trainable, and the other is whether the layer containing the variable is trainable. The first is immutable, while the second is mutable. The layer/container's trainable state affects the return value of trainable_variables/non_trainable_variables.

When creating a variable, a user can call layer.add_variable(trainable=False), which forces the variable to be non-trainable regardless of whether the layer itself is trainable.

In the case that the layer is not trainable, we still create its variables as trainable variables; layer.trainable_weights/non_trainable_weights check the layer.trainable state in order to return the correct values.
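The two levels can be seen directly on a Keras layer. A small sketch (WithFrozenWeight is a made-up layer for illustration, not part of the report):

```python
import tensorflow as tf

# Layer-level "trainable": the whole layer is flagged as non-trainable,
# so Keras reports its kernel and bias as non-trainable weights.
frozen = tf.keras.layers.Dense(5, trainable=False)
frozen.build((None, 10))
print(len(frozen.trainable_weights))      # 0
print(len(frozen.non_trainable_weights))  # 2

# Variable-level "trainable": a weight created with trainable=False is
# non-trainable regardless of the layer's own trainable flag.
class WithFrozenWeight(tf.keras.layers.Layer):  # hypothetical example layer
    def build(self, input_shape):
        self.scale = self.add_weight(
            shape=(), initializer="ones", trainable=False, name="scale")

layer = WithFrozenWeight()
layer.build((None, 10))
print(len(layer.non_trainable_weights))   # 1
```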

In your specific case, tf.Module just recursively visits all of its attributes and finds all variable-like objects. It discards the container/layer trainable state and relies only on the variable-level trainable information, which is why it returns 8 here.

A simple workaround is to change the base class of M to a Keras Layer, which correctly checks the sub-layer trainable state, while we fix the issue on the tf.Module side.
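Spelled out, that workaround might look like the following sketch (the list attribute is renamed to blocks; the model is otherwise the same as the report's example):

```python
import tensorflow as tf

class M(tf.keras.layers.Layer):
    """Same model as the report, but with tf.keras.layers.Layer as base class."""

    def __init__(self):
        super().__init__()
        self.blocks = [
            [tf.keras.layers.Dense(5, trainable=False), tf.keras.layers.Dense(5)],
            [tf.keras.layers.Dense(5), tf.keras.layers.Dense(5)],
        ]

    def call(self, inputs):
        output = inputs
        for block in self.blocks:
            for layer in block:
                output = layer(output)
        return output

m = M()
m(tf.ones((10, 10)))
# The Layer base class respects the sub-layer trainable state:
print(len(m.trainable_variables))  # 6
print(len(m.variables))            # 8
```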

qlzh727 (Member) commented Jun 4, 2019

Also adding @tomhennigan, who is the owner of tf.Module.

qlzh727 (Member) commented Jun 4, 2019

Unfortunately, the trainable_variables definition for keras.Layer is different from tf.Module's: keras.Layer respects both "trainable" concepts, while tf.Module respects only the variable-level "trainable" state.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 5, 2019
tomhennigan (Member) commented

+1 to what @qlzh727 said. One option, if you want the trainable behavior from Keras, is to swap from tf.Module to tf.keras.layers.Layer as your base class. Keras layers support being deeply nested, so the rest of your code works unchanged:

class M(tf.keras.layers.Layer):
  # .. same as your example ..

print(len(m.trainable_variables))  # 6

In general, I think it's worth pointing out that there are a few places where Keras and TF don't agree on the definition of trainable; tf.Module is consistent with those other parts of TF. A few examples (ignoring TF1 and global collections etc.):

print(sum(1 for v in m.variables if v.trainable))  # 8

with tf.GradientTape() as tape:
  m(tf.ones((10, 10)))
  print(len(tape.watched_variables()))  # 8


n3011 (Contributor, Author) commented Jun 7, 2019

I think this issue should not be closed until it is fixed in tf.Module. If tf.Module can't work properly with tf.keras.layers, what's the purpose of it?

tomhennigan (Member) commented

Not all TensorFlow users use Keras. tf.Module is a simple, framework-independent base class for stateful objects in TensorFlow: it enables checkpointing and variable/module tracking. For more motivation, please see the RFC tensorflow/community#56.
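For illustration, a minimal Keras-free tf.Module could look like this (Linear is a made-up example module; note that tf.Module respects only the variable-level trainable flag):

```python
import tensorflow as tf

class Linear(tf.Module):  # hypothetical example module
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([in_dim, out_dim]), name="w")
        # Variable-level trainable=False is the only notion tf.Module checks.
        self.b = tf.Variable(tf.zeros([out_dim]), trainable=False, name="b")

    def __call__(self, x):
        return x @ self.w + self.b

m = Linear(3, 2)
y = m(tf.ones([4, 3]))
print(y.shape)                     # (4, 2)
print(len(m.variables))            # 2 -- both tracked for checkpointing
print(len(m.trainable_variables))  # 1 -- only the trainable=True variable
```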

Keras has its own definition of trainable/non-trainable variables (defined in terms of trainable/non-trainable Layers); if you want to use the Keras definition, then please use Keras directly. The good news is that since 23c8fd4, Keras Layer extends tf.Module, so if subclassing Layer is a more appropriate choice, you don't lose the benefits enabled by tf.Module.
