Skip to content

Commit

Permalink
Add a model for assign tff.learning.Model weights back to a tf.keras.…
Browse files Browse the repository at this point in the history
…Model.

This is a workaround for issue #258, which uncovered that tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the same.

- Add a toy example model that uses batch norm (includes non-trainable variables), which will fail without this change.
- Move client optimizer variables to local_variables, this includes variables such as iteration number.

PiperOrigin-RevId: 241075678
  • Loading branch information
ZacharyGarrett authored and Copybara-Service committed Mar 29, 2019
1 parent 6824f2f commit 46f9c69
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 5 deletions.
17 changes: 17 additions & 0 deletions docs/api_docs/python/tff/learning/framework/ModelWeights.md
Expand Up @@ -5,6 +5,7 @@
<meta itemprop="property" content="non_trainable"/>
<meta itemprop="property" content="keras_weights"/>
<meta itemprop="property" content="__new__"/>
<meta itemprop="property" content="assign_weights_to"/>
<meta itemprop="property" content="from_model"/>
<meta itemprop="property" content="from_tff_value"/>
</div>
Expand Down Expand Up @@ -47,8 +48,24 @@ Returns a list of weights in the same order as `tf.keras.Model.weights`.
(Assuming that this ModelWeights object corresponds to the weights of a keras
model).

IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
hence will not work with `tf.keras.Model.set_weights()`. Instead, use
`tff.learning.ModelWeights.assign_weights_to`.

## Methods

<h3 id="assign_weights_to"><code>assign_weights_to</code></h3>

```python
assign_weights_to(keras_model)
```

Assign these TFF model weights to the weights of a `tf.keras.Model`.

#### Args:

* <b>`keras_model`</b>: the `tf.keras.Model` object to assign weights to.

<h3 id="from_model"><code>from_model</code></h3>

