Skip to content

Commit

Permalink
Fix bug with loading nested model with trainable/nontrainable weights.
Browse files Browse the repository at this point in the history
Changed the test to the example from #27769.

PiperOrigin-RevId: 254305891
  • Loading branch information
k-w-w authored and tensorflower-gardener committed Jun 21, 2019
1 parent 7e5a151 commit f42549a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 38 deletions.
40 changes: 20 additions & 20 deletions tensorflow/python/keras/saving/hdf5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,30 +272,30 @@ def convert_nested_model(weights):
Returns:
A list of weights values (Numpy arrays).
"""
new_weights = []
# trainable weights
for sublayer in layer.layers:
num_weights = len(sublayer.trainable_weights)
if num_weights > 0:
new_weights.extend(preprocess_weights_for_loading(
layer=sublayer,
weights=weights[:num_weights],
original_keras_version=original_keras_version,
original_backend=original_backend))
weights = weights[num_weights:]
trainable_weights = weights[:len(layer.trainable_weights)]
non_trainable_weights = weights[len(layer.trainable_weights):]

new_trainable_weights = []
new_non_trainable_weights = []

# non-trainable weights
for sublayer in layer.layers:
num_weights = len([l for l in sublayer.weights
if l not in sublayer.trainable_weights])
if num_weights > 0:
new_weights.extend(preprocess_weights_for_loading(
num_trainable_weights = len(sublayer.trainable_weights)
num_non_trainable_weights = len(sublayer.non_trainable_weights)
if sublayer.weights:
preprocessed = preprocess_weights_for_loading(
layer=sublayer,
weights=weights[:num_weights],
weights=(trainable_weights[:num_trainable_weights] +
non_trainable_weights[:num_non_trainable_weights]),
original_keras_version=original_keras_version,
original_backend=original_backend))
weights = weights[num_weights:]
return new_weights
original_backend=original_backend)
new_trainable_weights.extend(preprocessed[:num_trainable_weights])
new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])

trainable_weights = trainable_weights[num_trainable_weights:]
non_trainable_weights = non_trainable_weights[
num_non_trainable_weights:]

return new_trainable_weights + new_non_trainable_weights

# Convert layers nested in Bidirectional/Model/Sequential.
# Both transformation should be ran for both Keras 1->2 conversion
Expand Down
36 changes: 18 additions & 18 deletions tensorflow/python/keras/saving/hdf5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,33 +265,33 @@ def test_nested_model_weight_loading(self):
self.addCleanup(shutil.rmtree, temp_dir)
h5_path = os.path.join(temp_dir, 'test.h5')

num_hidden = 5
input_dim = 3
batch_size = 5
num_classes = 2
shape = (None, None, 3)

with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
model.add(keras.layers.Dense(num_classes))
def gen_model():

nested_model = keras.models.Sequential()
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
nested_model.add(keras.layers.Dense(num_classes))
model.add(nested_model)
def seq_model():
model = keras.models.Sequential([
keras.layers.Conv2D(3, 1, input_shape=shape),
keras.layers.BatchNormalization()])
return model

x = np.random.random((batch_size, input_dim))
x = inner_inputs = keras.layers.Input((None, None, 3))
x = seq_model()(x)
x = seq_model()(x)
inner_model = keras.models.Model(inner_inputs, x)

inputs = keras.layers.Input(shape)
return keras.models.Model(inputs, inner_model(inputs))

model = gen_model()
x = np.random.random((batch_size, 1, 1, 3))
ref_y = model.predict(x)

model.save_weights(h5_path)

model = keras.models.Sequential()
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
model.add(keras.layers.Dense(num_classes))
nested_model = keras.models.Sequential()
nested_model.add(keras.layers.Dense(num_hidden, input_dim=num_classes))
nested_model.add(keras.layers.Dense(num_classes))
model.add(nested_model)
model = gen_model()
model.load_weights(h5_path)
y = model.predict(x)

Expand Down

0 comments on commit f42549a

Please sign in to comment.