Keras set_weights function not working with batch normalization layer in the network #258
Comments
Any updates on this? Thank you in advance.
Thank you for the report @aqibsaeed, and apologies for the delay. I've been trying to reproduce this with little success. Could you provide a minimal code sample that defines an example model and input data to reproduce this? Which TFF release are you seeing this behavior on? It's also not entirely clear to us (yet) how the batch normalization parameters should be handled in the federated setting.
I used the Colab notebook provided on the tensorflow.org/federated website.
After one round of training as described in the notebook, I tried to set the weights as follows and got the error pasted below: (Am I not setting the weights correctly?)
```
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in set_weights(self, weights)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py in batch_set_value(tuples)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
ValueError: Shapes (64,) and (64, 64) are incompatible
```
To debug this issue I tried printing the weight matrix length for each layer; it seems the order of the weights returned by `tff.learning.keras_weights_from_tff_weights(state.model)` does not match the output of `model.get_weights()`.
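For reference, a minimal sketch of that debugging step, assuming the notebook's compiled Keras model (`keras_model`) and the `state` produced by the training round (names taken from the thread, not the exact notebook code):

```python
import tensorflow_federated as tff

# Compare the per-layer shapes Keras expects against the shapes produced
# by the TFF helper; a positional mismatch here explains the ValueError.
expected = [w.shape for w in keras_model.get_weights()]
received = [w.shape for w in
            tff.learning.keras_weights_from_tff_weights(state.model)]
for e, r in zip(expected, received):
    print('expected', e, 'received', r)
```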
@aqibsaeed thank you for sharing the model you used and the steps you have already taken, they were very helpful! It looks like `tf.keras.Model.weights` and `tf.keras.Model.get_weights()` are not ordered the same. I'll look into a fix for this.
Thank you very much!
…Model. This is a workaround for issue #258, which uncovered that tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the same.
- Add a toy example model that uses batch norm (includes non-trainable variables), which will fail without this change.
- Move client optimizer variables to local_variables; this includes variables such as the iteration number.

PiperOrigin-RevId: 240198034

…Model. This is a workaround for issue #258, which uncovered that tf.keras.Model.weights and tf.keras.Model.get_weights() are not ordered the same.
- Add a toy example model that uses batch norm (includes non-trainable variables), which will fail without this change.
- Move client optimizer variables to local_variables; this includes variables such as the iteration number.

PiperOrigin-RevId: 241075678
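A quick way to see the mismatch these commits describe, as a standalone sketch with a toy model (not the exact test added in the commit):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, input_shape=(10,)),
    tf.keras.layers.BatchNormalization(),
])

# model.weights is a list of tf.Variable objects, while get_weights()
# returns numpy arrays. If the two lists are ordered differently,
# pairing them up misaligns shapes -- the same kind of mismatch as the
# "Shapes (64,) and (64, 64) are incompatible" error above.
for var, value in zip(model.weights, model.get_weights()):
    print(var.name, tuple(var.shape), '<->', value.shape)
```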
(Note: this is checked in to source, but isn't in the pip package yet. The next release will be out in a week or two.)
The code would look something like:
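A minimal sketch, assuming the tutorial's `create_compiled_keras_model` builder and the trained `state` from earlier in this thread:

```python
import tensorflow_federated as tff

# Rebuild the tutorial's Keras model and copy the federated weights into
# it; with the fix, the helper returns weights in the order set_weights
# expects.
keras_model = create_compiled_keras_model()
keras_model.set_weights(
    tff.learning.keras_weights_from_tff_weights(state.model))
```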
Please also be aware that BatchNorm has trainable variables (beta, gamma) and non-trainable variables (moving mean, moving variance). How these should be aggregated into the global model is still an open question, and the vanilla FedAvg algorithm in the FL API is likely not doing what is desired here.
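For reference, a quick standalone check of which variables fall into each bucket (plain TensorFlow, no TFF needed):

```python
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
bn.build((None, 64))  # create the layer's variables for a 64-feature input

print([v.name for v in bn.trainable_variables])
# ['batch_normalization/gamma:0', 'batch_normalization/beta:0']
print([v.name for v in bn.non_trainable_variables])
# ['batch_normalization/moving_mean:0', 'batch_normalization/moving_variance:0']
```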
Sounds good! Thank you very much for the prompt action. I think this issue can be closed now.
Hello,
I would appreciate help with the following error: the `model.set_weights()` function from Keras is not working if the network contains batch normalization layer(s). I looked into the issue and believe it is due to an incorrect order of array elements returned by `tff.learning.keras_weights_from_tff_weights(state.model)`, which does not match the output of `model.get_weights()`. The underlying reason could be that `state.model` contains separate tuples for trainable and non-trainable weights.
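To make the suspected cause concrete, a hedged sketch; the `.trainable` / `.non_trainable` attribute names follow the description above and may not match the exact TFF structure:

```python
# Suspected mismatch: flattening the two tuples as trainable-then-
# non-trainable groups every moving_mean / moving_variance at the end,
# while Keras's set_weights() expects variables interleaved per layer
# (kernel, bias, gamma, beta, moving_mean, moving_variance, ...).
flat = list(state.model.trainable) + list(state.model.non_trainable)
keras_model.set_weights(flat)  # raises ValueError when the orders differ
```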