@@ -202,6 +202,7 @@ py_strict_test(
# numpy dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:test_utils",
"//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster",
],
)

@@ -228,5 +229,6 @@ py_strict_test(
":cluster",
":cluster_config",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/clustering/keras/experimental",
],
)
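Both BUILD hunks make the same kind of change: a py_strict_test may only import from targets listed in its deps, so the experimental clustering package has to be declared before the updated tests below can import it. A minimal sketch of the dep-to-import pairing (the target comes from the hunks above; the import alias matches the test files later in this diff):

# deps entry in the test target:
#   "//tensorflow_model_optimization/python/core/clustering/keras/experimental:cluster"
# enables this import in the test source:
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster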
@@ -14,14 +14,18 @@
# ==============================================================================
"""Distributed clustering test."""

from absl.testing import parameterized
import itertools
import unittest

import numpy as np
import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from absl.testing import parameterized
from tensorflow_model_optimization.python.core.clustering.keras import (
cluster, cluster_config, cluster_wrapper)
from tensorflow_model_optimization.python.core.clustering.keras.experimental import \
cluster as experimental_cluster
from tensorflow_model_optimization.python.core.keras import \
test_utils as keras_test_utils

keras = tf.keras
CentroidInitialization = cluster_config.CentroidInitialization
@@ -30,23 +34,37 @@
def _distribution_strategies():
return [tf.distribute.MirroredStrategy()]

def _clustering_strategies():
return [
{
'number_of_clusters': 2,
'cluster_centroids_init': CentroidInitialization.LINEAR,
'preserve_sparsity': False
},
{
'number_of_clusters': 3,
'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
'preserve_sparsity': True
}
]

class ClusterDistributedTest(tf.test.TestCase, parameterized.TestCase):
"""Distributed tests for clustering."""

def setUp(self):
super(ClusterDistributedTest, self).setUp()
self.params = {
'number_of_clusters': 2,
'cluster_centroids_init': CentroidInitialization.LINEAR
}

@parameterized.parameters(_distribution_strategies())
def testClusterSimpleDenseModel(self, distribution):
@parameterized.parameters(
*itertools.product(
_distribution_strategies(),
_clustering_strategies()
)
)
def testClusterSimpleDenseModel(self, distribution, clustering):
"""End-to-end test."""
with distribution.scope():
model = cluster.cluster_weights(
keras_test_utils.build_simple_dense_model(), **self.params)
model = experimental_cluster.cluster_weights(
keras_test_utils.build_simple_dense_model(), **clustering)
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
@@ -64,9 +82,11 @@ def testClusterSimpleDenseModel(self, distribution):
weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(
-1,).tolist()
unique_weights = set(weights_as_list)
self.assertLessEqual(len(unique_weights), self.params['number_of_clusters'])
self.assertLessEqual(len(unique_weights), clustering['number_of_clusters'])
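Each (distribution, clustering) pair produced by itertools.product above becomes its own test case, so the end-to-end test runs once per combination of distribution strategy and clustering configuration. A self-contained sketch of the same pattern (ProductTest and test_pair are illustrative names, not part of the PR):

import itertools

from absl.testing import absltest
from absl.testing import parameterized


class ProductTest(parameterized.TestCase):

  # Expands to four cases: (1, 'a'), (1, 'b'), (2, 'a'), (2, 'b').
  @parameterized.parameters(*itertools.product([1, 2], ['a', 'b']))
  def test_pair(self, number, letter):
    self.assertIn(number, (1, 2))
    self.assertIn(letter, ('a', 'b'))


if __name__ == '__main__':
  absltest.main()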

@parameterized.parameters(_distribution_strategies())
@parameterized.parameters(
_distribution_strategies()
)
def testAssociationValuesPerReplica(self, distribution):
"""Verifies that associations of weights are updated per replica."""
assert tf.distribute.get_replica_context() is not None
@@ -76,8 +96,9 @@ def testAssociationValuesPerReplica(self, distribution):
output_shape = (2, 8)
l = cluster_wrapper.ClusterWeights(
keras.layers.Dense(8, input_shape=input_shape),
number_of_clusters=self.params['number_of_clusters'],
cluster_centroids_init=self.params['cluster_centroids_init'])
number_of_clusters=2,
cluster_centroids_init=CentroidInitialization.LINEAR
)
l.build(input_shape)

clusterable_weights = l.layer.get_clusterable_weights()
@@ -128,6 +128,14 @@ def _verify_tflite(tflite_file, x_test):
interpreter.invoke()
interpreter.get_tensor(output_index)

@staticmethod
def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):
layer = stripped_model.layers[layer_nr]
weight = getattr(layer, weight_name)
weights_as_list = weight.numpy().flatten()
nr_of_unique_weights = len(set(weights_as_list))
return nr_of_unique_weights

@keras_parameterized.run_all_keras_modes
def testValuesRemainClusteredAfterTraining(self):
"""Verifies that training a clustered model does not destroy the clusters."""
@@ -150,73 +158,59 @@ def testValuesRemainClusteredAfterTraining(self):
unique_weights = set(weights_as_list)
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])


@keras_parameterized.run_all_keras_modes
def testSparsityIsPreservedDuringTraining(self):
# Set a specific random seed to ensure that we get some null weights to
# test sparsity preservation with.
"""Set a specific random seed to ensure that we get some null weights
to test sparsity preservation with."""
tf.random.set_seed(1)

# Verifies that training a clustered model does not destroy the sparsity of
# the weights.
# Verifies that training a clustered model with null weights in it
# does not destroy the sparsity of the weights.
original_model = keras.Sequential([
layers.Dense(5, input_shape=(5,)),
layers.Dense(5),
layers.Flatten(),
])

# Using a minimum number of centroids to make it more likely that some
# weights will be zero.
# Reset the kernel weights to reflect potential zero drifting of
# the cluster centroids
first_layer_weights = original_model.layers[0].get_weights()
first_layer_weights[0][:][0:2] = 0.0
first_layer_weights[0][:][3] = [-0.13, -0.08, -0.05, 0.005, 0.13]
first_layer_weights[0][:][4] = [-0.13, -0.08, -0.05, 0.005, 0.13]
original_model.layers[0].set_weights(first_layer_weights)
clustering_params = {
"number_of_clusters": 3,
"number_of_clusters": 6,
"cluster_centroids_init": CentroidInitialization.LINEAR,
"preserve_sparsity": True
}

clustered_model = experimental_cluster.cluster_weights(
original_model, **clustering_params)

stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)

nr_of_unique_weights_before = self._get_number_of_unique_weights(
stripped_model_before_tuning, 0, 'kernel')
clustered_model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer="adam",
metrics=["accuracy"],
)
clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)

clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=100)
stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(
-1,).tolist()
unique_weights_after_tuning = set(weights_as_list_after_tuning)

nr_of_unique_weights_after = self._get_number_of_unique_weights(
stripped_model_after_tuning, 0, 'kernel')
# Check that after sparsity-aware clustering, even though the zero centroid
# can drift, the final number of unique weights does not increase.
self.assertLessEqual(nr_of_unique_weights_after, nr_of_unique_weights_before)
# Check that the null weights stayed the same before and after tuning.
# New weights may become zero during training, but sparsity-aware
# clustering preserves the original null weights in their original
# positions in the weight array.
self.assertTrue(
np.array_equal(non_zero_weight_indices_before_tuning,
non_zero_weight_indices_after_tuning))

np.array_equal(first_layer_weights[0][:][0:2],
weights_after_tuning[:][0:2]))
# Check that the number of unique weights does not exceed the number of clusters.
self.assertLessEqual(
len(unique_weights_after_tuning), self.params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndSequential(self):
"""Test End to End clustering - sequential model."""
original_model = keras.Sequential([
layers.Dense(5, input_shape=(5,)),
layers.Dense(5),
])

def clusters_check(stripped_model):
# dense layer
weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
unique_weights = set(weights_as_list)
self.assertLessEqual(
len(unique_weights), self.params["number_of_clusters"])

self.end_to_end_testing(original_model, clusters_check)
nr_of_unique_weights_after,
clustering_params["number_of_clusters"])

@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def testEndToEndFunctional(self):
@@ -108,6 +108,9 @@ def __init__(self,
# Stores the pairs of weight names and their respective sparsity masks
self.sparsity_masks = {}

# Stores the pairs of weight names and the index of their zero-point centroid
self.zero_idx = {}

# Map weight names to original clusterable weights variables
# Those weights will still be updated during backpropagation
self.original_clusterable_weights = {}
@@ -199,10 +202,32 @@ def build(self, input_shape):
pulling_indices, original_weight))
self.sparsity_masks[weight_name] = (
tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32))
# If the model is pruned (as we assume here), this centroid is approximately zero
self.zero_idx[weight_name] = tf.argmin(
tf.abs(self.cluster_centroids[weight_name]), axis=-1)

def update_clustered_weights_associations(self):
for weight_name, original_weight in self.original_clusterable_weights.items(
):

if self.preserve_sparsity:
# Set the centroid closest to zero exactly to zero to enforce sparsity
# and prevent an extra cluster from forming
zero_idx_mask = (
tf.cast(tf.math.not_equal(
self.cluster_centroids[weight_name],
self.cluster_centroids[weight_name][self.zero_idx[weight_name]]),
dtype=tf.float32)
)
self.cluster_centroids[weight_name].assign(
tf.math.multiply(self.cluster_centroids[weight_name],
zero_idx_mask))
# During training, the originally-zero weights can drift slightly.
# Prevent this by forcing them back to zero at the positions where
# they were zero to begin with.
original_weight = tf.math.multiply(original_weight,
self.sparsity_masks[weight_name])

# Update pulling indices (cluster associations)
pulling_indices = (
self.clustering_algorithms[weight_name].get_pulling_indices(
Expand All @@ -214,11 +239,6 @@ def update_clustered_weights_associations(self):
self.clustering_algorithms[weight_name].get_clustered_weight(
pulling_indices, original_weight))

if self.preserve_sparsity:
# Apply the sparsity mask to the clustered weights
clustered_weights = tf.math.multiply(clustered_weights,
self.sparsity_masks[weight_name])

# Replace the weights with their clustered counterparts
setattr(self.layer, weight_name, clustered_weights)
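A standalone sketch of the zero-centroid masking above, with made-up centroid values (zero_idx mirrors the index cached in build()). Note that tf.math.not_equal compares by value, so any other centroid that happened to equal the near-zero one would be zeroed as well, just as in the wrapper code above.

import tensorflow as tf

centroids = tf.Variable([-0.4, 0.02, 0.5])        # one centroid sits near zero
zero_idx = tf.argmin(tf.abs(centroids), axis=-1)  # -> 1
mask = tf.cast(
    tf.math.not_equal(centroids, centroids[zero_idx]), tf.float32)
centroids.assign(centroids * mask)                # -> [-0.4, 0.0, 0.5]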

@@ -14,10 +14,12 @@
# ==============================================================================
"""Tests for a simple convnet with clusterable layer on the MNIST dataset."""

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster

tf.random.set_seed(42)

@@ -63,19 +65,21 @@ def _train_model(model):
model.fit(x_train, y_train, epochs=EPOCHS)


def _cluster_model(model, number_of_clusters):
def _cluster_model(model, number_of_clusters, preserve_sparsity=False):

(x_train, y_train), _ = _get_dataset()

clustering_params = {
'number_of_clusters':
number_of_clusters,
'cluster_centroids_init':
cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS
cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS,
'preserve_sparsity':
preserve_sparsity,
}

# Cluster model
clustered_model = cluster.cluster_weights(model, **clustering_params)
clustered_model = experimental_cluster.cluster_weights(
model, **clustering_params)

# Use smaller learning rate for fine-tuning
# clustered model
@@ -106,13 +110,27 @@ def _get_number_of_unique_weights(stripped_model, layer_nr, weight_name):

return nr_of_unique_weights

def _deepcopy_model(model):
model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
return model_copy
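The copy matters because testMnist below evaluates the original model and then clusters it: clone_model rebuilds the architecture with fresh variables and set_weights copies the values over, so clustering the copy cannot disturb the weights of self.model. A small usage sketch (assuming some already-built Keras model bound to the name model, with a kernel in its first layer):

import numpy as np

model_copy = _deepcopy_model(model)
model_copy.layers[0].kernel.assign(
    np.zeros_like(model_copy.layers[0].kernel.numpy()))
# The original model's kernel is unchanged.
assert np.any(model.layers[0].kernel.numpy() != 0)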

class FunctionalTest(tf.test.TestCase):
class FunctionalTest(tf.test.TestCase, parameterized.TestCase):

def testMnist(self):
"""In this test we test that 'kernel' weights are clustered."""
def setUp(self):
model = _build_model()
_train_model(model)
self.model = model
self.dataset = _get_dataset()

@parameterized.parameters(
(False,),
(True,),
)
def testMnist(self, preserve_sparsity):
"""Tests that 'kernel' weights are clustered."""
model = self.model
_, (x_test, y_test) = self.dataset

# Checks that number of original weights('kernel') is greater than the
# number of clusters
@@ -123,12 +141,11 @@ def testMnist(self):
nr_of_bias_weights = _get_number_of_unique_weights(model, -1, 'bias')
self.assertGreater(nr_of_bias_weights, NUMBER_OF_CLUSTERS)

_, (x_test, y_test) = _get_dataset()

results_original = model.evaluate(x_test, y_test)
self.assertGreater(results_original[1], 0.8)

clustered_model = _cluster_model(model, NUMBER_OF_CLUSTERS)
model_copy = _deepcopy_model(model)
clustered_model = _cluster_model(model_copy, NUMBER_OF_CLUSTERS,
preserve_sparsity)

results = clustered_model.evaluate(x_test, y_test)
