Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.h5
bazel-*
__pycache__


106 changes: 90 additions & 16 deletions tensorflow_model_optimization/python/core/clustering/keras/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Clustering API functions for Keras models."""
import distutils.version

import tensorflow as tf
from tensorflow import keras

from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
Expand All @@ -24,6 +26,11 @@
Layer = keras.layers.Layer
InputLayer = keras.layers.InputLayer

# From tf version 2.4.0 onwards the internal variable
# _layers has been renamed to _self_tracked_trackables.
# This variable is the only way to add cluster wrapper
# to layers of a subclassed model.
TF_VERSION_LAYERS = "2.4.0"

def cluster_scope():
"""Provides a scope in which Clustered layers and models can be deserialized.
Expand All @@ -50,6 +57,27 @@ def cluster_scope():
}
)

def _type_model(model):
""" Auxiliary function to check type of the model:
Sequential/Functional, Layer or Subclassed.

Args:
model : provided model to check

Returns:
[tuple]: (is_sequential_or_functional, is_keras_layer, is_subclassed_model)
"""
is_sequential_or_functional = isinstance(
model, keras.Model) and (isinstance(model, keras.Sequential) or
model._is_graph_network)

is_keras_layer = isinstance(
model, keras.layers.Layer) and not isinstance(model, keras.Model)

is_subclassed_model = isinstance(model, keras.Model) and \
not model._is_graph_network

return (is_sequential_or_functional, is_keras_layer, is_subclassed_model)

def cluster_weights(to_cluster,
number_of_clusters,
Expand Down Expand Up @@ -221,8 +249,9 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
cluster_centroids_init))

def _add_clustering_wrapper(layer):
if isinstance(layer, keras.Model):
# Check whether the model is a subclass.
if (isinstance(layer, keras.Model)):
# Check whether the model is subclassed.

# NB: This check is copied from keras.py file in tensorflow.
# There is no available public API to do this check.
# pylint: disable=protected-access
Expand All @@ -248,29 +277,49 @@ def _wrap_list(layers):

return output

if isinstance(to_cluster, keras.Model):
(is_sequential_or_functional, is_keras_layer, is_subclassed_model) =\
_type_model(to_cluster)

if isinstance(to_cluster, list):
return _wrap_list(to_cluster)
elif is_sequential_or_functional:
return keras.models.clone_model(to_cluster,
input_tensors=None,
clone_function=_add_clustering_wrapper)
if isinstance(to_cluster, Layer):
elif is_keras_layer:
return _add_clustering_wrapper(layer=to_cluster)
if isinstance(to_cluster, list):
return _wrap_list(to_cluster)

elif is_subclassed_model:
# If the subclassed model is provided, then
# we add wrappers for all available layers and
# we wrap the whole model, so that augmented
# 'build' and 'call' functions are called.
tf_version = distutils.version.LooseVersion(tf.__version__)
layers_tf_version = distutils.version.LooseVersion(TF_VERSION_LAYERS)
for i, layer in enumerate(to_cluster.submodules):
if tf_version > layers_tf_version:
to_cluster._self_tracked_trackables[i] = _add_clustering_wrapper(layer=layer)
else:
to_cluster._layers[i] = _add_clustering_wrapper(layer=layer)
return cluster_wrapper.WrapperSubclassedModel(to_cluster)
else:
raise ValueError(
' Clustering cannot be applied. You passed '
'an object of type: {input}.'.format(input=to_cluster.__class__.__name__))

def strip_clustering(model):
"""Strips clustering wrappers from the model.
def strip_clustering(to_strip):
"""Strip clustering wrappers from the model.

Once a model has been clustered, this method can be used
to restore the original model with the clustered weights.
to restore the original model or layer with the clustered weights.

Only sequential and functional models are supported for now.
Sequential, functional and subclassed models are supported.

Arguments:
model: A `tf.keras.Model` instance with clustered layers.
to_strip: A `tf.keras.Model` instance with clustered layers or a
`tf.keras.layers.Layer` instance

Returns:
A keras model with clustering wrappers removed.
A keras model or layer with clustering wrappers removed.

