Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions sklearn/metrics/cluster/tests/test_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,24 +184,25 @@ def test_silhouette_nonzero_diag(dtype):
silhouette_samples(dists, labels, metric='precomputed')


def assert_raises_on_only_one_label(func):
"""Assert message when there is only one label"""
def raise_exception_on_only_one_label(func):
"""Raise exception when there is only one label"""
rng = np.random.RandomState(seed=0)
with pytest.raises(ValueError, match="Number of labels is"):
func(rng.rand(10, 2), np.zeros(10))


def assert_raises_on_all_points_same_cluster(func):
"""Assert message when all point are in different clusters"""
def raise_exception_on_all_points_in_different_clusters(func):
"""Raise exception when all points are in different clusters"""
rng = np.random.RandomState(seed=0)
with pytest.raises(ValueError, match="Number of labels is"):
func(rng.rand(10, 2), np.arange(10))


def test_calinski_harabasz_score():
assert_raises_on_only_one_label(calinski_harabasz_score)
raise_exception_on_only_one_label(calinski_harabasz_score)

assert_raises_on_all_points_same_cluster(calinski_harabasz_score)
raise_exception_on_all_points_in_different_clusters(
calinski_harabasz_score)

# Assert the value is 1. when all samples are equals
assert 1. == calinski_harabasz_score(np.ones((10, 2)),
Expand All @@ -220,8 +221,8 @@ def test_calinski_harabasz_score():


def test_davies_bouldin_score():
assert_raises_on_only_one_label(davies_bouldin_score)
assert_raises_on_all_points_same_cluster(davies_bouldin_score)
raise_exception_on_only_one_label(davies_bouldin_score)
raise_exception_on_all_points_in_different_clusters(davies_bouldin_score)

# Assert the value is 0. when all samples are equals
assert davies_bouldin_score(np.ones((10, 2)),
Expand Down