tf.Module doesn't recognise non trainable variables from keras layers [TF 2.0] #29329
Comments
I could reproduce the reported issue with the tf-nightly-2.0-preview version. Thanks!
Thanks for reporting the issue. Here is some context about the root cause. There are two "trainable" concepts here: one is whether the variable itself is trainable, and the other is whether the layer that contains the variable is trainable. The first is immutable, while the second is mutable. The layer/container's trainable state affects the return value of trainable_variables/non_trainable_variables. When creating a variable, the user can call layer.add_variable(trainable=False), which forces the variable to be non-trainable regardless of whether the layer itself is trainable. When the layer is not trainable, we still create its variables as trainable variables, and we check layer.trainable inside layer.trainable_weights/non_trainable_weights to return the correct value.

In your specific case, tf.Module just recursively visits all its attributes and finds all variable-like objects. It discards the container/layer trainable state and relies only on the variable-level trainable information, which is why it returns 8 here. A simple workaround is to change the base class of M to a Keras layer, which correctly checks the sub-layer trainable state, while we fix the issue on the tf.Module side.
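To make the two flags concrete, here is a small illustrative snippet (the Dense layer and its shapes are arbitrary choices, not taken from the original report):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 4))  # create the kernel and bias

# Mutable, layer-level flag: filters what trainable_weights returns.
layer.trainable = False
print(len(layer.trainable_weights))  # 0 once the layer is frozen

# Variable-level flag: set at creation time and, per the behavior
# described in this thread, left unchanged by freezing the layer,
# which is why tf.Module still sees these variables as trainable.
print([v.trainable for v in layer.weights])
```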
Also adding @tomhennigan, who is the owner of tf.Module.
Unfortunately, the trainable_variables definition for keras.Layer is different from tf.Module's: keras.Layer respects both "trainable" concepts, while tf.Module only respects the variable-level "trainable" state.
+1 to what @qlzh727 said. One option, if you want the trainable behavior from Keras, is to swap the base class to class M(tf.keras.layers.Layer):

```python
class M(tf.keras.layers.Layer):
    # .. same as your example ..

print(len(m.trainable_variables))  # 6
```

I think in general it's worth pointing out that there are a few places where Keras and TF don't agree on the definition of trainable:

```python
print(sum(1 for v in m.variables if v.trainable))  # 8

with tf.GradientTape() as tape:
    m(tf.ones((10, 10)))
print(len(tape.watched_variables()))  # 8
```
I think this issue should not be closed until it is fixed in tf.Module.
Not all TensorFlow users use Keras. Keras has its own definition of trainable/non-trainable variables (defined in terms of trainable/non-trainable layers).
System information

When using Keras layers inside a tf.Module, setting trainable=False on the Keras layer does not result in non-trainable variables in the tf.Module scope. In the example code below, module M's trainable_variables should return 6, but it returns 8.

Code to reproduce the issue
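The original snippet is not preserved in this page; a minimal reconstruction consistent with the counts described (8 variables total, 6 expected trainable) might look like the following, where the number and width of the Dense layers are assumptions:

```python
import tensorflow as tf

class M(tf.Module):
    def __init__(self):
        super().__init__()
        # Three ordinary Dense layers plus one frozen one: 4 layers x
        # (kernel + bias) = 8 variables, of which 6 should be trainable.
        self.hidden = [tf.keras.layers.Dense(10) for _ in range(3)]
        self.frozen = tf.keras.layers.Dense(10, trainable=False)

    def __call__(self, x):
        for layer in self.hidden:
            x = layer(x)
        return self.frozen(x)

m = M()
out = m(tf.ones((10, 10)))  # call once so all layers build their variables
# The issue reports 8 here; 6 is the expected value, since the frozen
# layer's kernel and bias should not count as trainable.
print(len(m.trainable_variables))
```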