Skip to content

Commit 80e0853

Browse files
albertcthomasTomDLT
authored andcommitted
[MRG+1] test that clustering returns int (#9912)
1 parent 766ba93 commit 80e0853

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,8 +1035,8 @@ def check_clustering(name, clusterer_orig):
10351035
# with lists
10361036
clusterer.fit(X.tolist())
10371037

1038-
assert_equal(clusterer.labels_.shape, (n_samples,))
10391038
pred = clusterer.labels_
1039+
assert_equal(pred.shape, (n_samples,))
10401040
assert_greater(adjusted_rand_score(pred, y), 0.4)
10411041
# fit another time with ``fit_predict`` and compare results
10421042
if name == 'SpectralClustering':
@@ -1047,6 +1047,25 @@ def check_clustering(name, clusterer_orig):
10471047
pred2 = clusterer.fit_predict(X)
10481048
assert_array_equal(pred, pred2)
10491049

1050+
# fit_predict(X) and labels_ should be of type int
1051+
assert_in(pred.dtype, [np.dtype('int32'), np.dtype('int64')])
1052+
assert_in(pred2.dtype, [np.dtype('int32'), np.dtype('int64')])
1053+
1054+
# There should be at least one sample in every cluster. Equivalently
1055+
# labels_ should contain all the consecutive values between its
1056+
# min and its max.
1057+
pred_sorted = np.unique(pred)
1058+
assert_array_equal(pred_sorted, np.arange(pred_sorted[0],
1059+
pred_sorted[-1] + 1))
1060+
1061+
# labels_ should be greater than -1
1062+
assert_greater_equal(pred_sorted[0], -1)
1063+
# labels_ should be less than n_clusters - 1
1064+
if hasattr(clusterer, 'n_clusters'):
1065+
n_clusters = getattr(clusterer, 'n_clusters')
1066+
assert_greater_equal(n_clusters - 1, pred_sorted[-1])
1067+
# else labels_ should be less than max(labels_) which is necessarily true
1068+
10501069

10511070
@ignore_warnings(category=DeprecationWarning)
10521071
def check_clusterer_compute_labels_predict(name, clusterer_orig):

0 commit comments

Comments
 (0)