Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Module containing clustering code built on Keras abstractions."""
# pylint: disable=g-bad-import-order
from tensorflow_model_optimization.python.core.clustering.keras import experimental

from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_scope
from tensorflow_model_optimization.python.core.clustering.keras.cluster import cluster_weights
from tensorflow_model_optimization.python.core.clustering.keras.cluster import strip_clustering
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module containing experimental clustering code built on Keras abstractions."""
from tensorflow_model_optimization.python.core.clustering.keras.experimental.cluster import cluster_weights
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ py_library(
srcs_version = "PY3",
deps = [
":cluster",
"//tensorflow_model_optimization/python/core/clustering/keras/experimental",
],
)

Expand Down Expand Up @@ -90,6 +91,7 @@ py_test(
visibility = ["//visibility:public"],
deps = [
":cluster",
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
# tensorflow dep1,
],
)
Expand Down Expand Up @@ -146,6 +148,7 @@ py_test(
":cluster",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:compat",
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
],
)

Expand Down
110 changes: 106 additions & 4 deletions tensorflow_model_optimization/python/core/clustering/keras/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def cluster_weights(to_cluster,
```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init':
CentroidInitialization.DENSITY_BASED
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

clustered_model = cluster_weights(original_model, **clustering_params)
Expand All @@ -92,8 +91,108 @@ def cluster_weights(to_cluster,
```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init':
CentroidInitialization.DENSITY_BASED
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

model = keras.Sequential([
layers.Dense(10, activation='relu', input_shape=(100,)),
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])
```

Arguments:
to_cluster: A single keras layer, list of keras layers, or a
`tf.keras.Model` instance.
number_of_clusters: the number of cluster centroids to form when
clustering a layer/model. For example, if number_of_clusters=8 then only
8 unique values will be used in each weight array.
cluster_centroids_init: enum value that determines how the cluster
centroids will be initialized.
Can have the following values:
1. RANDOM : centroids are sampled using the uniform distribution
between the minimum and maximum weight values in a given layer
2. DENSITY_BASED : density-based sampling. First, cumulative
distribution function is built for weights, then y-axis is evenly
spaced into number_of_clusters regions. After this the corresponding x
values are obtained and used to initialize clusters centroids.
3. LINEAR : cluster centroids are evenly spaced between the minimum
and maximum values of a given weight
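preserve_sparsity: boolean value that determines whether or not sparsity
preservation will be enforced during training. NOTE: this function always
forwards `preserve_sparsity=False` to the underlying implementation, so
sparsity preservation is never enabled here; the experimental
`cluster_weights` is the entry point that is called with this option.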
**kwargs: Additional keyword arguments to be passed to the keras layer.
Ignored when to_cluster is not a keras layer.

Returns:
Layer or model modified to include clustering related metadata.

Raises:
ValueError: if the keras layer is unsupported, or the keras model contains
an unsupported layer.
"""
return _cluster_weights(to_cluster,
number_of_clusters,
cluster_centroids_init,
preserve_sparsity=False,
**kwargs)


def _cluster_weights(to_cluster,
number_of_clusters,
cluster_centroids_init,
preserve_sparsity,
**kwargs):
"""Modify a keras layer or model to be clustered during training (private method).

This function wraps a keras model or layer with clustering functionality
which clusters the layer's weights during training. For example, using
this with number_of_clusters equal to 8 will ensure that each weight tensor has
no more than 8 unique values.

Before passing to the clustering API, a model should already be trained and
show some acceptable performance on the testing/validation sets.

The function accepts either a single keras layer
(subclass of `keras.layers.Layer`), list of keras layers or a keras model
(instance of `keras.models.Model`) and handles them appropriately.

If it encounters a layer it does not know how to handle, it will throw an
error. While clustering an entire model, even a single unknown layer would
lead to an error.

Cluster a model:

```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
'preserve_sparsity': False
}

clustered_model = cluster_weights(original_model, **clustering_params)
```

Cluster a layer:

```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
'preserve_sparsity': False
}

model = keras.Sequential([
layers.Dense(10, activation='relu', input_shape=(100,)),
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
])
```

Cluster a layer with sparsity preservation (experimental):

```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
'preserve_sparsity': True
}

model = keras.Sequential([
Expand All @@ -110,6 +209,8 @@ def cluster_weights(to_cluster,
8 unique values will be used in each weight array.
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
instance that determines how the cluster centroids will be initialized.
preserve_sparsity (experimental): optional boolean value that determines whether or not
sparsity preservation will be enforced during training.
**kwargs: Additional keyword arguments to be passed to the keras layer.
Ignored when to_cluster is not a keras layer.

Expand Down Expand Up @@ -146,6 +247,7 @@ def _add_clustering_wrapper(layer):
return cluster_wrapper.ClusterWeights(layer,
number_of_clusters,
cluster_centroids_init,
preserve_sparsity,
**kwargs)

def _wrap_list(layers):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.keras import compat

from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster

keras = tf.keras
layers = keras.layers
test = tf.test
Expand Down Expand Up @@ -60,10 +62,26 @@ def setUp(self):
dtype="float32",
)

self.x_train2 = np.array(
[[0.0, 1.0, 2.0, 3.0, 4.0], [2.0, 0.0, 2.0, 3.0, 4.0], [0.0, 3.0, 2.0, 3.0, 4.0],
[4.0, 1.0, 2.0, 3.0, 4.0], [5.0, 1.0, 2.0, 3.0, 4.0]],
dtype="float32",
)

self.y_train2 = np.array(
[[0.0, 1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0], [1.0, 0.0, 2.0, 3.0, 4.0],
[0.0, 1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 2.0, 3.0, 4.0]],
dtype="float32",
)

def dataset_generator(self):
  """Yield the first training set one sample at a time.

  Pairs up `self.x_train` and `self.y_train` row by row and yields each pair
  as a 2-tuple of numpy arrays, with every row wrapped in an extra leading
  dimension (i.e. a batch containing a single sample).
  """
  paired_samples = zip(self.x_train, self.y_train)
  for features, labels in paired_samples:
    single_sample_batch = (np.array([features]), np.array([labels]))
    yield single_sample_batch

def dataset_generator2(self):
  """Yield the second training set one sample at a time.

  Pairs up `self.x_train2` and `self.y_train2` row by row and yields each pair
  as a 2-tuple of numpy arrays, with every row wrapped in an extra leading
  dimension (i.e. a batch containing a single sample).
  """
  paired_samples = zip(self.x_train2, self.y_train2)
  for features, labels in paired_samples:
    single_sample_batch = (np.array([features]), np.array([labels]))
    yield single_sample_batch

def end_to_end_testing(self, original_model, clusters_check=None):
"""Test End to End clustering."""

Expand Down Expand Up @@ -128,6 +146,50 @@ def testValuesRemainClusteredAfterTraining(self):
unique_weights = set(weights_as_list)
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes
def testSparsityIsPreservedDuringTraining(self):
  """Verifies that training a clustered model does not destroy the sparsity of the weights."""
  # Set a specific random seed to ensure that we get some null weights to test
  # sparsity preservation with.
  # NOTE(review): a reviewer asked whether zeros are guaranteed in this model
  # and suggested making the initial weights deterministic, with a known
  # number and location of zeros. The author's answer was to fix the seed
  # below, reporting 5 null weights on each test run. The test therefore
  # relies on this seed producing zero-valued weights.
  tf.random.set_seed(1)

  original_model = keras.Sequential([
      layers.Dense(5, input_shape=(5,)),
      layers.Dense(5),
  ])

  # Using a minimum number of centroids to make it more likely that some
  # weights will be zero.
  clustering_params = {
      "number_of_clusters": 3,
      "cluster_centroids_init": CentroidInitialization.LINEAR,
      "preserve_sparsity": True
  }

  clustered_model = experimental_cluster.cluster_weights(original_model, **clustering_params)

  # Snapshot the positions of the non-zero weights of the first weight array
  # before any fine-tuning happens.
  stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
  weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
  non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)

  clustered_model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer="adam",
      metrics=["accuracy"],
  )
  clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)

  # Take the same snapshot after one training step.
  stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
  weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
  non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
  weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist()
  unique_weights_after_tuning = set(weights_as_list_after_tuning)

  # Check that the null weights stayed the same before and after tuning.
  self.assertTrue(np.array_equal(non_zero_weight_indices_before_tuning,
                                 non_zero_weight_indices_after_tuning))

  # Check that the number of unique weights does not exceed the number of
  # clusters.
  # NOTE(review): the bound is read from self.params, not from the local
  # clustering_params (3 clusters) used to build this model -- confirm that
  # this is the intended limit.
  self.assertLessEqual(len(unique_weights_after_tuning), self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSequential(self):
"""Test End to End clustering - sequential model."""
Expand Down
Loading