FIX Require explicit average arg for multiclass/label P/R/F metrics
jnothman committed Dec 19, 2013
1 parent 8dabff1 commit 26ac3cf
Showing 11 changed files with 76 additions and 29 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_multilabel_metrics.py
@@ -20,7 +20,7 @@


METRICS = {
'f1': f1_score,
'f1': partial(f1_score, average='micro'),
'f1-by-sample': partial(f1_score, average='samples'),
'accuracy': accuracy_score,
'hamming': hamming_loss,
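
The benchmark keeps a single calling convention, metric(y_true, y_pred), by pre-binding the averaging mode with functools.partial. A minimal sketch of the same pattern, with made-up multilabel indicator arrays standing in for the benchmark data:

from functools import partial
import numpy as np
from sklearn.metrics import f1_score

METRICS = {
    'f1': partial(f1_score, average='micro'),               # pooled over all labels
    'f1-by-sample': partial(f1_score, average='samples'),   # averaged per sample
}

# Illustrative multilabel indicator matrices (rows = samples, columns = labels).
y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 0]])

for name, metric in METRICS.items():
    print(name, metric(y_true, y_pred))   # every metric is called the same way
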
6 changes: 3 additions & 3 deletions doc/datasets/twenty_newsgroups.rst
@@ -131,7 +131,7 @@ which is fast to train and achieves a decent F-score::
>>> clf = MultinomialNB(alpha=.01)
>>> clf.fit(vectors, newsgroups_train.target)
>>> pred = clf.predict(vectors_test)
>>> metrics.f1_score(newsgroups_test.target, pred)
>>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
0.88251152461278892

(The example :ref:`example_document_classification_20newsgroups.py` shuffles
@@ -181,7 +181,7 @@ blocks, and quotation blocks respectively.
... categories=categories)
>>> vectors_test = vectorizer.transform(newsgroups_test.data)
>>> pred = clf.predict(vectors_test)
>>> metrics.f1_score(pred, newsgroups_test.target)
>>> metrics.f1_score(pred, newsgroups_test.target, average='weighted')
0.78409163025839435

This classifier lost a lot of its F-score, just because we removed
@@ -196,7 +196,7 @@ It loses even more if we also strip this metadata from the training data:
>>> clf.fit(vectors, newsgroups_train.target)
>>> vectors_test = vectorizer.transform(newsgroups_test.data)
>>> pred = clf.predict(vectors_test)
>>> metrics.f1_score(newsgroups_test.target, pred)
>>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
0.73160869205141166

Some other classifiers cope better with this harder version of the task. Try
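
For reference, 'weighted' averaging is the support-weighted mean of the per-class F1 scores, so the values above depend on how many test documents each newsgroup contributes. A small self-contained sketch with toy labels (not the newsgroups data) showing the equivalence:

import numpy as np
from sklearn.metrics import f1_score

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

per_class = f1_score(y_true, y_pred, average=None)   # one F1 score per class
support = np.bincount(y_true)                        # class frequencies in y_true
weighted = np.average(per_class, weights=support)    # support-weighted mean

assert np.isclose(weighted, f1_score(y_true, y_pred, average='weighted'))
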
6 changes: 6 additions & 0 deletions doc/whats_new.rst
@@ -114,6 +114,12 @@ API changes summary
:class:`RandomizedPCA <decomposition.RandomizedPCA>`.
By `Alexandre Gramfort`_.

- Users should now supply an explicit ``average`` parameter to
:func:`sklearn.metrics.f1_score`, :func:`sklearn.metrics.fbeta_score`,
:func:`sklearn.metrics.recall_score` and
:func:`sklearn.metrics.precision_score` when performing multiclass
or multilabel (i.e. not binary) classification. By `Joel Nothman`_.
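
In practice multiclass callers now spell out the averaging they want; a short sketch with toy labels, assuming the deprecation behaviour introduced in this commit:

from sklearn.metrics import f1_score

y_true = [1, 2, 3, 3]
y_pred = [1, 2, 3, 1]

# Relying on the implicit default now emits a DeprecationWarning for
# non-binary targets:
#     f1_score(y_true, y_pred)

# Explicit choices are unambiguous:
f1_score(y_true, y_pred, average='weighted')   # support-weighted mean of per-class F1
f1_score(y_true, y_pred, average='macro')      # unweighted mean of per-class F1
f1_score(y_true, y_pred, average='micro')      # computed from pooled TP/FP/FN counts
f1_score(y_true, y_pred, average=None)         # per-class scores, no averaging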

.. _changes_0_14:

0.14
4 changes: 2 additions & 2 deletions examples/document_classification_20newsgroups.py
@@ -201,8 +201,8 @@ def benchmark(clf):
test_time = time() - t0
print("test time: %0.3fs" % test_time)

score = metrics.f1_score(y_test, pred)
print("f1-score: %0.3f" % score)
score = metrics.f1_score(y_test, pred, average='micro')
print("micro f1-score: %0.3f" % score)

if hasattr(clf, 'coef_'):
print("dimensionality: %d" % clf.coef_.shape[1])
16 changes: 9 additions & 7 deletions sklearn/linear_model/tests/test_sgd.py
@@ -454,14 +454,16 @@ def test_auto_weight(self):
y = y[idx]
clf = self.factory(alpha=0.0001, n_iter=1000,
class_weight=None).fit(X, y)
assert_almost_equal(metrics.f1_score(y, clf.predict(X)), 0.96,
decimal=1)
assert_almost_equal(metrics.f1_score(y, clf.predict(X),
average='weighted'),
0.96, decimal=1)

# make the same prediction using automated class_weight
clf_auto = self.factory(alpha=0.0001, n_iter=1000,
class_weight="auto").fit(X, y)
assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X)), 0.96,
decimal=1)
assert_almost_equal(metrics.f1_score(y, clf_auto.predict(X),
average='weighted'),
0.96, decimal=1)

# Make sure that in the balanced case it does not change anything
# to use "auto"
@@ -478,19 +480,19 @@ def test_auto_weight(self):
clf = self.factory(n_iter=1000, class_weight=None)
clf.fit(X_imbalanced, y_imbalanced)
y_pred = clf.predict(X)
assert_less(metrics.f1_score(y, y_pred), 0.96)
assert_less(metrics.f1_score(y, y_pred, average='weighted'), 0.96)

# fit a model with auto class_weight enabled
clf = self.factory(n_iter=1000, class_weight="auto")
clf.fit(X_imbalanced, y_imbalanced)
y_pred = clf.predict(X)
assert_greater(metrics.f1_score(y, y_pred), 0.96)
assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96)

# fit another using a fit parameter override
clf = self.factory(n_iter=1000, class_weight="auto")
clf.fit(X_imbalanced, y_imbalanced)
y_pred = clf.predict(X)
assert_greater(metrics.f1_score(y, y_pred), 0.96)
assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96)

def test_sample_weights(self):
"""Test weights on individual samples"""
30 changes: 21 additions & 9 deletions sklearn/metrics/metrics.py
@@ -1214,7 +1214,7 @@ def accuracy_score(y_true, y_pred, normalize=True):
return np.sum(score)


def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='compat'):
"""Compute the F1 score, also known as balanced F-score or F-measure
The F1 score can be interpreted as a weighted average of the precision and
@@ -1242,7 +1242,8 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
If ``average`` is not ``None`` and the classification target is binary,
only this class's scores will be returned.
average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
average : string, [None, 'micro', 'macro', 'samples', 'weighted']
If the targets are multiclass, this should be set explicitly.
If ``None``, the scores for each class are returned. Otherwise,
unless ``pos_label`` is given in binary classification, this
determines the type of averaging performed on the data:
@@ -1296,7 +1297,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):


def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
average='weighted'):
average='compat'):
"""Compute the F-beta score
The F-beta score is the weighted harmonic mean of precision and recall,
@@ -1325,7 +1326,8 @@ def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
If ``average`` is not ``None`` and the classification target is binary,
only this class's scores will be returned.
average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
average : string, [None, 'micro', 'macro', 'samples', 'weighted']
If the targets are multiclass, this should be set explicitly.
If ``None``, the scores for each class are returned. Otherwise,
unless ``pos_label`` is given in binary classification, this
determines the type of averaging performed on the data:
@@ -1549,14 +1551,22 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
"""
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
if average not in average_options:
if average not in average_options and average != 'compat':
raise ValueError('average has to be one of ' +
str(average_options))
if beta <= 0:
raise ValueError("beta should be >0 in the F-beta score")

y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)

if average == 'compat' and y_type != 'binary':
warnings.warn('The default `weighted` averaging is deprecated, '
'and another default may be used from version 0.17. '
'Please set an explicit value for `average`, one of '
'%s.' % str(average_options),
DeprecationWarning, stacklevel=2)
average = 'weighted'

label_order = labels # save this for later
if labels is None:
labels = unique_labels(y_true, y_pred)
@@ -1665,7 +1675,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,


def precision_score(y_true, y_pred, labels=None, pos_label=1,
average='weighted'):
average='compat'):
"""Compute the precision
The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
@@ -1690,7 +1700,8 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
If ``average`` is not ``None`` and the classification target is binary,
only this class's scores will be returned.
average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
average : string, [None, 'micro', 'macro', 'samples', 'weighted']
If the targets are multiclass, this should be set explicitly.
If ``None``, the scores for each class are returned. Otherwise,
unless ``pos_label`` is given in binary classification, this
determines the type of averaging performed on the data:
@@ -1743,7 +1754,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
return p


def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='compat'):
"""Compute the recall
The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
@@ -1767,7 +1778,8 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
If ``average`` is not ``None`` and the classification target is binary,
only this class's scores will be returned.
average : string, [None, 'micro', 'macro', 'samples', 'weighted' (default)]
average : string, [None, 'micro', 'macro', 'samples', 'weighted']
If the targets are multiclass, this should be set explicitly.
If ``None``, the scores for each class are returned. Otherwise,
unless ``pos_label`` is given in binary classification, this
determines the type of averaging performed on the data:
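
The 'compat' value above acts as a sentinel: when the caller leaves average unset and the targets are not binary, precision_recall_fscore_support warns and falls back to the old 'weighted' behaviour. A sketch of the caller-visible effect, assuming a scikit-learn build that includes this commit:

import warnings
from sklearn.metrics import recall_score

y_true = [1, 2, 3, 3]
y_pred = [1, 2, 3, 1]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    score = recall_score(y_true, y_pred)   # implicit average on multiclass targets
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# During the deprecation period the result still matches the old default:
assert score == recall_score(y_true, y_pred, average='weighted')
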
26 changes: 26 additions & 0 deletions sklearn/metrics/tests/test_metrics.py
@@ -114,11 +114,13 @@
"zero_one_loss": zero_one_loss,
"unnormalized_zero_one_loss": partial(zero_one_loss, normalize=False),

# These are needed to test averaging
"precision_score": precision_score,
"recall_score": recall_score,
"f1_score": f1_score,
"f2_score": partial(fbeta_score, beta=2),
"f0.5_score": partial(fbeta_score, beta=0.5),

"matthews_corrcoef_score": matthews_corrcoef,

"weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5),
@@ -1293,6 +1295,7 @@ def test_losses_at_limits():
assert_almost_equal(r2_score([0., 1], [0., 1]), 1.00, 2)


@ignore_warnings
def test_symmetry():
"""Test the symmetry of score and loss functions"""
y_true, y_pred, _ = make_prediction(binary=True)
@@ -1321,6 +1324,7 @@ def test_symmetry():
msg="%s seems to be symmetric" % name)


@ignore_warnings
def test_sample_order_invariance():
y_true, y_pred, _ = make_prediction(binary=True)
y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)
@@ -1335,6 +1339,7 @@ def test_sample_order_invariance():
% name)


@ignore_warnings
def test_sample_order_invariance_multilabel_and_multioutput():
random_state = check_random_state(0)

@@ -1374,6 +1379,7 @@ def test_sample_order_invariance_multilabel_and_multioutput():
% name)


@ignore_warnings
def test_format_invariance_with_1d_vectors():
y1, y2, _ = make_prediction(binary=True)

@@ -1450,6 +1456,7 @@ def test_format_invariance_with_1d_vectors():
assert_raises(ValueError, metric, y1_row, y2_row)


@ignore_warnings
def test_invariance_string_vs_numbers_labels():
"""Ensure that classification metrics with string labels"""
y1, y2, _ = make_prediction(binary=True)
@@ -1597,6 +1604,7 @@ def test_multioutput_regression_invariance_to_dimension_shuffling():
"invariant" % name)


@ignore_warnings
def test_multilabel_representation_invariance():

# Generate some data
@@ -2228,6 +2236,24 @@ def test_fscore_warnings():
'being set to 0.0 due to no true samples.')


def test_prf_average_compat():
"""Ensure warning if f1_score et al.'s average is implicit for multiclass"""
y_true = [1, 2, 3, 3]
y_pred = [1, 2, 3, 1]

for metric in [precision_score, recall_score, f1_score,
partial(fbeta_score, beta=2)]:
score = assert_warns(DeprecationWarning, metric, y_true, y_pred)
score_weighted = assert_no_warnings(metric, y_true, y_pred,
average='weighted')
assert_equal(score, score_weighted,
'average is not "weighted" by default')

# check binary passes without warning
assert_no_warnings(metric, [0, 1, 1], [0, 1, 0])



def test__check_clf_targets():
"""Check that _check_clf_targets correctly merges target types, squeezes
output and fails if input lengths differ."""
4 changes: 2 additions & 2 deletions sklearn/metrics/tests/test_score_objects.py
@@ -36,13 +36,13 @@ def test_classification_scores():
clf = LinearSVC(random_state=0)
clf.fit(X_train, y_train)
score1 = SCORERS['f1'](clf, X_test, y_test)
score2 = f1_score(y_test, clf.predict(X_test))
score2 = f1_score(y_test, clf.predict(X_test), average='weighted')
assert_almost_equal(score1, score2)

# test fbeta score that takes an argument
scorer = make_scorer(fbeta_score, beta=2)
score1 = scorer(clf, X_test, y_test)
score2 = fbeta_score(y_test, clf.predict(X_test), beta=2)
score2 = fbeta_score(y_test, clf.predict(X_test), beta=2, average='weighted')
assert_almost_equal(score1, score2)

# test that custom scorer can be pickled
5 changes: 3 additions & 2 deletions sklearn/svm/tests/test_svm.py
@@ -368,8 +368,9 @@ def test_auto_weight():
y_pred = clf.fit(X[unbalanced], y[unbalanced]).predict(X)
clf.set_params(class_weight='auto')
y_pred_balanced = clf.fit(X[unbalanced], y[unbalanced],).predict(X)
assert_true(metrics.f1_score(y, y_pred)
<= metrics.f1_score(y, y_pred_balanced))
assert_true(metrics.f1_score(y, y_pred, average='weighted')
<= metrics.f1_score(y, y_pred_balanced,
average='weighted'))


def test_bad_input():
4 changes: 2 additions & 2 deletions sklearn/tests/test_common.py
@@ -981,8 +981,8 @@ def test_class_weight_auto_classifies():
classifier.set_params(class_weight='auto')
classifier.fit(X_train, y_train)
y_pred_auto = classifier.predict(X_test)
assert_greater(f1_score(y_test, y_pred_auto),
f1_score(y_test, y_pred))
assert_greater(f1_score(y_test, y_pred_auto, average='weighted'),
f1_score(y_test, y_pred, average='weighted'))


def test_estimators_overwrite_params():
2 changes: 1 addition & 1 deletion sklearn/tests/test_cross_validation.py
@@ -594,7 +594,7 @@ def test_permutation_score():
assert_true(pvalue_label == pvalue)

# test with custom scoring object
scorer = make_scorer(fbeta_score, beta=2)
scorer = make_scorer(fbeta_score, beta=2, average='weighted')
score_label, _, pvalue_label = cval.permutation_test_score(
svm, X, y, scoring=scorer, cv=cv, labels=np.ones(y.size),
random_state=0)
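
make_scorer forwards extra keyword arguments straight to the metric, so the averaging choice is baked into the scorer object. A brief illustrative sketch using the iris data rather than the test's own fixtures:

from sklearn.datasets import load_iris
from sklearn.metrics import fbeta_score, make_scorer
from sklearn.svm import LinearSVC

iris = load_iris()
X, y = iris.data, iris.target

# beta and average are stored on the scorer and passed through on every call.
scorer = make_scorer(fbeta_score, beta=2, average='weighted')

clf = LinearSVC(random_state=0).fit(X, y)
print(scorer(clf, X, y))   # same as fbeta_score(y, clf.predict(X), beta=2, average='weighted')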
