Skip to content

Commit

Permalink
[MRG+1] test that clustering returns int (#9912)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertcthomas authored and TomDLT committed Oct 18, 2017
1 parent 766ba93 commit 80e0853
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion sklearn/utils/estimator_checks.py
Expand Up @@ -1035,8 +1035,8 @@ def check_clustering(name, clusterer_orig):
# with lists
clusterer.fit(X.tolist())

assert_equal(clusterer.labels_.shape, (n_samples,))
pred = clusterer.labels_
assert_equal(pred.shape, (n_samples,))
assert_greater(adjusted_rand_score(pred, y), 0.4)
# fit another time with ``fit_predict`` and compare results
if name == 'SpectralClustering':
Expand All @@ -1047,6 +1047,25 @@ def check_clustering(name, clusterer_orig):
pred2 = clusterer.fit_predict(X)
assert_array_equal(pred, pred2)

# fit_predict(X) and labels_ should be of type int
assert_in(pred.dtype, [np.dtype('int32'), np.dtype('int64')])
assert_in(pred2.dtype, [np.dtype('int32'), np.dtype('int64')])

# There should be at least one sample in every cluster. Equivalently
# labels_ should contain all the consecutive values between its
# min and its max.
pred_sorted = np.unique(pred)
assert_array_equal(pred_sorted, np.arange(pred_sorted[0],
pred_sorted[-1] + 1))

# labels_ should be greater than or equal to -1
assert_greater_equal(pred_sorted[0], -1)
# labels_ should be less than or equal to n_clusters - 1
if hasattr(clusterer, 'n_clusters'):
n_clusters = getattr(clusterer, 'n_clusters')
assert_greater_equal(n_clusters - 1, pred_sorted[-1])
# else labels_ should be less than or equal to max(labels_), which is necessarily true


@ignore_warnings(category=DeprecationWarning)
def check_clusterer_compute_labels_predict(name, clusterer_orig):
Expand Down

0 comments on commit 80e0853

Please sign in to comment.