Raises:
ValueError: if the model is not a `tf.keras.Model` instance.
Expand All @@ -285,9 +334,11 @@ def strip_clustering(model):
```
The exported_model and the orig_model have the same structure.
"""
if not isinstance(model, keras.Model):
if not isinstance(to_strip, keras.Model) and not isinstance(
to_strip, keras.layers.Layer):
raise ValueError(
'Expected model to be a `tf.keras.Model` instance but got: ', model)
'Expected to_strip to be a `tf.keras.Model` or \
`tf.keras.layers.Layer` instance but got: ', to_strip)

def _strip_clustering_wrapper(layer):
if isinstance(layer, keras.Model):
Expand Down Expand Up @@ -325,7 +376,30 @@ def _strip_clustering_wrapper(layer):
return layer.layer
return layer

(is_sequential_or_functional, is_keras_layer, is_subclassed_model) =\
_type_model(to_strip)

# Just copy the model with the right callback
return keras.models.clone_model(model,
if is_sequential_or_functional:
return keras.models.clone_model(to_strip,
input_tensors=None,
clone_function=_strip_clustering_wrapper)
elif is_keras_layer:
if isinstance(to_strip, keras.layers.Layer):
return _strip_clustering_wrapper(to_strip)
elif is_subclassed_model:
to_strip_model = to_strip.model
tf_version = distutils.version.LooseVersion(tf.__version__)
layers_tf_version = distutils.version.LooseVersion(TF_VERSION_LAYERS)
if tf_version > layers_tf_version:
for i, layer in enumerate(to_strip_model._self_tracked_trackables):
to_strip_model._self_tracked_trackables[i] = _strip_clustering_wrapper(layer=layer)
else:
for i, layer in enumerate(to_strip_model._layers):
to_strip_model._layers[i] = _strip_clustering_wrapper(layer=layer)
return to_strip_model
else:
raise ValueError(
' Strip clustering cannot be applied. You passed '
'an object of type: {input}.'.format(input=to_strip.__class__.__name__))

Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,42 @@
from tensorflow.python.keras import keras_parameterized
from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

from tensorflow_model_optimization.python.core.keras import compat
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper

from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster


keras = tf.keras
layers = keras.layers
test = tf.test

CentroidInitialization = cluster_config.CentroidInitialization

class SubclassedModel(keras.Model):
"""Subclassed model with one layer."""

def __init__(self):
"""Subclassed model with one dense layer."""
super(SubclassedModel, self).__init__(name='subclass_model')
self.dense_layer = keras.layers.Dense(5, activation='relu')

def call(self, inputs):
return self.dense_layer(inputs)

class SubclassedModelTwoLayers(keras.Model):
"""Subclassed model with two layers."""

def __init__(self):
"""Subclassed model with two layers."""
super(SubclassedModelTwoLayers, self).__init__(name='subclass_model')
self.dense_layer1 = keras.layers.Dense(5, activation='relu')
self.dense_layer2 = keras.layers.Dense(5, activation='softmax')

def call(self, inputs):
x = self.dense_layer1(inputs)
return self.dense_layer2(x)

class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
"""Integration tests for clustering."""
Expand Down Expand Up @@ -233,6 +261,105 @@ def clusters_check(stripped_model):

self.end_to_end_testing(original_model, clusters_check)

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSubclassedModel(self):
"""Test End to End clustering for the subclassed model.
In this test we pass the whole subclassed model for clustering.
We check that the number of weights is less the requested
number of clusters after stripping clustering wrapper.

"""
subclassed_model = SubclassedModel()

clustered_model = cluster.cluster_weights(subclassed_model, **self.params)

clustered_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer="adam",
metrics=["accuracy"]
)

# The model should be trained a little bit.
clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
stripped_model = cluster.strip_clustering(clustered_model)

nr_unique_weights = len(np.unique(stripped_model.layers[0].\
trainable_weights[0].numpy().flatten()))
self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSubclassedModelTwoLayers(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test re-produces the approach tested here: #554

"""Test End to End clustering for the subclass model.

This test demonstrates another approach.
All layers that are present in the subclassed model
(see SubclassedModelTwoLayers definition above) are wrapped
manually. The model should be re-build in this case.

We need to strip clustering away manually as well (see how it is
done inside the test).

Clustering is working well and clusters are updated during
training."""
subclassed_model = SubclassedModelTwoLayers()
input_shape = (1, 5)

# We need to build the model
subclassed_model.build(input_shape=input_shape)

# Check that the number of weights is bigger than the number of clusters.
nr_unique_weights = len(np.unique(subclassed_model.layers[0].\
trainable_weights[0].numpy().flatten()))
self.assertGreater(nr_unique_weights, self.params["number_of_clusters"])
nr_unique_weights = len(np.unique(subclassed_model.layers[1].\
trainable_weights[0].numpy().flatten()))
self.assertGreater(nr_unique_weights, self.params["number_of_clusters"])

# Now we apply cluster_weights for each layer.
subclassed_model.dense_layer1 = cluster.cluster_weights(
subclassed_model.dense_layer1, **self.params)
subclassed_model.dense_layer2 = cluster.cluster_weights(
subclassed_model.dense_layer2, **self.params)

# We need to re-build the model again.
subclassed_model.build(input_shape=input_shape)

subclassed_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer="adam",
metrics=["accuracy"]
)

subclassed_model.fit(x=self.dataset_generator(), steps_per_epoch=1)

# We strip from layers that were wrapped.
subclassed_model.dense_layer1 = cluster.strip_clustering(subclassed_model.dense_layer1)
subclassed_model.dense_layer2 = cluster.strip_clustering(subclassed_model.dense_layer2)

# Checks that the number of unique values is less than the requested
# number of clusters.
nr_unique_weights = len(np.unique(subclassed_model.layers[0].\
trainable_weights[0].numpy().flatten()))
self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"])
nr_unique_weights = len(np.unique(subclassed_model.layers[1].\
trainable_weights[0].numpy().flatten()))
self.assertLessEqual(nr_unique_weights, self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSubclassedModelAsDeepLayer(self):
"""Test End to End clustering for the model with the layer as a subclass model."""
# This case is not supported currently.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this case will be enabled later once the current approach is approved


internal_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))])
subclassed_model = SubclassedModel()
original_model = keras.Sequential([
internal_model,
subclassed_model,
])

with self.assertRaisesRegexp(ValueError, "Subclassed models.*"):
self.end_to_end_testing(original_model)

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndDeepLayer(self):
"""Test End to End clustering for the model with deep layer."""
Expand Down Expand Up @@ -302,3 +429,4 @@ def clusters_check(stripped_model):

if __name__ == "__main__":
test.main()

Loading