Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+2] Invariance tests for clustering metrics #10828

Merged
merged 46 commits into from Mar 18, 2018
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
157efbc
Common test file for different cluster metrics addressingissue 8102
anki08 Dec 29, 2016
d4875f4
common test file for different cluster metrics addressing issue 8102 …
anki08 Dec 29, 2016
9165f23
Update and rename test_file.py to test_common.py
anki08 Dec 30, 2016
89df474
Add files via upload
anki08 Dec 30, 2016
2f76ca6
Add files via upload
anki08 Dec 30, 2016
c469da8
Add files via upload
anki08 Jan 1, 2017
2d5c29b
Delete test_common.py
anki08 Jan 2, 2017
57cec20
Delete 1.py
anki08 Jan 2, 2017
b4050a2
Add files via upload
anki08 Jan 2, 2017
35f2417
Add files via upload
anki08 Jan 3, 2017
52476ab
Wrote different dictionaries for supervised and unsupervised metrics.…
anki08 Jan 4, 2017
7e2f9df
fixed flake8 issues
anki08 Jan 4, 2017
cba2dd2
Fixed flake8 issues
anki08 Jan 4, 2017
4e22e3d
Fixed flake8 issues
anki08 Jan 6, 2017
072c033
Add files via upload
anki08 Jan 6, 2017
3e6d533
Fixed flake8 issues
anki08 Jan 6, 2017
b230cb9
Fixed flake8 issues
anki08 Jan 6, 2017
15c9407
Added tests to test_format_invariance .
anki08 Jan 9, 2017
7ac5f6f
Update test_common.py
anki08 Jan 9, 2017
8cc59ea
Updated function test_format_invariance
anki08 Jan 10, 2017
126f8ad
Update test_common.py
anki08 Jan 10, 2017
b5296b9
Updated function test_format_invariance
anki08 Jan 10, 2017
4aed4c1
Added error message to all metrics
anki08 Jan 11, 2017
e1eaf04
Added assert_contains
anki08 Jan 11, 2017
e16cf65
Added assert_contains
anki08 Jan 11, 2017
ed05584
Added assert_contains
anki08 Jan 11, 2017
40ca799
Added assert_contains
anki08 Jan 11, 2017
8b2e9b4
Added assert_true
anki08 Jan 11, 2017
149364e
Update test_common.py
anki08 Jan 11, 2017
c9fa819
assert_equal -> assert_almost_equal
anki08 Jan 17, 2017
bdce12e
changed assert_almost_equal
anki08 Jan 17, 2017
4f0639d
Update test_common.py
anki08 Jan 17, 2017
71a575d
Updated test_common.py
anki08 Jan 17, 2017
06c2180
Updated test_common.py
anki08 Jan 17, 2017
987a9fe
Updated test_common.py
anki08 Jan 17, 2017
907aef2
Updated test_common.py
anki08 Jan 17, 2017
a3af713
assert_almost_equal -> assert_equal
anki08 Jan 17, 2017
05c687b
updated test_common.py
anki08 Jan 17, 2017
0c70e3f
updated test_common.py
anki08 Jan 17, 2017
8f4de0f
Updated test_common.py
anki08 Jan 18, 2017
7718d36
updated test_common.py
anki08 Jan 18, 2017
481e096
Update test_common.py
amueller Dec 12, 2017
0a792d8
TST use pytest
glemaitre Mar 18, 2018
6398349
Merge remote-tracking branch 'origin/master' into feature1
glemaitre Mar 18, 2018
b5441d9
add whats-new entry
glemaitre Mar 18, 2018
54dc723
DOC syntax
jnothman Mar 18, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v0.20.rst
Expand Up @@ -492,3 +492,6 @@ Changes to estimator checks
- Add test :func:`estimator_checks.check_methods_subset_invariance` to check
that estimators methods are invariant if applied to a data subset.
:issue:`10420` by :user:`Jonathan Ohayon <Johayon>`

- Add invariance tests for clustering metrics. :issue:`8102` by :user:`Ankita
Sinha <anki08>` and :user:`Guillaume Lemaitre <glemaitre>`.
Copy link
Member

@jnothman jnothman Mar 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You either need :user: or you need to add yourself to _contributors.rst and use `Guillaume Lemaitre`_

178 changes: 178 additions & 0 deletions sklearn/metrics/cluster/tests/test_common.py
@@ -0,0 +1,178 @@
from functools import partial

import pytest
import numpy as np

from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import completeness_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import mutual_info_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics.cluster import silhouette_score
from sklearn.metrics.cluster import calinski_harabaz_score

from sklearn.utils.testing import assert_allclose


