Skip to content

Commit

Permalink
k-means clustering, monitor distances
Browse files Browse the repository at this point in the history
  • Loading branch information
wannesm committed Jul 5, 2023
1 parent 15cac70 commit 0343fe0
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion dtaidistance/clustering/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,19 @@ def kmeansplusplus_centers(self, series, use_c=False):
def fit_fast(self, series):
return self.fit(series, use_c=True, use_parallel=True)

def fit(self, series, use_c=False, use_parallel=True):
def fit(self, series, use_c=False, use_parallel=True, monitor_distances=None):
"""Perform K-means clustering.
:param series: Container with series
:param use_c: Use the C-library (only available if package is compiled)
:param use_parallel: Use multipool for parallelization
:param monitor_distances: This function is called with two arguments:
(1) a list of (cluster, distance) for each instance;
(2) a boolean indicating whether the clustering has been stopped or not.
From this one can compute inertia or other metrics
to monitor the clustering. If the boolean argument is true, this is the
final assignment. If this function returns True, the clustering
continues, if False is returned the clustering is stopped.
:return: cluster indices, number of iterations
If the number of iterations is equal to max_it, the clustering
did not converge.
Expand Down Expand Up @@ -299,6 +306,10 @@ def fit(self, series, use_c=False, use_parallel=True):
else:
clusters_distances = list(map(fn, [(self.series[idx], self.means, self.dists_options) for idx in
range(len(self.series))]))
if monitor_distances is not None:
cont = monitor_distances(clusters_distances, False)
if cont is False:
break
clusters, distances = zip(*clusters_distances)
distances = list(distances)

Expand Down Expand Up @@ -393,6 +404,8 @@ def fit(self, series, use_c=False, use_parallel=True):
else:
clusters_distances = list(map(fn, [(self.series[idx], self.means, self.dists_options) for idx in
range(len(self.series))]))
if monitor_distances is not None:
monitor_distances(clusters_distances, True)
clusters, distances = zip(*clusters_distances)

# self.cluster_idx = {medoid: {inst for inst in instances}
Expand Down

0 comments on commit 0343fe0

Please sign in to comment.