[clustering] Clusterable layer API #616
Thank you, and apologies for taking so long to review!
Can you please tidy up the description? The formatting is a bit off and makes it difficult to read. Also, the code snippet for the second use case does not include get_clusterable_algorithm().
cluster_wrapper.ClusterWeights(keras_custom_layer,
                               **self.params)

>>>>>>> 8fe29ec... MLTOOLS-1031 Customerable layer API.
Please remove!
pass


class MyCustomerableLayer(keras.layers.Dense,
nit: MyClusterableLayer instead? There are similar instances in other parts of the PR.
layers.Embedding: {'embeddings': DenseWeightsCA},
layers.LocallyConnected1D: {'kernel': ConvolutionalWeightsCA},
layers.LocallyConnected2D: {'kernel': ConvolutionalWeightsCA},
layers.Conv1D: {'kernel': ConvolutionalWeightsCA, 'bias': BiasWeightsCA},
We don't cluster biases by default. Will this change that behaviour? It'd be good to have a test for this.
@@ -0,0 +1,249 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
nit: please rename to mnist_clusterable_layer_test.py
Thanks. Looks good. Just a minor request.
self.assertGreater(nr_of_unique_weights, NUMBER_OF_CLUSTERS)

# Record the number of unique values of 'bias'
nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 1)
Please add an assertion for the number of unique bias values, similar to the one for the weights above.
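The requested check could look roughly like the sketch below. It is only an illustration: plain Python lists stand in for the layer's 'kernel' and 'bias' values, `count_unique` is a hypothetical helper (not the PR's `_get_number_of_unique_weights`), and the value of `NUMBER_OF_CLUSTERS` is assumed.

```python
import unittest

NUMBER_OF_CLUSTERS = 8  # assumed value, for illustration only


def count_unique(values):
    """Return the number of distinct values in a flat list of weights."""
    return len(set(values))


class UniqueValuesBeforeClusteringTest(unittest.TestCase):

    def test_kernel_and_bias_have_more_unique_values_than_clusters(self):
        # Hypothetical stand-ins for an unclustered layer's weights.
        kernel = [0.01 * i for i in range(64)]
        bias = [0.1 * i for i in range(16)]

        nr_of_unique_weights = count_unique(kernel)
        nr_of_bias_weights = count_unique(bias)

        # Existing style of check for the kernel.
        self.assertGreater(nr_of_unique_weights, NUMBER_OF_CLUSTERS)
        # The counterpart for 'bias' that the review asks for.
        self.assertGreater(nr_of_bias_weights, NUMBER_OF_CLUSTERS)
```

In the real test the two counts would come from the model rather than from hand-built lists.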
Thanks
Thank you!
Apologies, I missed it initially. We need to expose ClusterableLayer in the public API, which will require exposing AbstractClusteringAlgorithm as well. The latter then needs to move out of clustering_registry.py into its own module, so that we don't expose implementation details in the public API. We should also then clean up clustering_registry.py to remove unused functions, e.g. register_new_implementation(). Any thoughts?
Hi @akarmi, please re-review this PR - all comments are addressed. Thanks!
Thanks.
@daverim, as this change affects the public API, please review it as well.
self.assertGreater(nr_of_unique_weights, NUMBER_OF_CLUSTERS)

# Record the number of unique values of 'bias'
nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 1)
Thanks
@daverim, any feedback? We would like to merge this now.
PiperOrigin-RevId: 368158644
Change-Id: I51ce035855ff8e82c339f1cd260b7f5891050ab7
Hi @daverim, I updated this PR with the merged changes. I ran the test that remains in this PR locally and it works as expected.
This PR extends our clustering support to the following two use cases:
In the first case, the user creates a new layer derived from both Dense and ClusterableLayer and specifies the weights to be clustered in the function get_clusterable_weights.
In the second case, to be able to cluster the layer, the user derives it from ClusterableLayer and provides two functions:
get_clusterable_weights, which specifies what should be clustered, as in the first use case;
get_clusterable_algorithm, which is needed to specify the layout of the weight tensor.
The example is given below:
, where
Both of these use cases are added as an MNIST test in the new test file: mnist_clusterable_layer_test.py.
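The two cases described above can be sketched without TensorFlow as follows. This is a rough illustration, not the PR's code: `ClusterableLayer` here is a local stand-in that only mirrors the two method names from the description, `FakeDense` and `MyClusteringAlgorithm` are hypothetical placeholders, and both the `(name, weight)` return format and the `weight_name` argument are assumptions.

```python
class ClusterableLayer:
    """Local stand-in for the interface described above (not the library class)."""

    def get_clusterable_weights(self):
        raise NotImplementedError("Must be implemented in subclasses.")

    def get_clusterable_algorithm(self, weight_name):
        # In this sketch, None means "no custom algorithm for this weight".
        return None


class FakeDense:
    """Hypothetical stand-in for a Dense layer, holding plain-list weights."""

    def __init__(self, units):
        self.kernel = [0.0] * units
        self.bias = [0.0] * units


class MyClusterableDense(FakeDense, ClusterableLayer):
    """First case: a Dense-derived layer that picks which weights to cluster."""

    def get_clusterable_weights(self):
        # Only the kernel is listed, so the bias is not clustered.
        return [("kernel", self.kernel)]


class MyClusteringAlgorithm:
    """Hypothetical placeholder for code that knows the weight tensor layout."""


class MyCustomLayer(ClusterableLayer):
    """Second case: the layer provides both functions."""

    def __init__(self):
        self.w = [0.0] * 4

    def get_clusterable_weights(self):
        return [("w", self.w)]

    def get_clusterable_algorithm(self, weight_name):
        return MyClusteringAlgorithm if weight_name == "w" else None


def clustering_plan(layer):
    """Map each clusterable weight name to its algorithm (None if not given)."""
    return {
        name: layer.get_clusterable_algorithm(name)
        for name, _ in layer.get_clusterable_weights()
    }
```

Calling `clustering_plan(MyClusterableDense(3))` lists only `kernel`, while `clustering_plan(MyCustomLayer())` pairs `w` with the custom algorithm class.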