Skip to content

Commit

Permalink
Use cut_tree as a backup for fcluster
Browse files Browse the repository at this point in the history
I observed that sometimes the number of clusters returned by fcluster is not equal to the specified number of clusters -- the returned number of clusters is smaller. This seems to happen when n_clusters is close to n_samples. This issue causes some problems downstream.

I tried cut_tree instead of fcluster, and it seems that cut_tree is returning the correct solution in the examples I saw.

However, I didn't want to replace fcluster by cut_tree, because there also seems to be some issue with cut_tree, see scipy/scipy#8063, which is not yet resolved in scipy.

An alternative solution is to use AgglomerativeClustering in sklearn: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html. See https://stackoverflow.com/questions/47535256/how-to-make-fcluster-to-return-the-same-output-as-cut-tree. However, doing that requires a lot of code changes.

So for now, as a temporary solution, I use cut_tree as a backup for fcluster, i.e., when the fcluster output is not desired, I run cut_tree. When both solutions are not desired, I output a warning.

We'll need to fix this in a better way in the future if this is really a problem that happens a lot. By far it does not seem to be happening a lot.
  • Loading branch information
Hu-JIN committed Jul 26, 2021
1 parent bc13084 commit 6f5dfe0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
10 changes: 10 additions & 0 deletions musical/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def hierarchical_cluster(X, k, metric='cosine', linkage_method='average'):
d_square_form = sp.spatial.distance.squareform(d)
linkage = sch.linkage(d, method=linkage_method)
cluster_membership = sch.fcluster(linkage, k, criterion="maxclust")
if len(set(cluster_membership)) != k:
cluster_membership = sch.cut_tree(linkage, n_clusters=k).flatten() + 1
if len(set(cluster_membership)) != k:
warnings.warn('Number of clusters output by cut_tree or fcluster is not equal to the specified number of clusters',
UserWarning)
return d_square_form, cluster_membership


Expand Down Expand Up @@ -150,6 +155,11 @@ def _cluster_statistic(X, max_k, metric='cosine', linkage_method='average'):
silscorek_percluster = {}
for k in range(1, max_k + 1):
cluster_membership = sch.fcluster(linkage, k, criterion="maxclust")
if len(set(cluster_membership)) != k:
cluster_membership = sch.cut_tree(linkage, n_clusters=k).flatten() + 1
if len(set(cluster_membership)) != k:
warnings.warn('Number of clusters output by cut_tree or fcluster is not equal to the specified number of clusters',
UserWarning)
Wk.append(_within_cluster_variation(d_square_form, cluster_membership))
if k == 1:
silscorek.append(np.nan)
Expand Down
5 changes: 5 additions & 0 deletions musical/denovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def _gather_results(X, Ws, Hs=None, method='cluster_by_matching', n_components=N
d_square_form = sp.spatial.distance.squareform(d)
linkage = sch.linkage(d, method='average')
cluster_membership = sch.fcluster(linkage, n_components, criterion="maxclust")
if len(set(cluster_membership)) != n_components:
cluster_membership = sch.cut_tree(linkage, n_clusters=n_components).flatten() + 1
if len(set(cluster_membership)) != n_components:
warnings.warn('Number of clusters output by cut_tree or fcluster is not equal to the specified number of clusters',
UserWarning)
W = []
for i in range(0, n_components):
W.append(np.mean(sigs[:, cluster_membership == i + 1], 1))
Expand Down
5 changes: 5 additions & 0 deletions musical/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def _init_cluster(X, n_components, metric='cosine', linkage='average',
linkage = sch.linkage(d, method=linkage)
for ncluster in range(n_components, np.min([n_samples, max_ncluster]) + 1):
cluster_membership = sch.fcluster(linkage, ncluster, criterion='maxclust')
if len(set(cluster_membership)) != ncluster:
cluster_membership = sch.cut_tree(linkage, n_clusters=ncluster).flatten() + 1
if len(set(cluster_membership)) != ncluster:
warnings.warn('Number of clusters output by cut_tree or fcluster is not equal to the specified number of clusters',
UserWarning)
W = []
for i in range(1, ncluster + 1):
if np.sum(cluster_membership == i) >= min_nsample:
Expand Down

0 comments on commit 6f5dfe0

Please sign in to comment.