# Dictionaries of metrics
# ------------------------
# The goal of having those dictionaries is to have an easy way to call a
# particular metric and associate a name to each function:
# - SUPERVISED_METRICS: all supervised cluster metrics - (when given a
# ground truth value)
# - UNSUPERVISED_METRICS: all unsupervised cluster metrics
#
# These dictionaries are used to systematically test some invariance
# properties, e.g. invariance with respect to the format of the inputs.
#

# Registry of the supervised clustering metrics under test, keyed by the
# name the parametrized tests use to look them up.
SUPERVISED_METRICS = dict(
    adjusted_mutual_info_score=adjusted_mutual_info_score,
    adjusted_rand_score=adjusted_rand_score,
    completeness_score=completeness_score,
    homogeneity_score=homogeneity_score,
    mutual_info_score=mutual_info_score,
    normalized_mutual_info_score=normalized_mutual_info_score,
    v_measure_score=v_measure_score,
    fowlkes_mallows_score=fowlkes_mallows_score,
)

# Registry of the unsupervised clustering metrics under test, keyed by the
# name the parametrized tests use to look them up. The silhouette score is
# registered twice: once with its default settings and once with the
# metric bound to 'manhattan'.
UNSUPERVISED_METRICS = dict(
    silhouette_score=silhouette_score,
    silhouette_manhattan=partial(silhouette_score, metric='manhattan'),
    calinski_harabaz_score=calinski_harabaz_score,
)

# Lists of metrics with common properties
# ---------------------------------------
# Lists of metrics with common properties are used to systematically test
# some functionalities and invariances, e.g. SYMMETRIC_METRICS lists all
# metrics that are symmetric with respect to their input arguments y_true
# and y_pred.
#
# --------------------------------------------------------------------
# Metrics whose score is unchanged when y_true and y_pred are swapped.
# Symmetry is only meaningful for the supervised metrics.
SYMMETRIC_METRICS = [
    "adjusted_rand_score",
    "v_measure_score",
    "mutual_info_score",
    "adjusted_mutual_info_score",
    "normalized_mutual_info_score",
    "fowlkes_mallows_score",
]

# Supervised metrics whose score is expected to change when the two label
# arguments are swapped.
NON_SYMMETRIC_METRICS = [
    "homogeneity_score",
    "completeness_score",
]

# Metrics whose score is bounded above by 1.
NORMALIZED_METRICS = [
    "adjusted_rand_score",
    "homogeneity_score",
    "completeness_score",
    "v_measure_score",
    "adjusted_mutual_info_score",
    "fowlkes_mallows_score",
    "normalized_mutual_info_score",
]


# Two reproducible labellings of 30 samples with labels drawn from
# {0, 1, 2}; they are drawn one after the other from the same seeded
# generator.
rng = np.random.RandomState(0)
y1, y2 = (rng.randint(3, size=30) for _ in range(2))


def test_symmetric_non_symmetric_union():
    """Check that the symmetry lists match the supervised registry.

    The names in SYMMETRIC_METRICS and NON_SYMMETRIC_METRICS, taken
    together and sorted, must equal the sorted keys of SUPERVISED_METRICS.
    """
    classified = sorted(SYMMETRIC_METRICS + NON_SYMMETRIC_METRICS)
    registered = sorted(SUPERVISED_METRICS)
    assert classified == registered


@pytest.mark.parametrize(
    'metric_name, y1, y2',
    [(symmetric_name, y1, y2) for symmetric_name in SYMMETRIC_METRICS],
)
def test_symmetry(metric_name, y1, y2):
    """Check that swapping the arguments leaves a symmetric metric unchanged.

    The two scores are compared approximately, through ``pytest.approx``.
    """
    score_fn = SUPERVISED_METRICS[metric_name]
    forward = score_fn(y1, y2)
    backward = score_fn(y2, y1)
    assert forward == pytest.approx(backward)


@pytest.mark.parametrize(
    'metric_name, y1, y2',
    [(asymmetric_name, y1, y2) for asymmetric_name in NON_SYMMETRIC_METRICS],
)
def test_non_symmetry(metric_name, y1, y2):
    """Check that swapping the arguments changes a non-symmetric metric.

    The two scores must differ by more than the ``pytest.approx``
    tolerance.
    """
    score_fn = SUPERVISED_METRICS[metric_name]
    forward = score_fn(y1, y2)
    backward = score_fn(y2, y1)
    assert forward != pytest.approx(backward)


