[MRG+2] Invariance tests for clustering metrics #10828
Merged: jnothman merged 46 commits into scikit-learn:master from glemaitre:common_test_clustering on Mar 18, 2018
Changes from 45 commits (46 commits total)
157efbc  Common test file for different cluster metrics addressing issue 8102 (anki08)
d4875f4  common test file for different cluster metrics addressing issue 8102 … (anki08)
9165f23  Update and rename test_file.py to test_common.py (anki08)
89df474  Add files via upload (anki08)
2f76ca6  Add files via upload (anki08)
c469da8  Add files via upload (anki08)
2d5c29b  Delete test_common.py (anki08)
57cec20  Delete 1.py (anki08)
b4050a2  Add files via upload (anki08)
35f2417  Add files via upload (anki08)
52476ab  Wrote different dictionaries for supervised and unsupervised metrics.… (anki08)
7e2f9df  fixed flake8 issues (anki08)
cba2dd2  Fixed flake8 issues (anki08)
4e22e3d  Fixed flake8 issues (anki08)
072c033  Add files via upload (anki08)
3e6d533  Fixed flake8 issues (anki08)
b230cb9  Fixed flake8 issues (anki08)
15c9407  Added tests to test_format_invariance (anki08)
7ac5f6f  Update test_common.py (anki08)
8cc59ea  Updated function test_format_invariance (anki08)
126f8ad  Update test_common.py (anki08)
b5296b9  Updated function test_format_invariance (anki08)
4aed4c1  Added error message to all metrics (anki08)
e1eaf04  Added assert_contains (anki08)
e16cf65  Added assert_contains (anki08)
ed05584  Added assert_contains (anki08)
40ca799  Added assert_contains (anki08)
8b2e9b4  Added assert_true (anki08)
149364e  Update test_common.py (anki08)
c9fa819  assert_equal -> assert_almost_equal (anki08)
bdce12e  changed assert_almost_equal (anki08)
4f0639d  Update test_common.py (anki08)
71a575d  Updated test_common.py (anki08)
06c2180  Updated test_common.py (anki08)
987a9fe  Updated test_common.py (anki08)
907aef2  Updated test_common.py (anki08)
a3af713  assert_almost_equal -> assert_equal (anki08)
05c687b  updated test_common.py (anki08)
0c70e3f  updated test_common.py (anki08)
8f4de0f  Updated test_common.py (anki08)
7718d36  updated test_common.py (anki08)
481e096  Update test_common.py (amueller)
0a792d8  TST use pytest (glemaitre)
6398349  Merge remote-tracking branch 'origin/master' into feature1 (glemaitre)
b5441d9  add whats-new entry (glemaitre)
54dc723  DOC syntax (jnothman)
@@ -0,0 +1,178 @@
from functools import partial

import pytest
import numpy as np

from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import completeness_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import mutual_info_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics.cluster import silhouette_score
from sklearn.metrics.cluster import calinski_harabaz_score

from sklearn.utils.testing import assert_allclose

# Dictionaries of metrics
# ------------------------
# These dictionaries associate a name with each metric function, giving an
# easy way to call a particular metric:
# - SUPERVISED_METRICS: all supervised cluster metrics (those that require
#   ground truth labels)
# - UNSUPERVISED_METRICS: all unsupervised cluster metrics
#
# They are used to systematically test invariance properties, e.g.
# invariance to different input layouts.

SUPERVISED_METRICS = {
    "adjusted_mutual_info_score": adjusted_mutual_info_score,
    "adjusted_rand_score": adjusted_rand_score,
    "completeness_score": completeness_score,
    "homogeneity_score": homogeneity_score,
    "mutual_info_score": mutual_info_score,
    "normalized_mutual_info_score": normalized_mutual_info_score,
    "v_measure_score": v_measure_score,
    "fowlkes_mallows_score": fowlkes_mallows_score
}

UNSUPERVISED_METRICS = {
    "silhouette_score": silhouette_score,
    "silhouette_manhattan": partial(silhouette_score, metric='manhattan'),
    "calinski_harabaz_score": calinski_harabaz_score
}

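As a sketch of how such a name-to-callable dictionary drives the parametrized tests below, consider a toy metric (`exact_match_score` is invented here for illustration; it is not a scikit-learn function):

```python
# Sketch of the name -> callable pattern used above. `exact_match_score`
# is a toy stand-in metric, NOT part of scikit-learn.
def exact_match_score(y_true, y_pred):
    """Fraction of positions where the two labelings agree exactly."""
    return sum(a == b for a, b in zip(y_true, y_pred)) / len(y_true)

TOY_METRICS = {"exact_match_score": exact_match_score}

# A test can then iterate over (name, metric) pairs uniformly, using the
# name for readable test ids and the callable for the actual check.
for name, metric in TOY_METRICS.items():
    score = metric([0, 0, 1, 1], [0, 0, 1, 2])
```

Keying the dictionaries by string names also lets the property lists further down (SYMMETRIC_METRICS and friends) refer to metrics without importing them again.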
# Lists of metrics with common properties
# ---------------------------------------
# These lists are used to systematically test common functionality and
# invariances, e.g. SYMMETRIC_METRICS lists all metrics that are symmetric
# with respect to their input arguments y_true and y_pred.
#
# --------------------------------------------------------------------
# Symmetric with respect to their input arguments y_true and y_pred.
# Symmetry only makes sense for supervised metrics, which compare two
# labelings.
SYMMETRIC_METRICS = [
    "adjusted_rand_score", "v_measure_score",
    "mutual_info_score", "adjusted_mutual_info_score",
    "normalized_mutual_info_score", "fowlkes_mallows_score"
]

NON_SYMMETRIC_METRICS = ["homogeneity_score", "completeness_score"]

# Metrics whose upper bound is 1
NORMALIZED_METRICS = [
    "adjusted_rand_score", "homogeneity_score", "completeness_score",
    "v_measure_score", "adjusted_mutual_info_score", "fowlkes_mallows_score",
    "normalized_mutual_info_score"
]


rng = np.random.RandomState(0)
y1 = rng.randint(3, size=30)
y2 = rng.randint(3, size=30)

def test_symmetric_non_symmetric_union():
    assert (sorted(SYMMETRIC_METRICS + NON_SYMMETRIC_METRICS) ==
            sorted(SUPERVISED_METRICS))


@pytest.mark.parametrize(
    'metric_name, y1, y2',
    [(name, y1, y2) for name in SYMMETRIC_METRICS]
)
def test_symmetry(metric_name, y1, y2):
    metric = SUPERVISED_METRICS[metric_name]
    assert metric(y1, y2) == pytest.approx(metric(y2, y1))


@pytest.mark.parametrize(
    'metric_name, y1, y2',
    [(name, y1, y2) for name in NON_SYMMETRIC_METRICS]
)
def test_non_symmetry(metric_name, y1, y2):
    metric = SUPERVISED_METRICS[metric_name]
    assert metric(y1, y2) != pytest.approx(metric(y2, y1))

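To see concretely why pair-counting metrics are symmetric, here is a minimal from-scratch Rand index (a standard clustering score, written in pure Python here rather than imported from scikit-learn): it only asks whether each pair of samples is grouped the same way by both labelings, a criterion that treats the two arguments identically.

```python
from itertools import combinations

def rand_index(y_a, y_b):
    # Rand index: fraction of sample pairs on which the two labelings
    # agree, i.e. the pair is grouped together in both or apart in both.
    pairs = list(combinations(range(len(y_a)), 2))
    agree = sum((y_a[i] == y_a[j]) == (y_b[i] == y_b[j]) for i, j in pairs)
    return agree / len(pairs)

y_a = [0, 0, 1, 1, 2]
y_b = [1, 1, 0, 2, 2]
# The pair-agreement criterion is symmetric in the two labelings,
# so swapping the arguments cannot change the score.
assert rand_index(y_a, y_b) == rand_index(y_b, y_a)
```

Homogeneity and completeness, by contrast, are defined as conditional entropies of one labeling given the other, which is why they land in NON_SYMMETRIC_METRICS.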
@pytest.mark.parametrize(
    "metric_name",
    [name for name in NORMALIZED_METRICS]
)
def test_normalized_output(metric_name):
    upper_bound_1 = [0, 0, 0, 1, 1, 1]
    upper_bound_2 = [0, 0, 0, 1, 1, 1]
    metric = SUPERVISED_METRICS[metric_name]
    assert metric([0, 0, 0, 1, 1], [0, 0, 0, 1, 2]) > 0.0
    assert metric([0, 0, 1, 1, 2], [0, 0, 1, 1, 1]) > 0.0
    assert metric([0, 0, 0, 1, 2], [0, 1, 1, 1, 1]) < 1.0
    assert metric(upper_bound_1, upper_bound_2) == pytest.approx(1.0)

    lower_bound_1 = [0, 0, 0, 0, 0, 0]
    lower_bound_2 = [0, 1, 2, 3, 4, 5]
    score = np.array([metric(lower_bound_1, lower_bound_2),
                      metric(lower_bound_2, lower_bound_1)])
    assert not (score < 0).any()

# All clustering metrics are invariant to permutations of the label values,
# e.g. renaming label 0 as 1 and vice versa.
@pytest.mark.parametrize(
    "metric_name",
    [name for name in dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)]
)
def test_permute_labels(metric_name):
    y_label = np.array([0, 0, 0, 1, 1, 0, 1])
    y_pred = np.array([1, 0, 1, 0, 1, 1, 0])
    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_pred, y_label)
        assert_allclose(score_1, metric(1 - y_pred, y_label))
        assert_allclose(score_1, metric(1 - y_pred, 1 - y_label))
        assert_allclose(score_1, metric(y_pred, 1 - y_label))
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        X = np.random.randint(10, size=(7, 10))
        score_1 = metric(X, y_pred)
        assert_allclose(score_1, metric(X, 1 - y_pred))

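The permutation invariance tested above holds because these scores are functions of the label co-occurrence counts only. As a sketch (a from-scratch mutual information in nats, not scikit-learn's implementation), renaming labels merely permutes rows or columns of the contingency table, so the sum over its cells is unchanged:

```python
from collections import Counter
from math import log

def mutual_info(labels_a, labels_b):
    # Mutual information (in nats) computed from the contingency table.
    # The score depends only on joint and marginal label counts, so
    # renaming labels on either side cannot change it.
    n = len(labels_a)
    joint = Counter(zip(labels_a, labels_b))      # contingency counts n_ij
    count_a = Counter(labels_a)                   # row marginals
    count_b = Counter(labels_b)                   # column marginals
    return sum((nij / n) * log(n * nij / (count_a[a] * count_b[b]))
               for (a, b), nij in joint.items())

y_label = [0, 0, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0]
flipped = [1 - y for y in y_pred]  # permute label values 0 <-> 1
assert abs(mutual_info(y_label, y_pred) - mutual_info(y_label, flipped)) < 1e-12
```

Only a tolerance-based comparison is used at the end because flipping the labels reorders the floating-point summation.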
# For all clustering metrics, the input parameters can be given as arrays or
# lists, with int (positive or negative) or string labels.
@pytest.mark.parametrize(
    "metric_name",
    [name for name in dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)]
)
def test_format_invariance(metric_name):
    y_true = [0, 0, 0, 0, 1, 1, 1, 1]
    y_pred = [0, 1, 2, 3, 4, 5, 6, 7]

    def generate_formats(y):
        y = np.array(y)
        yield y, 'array of ints'
        yield y.tolist(), 'list of ints'
        yield [str(x) for x in y.tolist()], 'list of strs'
        yield y - 1, 'including negative ints'
        yield y + 1, 'strictly positive ints'

    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_true, y_pred)
        y_true_gen = generate_formats(y_true)
        y_pred_gen = generate_formats(y_pred)
        for (y_true_fmt, fmt_name), (y_pred_fmt, _) in zip(y_true_gen,
                                                           y_pred_gen):
            assert score_1 == metric(y_true_fmt, y_pred_fmt)
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        X = np.random.randint(10, size=(8, 10))
        score_1 = metric(X, y_true)
        assert score_1 == metric(X.astype(float), y_true)
        y_true_gen = generate_formats(y_true)
        for (y_true_fmt, fmt_name) in y_true_gen:
            assert score_1 == metric(X, y_true_fmt)
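Format invariance holds because clustering metrics only use labels to define a partition of the samples, not their values. A minimal sketch of the underlying idea (`encode_labels` is a hypothetical helper for illustration, not a scikit-learn function): any labeling, whatever its type or offset, reduces to the same canonical integer encoding.

```python
def encode_labels(y):
    # Map arbitrary hashable labels (ints, strings, ...) to consecutive
    # integer ids in order of first appearance. A metric that only looks
    # at the partition structure gives identical scores for any encoding.
    first_seen = {}
    return [first_seen.setdefault(label, len(first_seen)) for label in y]

# Ints, strings, negative or shifted labels all induce the same partition:
assert encode_labels([0, 0, 1, 2]) == encode_labels(["a", "a", "b", "c"])
assert encode_labels([5, 5, 6, 7]) == encode_labels([-1, -1, 0, 1])
```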
You either need :user: or you need to add yourself to _contributors.rst and use `Guillaume Lemaitre`_.