Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[r2.0-CherryPick] Deduplicate Keras weights #32257

Merged
merged 1 commit on Sep 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 9 additions & 8 deletions tensorflow/python/keras/engine/base_layer.py
Expand Up @@ -905,18 +905,20 @@ def input_spec(self, value):
def trainable_weights(self):
if self.trainable:
nested = self._gather_children_attribute('trainable_weights')
return self._trainable_weights + nested
return self._dedup_weights(self._trainable_weights + nested)
else:
return []

@property
def non_trainable_weights(self):
if self.trainable:
nested = self._gather_children_attribute('non_trainable_weights')
return self._non_trainable_weights + nested
non_trainable_weights = self._non_trainable_weights + nested
else:
nested = self._gather_children_attribute('weights')
return self._trainable_weights + self._non_trainable_weights + nested
non_trainable_weights = (
self._trainable_weights + self._non_trainable_weights + nested)
return self._dedup_weights(non_trainable_weights)

@property
def weights(self):
Expand Down Expand Up @@ -2452,14 +2454,13 @@ def _list_functions_for_serialization(self, serialization_cache):
serialization_cache))
return fns

@property
def _unique_trainable_weights(self):
"""Dedupe trainable weights while maintaining order as much as possible."""
trainable_weights = self.trainable_weights
def _dedup_weights(self, weights):
"""Dedupe weights while maintaining order as much as possible."""
output, seen_weights = [], object_identity.ObjectIdentitySet()
for w in trainable_weights:
for w in weights:
if w not in seen_weights:
output.append(w)
# Track the Variable's identity to avoid __eq__ issues.
seen_weights.add(w)
return output

Expand Down
24 changes: 16 additions & 8 deletions tensorflow/python/keras/engine/network.py
Expand Up @@ -472,6 +472,11 @@ def weights(self):
Returns:
A list of variables.
"""
return self._dedup_weights(self._undeduplicated_weights)

@property
def _undeduplicated_weights(self):
"""Returns the undeduplicated list of all layer variables/weights."""
self._assert_weights_created()
weights = []
for layer in self._layers:
Expand Down Expand Up @@ -535,18 +540,21 @@ def get_layer(self, name=None, index=None):
@property
def trainable_weights(self):
self._assert_weights_created()
return trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights)
return self._dedup_weights(
trackable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._trainable_weights))

@property
def non_trainable_weights(self):
self._assert_weights_created()
return trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights + self._trainable_weights)
return self._dedup_weights(
trackable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._non_trainable_weights +
self._trainable_weights))

@property
def input_spec(self):
Expand Down
7 changes: 3 additions & 4 deletions tensorflow/python/keras/engine/training.py
Expand Up @@ -386,7 +386,7 @@ def compile(self,
self.predict_function = None

# Collected trainable weights, sorted in topological order.
self._collected_trainable_weights = self._unique_trainable_weights
self._collected_trainable_weights = self.trainable_weights

# Validate all variables were correctly created in distribution scope.
if self._distribution_strategy and not self._compile_distribution:
Expand Down Expand Up @@ -1535,7 +1535,7 @@ def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
# Set metric attributes on model.
self._set_metric_attributes()

self._collected_trainable_weights = self._unique_trainable_weights
self._collected_trainable_weights = self.trainable_weights

def _update_sample_weight_modes(self, sample_weights=None):
"""Updates sample weight modes based on training/eval inputs.
Expand Down Expand Up @@ -2046,8 +2046,7 @@ def _check_trainable_weights_consistency(self):
if not hasattr(self, '_collected_trainable_weights'):
return

