
tf.Module doesn't recognise non-trainable variables from Keras layers [TF 2.0] #29329

@n3011

Description


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tf-nightly-2.0-preview==2.0.0.dev20190602
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: CPU
  • GPU model and memory:

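When Keras layers are used inside a tf.Module, passing trainable=False to a Keras layer does not make that layer's variables non-trainable from the tf.Module's point of view: they still appear in the module's trainable_variables.
In the example below, each of the four Dense layers creates a kernel and a bias (8 variables in total). Because the first layer is frozen, len(m.trainable_variables) for module M should be 6, but it is 8.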

Code to reproduce the issue


import tensorflow as tf

class M(tf.Module):

    def __init__(self):
        super(M, self).__init__()
        self.list = []
        self.list.append([tf.keras.layers.Dense(5, trainable=False), tf.keras.layers.Dense(5)])
        self.list.append([tf.keras.layers.Dense(5), tf.keras.layers.Dense(5)])

    def __call__(self, inputs):
        output = inputs
        for l_list in self.list:
            for l in l_list:
                output = l(output)
        return output

m = M()
m(tf.ones((10, 10)))
print(len(m.trainable_variables))
# Got: 8
# Expected: 6
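The 8-versus-6 arithmetic can be modelled without TensorFlow. This is a minimal sketch of one plausible explanation for the observed count, assuming (not confirmed in this issue) that Keras records trainable=False at the layer level while each underlying kernel and bias variable keeps its own trainable flag set to True, and that tf.Module filters only on that per-variable flag. FakeVariable, FakeDense, per_variable_trainable and layer_aware_trainable are hypothetical stand-ins, not TensorFlow APIs.

```python
# Hypothetical, framework-free model: none of these classes are TensorFlow APIs.

class FakeVariable:
    """Stand-in for a variable that carries its own `trainable` flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable


class FakeDense:
    """Stand-in for a Dense layer with a layer-level `trainable` flag."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable
        # Assumption: kernel and bias are created as trainable variables
        # even when the layer itself is frozen.
        self.variables = [FakeVariable(name + "/kernel"), FakeVariable(name + "/bias")]

    @property
    def trainable_weights(self):
        # Layer-level view: a frozen layer reports no trainable weights.
        return list(self.variables) if self.trainable else []


def per_variable_trainable(groups):
    """Module-style collection: filter only on each variable's own flag."""
    return [v for group in groups for layer in group for v in layer.variables if v.trainable]


def layer_aware_trainable(groups):
    """Collection that respects each layer's `trainable` flag."""
    return [v for group in groups for layer in group for v in layer.trainable_weights]


# Same nesting as M.list in the repro: two groups of two Dense layers.
groups = [
    [FakeDense("dense", trainable=False), FakeDense("dense_1")],
    [FakeDense("dense_2"), FakeDense("dense_3")],
]

print(len(per_variable_trainable(groups)))  # 8, matches "Got"
print(len(layer_aware_trainable(groups)))   # 6, matches "Expected"
```

If this model matches what happens inside TensorFlow, a possible workaround is to collect variables from each Keras layer's trainable_weights property instead of relying on tf.Module.trainable_variables; that is a suggestion, not a fix confirmed in this issue.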

Labels: TF 2.0, comp:keras, type:bug
