-
Notifications
You must be signed in to change notification settings - Fork 336
[Clustering] Support for clustering of a subclassed model. #571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| *.h5 | ||
| bazel-* | ||
| __pycache__ | ||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,14 +24,42 @@ | |
| from tensorflow.python.keras import keras_parameterized | ||
| from tensorflow_model_optimization.python.core.clustering.keras import cluster | ||
| from tensorflow_model_optimization.python.core.clustering.keras import cluster_config | ||
|
|
||
| from tensorflow_model_optimization.python.core.keras import compat | ||
| from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper | ||
|
|
||
| from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster | ||
|
|
||
|
|
||
| keras = tf.keras | ||
| layers = keras.layers | ||
| test = tf.test | ||
|
|
||
| CentroidInitialization = cluster_config.CentroidInitialization | ||
|
|
||
| class SubclassedModel(keras.Model): | ||
| """Subclassed model with one layer.""" | ||
|
|
||
| def __init__(self): | ||
| """Subclassed model with one dense layer.""" | ||
| super(SubclassedModel, self).__init__(name='subclass_model') | ||
| self.dense_layer = keras.layers.Dense(5, activation='relu') | ||
|
|
||
| def call(self, inputs): | ||
| return self.dense_layer(inputs) | ||
|
|
||
| class SubclassedModelTwoLayers(keras.Model): | ||
| """Subclassed model with two layers.""" | ||
|
|
||
| def __init__(self): | ||
| """Subclassed model with two layers.""" | ||
| super(SubclassedModelTwoLayers, self).__init__(name='subclass_model') | ||
| self.dense_layer1 = keras.layers.Dense(5, activation='relu') | ||
| self.dense_layer2 = keras.layers.Dense(5, activation='softmax') | ||
|
|
||
| def call(self, inputs): | ||
| x = self.dense_layer1(inputs) | ||
| return self.dense_layer2(x) | ||
|
|
||
| class ClusterIntegrationTest(test.TestCase, parameterized.TestCase): | ||
| """Integration tests for clustering.""" | ||
|
|
@@ -233,6 +261,105 @@ def clusters_check(stripped_model): | |
|
|
||
| self.end_to_end_testing(original_model, clusters_check) | ||
|
|
||
| @keras_parameterized.run_all_keras_modes(always_skip_v1=True) | ||
| def testEndToEndSubclassedModel(self): | ||
| """Test End to End clustering for the subclassed model. | ||
| In this test we pass the whole subclassed model for clustering. | ||
| We check that the number of weights is less the requested | ||
| number of clusters after stripping clustering wrapper. | ||
|
|
||
| """ | ||
| subclassed_model = SubclassedModel() | ||
|
|
||
| clustered_model = cluster.cluster_weights(subclassed_model, **self.params) | ||
|
|
||
| clustered_model.compile( | ||
| loss=keras.losses.categorical_crossentropy, | ||
| optimizer="adam", | ||
| metrics=["accuracy"] | ||
| ) | ||
|
|
||
| # The model should be trained a little bit. | ||
| clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1) | ||
| stripped_model = cluster.strip_clustering(clustered_model) | ||
|
|
||
| nr_unique_weights = len(np.unique(stripped_model.layers[0].\ | ||
| trainable_weights[0].numpy().flatten())) | ||
| self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) | ||
|
|
||
| @keras_parameterized.run_all_keras_modes(always_skip_v1=True) | ||
| def testEndToEndSubclassedModelTwoLayers(self): | ||
| """Test End to End clustering for the subclass model. | ||
|
|
||
| This test demonstrates another approach. | ||
| All layers that are present in the subclassed model | ||
| (see SubclassedModelTwoLayers definition above) are wrapped | ||
| manually. The model should be re-build in this case. | ||
|
|
||
| We need to strip clustering away manually as well (see how it is | ||
| done inside the test). | ||
|
|
||
| Clustering is working well and clusters are updated during | ||
| training.""" | ||
| subclassed_model = SubclassedModelTwoLayers() | ||
| input_shape = (1, 5) | ||
|
|
||
| # We need to build the model | ||
| subclassed_model.build(input_shape=input_shape) | ||
|
|
||
| # Check that the number of weights is bigger than the number of clusters. | ||
| nr_unique_weights = len(np.unique(subclassed_model.layers[0].\ | ||
| trainable_weights[0].numpy().flatten())) | ||
| self.assertGreater(nr_unique_weights, self.params["number_of_clusters"]) | ||
| nr_unique_weights = len(np.unique(subclassed_model.layers[1].\ | ||
| trainable_weights[0].numpy().flatten())) | ||
| self.assertGreater(nr_unique_weights, self.params["number_of_clusters"]) | ||
|
|
||
| # Now we apply cluster_weights for each layer. | ||
| subclassed_model.dense_layer1 = cluster.cluster_weights( | ||
| subclassed_model.dense_layer1, **self.params) | ||
| subclassed_model.dense_layer2 = cluster.cluster_weights( | ||
| subclassed_model.dense_layer2, **self.params) | ||
|
|
||
| # We need to re-build the model again. | ||
| subclassed_model.build(input_shape=input_shape) | ||
|
|
||
| subclassed_model.compile( | ||
| loss=keras.losses.categorical_crossentropy, | ||
| optimizer="adam", | ||
| metrics=["accuracy"] | ||
| ) | ||
|
|
||
| subclassed_model.fit(x=self.dataset_generator(), steps_per_epoch=1) | ||
|
|
||
| # We strip from layers that were wrapped. | ||
| subclassed_model.dense_layer1 = cluster.strip_clustering(subclassed_model.dense_layer1) | ||
| subclassed_model.dense_layer2 = cluster.strip_clustering(subclassed_model.dense_layer2) | ||
|
|
||
| # Checks that the number of unique values is less than the requested | ||
| # number of clusters. | ||
| nr_unique_weights = len(np.unique(subclassed_model.layers[0].\ | ||
| trainable_weights[0].numpy().flatten())) | ||
| self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) | ||
| nr_unique_weights = len(np.unique(subclassed_model.layers[1].\ | ||
| trainable_weights[0].numpy().flatten())) | ||
| self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"]) | ||
|
|
||
| @keras_parameterized.run_all_keras_modes(always_skip_v1=True) | ||
| def testEndToEndSubclassedModelAsDeepLayer(self): | ||
| """Test End to End clustering for the model with the layer as a subclass model.""" | ||
| # This case is not supported currently. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this case will be enabled later once the current approach is approved |
||
|
|
||
| internal_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))]) | ||
| subclassed_model = SubclassedModel() | ||
| original_model = keras.Sequential([ | ||
| internal_model, | ||
| subclassed_model, | ||
| ]) | ||
|
|
||
| with self.assertRaisesRegexp(ValueError, "Subclassed models.*"): | ||
| self.end_to_end_testing(original_model) | ||
|
|
||
| @keras_parameterized.run_all_keras_modes(always_skip_v1=True) | ||
| def testEndToEndDeepLayer(self): | ||
| """Test End to End clustering for the model with deep layer.""" | ||
|
|
@@ -302,3 +429,4 @@ def clusters_check(stripped_model): | |
|
|
||
| if __name__ == "__main__": | ||
| test.main() | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test re-produces the approach tested here: #554