@@ -1035,8 +1035,8 @@ def check_clustering(name, clusterer_orig):
1035
1035
# with lists
1036
1036
clusterer .fit (X .tolist ())
1037
1037
1038
- assert_equal (clusterer .labels_ .shape , (n_samples ,))
1039
1038
pred = clusterer .labels_
1039
+ assert_equal (pred .shape , (n_samples ,))
1040
1040
assert_greater (adjusted_rand_score (pred , y ), 0.4 )
1041
1041
# fit another time with ``fit_predict`` and compare results
1042
1042
if name == 'SpectralClustering' :
@@ -1047,6 +1047,25 @@ def check_clustering(name, clusterer_orig):
1047
1047
pred2 = clusterer .fit_predict (X )
1048
1048
assert_array_equal (pred , pred2 )
1049
1049
1050
+ # fit_predict(X) and labels_ should be of type int
1051
+ assert_in (pred .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
1052
+ assert_in (pred2 .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
1053
+
1054
+ # There should be at least one sample in every cluster. Equivalently
1055
+ # labels_ should contain all the consecutive values between its
1056
+ # min and its max.
1057
+ pred_sorted = np .unique (pred )
1058
+ assert_array_equal (pred_sorted , np .arange (pred_sorted [0 ],
1059
+ pred_sorted [- 1 ] + 1 ))
1060
+
1061
+ # labels_ should be greater than -1
1062
+ assert_greater_equal (pred_sorted [0 ], - 1 )
1063
+ # labels_ should be less than n_clusters - 1
1064
+ if hasattr (clusterer , 'n_clusters' ):
1065
+ n_clusters = getattr (clusterer , 'n_clusters' )
1066
+ assert_greater_equal (n_clusters - 1 , pred_sorted [- 1 ])
1067
+ # else labels_ should be less than max(labels_) which is necessarily true
1068
+
1050
1069
1051
1070
@ignore_warnings (category = DeprecationWarning )
1052
1071
def check_clusterer_compute_labels_predict (name , clusterer_orig ):
0 commit comments