Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,6 @@ py_test(
deps = [
":cluster",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:compat",
],
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,9 @@ def _strip_clustering_wrapper(layer):
# If the value was not clustered(e.g. bias), we still store a valid
# reference to the tensor. We use this reference to get the value
new_weight_value = k.batch_get_value([weight])[0]
layer.layer.add_weight(
name=name,
shape=new_weight_value.shape,
initializer=initializers.Constant(new_weight_value),
trainable=True
)
setattr(layer.layer,
name,
k.variable(new_weight_value, name=name))
# When all weights are filled with the values, just return the underlying
# layer since it is now fully autonomous from its wrapper
return layer.layer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import numpy as np
import tensorflow as tf

import tempfile
from absl.testing import parameterized
from tensorflow.python.keras import keras_parameterized

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.keras import compat
import os

keras = tf.keras
layers = keras.layers
Expand All @@ -30,55 +32,105 @@
CentroidInitialization = cluster_config.CentroidInitialization

class ClusterIntegrationTest(test.TestCase, parameterized.TestCase):
"""Integration tests for clustering."""

@keras_parameterized.run_all_keras_modes
def testValuesRemainClusteredAfterTraining(self):
"""
Verifies that training a clustered model does not destroy the clusters.
Integration tests for clustering.
"""
number_of_clusters = 10
original_model = keras.Sequential([
layers.Dense(2, input_shape=(2,)),
layers.Dense(2),
])

clustered_model = cluster.cluster_weights(
original_model,
number_of_clusters=number_of_clusters,
cluster_centroids_init=CentroidInitialization.LINEAR
)

clustered_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy']
)

def dataset_generator():
x_train = np.array([
[0, 1],
[2, 0],
[0, 3],
[4, 1],
[5, 1],
])
y_train = np.array([
[0, 1],
[1, 0],
[1, 0],
[0, 1],
[0, 1],
])
for x, y in zip(x_train, y_train):
yield np.array([x]), np.array([y])

clustered_model.fit_generator(dataset_generator(), steps_per_epoch=1)
stripped_model = cluster.strip_clustering(clustered_model)
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
unique_weights = set(weights_as_list)
self.assertLessEqual(len(unique_weights), number_of_clusters)


if __name__ == '__main__':
test.main()
def setUp(self):
self.params = {
"number_of_clusters": 8,
"cluster_centroids_init": CentroidInitialization.LINEAR,
}

self.x_train = np.array(
[[0.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0], [5.0, 1.0]],
dtype="float32",
)

self.y_train = np.array(
[[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
dtype="float32",
)

def dataset_generator(self):
for x, y in zip(self.x_train, self.y_train):
yield np.array([x]), np.array([y])

@staticmethod
def _verify_tflite(tflite_file, x_test):
interpreter = tf.lite.Interpreter(model_path=tflite_file)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
x = x_test[0]
x = x.reshape((1,) + x.shape)
interpreter.set_tensor(input_index, x)
interpreter.invoke()
interpreter.get_tensor(output_index)

@keras_parameterized.run_all_keras_modes
def testValuesRemainClusteredAfterTraining(self):

"""
Verifies that training a clustered model does not destroy the clusters.
"""
original_model = keras.Sequential(
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
)

clustered_model = cluster.cluster_weights(original_model, **self.params)

clustered_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer="adam",
metrics=["accuracy"],
)

clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
stripped_model = cluster.strip_clustering(clustered_model)
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
unique_weights = set(weights_as_list)
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEnd(self):

"""
Test End to End clustering.
"""
original_model = keras.Sequential(
[layers.Dense(2, input_shape=(2,)), layers.Dense(2),]
)

clustered_model = cluster.cluster_weights(original_model, **self.params)

clustered_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer="adam",
metrics=["accuracy"],
)

clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)
stripped_model = cluster.strip_clustering(clustered_model)

_, tflite_file = tempfile.mkstemp(".tflite")
_, keras_file = tempfile.mkstemp(".h5")

if not compat.is_v1_apis():
converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
else:
tf.keras.models.save_model(stripped_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)

converter.experimental_new_converter = True
tflite_model = converter.convert()
with open(tflite_file, "wb") as f:
f.write(tflite_model)

self._verify_tflite(tflite_file, self.x_train)

os.remove(keras_file)
Copy link

@alanchiao alanchiao Jun 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to cleanup in a followup PR: os.remove should not be necessary since tempfile will automatically delete it.

os.remove(tflite_file)

if __name__ == "__main__":
test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,28 @@ def testClusterWeightsStrippedWeights(self):
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
self.assertEqual(len(stripped_model.get_weights()), cluster_weight_length)

@keras_parameterized.run_all_keras_modes
def testStrippedKernel(self):
"""
Verifies that stripping the clustering wrappers from a functional model
restores the layers kernel and the layers weight array to the new clustered weight value .
"""
i1 = keras.Input(shape=(1, 1, 1))
x1 = layers.Conv2D(1, 1)(i1)
outputs = x1
model = keras.Model(inputs=[i1], outputs=outputs)

clustered_model = cluster.cluster_weights(model, **self.params)
clustered_conv2d_layer = clustered_model.layers[1]
clustered_kernel = clustered_conv2d_layer.layer.kernel
stripped_model = cluster.strip_clustering(clustered_model)
stripped_conv2d_layer = stripped_model.layers[1]

self.assertEqual(self._count_clustered_layers(stripped_model), 0)
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
self.assertEqual(stripped_conv2d_layer.kernel,
stripped_conv2d_layer.weights[0])

@keras_parameterized.run_all_keras_modes
def testStripSelectivelyClusteredFunctionalModel(self):
"""
Expand Down