```python
Expand Down
49 changes: 46 additions & 3 deletions tensorflow_federated/python/learning/model_examples.py
Expand Up @@ -171,21 +171,21 @@ def _dense_all_zeros_layer(input_dims=None, output_dim=1):
return build_keras_dense_layer()


def build_linear_regresion_keras_sequential_model(feature_dims):
def build_linear_regresion_keras_sequential_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the Sequential API."""
keras_model = tf.keras.models.Sequential()
keras_model.add(_dense_all_zeros_layer(feature_dims))
return keras_model


def build_linear_regresion_keras_functional_model(feature_dims):
def build_linear_regresion_keras_functional_model(feature_dims=2):
"""Build a linear regression `tf.keras.Model` using the functional API."""
a = tf.keras.layers.Input(shape=(feature_dims,))
b = _dense_all_zeros_layer()(a)
return tf.keras.Model(inputs=a, outputs=b)


def build_linear_regresion_keras_subclass_model(feature_dims):
def build_linear_regresion_keras_subclass_model(feature_dims=2):
"""Build a linear regression model by sub-classing `tf.keras.Model`."""
del feature_dims # unused.

Expand All @@ -207,3 +207,46 @@ def build_embedding_keras_model(vocab_size=10):
keras_model.add(tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=5))
keras_model.add(tf.keras.layers.Softmax())
return keras_model


def build_conv_batch_norm_keras_model():
"""Builds a test model with convolution and batch normalization."""
# This is an example of a model that has trainable and non-trainable
# variables.
l = tf.keras.layers
data_format = 'channels_last'
max_pool = l.MaxPooling2D((2, 2), (2, 2),
padding='same',
data_format=data_format)
keras_model = tf.keras.models.Sequential([
l.Reshape(target_shape=[28, 28, 1], input_shape=(28 * 28,)),
l.Conv2D(
32,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu,
kernel_initializer='zeros',
bias_initializer='zeros'),
max_pool,
l.BatchNormalization(),
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu,
kernel_initializer='zeros',
bias_initializer='zeros'),
max_pool,
l.BatchNormalization(),
l.Flatten(),
l.Dense(
1024,
activation=tf.nn.relu,
kernel_initializer='zeros',
bias_initializer='zeros'),
l.Dropout(0.4),
l.Dense(10, kernel_initializer='zeros', bias_initializer='zeros'),
])
return keras_model
17 changes: 15 additions & 2 deletions tensorflow_federated/python/learning/model_utils.py
Expand Up @@ -96,9 +96,22 @@ def keras_weights(self):
(Assuming that this ModelWeights object corresponds to the weights of
a keras model).
IMPORTANT: this is not the same order as `tf.keras.Model.get_weights()`, and
hence will not work with `tf.keras.Model.set_weights()`. Instead, use
`tff.learning.ModelWeights.assign_weights_to`.
"""
return list(self.trainable.values()) + list(self.non_trainable.values())

def assign_weights_to(self, keras_model):
"""Assign these TFF model weights to the weights of a `tf.keras.Model`.
Args:
keras_model: the `tf.keras.Model` object to assign weights to.
"""
for k, w in zip(keras_model.weights, self.keras_weights):
k.assign(w)


def keras_weights_from_tff_weights(tff_weights):
"""Converts TFF's nested weights structure to flat weights.
Expand Down Expand Up @@ -442,8 +455,8 @@ def __init__(self, inner_model, dummy_batch):
inner_model.loss_functions[0], inner_model.metrics)

@property
def non_trainable_variables(self):
return (super(_TrainableKerasModel, self).non_trainable_variables +
def local_variables(self):
return (super(_TrainableKerasModel, self).local_variables +
self._keras_model.optimizer.variables())

@tf.contrib.eager.function(autograph=False)
Expand Down
52 changes: 52 additions & 0 deletions tensorflow_federated/python/learning/model_utils_test.py
Expand Up @@ -312,6 +312,58 @@ def loss_fn(y_true, y_pred):
self.assertGreater(m['loss'][0], 0.0)
self.assertEqual(m['loss'][1], input_vocab_size * num_iterations)

def test_keras_model_using_batch_norm(self):
model = model_examples.build_conv_batch_norm_keras_model()

def loss_fn(y_true, y_pred):
loss_per_example = tf.keras.losses.sparse_categorical_crossentropy(
y_true=y_true, y_pred=y_pred)
return tf.reduce_mean(loss_per_example)

model.compile(
optimizer=gradient_descent.SGD(learning_rate=0.01),
loss=loss_fn,
metrics=[NumBatchesCounter(), NumExamplesCounter()])

dummy_batch = collections.OrderedDict([
('x', np.zeros([1, 28 * 28], dtype=np.float32)),
('y', np.zeros([1, 1], dtype=np.int64)),
])
tff_model = model_utils.from_compiled_keras_model(
keras_model=model, dummy_batch=dummy_batch)

batch_size = 2
batch = {
'x':
np.random.uniform(low=0.0, high=1.0,
size=[batch_size, 28 * 28]).astype(np.float32),
'y':
np.random.random_integers(low=0, high=9, size=[batch_size,
1]).astype(np.int64),
}

num_iterations = 2
for _ in range(num_iterations):
self.evaluate(tff_model.train_on_batch(batch))

m = self.evaluate(tff_model.report_local_outputs())
self.assertEqual(m['num_batches'], [num_iterations])
self.assertEqual(m['num_examples'], [batch_size * num_iterations])
self.assertGreater(m['loss'][0], 0.0)
self.assertEqual(m['loss'][1], batch_size * num_iterations)

# Ensure we can assign the FL trained model weights to a new model.
tff_weights = model_utils.ModelWeights.from_model(tff_model)
keras_model = model_examples.build_conv_batch_norm_keras_model()
tff_weights.assign_weights_to(keras_model)

for keras_w, tff_w in zip(keras_model.weights, tff_weights.keras_weights):
self.assertAllClose(
self.evaluate(keras_w),
self.evaluate(tff_w),
atol=1e-4,
msg='Variable [{}]'.format(keras_w.name))

def test_wrap_tff_model_in_tf_computation(self):
feature_dims = 3

Expand Down

0 comments on commit 46f9c69

Please sign in to comment.