
Keras set_weights function not working with batch normalization layer in the network #258

Closed
aqibsaeed opened this issue Mar 20, 2019 · 8 comments

aqibsaeed commented Mar 20, 2019

model.set_weights() from Keras is not working if the network contains batch normalization layer(s). I looked into the issue and believe it is due to an incorrect order of array elements returned by the tff.learning.keras_weights_from_tff_weights(state.model) function, which does not match the output of model.get_weights(). The underlying reason could be that state.model contains separate tuples for trainable and non-trainable weights.
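A minimal self-contained sketch of this suspected mechanism (plain tf.keras in eager mode, no TFF; the layer sizes are arbitrary):

import tensorflow as tf

# Flattening trainable weights followed by non-trainable weights yields a
# different order than get_weights()/set_weights() expect once a
# BatchNormalization layer is present.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(2),
])

trainable_first = model.trainable_weights + model.non_trainable_weights
for flat_w, keras_w in zip(trainable_first, model.get_weights()):
    print(flat_w.shape, keras_w.shape)  # shapes diverge after the BatchNorm layer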

@ZacharyGarrett ZacharyGarrett self-assigned this Mar 21, 2019
@aqibsaeed (Author)

Any updates on this? Thank you in advance.

@ZacharyGarrett (Collaborator)

Thank you for the report @aqibsaeed, apologies for the delay.

I've been trying to reproduce this with little success. Might you be able to provide a minimal code sample that defines an example model and input data to reproduce this? Which TFF release are you seeing this behavior with?

It's also not entirely clear to us (yet) how the batch normalization parameters should be handled in the federated setting.

aqibsaeed (Author) commented Mar 28, 2019

I used the Colab notebook provided on the tensorflow.org/federated website.

import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent  # import used in the tutorial notebook

def create_compiled_keras_model():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(64, activation='linear', kernel_initializer='zeros', input_shape=(784,)),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Activation('relu'),
      tf.keras.layers.Dense(64, activation='linear', kernel_initializer='zeros'),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Activation('relu'),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax, kernel_initializer='zeros')
  ])
  
  def loss_fn(y_true, y_pred):
    return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred))
 
  model.compile(
      loss=loss_fn,
      optimizer=gradient_descent.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model

After one round of training as described in the notebook, I tried to set the weights as follows and got the error pasted below. (Am I not setting the weights correctly?)

nglobal_model = create_compiled_keras_model()
nglobal_model.set_weights(tff.learning.keras_weights_from_tff_weights(state.model))

ValueError                                Traceback (most recent call last)
<ipython-input> in <module>()
      1 nglobal_model = create_compiled_keras_model()
----> 2 nglobal_model.set_weights(tff.learning.keras_weights_from_tff_weights(state.model))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in set_weights(self, weights)
    406         tuples.append((sw, w))
    407       weights = weights[num_param:]
--> 408     backend.batch_set_value(tuples)
    409
    410   def compute_mask(self, inputs, mask):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py in batch_set_value(tuples)
   2858   if ops.executing_eagerly_outside_functions():
   2859     for x, value in tuples:
--> 2860       x.assign(np.asarray(value, dtype=dtype(x)))
   2861   else:
   2862     with get_graph().as_default():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
    913     with _handle_graph(self.handle):
    914       value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 915       self._shape.assert_is_compatible_with(value_tensor.shape)
    916       assign_op = gen_resource_variable_ops.assign_variable_op(
    917           self.handle, value_tensor, name=name)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
   1021     """
   1022     if not self.is_compatible_with(other):
--> 1023       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
   1024
   1025   def most_specific_compatible_shape(self, other):

ValueError: Shapes (64,) and (64, 64) are incompatible

To debug this issue, I tried printing the weight matrix length for each layer; it seems the order of the weight matrices returned by tff.learning.keras_weights_from_tff_weights() differs from that of base_model.get_weights().

base_model = create_compiled_keras_model()
random_weights = base_model.get_weights()
for rw in random_weights:
  print(len(rw))

784
64
64
64
64
64
64
64
64
64
64
64
64
10

federated_weights = tff.learning.keras_weights_from_tff_weights(state.model)
for fw in federated_weights:
  print(len(fw))

784
64
64
64
64
64
64
64
64
10
64
64
64
64

@ZacharyGarrett (Collaborator)

@aqibsaeed thank you for sharing the model you used and the steps you have already taken; they were very helpful!

It looks like tf.keras.Model.weights and tf.keras.Model.get_weights() return the weights in different orders. I suspect set_weights() expects the order returned by get_weights(), but TFF is pulling from the .weights property.

I'll look into a fix for this.

federated_weights = tff.learning.keras_weights_from_tff_weights(state.model)
for f_w, k_w in zip(federated_weights, base_model.weights):
  print('{} == {}: {}'.format(f_w.shape, k_w.shape, f_w.shape == k_w.shape))

(784, 64) == (784, 64): True
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64,): True
(64, 64) == (64, 64): True
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64,): True
(64, 10) == (64, 10): True
(10,) == (10,): True
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64,): True

for f_w, k_w in zip(federated_weights, base_model.get_weights()):
  print('{} == {}: {}'.format(f_w.shape, k_w.shape, f_w.shape == k_w.shape))

(784, 64) == (784, 64): True
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64,): True
(64, 64) == (64,): False
(64,) == (64,): True
(64,) == (64, 64): False
(64,) == (64,): True
(64, 10) == (64,): False
(10,) == (64,): False
(64,) == (64,): True
(64,) == (64,): True
(64,) == (64, 10): False
(64,) == (10,): False

@aqibsaeed (Author)

Thank you very much!

tensorflow-copybara pushed a commit that referenced this issue Mar 29, 2019
…Model.

This is a fix for issue #258, which uncovered that tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the same.

- Add a toy example model that uses batch norm (includes non-trainable variables), which will fail without this change.
- Move client optimizer variables to local_variables, this includes variables such as iteration number.

PiperOrigin-RevId: 240198034
tensorflow-copybara pushed a commit that referenced this issue Mar 29, 2019
…Model.

This is a workaround for issue #258, which uncovered that tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the same.

- Add a toy example model that uses batch norm (includes non-trainable variables), which will fail without this change.
- Move client optimizer variables to local_variables, this includes variables such as iteration number.

PiperOrigin-RevId: 241075678
@ZacharyGarrett (Collaborator)

@aqibsaeed

(Note: this is checked in to source, but isn't in the pip package yet. The next release will be out in a week or two.)

tff.learning.framework.ModelWeights has a new interface method called assign_weights_to (api_docs), which should serve as a workaround in the short term.

The code would look something like:

keras_model = tf.keras.models.Sequential(...)
tff_weights = tff.learning.framework.ModelWeights.from_tff_value(state.model)
tff_weights.assign_weights_to(keras_model)

Please also be aware that BatchNorm has trainable variables (beta, gamma) and non-trainable variables (moving mean, moving variance). How these should be aggregated into the global model is still an open question, and the vanilla FedAvg algorithm in the FL API is likely not doing what is desired.
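
For reference, a minimal sketch (plain tf.keras, nothing TFF-specific) of which BatchNorm variables fall into each category:

import tensorflow as tf

# BatchNormalization exposes trainable variables (gamma, beta) and
# non-trainable variables (moving_mean, moving_variance); the moving
# statistics are what make naive server-side averaging questionable.
bn = tf.keras.layers.BatchNormalization()
bn.build(input_shape=(None, 64))
print([v.name for v in bn.trainable_variables])      # gamma, beta
print([v.name for v in bn.non_trainable_variables])  # moving_mean, moving_variance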

@aqibsaeed (Author)

Sounds good! Thank you very much for the prompt action. I think this issue can be closed now.

AbbasiAYE commented Jun 1, 2019

Hello,
I followed the above discussion on Keras models with batch normalization; this is the relevant part of the code:

for round_num in range(2, 70): 
    FLstate, FLoutputs = trainer_Itr_Process.next(FLstate, federated_train_data)   
    FLlosses_arr.append(FLoutputs.loss)
    tff_weights = tff.learning.keras_weights_from_tff_weights(FLstate.model)
    tff_weights.assign_weights_to(tff_weights, Local_model_Fed)

I would appreciate help with the following error:

Traceback (most recent call last):
  File "C:\Users\ezalaab\.p2\pool\plugins\org.python.pydev.core_7.1.0.201902031515\pysrc\pydevd.py", line 2225, in <module>
    main()
  File "C:\Users\ezalaab\.p2\pool\plugins\org.python.pydev.core_7.1.0.201902031515\pysrc\pydevd.py", line 2218, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Users\ezalaab\.p2\pool\plugins\org.python.pydev.core_7.1.0.201902031515\pysrc\pydevd.py", line 1560, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "C:\Users\ezalaab\.p2\pool\plugins\org.python.pydev.core_7.1.0.201902031515\pysrc\pydevd.py", line 1567, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Users\ezalaab\.p2\pool\plugins\org.python.pydev.core_7.1.0.201902031515\pysrc\_pydev_imps\_pydev_execfile.py", line 25, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:\Users\ezalaab\Documents\eclipse-workspace\MANA-FederatedLearning\location-based-federated-learning\Test\FederatedLearningTFv3.py", line 299, in <module>
    tff_weights.assign_weights_to(tff_weights, Local_model_Fed)
AttributeError: 'list' object has no attribute 'assign_weights_to'
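
The error arises because tff.learning.keras_weights_from_tff_weights returns a plain Python list of weight values, which has no assign_weights_to method; that method lives on tff.learning.framework.ModelWeights and takes only the target model as its argument. A minimal sketch of the corrected loop body, assuming the 2019-era TFF API discussed above and the FLstate / Local_model_Fed names from the snippet:

# Sketch: wrap the server state in a ModelWeights struct, then assign
# its contents into the local Keras model.
tff_weights = tff.learning.framework.ModelWeights.from_tff_value(FLstate.model)
tff_weights.assign_weights_to(Local_model_Fed)  # called on ModelWeights; takes only the model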
