diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py index 70cc459cd..ae1edfda3 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py @@ -29,7 +29,9 @@ class CentroidInitialization(str, enum.Enum): axis is evenly spaced into as many regions as many clusters we want to have. After this the corresponding X values are obtained and used to initialize the clusters centroids. + * `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm """ LINEAR = "LINEAR" RANDOM = "RANDOM" DENSITY_BASED = "DENSITY_BASED" + KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS" diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py index d24f5a17a..7f32090f2 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids.py @@ -17,12 +17,11 @@ import abc import six import tensorflow as tf - +from tensorflow.python.ops import clustering_ops from tensorflow_model_optimization.python.core.clustering.keras import cluster_config k = tf.keras.backend CentroidInitialization = cluster_config.CentroidInitialization - @six.add_metaclass(abc.ABCMeta) class AbstractCentroidsInitialisation: """ @@ -53,6 +52,20 @@ def get_cluster_centroids(self): self.number_of_clusters) return cluster_centroids +class KmeansPlusPlusCentroidsInitialisation(AbstractCentroidsInitialisation): + """ + Cluster centroids based on kmeans++ algorithm + """ + def get_cluster_centroids(self): + + weights = tf.reshape(self.weights, [-1, 1]) + + cluster_centroids = clustering_ops.kmeans_plus_plus_initialization(weights, + self.number_of_clusters, + seed=9, + num_retries_per_sample=-1) + + return cluster_centroids class RandomCentroidsInitialisation(AbstractCentroidsInitialisation): """ @@ -192,7 +205,9 @@ class CentroidsInitializerFactory: CentroidInitialization.LINEAR : LinearCentroidsInitialisation, CentroidInitialization.RANDOM : RandomCentroidsInitialisation, CentroidInitialization.DENSITY_BASED : - DensityBasedCentroidsInitialisation + DensityBasedCentroidsInitialisation, + CentroidInitialization.KMEANS_PLUS_PLUS : + KmeansPlusPlusCentroidsInitialisation, } @classmethod diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py index 693276b6d..a0fc2052c 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_centroids_test.py @@ -40,6 +40,7 @@ def setUp(self): (CentroidInitialization.LINEAR), (CentroidInitialization.RANDOM), (CentroidInitialization.DENSITY_BASED), + (CentroidInitialization.KMEANS_PLUS_PLUS), ) def testExistingInitsAreSupported(self, init_type): """ @@ -63,6 +64,10 @@ def testNonExistingInitIsNotSupported(self): CentroidInitialization.DENSITY_BASED, clustering_centroids.DensityBasedCentroidsInitialisation ), + ( + CentroidInitialization.KMEANS_PLUS_PLUS, + clustering_centroids.KmeansPlusPlusCentroidsInitialisation + ), ) def testReturnsMethodForExistingInit(self, init_type, method): """ @@ -177,6 +182,30 @@ def testClusterCentroids(self, weights, number_of_clusters, centroids): calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0] self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) + @parameterized.parameters( + ( + [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], + 5, + [3.1, 0., 2., 1., 3.4] + ), + ( + [0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5], + 3, + [3.1, 0., 2.] + ), + ( + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], + 3, + [6., 1., 8.] + ) + ) + def testKmeanPlusPlusValues(self, weights, number_of_clusters, centroids): + kmci = clustering_centroids.KmeansPlusPlusCentroidsInitialisation( + weights, + number_of_clusters + ) + calc_centroids = K.batch_get_value([kmci.get_cluster_centroids()])[0] + self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4) if __name__ == '__main__': test.main()