if (len(self._unique_trainable_weights) !=
len(self._collected_trainable_weights)):
if len(self.trainable_weights) != len(self._collected_trainable_weights):
logging.log_first_n(
logging.WARN, 'Discrepancy between trainable weights and collected'
' trainable weights, did you set `model.trainable`'
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/keras/engine/training_eager.py
Expand Up @@ -258,7 +258,7 @@ def _process_single_batch(model,
else:
scaled_total_loss = total_loss
if training:
trainable_weights = model._unique_trainable_weights
trainable_weights = model.trainable_weights
if trainable_weights:
# TODO(tanzheny) b/132690565: Provide mechanism for user to override
# model.train_on_batch.
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/python/keras/engine/training_test.py
Expand Up @@ -904,6 +904,23 @@ def test_that_trainable_disables_updates(self):
x2 = model.predict(val_a)
self.assertAllClose(x1, x2, atol=1e-7)

def test_weight_deduplication_in_methods(self):
    """Weights of layers shared across nested models are counted only once."""
    inp = keras.layers.Input(shape=(1,))
    bn = keras.layers.BatchNormalization()
    d = keras.layers.Dense(1)

    # Both inner models are built from the SAME `bn` and `d` layer
    # instances, so their variables are shared between m0 and m1.
    m0 = keras.models.Model(inp, d(bn(inp)))
    m1 = keras.models.Model(inp, d(bn(inp)))

    x0 = m0(inp)
    x1 = m1(inp)
    x = keras.layers.Add()([x0, x1])

    # The outer model reaches each shared variable through two paths
    # (m0 and m1); the weight properties must deduplicate them.
    model = keras.models.Model(inp, x)
    # NOTE(review): counts assume Dense contributes 2 trainable variables
    # and BatchNormalization contributes 2 trainable + 2 non-trainable —
    # standard Keras layer behavior; confirm against the layer docs.
    self.assertLen(model.trainable_weights, 4)
    self.assertLen(model.non_trainable_weights, 2)
    self.assertLen(model.weights, 6)

@keras_parameterized.run_all_keras_modes
def test_weight_deduplication(self):
class WatchingLayer(keras.layers.Layer):
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/python/keras/premade/wide_deep.py
Expand Up @@ -102,8 +102,8 @@ def _get_optimizers(self):

# This does not support gradient scaling and LossScaleOptimizer.
def _backwards(self, tape, loss):
linear_vars = self.linear_model._unique_trainable_weights # pylint: disable=protected-access
dnn_vars = self.dnn_model._unique_trainable_weights # pylint: disable=protected-access
linear_vars = self.linear_model.trainable_weights # pylint: disable=protected-access
dnn_vars = self.dnn_model.trainable_weights # pylint: disable=protected-access
linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))
linear_optimizer, dnn_optimizer = self._get_optimizers()
linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
Expand Down Expand Up @@ -134,11 +134,11 @@ def _make_train_function(self):
# Training updates
updates = []
linear_updates = linear_optimizer.get_updates(
params=self.linear_model._unique_trainable_weights, # pylint: disable=protected-access
params=self.linear_model.trainable_weights, # pylint: disable=protected-access
loss=self.total_loss)
updates += linear_updates
dnn_updates = dnn_optimizer.get_updates(
params=self.dnn_model._unique_trainable_weights, # pylint: disable=protected-access
params=self.dnn_model.trainable_weights, # pylint: disable=protected-access
loss=self.total_loss)
updates += dnn_updates
# Unconditional updates
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/python/keras/saving/hdf5_format.py
Expand Up @@ -75,6 +75,12 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
# TODO(psv) Add warning when we save models that contain non-serializable
# entities like metrics added using `add_metric` and losses added using
# `add_loss.`
if len(model.weights) != len(model._undeduplicated_weights):
logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
'This is usually caused by `Variable`s being shared by '
'Layers in the Model. These `Variable`s will be treated '
'as separate `Variable`s when the Model is restored. To '
'avoid this, please save with `save_format="tf"`.')

if not isinstance(filepath, h5py.File):
# If file exists and should not be overwritten.
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/keras/utils/layer_utils.py
Expand Up @@ -235,7 +235,7 @@ def print_layer_summary_with_connections(layer):
if hasattr(model, '_collected_trainable_weights'):
trainable_count = count_params(model._collected_trainable_weights)
else:
trainable_count = count_params(model._unique_trainable_weights)
trainable_count = count_params(model.trainable_weights)

non_trainable_count = count_params(model.non_trainable_weights)

Expand Down