@pytest.mark.parametrize("metric_name", NORMALIZED_METRICS)
def test_normalized_output(metric_name):
    """Check the bounds of a normalized metric on hand-written labellings.

    Partially agreeing labellings must score strictly between 0 and 1,
    identical labellings must score (approximately) 1, and the score of a
    single cluster against all-singleton clusters must not be negative.

    Parameters
    ----------
    metric_name : str
        Key of the metric in SUPERVISED_METRICS.
    """
    metric = SUPERVISED_METRICS[metric_name]

    # Partially agreeing labellings score strictly above 0 ...
    assert metric([0, 0, 0, 1, 1], [0, 0, 0, 1, 2]) > 0.0
    assert metric([0, 0, 1, 1, 2], [0, 0, 1, 1, 1]) > 0.0
    # ... and strictly below 1. The bound is checked in both argument
    # orders so that the non-symmetric metrics are covered in both
    # directions.
    assert metric([0, 0, 0, 1, 2], [0, 1, 1, 1, 1]) < 1.0
    assert metric([0, 1, 1, 1, 1], [0, 0, 0, 1, 2]) < 1.0

    # Identical labellings reach the upper bound.
    upper_bound_1 = [0, 0, 0, 1, 1, 1]
    upper_bound_2 = [0, 0, 0, 1, 1, 1]
    assert metric(upper_bound_1, upper_bound_2) == pytest.approx(1.0)

    # One single cluster against one cluster per sample, in both orders:
    # the score must not drop below 0.
    lower_bound_1 = [0, 0, 0, 0, 0, 0]
    lower_bound_2 = [0, 1, 2, 3, 4, 5]
    score = np.array([metric(lower_bound_1, lower_bound_2),
                      metric(lower_bound_2, lower_bound_1)])
    assert not (score < 0).any()


# The score of every clustering metric must be unchanged when the cluster
# labels are permuted, i.e. here when labels 0 and 1 are exchanged.
@pytest.mark.parametrize(
    "metric_name",
    list(dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS))
)
def test_permute_labels(metric_name):
    """Check that exchanging labels 0 and 1 does not change the score.

    Supervised metrics are checked with either or both label vectors
    flipped; unsupervised metrics are checked with the label vector
    flipped on fixed data.

    Parameters
    ----------
    metric_name : str
        Key of the metric in SUPERVISED_METRICS or UNSUPERVISED_METRICS.
    """
    y_label = np.array([0, 0, 0, 1, 1, 0, 1])
    y_pred = np.array([1, 0, 1, 0, 1, 1, 0])
    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_pred, y_label)
        # ``1 - y`` swaps the two binary labels.
        assert_allclose(score_1, metric(1 - y_pred, y_label))
        assert_allclose(score_1, metric(1 - y_pred, 1 - y_label))
        assert_allclose(score_1, metric(y_pred, 1 - y_label))
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        # Locally seeded generator: keeps the test reproducible without
        # touching the global numpy random state.
        X = np.random.RandomState(0).randint(10, size=(7, 10))
        score_1 = metric(X, y_pred)
        assert_allclose(score_1, metric(X, 1 - y_pred))


# For all clustering metrics, the label arguments can be given as arrays or
# lists, and can hold positive ints, negative ints or strings.
@pytest.mark.parametrize(
    "metric_name",
    list(dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS))
)
def test_format_invariance(metric_name):
    """Check that the score does not depend on the format of the labels.

    The reference score computed on lists of ints must be reproduced
    exactly for every variant yielded by ``generate_formats``. For
    unsupervised metrics, casting the data to float must not change the
    score either.

    Parameters
    ----------
    metric_name : str
        Key of the metric in SUPERVISED_METRICS or UNSUPERVISED_METRICS.
    """
    y_true = [0, 0, 0, 0, 1, 1, 1, 1]
    y_pred = [0, 1, 2, 3, 4, 5, 6, 7]

    def generate_formats(y):
        # Yield (labels, description) pairs for each supported format.
        y = np.array(y)
        yield y, 'array of ints'
        yield y.tolist(), 'list of ints'
        yield [str(x) for x in y.tolist()], 'list of strs'
        yield y - 1, 'including negative ints'
        yield y + 1, 'strictly positive ints'

    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_true, y_pred)
        # Both label vectors go through the same format at each step.
        paired_formats = zip(generate_formats(y_true),
                             generate_formats(y_pred))
        for (y_true_fmt, fmt_name), (y_pred_fmt, _) in paired_formats:
            # The format description identifies the failing variant.
            assert score_1 == metric(y_true_fmt, y_pred_fmt), fmt_name
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        # Locally seeded generator: keeps the test reproducible without
        # touching the global numpy random state.
        X = np.random.RandomState(0).randint(10, size=(8, 10))
        score_1 = metric(X, y_true)
        assert score_1 == metric(X.astype(float), y_true), 'float X'
        for y_true_fmt, fmt_name in generate_formats(y_true):
            assert score_1 == metric(X, y_true_fmt), fmt_name