Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class CentroidInitialization(str, enum.Enum):
axis is evenly spaced into as many regions as many clusters we want to
have. After this the corresponding X values are obtained and used to
initialize the clusters centroids.
* `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm
"""
LINEAR = "LINEAR"
RANDOM = "RANDOM"
DENSITY_BASED = "DENSITY_BASED"
KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS"
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
import abc
import six
import tensorflow as tf

from tensorflow.python.ops import clustering_ops
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config

k = tf.keras.backend
CentroidInitialization = cluster_config.CentroidInitialization

@six.add_metaclass(abc.ABCMeta)
class AbstractCentroidsInitialisation:
"""
Expand Down Expand Up @@ -53,6 +52,20 @@ def get_cluster_centroids(self):
self.number_of_clusters)
return cluster_centroids

class KmeansPlusPlusCentroidsInitialisation(AbstractCentroidsInitialisation):
"""
Cluster centroids based on kmeans++ algorithm
"""
def get_cluster_centroids(self):

weights = tf.reshape(self.weights, [-1, 1])

cluster_centroids = clustering_ops.kmeans_plus_plus_initialization(weights,
self.number_of_clusters,
seed=9,
num_retries_per_sample=-1)

return cluster_centroids

class RandomCentroidsInitialisation(AbstractCentroidsInitialisation):
"""
Expand Down Expand Up @@ -192,7 +205,9 @@ class CentroidsInitializerFactory:
CentroidInitialization.LINEAR : LinearCentroidsInitialisation,
CentroidInitialization.RANDOM : RandomCentroidsInitialisation,
CentroidInitialization.DENSITY_BASED :
DensityBasedCentroidsInitialisation
DensityBasedCentroidsInitialisation,
CentroidInitialization.KMEANS_PLUS_PLUS :
KmeansPlusPlusCentroidsInitialisation,
}

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def setUp(self):
(CentroidInitialization.LINEAR),
(CentroidInitialization.RANDOM),
(CentroidInitialization.DENSITY_BASED),
(CentroidInitialization.KMEANS_PLUS_PLUS),
)
def testExistingInitsAreSupported(self, init_type):
"""
Expand All @@ -63,6 +64,10 @@ def testNonExistingInitIsNotSupported(self):
CentroidInitialization.DENSITY_BASED,
clustering_centroids.DensityBasedCentroidsInitialisation
),
(
CentroidInitialization.KMEANS_PLUS_PLUS,
clustering_centroids.KmeansPlusPlusCentroidsInitialisation
),
)
def testReturnsMethodForExistingInit(self, init_type, method):
"""
Expand Down Expand Up @@ -177,6 +182,30 @@ def testClusterCentroids(self, weights, number_of_clusters, centroids):
calc_centroids = K.batch_get_value([dbci.get_cluster_centroids()])[0]
self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4)

@parameterized.parameters(
(
[0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5],
5,
[3.1, 0., 2., 1., 3.4]
),
(
[0, 1, 2, 3, 3.1, 3.2, 3.3, 3.4, 3.5],
3,
[3.1, 0., 2.]
),
(
[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
3,
[6., 1., 8.]
)
)
def testKmeanPlusPlusValues(self, weights, number_of_clusters, centroids):
kmci = clustering_centroids.KmeansPlusPlusCentroidsInitialisation(
weights,
number_of_clusters
)
calc_centroids = K.batch_get_value([kmci.get_cluster_centroids()])[0]
self.assertSequenceAlmostEqual(centroids, calc_centroids, places=4)

if __name__ == '__main__':
test.main()