Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_model_optimization/python/core/clustering/keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,20 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":cluster_config",
":cluster_wrapper",
":clustering_centroids",
":clustering_registry",
],
)

py_library(
name = "cluster_config",
srcs = ["cluster_config.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)

py_library(
name = "clustering_registry",
srcs = ["clustering_registry.py"],
Expand All @@ -51,13 +59,19 @@ py_library(
srcs = ["clustering_centroids.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":cluster_config",
],
)

py_library(
name = "cluster_wrapper",
srcs = ["cluster_wrapper.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":cluster_config",
],
)

py_test(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from tensorflow import keras
from tensorflow.keras import initializers

from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids

k = keras.backend
CustomObjectScope = keras.utils.CustomObjectScope
Expand Down Expand Up @@ -47,7 +48,7 @@ def cluster_scope():
"""
return CustomObjectScope(
{
'ClusterWeights': cluster_wrapper.ClusterWeights
'ClusterWeights' : cluster_wrapper.ClusterWeights
}
)

Expand Down Expand Up @@ -79,7 +80,8 @@ def cluster_weights(to_cluster,
```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': 'density-based'
'cluster_centroids_init':
cluster_config.CentroidInitialization.DENSITY_BASED
}

clustered_model = cluster_weights(original_model, **clustering_params)
Expand All @@ -90,7 +92,8 @@ def cluster_weights(to_cluster,
```python
clustering_params = {
'number_of_clusters': 8,
'cluster_centroids_init': 'density-based'
'cluster_centroids_init':
cluster_config.CentroidInitialization.DENSITY_BASED
}

model = keras.Sequential([
Expand All @@ -105,15 +108,16 @@ def cluster_weights(to_cluster,
number_of_clusters: the number of cluster centroids to form when
clustering a layer/model. For example, if number_of_clusters=8 then only
8 unique values will be used in each weight array.
cluster_centroids_init: how to initialize the cluster centroids.
cluster_centroids_init: enum value that determines how the cluster
centroids will be initialized.
Can have following values:
1. 'random' : centroids are sampled using the uniform distribution
1. RANDOM : centroids are sampled using the uniform distribution
between the minimum and maximum weight values in a given layer
2. 'density-based' : density-based sampling. First, cumulative
2. DENSITY_BASED : density-based sampling. First, cumulative
distribution function is built for weights, then y-axis is evenly
spaced into number_of_clusters regions. After this the corresponding x
values are obtained and used to initialize clusters centroids.
3. 'linear' : cluster centroids are evenly spaced between the minimum
3. LINEAR : cluster centroids are evenly spaced between the minimum
and maximum values of a given weight
**kwargs: Additional keyword arguments to be passed to the keras layer.
Ignored when to_cluster is not a keras layer.
Expand All @@ -127,8 +131,8 @@ def cluster_weights(to_cluster,
"""
if not clustering_centroids.CentroidsInitializerFactory.\
init_is_supported(cluster_centroids_init):
raise ValueError("cluster centroids can only be one of three values: "
"random, density-based, linear")
raise ValueError("Cluster centroid initialization {} not supported".\
format(cluster_centroids_init))

def _add_clustering_wrapper(layer):
if isinstance(layer, cluster_wrapper.ClusterWeights):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration classes for clustering."""

import enum


class CentroidInitialization(str, enum.Enum):
LINEAR = "LINEAR"
RANDOM = "RANDOM"
DENSITY_BASED = "DENSITY_BASED"
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

from absl.testing import parameterized
from tensorflow.python.keras import keras_parameterized

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

keras = tf.keras
layers = keras.layers
test = tf.test

CentroidInitialization = cluster_config.CentroidInitialization

class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
"""Integration tests for clustering."""
Expand All @@ -43,7 +46,7 @@ def testValuesRemainClusteredAfterTraining(self):
clustered_model = cluster.cluster_weights(
original_model,
number_of_clusters=number_of_clusters,
cluster_centroids_init='linear'
cluster_centroids_init=CentroidInitialization.LINEAR
)

clustered_model.compile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow.python.keras import keras_parameterized

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
Expand Down Expand Up @@ -76,7 +77,8 @@ def setUp(self):
self.model = keras.Sequential()
self.params = {
'number_of_clusters': 8,
'cluster_centroids_init': 'density-based'
'cluster_centroids_init':
cluster_config.CentroidInitialization.DENSITY_BASED
}

def _build_clustered_layer_model(self, layer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tensorflow.keras import initializers

from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
Expand Down Expand Up @@ -273,7 +274,8 @@ def from_config(cls, config, custom_objects=None):
number_of_clusters = config.pop('number_of_clusters')
cluster_centroids_init = config.pop('cluster_centroids_init')
config['number_of_clusters'] = number_of_clusters
config['cluster_centroids_init'] = cluster_centroids_init
config['cluster_centroids_init'] = cluster_config.CentroidInitialization(
cluster_centroids_init)

from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
layer = deserialize_layer(config.pop('layer'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl.testing import parameterized

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
Expand All @@ -30,7 +31,7 @@
layers = keras.layers
test = tf.test

layers = keras.layers
CentroidInitialization = cluster_config.CentroidInitialization
ClusterRegistry = clustering_registry.ClusteringRegistry
ClusteringLookupRegistry = clustering_registry.ClusteringLookupRegistry

Expand All @@ -55,28 +56,34 @@ def testCannotBeInitializedWithNonLayerObject(self):
not an instance of keras.layers.Layer.
"""
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights({
'this': 'is not a Layer instance'
}, number_of_clusters=13, cluster_centroids_init='linear')
cluster_wrapper.ClusterWeights(
{'this': 'is not a Layer instance'},
number_of_clusters=13,
cluster_centroids_init=CentroidInitialization.LINEAR
)

def testCannotBeInitializedWithNonClusterableLayer(self):
"""
Verifies that ClusterWeights cannot be initialized with a non-clusterable
custom layer.
"""
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights(NonClusterableLayer(10),
number_of_clusters=13,
cluster_centroids_init='linear')
cluster_wrapper.ClusterWeights(
NonClusterableLayer(10),
number_of_clusters=13,
cluster_centroids_init=CentroidInitialization.LINEAR
)

def testCanBeInitializedWithClusterableLayer(self):
"""
Verifies that ClusterWeights can be initialized with a built-in clusterable
layer.
"""
l = cluster_wrapper.ClusterWeights(layers.Dense(10),
number_of_clusters=13,
cluster_centroids_init='linear')
l = cluster_wrapper.ClusterWeights(
layers.Dense(10),
number_of_clusters=13,
cluster_centroids_init=CentroidInitialization.LINEAR
)
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)

def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
Expand All @@ -85,19 +92,23 @@ def testCannotBeInitializedWithNonIntegerNumberOfClusters(self):
provided for the number of clusters.
"""
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights(layers.Dense(10),
number_of_clusters="13",
cluster_centroids_init='linear')
cluster_wrapper.ClusterWeights(
layers.Dense(10),
number_of_clusters="13",
cluster_centroids_init=CentroidInitialization.LINEAR
)

def testCannotBeInitializedWithFloatNumberOfClusters(self):
"""
Verifies that ClusterWeights cannot be initialized with a decimal value
provided for the number of clusters.
"""
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights(layers.Dense(10),
number_of_clusters=13.4,
cluster_centroids_init='linear')
cluster_wrapper.ClusterWeights(
layers.Dense(10),
number_of_clusters=13.4,
cluster_centroids_init=CentroidInitialization.LINEAR
)

@parameterized.parameters(
(0),
Expand All @@ -111,34 +122,47 @@ def testCannotBeInitializedWithNumberOfClustersLessThanTwo(
clusters.
"""
with self.assertRaises(ValueError):
cluster_wrapper.ClusterWeights(layers.Dense(10),
number_of_clusters=number_of_clusters,
cluster_centroids_init='linear')
cluster_wrapper.ClusterWeights(
layers.Dense(10),
number_of_clusters=number_of_clusters,
cluster_centroids_init=CentroidInitialization.LINEAR
)

def testCanBeInitializedWithAlreadyClusterableLayer(self):
"""
Verifies that ClusterWeights can be initialized with a custom clusterable
layer.
"""
layer = AlreadyClusterableLayer(10)
l = cluster_wrapper.ClusterWeights(layer,
number_of_clusters=13,
cluster_centroids_init='linear')
l = cluster_wrapper.ClusterWeights(
layer,
number_of_clusters=13,
cluster_centroids_init=CentroidInitialization.LINEAR
)
self.assertIsInstance(l, cluster_wrapper.ClusterWeights)

def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
"""
Verifies that the ClusterWeights instance created from a layer that has
a batch shape attribute, will also have this attribute.
"""
l = cluster_wrapper.ClusterWeights(layers.Dense(10, input_shape=(10,)),
number_of_clusters=13,
cluster_centroids_init='linear')
l = cluster_wrapper.ClusterWeights(
layers.Dense(10, input_shape=(10,)),
number_of_clusters=13,
cluster_centroids_init=CentroidInitialization.LINEAR
)
self.assertTrue(hasattr(l, '_batch_input_shape'))

# Makes it easier to test all possible parameters combinations.
@parameterized.parameters(
*itertools.product(range(2, 16, 4), ('linear', 'random', 'density-based'))
*itertools.product(
range(2, 16, 4),
(
CentroidInitialization.LINEAR,
CentroidInitialization.RANDOM,
CentroidInitialization.DENSITY_BASED
)
)
)
def testValuesAreClusteredAfterStripping(self,
number_of_clusters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import six
import tensorflow as tf

k = tf.keras.backend
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

k = tf.keras.backend
CentroidInitialization = cluster_config.CentroidInitialization

@six.add_metaclass(abc.ABCMeta)
class AbstractCentroidsInitialisation:
Expand Down Expand Up @@ -187,9 +189,10 @@ class CentroidsInitializerFactory:
reflect new methods available.
"""
_initialisers = {
'linear': LinearCentroidsInitialisation,
'random': RandomCentroidsInitialisation,
'density-based': DensityBasedCentroidsInitialisation
CentroidInitialization.LINEAR : LinearCentroidsInitialisation,
CentroidInitialization.RANDOM : RandomCentroidsInitialisation,
CentroidInitialization.DENSITY_BASED :
DensityBasedCentroidsInitialisation
}

@classmethod
Expand All @@ -199,9 +202,11 @@ def init_is_supported(cls, init_method):
@classmethod
def get_centroid_initializer(cls, init_method):
"""
:param init_method: a string representation of the init methods requested
:return: A concrete implementation of AbstractCentroidsInitialisation
:raises: ValueError if the string representation is not recognised
:param init_method: a CentroidInitialization value representing the init
method requested
:return: A concrete implementation of AbstractCentroidsInitialisation
:raises: ValueError if the requested centroid initialization method is not
recognised
"""
if not cls.init_is_supported(init_method):
raise ValueError(
Expand Down
Loading