Make binary-only default for precision/recall/fscore and their scorers
jnothman committed Aug 31, 2014
1 parent fabb13e · commit afe2d23
Showing 8 changed files with 87 additions and 110 deletions.
8 changes: 4 additions & 4 deletions doc/modules/model_evaluation.rst
@@ -55,13 +55,13 @@ Scoring Function
**Classification**
'accuracy' :func:`sklearn.metrics.accuracy_score`
'average_precision' :func:`sklearn.metrics.average_precision_score`
'f1_binary' :func:`sklearn.metrics.f1_score` with `pos_label=1`
'f1' :func:`sklearn.metrics.f1_score` for binary targets
'f1_micro' :func:`sklearn.metrics.f1_score` micro-averaged
'f1_macro' :func:`sklearn.metrics.f1_score` macro-averaged
'f1_weighted' :func:`sklearn.metrics.f1_score` weighted average
'f1_samples' :func:`sklearn.metrics.f1_score` by multilabel sample
'precision_...' :func:`sklearn.metrics.precision_score` likewise
'recall_...' :func:`sklearn.metrics.recall_score` likewise
'precision...' :func:`sklearn.metrics.precision_score` likewise
'recall...' :func:`sklearn.metrics.recall_score` likewise
'roc_auc' :func:`sklearn.metrics.roc_auc_score`

**Clustering**
@@ -82,7 +82,7 @@ and is shown if ``scoring`` is set to an unknown string::
>>> model = svm.SVC()
>>> cross_validation.cross_val_score(model, X, y, scoring='wrong_choice')
Traceback (most recent call last):
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1_binary', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'precision_binary', 'precision_macro', 'precision_micro', 'precision_weighted', 'r2', 'recall_binary', 'recall_macro', 'recall_micro', 'recall_weighted', 'roc_auc']
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_rand_score', 'average_precision', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'log_loss', 'mean_absolute_error', 'mean_squared_error', 'precision', 'precision_macro', 'precision_micro', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_weighted', 'roc_auc']

.. note::

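Not part of the diff, but a minimal sketch of how the renamed scoring strings are meant to be used after this change; the synthetic data, `make_classification` arguments and the `SVC` estimator are illustrative assumptions, not taken from the commit:

    from sklearn import cross_validation, datasets, svm

    # Binary target: 'f1' now scores the positive class (pos_label=1) only.
    X, y = datasets.make_classification(n_classes=2, random_state=0)
    model = svm.SVC()
    print(cross_validation.cross_val_score(model, X, y, scoring='f1'))

    # Multiclass target: an averaged variant must be named explicitly,
    # e.g. 'f1_macro', 'f1_micro' or 'f1_weighted'.
    X, y = datasets.make_classification(n_classes=3, n_informative=4,
                                        random_state=0)
    print(cross_validation.cross_val_score(model, X, y, scoring='f1_macro'))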
6 changes: 3 additions & 3 deletions doc/whats_new.rst
@@ -98,9 +98,9 @@ API changes summary
and pass these to their distance metric. This will no longer be supported
in scikit-learn 0.18; use the ``metric_params`` argument instead.

- `scoring` parameter for cross validation now accepts `'f1_binary'`,
`'f1_micro'`, `'f1_macro'` or `'f1_weighted'`, deprecating the generic
`'f1'`. Similarly, `'precision'` and `'recall'` are deprecated.
- `scoring` parameter for cross validation now accepts `'f1_micro'`,
`'f1_macro'` or `'f1_weighted'`. `'f1'` is now for binary classification
only. Similar changes apply to `'precision'` and `'recall'`.
By `Joel Nothman`_.

- Users should now supply an explicit ``average`` parameter to
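For illustration only (not in this commit), the explicit ``average`` usage the changelog entry refers to looks roughly like this on multiclass labels; the toy labels are assumptions:

    from sklearn.metrics import f1_score

    y_true = [0, 1, 2, 0, 1, 2]
    y_pred = [0, 2, 1, 0, 0, 1]

    # Relying on the old implicit default on multiclass data now triggers a
    # DeprecationWarning; an explicit averaging choice is the supported spelling.
    f1_score(y_true, y_pred, average='macro')
    f1_score(y_true, y_pred, average='weighted')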
2 changes: 0 additions & 2 deletions sklearn/metrics/__init__.py
@@ -52,7 +52,6 @@
from .scorer import make_scorer
from .scorer import SCORERS
from .scorer import get_scorer
from .scorer import list_scorers

# Deprecated in 0.16
from .ranking import auc_score
@@ -79,7 +78,6 @@
'homogeneity_score',
'jaccard_similarity_score',
'label_ranking_average_precision_score',
'list_scorers',
'log_loss',
'make_scorer',
'matthews_corrcoef',
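With the short-lived ``list_scorers`` helper removed again, valid scorer names can still be enumerated from the ``SCORERS`` dict that stays exported; a sketch, assuming interactive use:

    from sklearn.metrics import SCORERS, get_scorer

    # The sorted key list is also what get_scorer() reports on an unknown name.
    print(sorted(SCORERS.keys()))

    # Look a scorer up by name; get_scorer passes callables through unchanged.
    f1_weighted_scorer = get_scorer('f1_weighted')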
32 changes: 18 additions & 14 deletions sklearn/metrics/classification.py
@@ -474,7 +474,7 @@ def zero_one_loss(y_true, y_pred, normalize=True, sample_weight=None):
return n_samples - score


def f1_score(y_true, y_pred, labels=None, pos_label=1, average='compat',
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
sample_weight=None):
"""Compute the F1 score, also known as balanced F-score or F-measure
@@ -561,7 +561,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='compat',


def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
average='compat', sample_weight=None):
average='binary', sample_weight=None):
"""Compute the F-beta score
The F-beta score is the weighted harmonic mean of precision and recall,
@@ -823,20 +823,23 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
"""
average_options = (None, 'micro', 'macro', 'weighted', 'samples')
if average not in average_options and average != 'compat':
if average not in average_options and average != 'binary':
raise ValueError('average has to be one of ' +
str(average_options))
if beta <= 0:
raise ValueError("beta should be >0 in the F-beta score")

y_type, y_true, y_pred = _check_targets(y_true, y_pred)

if average == 'compat' and y_type != 'binary':
if average == 'binary' and y_type != 'binary':
warnings.warn('The default `weighted` averaging is deprecated, '
'and another default may be used from version 0.18. '
'and from version 0.18, use of precision, recall or '
'F-score with multiclass or multilabel data will result '
'in an exception. '
'Please set an explicit value for `average`, one of '
'%s.' % str(average_options),
DeprecationWarning, stacklevel=2)
'%s. In cross validation use, for instance, '
'scoring="f1_weighted" instead of scoring="f1".'
% str(average_options), DeprecationWarning, stacklevel=2)
average = 'weighted'

label_order = labels # save this for later
@@ -861,7 +864,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,

elif average == 'samples':
raise ValueError("Sample-based precision, recall, fscore is "
"not meaningful outside multilabel"
"not meaningful outside multilabel "
"classification. See the accuracy_score instead.")
else:
lb = LabelEncoder()
@@ -894,11 +897,12 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
### Select labels to keep ###

if y_type == 'binary' and average is not None and pos_label is not None:
if label_order is not None and len(label_order) == 2:
if average != 'binary' and label_order is not None \
and len(label_order) == 2:
warnings.warn('In the future, providing two `labels` values, as '
'well as `average` will average over those '
'labels. For now, please use `labels=None` with '
'`pos_label` to evaluate precision, recall and '
'well as `average!=\'binary\'` will average over '
'those labels. For now, please use `labels=None` '
'with `pos_label` to evaluate precision, recall and '
'F-score for the positive label only.',
FutureWarning)
if pos_label not in labels:
@@ -963,7 +967,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,


def precision_score(y_true, y_pred, labels=None, pos_label=1,
average='compat', sample_weight=None):
average='binary', sample_weight=None):
"""Compute the precision
The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
@@ -1046,7 +1050,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
return p


def recall_score(y_true, y_pred, labels=None, pos_label=1, average='compat',
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
sample_weight=None):
"""Compute the recall
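To summarise the behaviour introduced above with an illustrative example (not from the diff): with the new ``average='binary'`` default, the metrics report the positive class only on binary targets and fall back to ``'weighted'`` with a DeprecationWarning everywhere else:

    from sklearn.metrics import f1_score

    # Binary target: no warning, score for pos_label=1 only.
    f1_score([0, 1, 1, 0], [0, 1, 0, 0])

    # Multiclass target with the default: still returns the weighted score
    # for now, but emits the DeprecationWarning added in this commit.
    f1_score([0, 1, 2, 0], [0, 2, 1, 0])

    # Explicit averaging, the recommended spelling going forward.
    f1_score([0, 1, 2, 0], [0, 2, 1, 0], average='macro')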
63 changes: 16 additions & 47 deletions sklearn/metrics/scorer.py
@@ -16,12 +16,10 @@
# Authors: Andreas Mueller <amueller@ais.uni-bonn.de>
# Lars Buitinck <L.J.Buitinck@uva.nl>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# License: Simplified BSD

from abc import ABCMeta, abstractmethod
from functools import partial
from warnings import warn

import numpy as np

@@ -89,7 +87,7 @@ def __call__(self, estimator, X, y_true, sample_weight=None):
else:
return self._sign * self._score_func(y_true, y_pred,
**self._kwargs)


class _ProbaScorer(_BaseScorer):
def __call__(self, clf, X, y, sample_weight=None):
@@ -186,28 +184,16 @@ def _factory_args(self):


def get_scorer(scoring):
"""Get a scorer by its name
Parameters
----------
scoring : string or callable
Returns
-------
scorer : callable
Returns the scorer of the given name if scoring is a string, and
otherwise the object passed in.
"""
if isinstance(scoring, six.string_types):
if scoring in SCORER_DEPRECATION:
warn(SCORER_DEPRECATION[scoring], DeprecationWarning)
try:
return SCORERS[scoring]
scorer = SCORERS[scoring]
except KeyError:
raise ValueError('%r is not a valid scoring name. '
'Valid options are %s' % (scoring,
list_scorers()))
return scoring
raise ValueError('%r is not a valid scoring value. '
'Valid options are %s'
% (scoring, sorted(SCORERS.keys())))
else:
scorer = scoring
return scorer


def _passthrough_scorer(estimator, *args, **kwargs):
@@ -325,17 +311,6 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
return cls(score_func, sign, kwargs)


def list_scorers():
"""Lists the names of known scorers
Returns
-------
scorer_names : list of strings
"""
return sorted(set(SCORERS) - set(SCORER_DEPRECATION))



# Standard regression scores
r2_scorer = make_scorer(r2_score)
mean_squared_error_scorer = make_scorer(mean_squared_error,
@@ -362,32 +337,26 @@ def list_scorers():
# Clustering scores
adjusted_rand_scorer = make_scorer(adjusted_rand_score)

SCORER_DEPRECATION = {}
SCORERS = dict(r2=r2_scorer,
mean_absolute_error=mean_absolute_error_scorer,
mean_squared_error=mean_squared_error_scorer,
accuracy=accuracy_scorer, f1=f1_scorer, roc_auc=roc_auc_scorer,
accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
average_precision=average_precision_scorer,
precision=precision_scorer, recall=recall_scorer,
log_loss=log_loss_scorer,
adjusted_rand_score=adjusted_rand_scorer)

msg = ("The {0!r} scorer has been deprecated and will be removed in version "
"0.17. Please choose one of '{0}_binary' or '{0}_weighted' depending "
"on your data; '{0}_macro', '{0}_micro' and '{0}_samples' provide "
"alternative multiclass/multilabel averaging.")
for name, metric in [('precision', precision_score),
('recall', recall_score), ('f1', f1_score)]:
('recall', recall_score), ('f1', f1_score)]:
SCORERS.update({
name: make_scorer(metric),
'{0}_binary'.format(name): make_scorer(partial(metric)),
'{0}'.format(name): make_scorer(partial(metric)),
'{0}_macro'.format(name): make_scorer(partial(metric, pos_label=None,
average='macro')),
average='macro')),
'{0}_micro'.format(name): make_scorer(partial(metric, pos_label=None,
average='micro')),
average='micro')),
'{0}_samples'.format(name): make_scorer(partial(metric, pos_label=None,
average='samples')),
'{0}_weighted'.format(name): make_scorer(partial(metric, pos_label=None,
average='weighted')),
'{0}_weighted'.format(name): make_scorer(partial(metric,
pos_label=None,
average='weighted')),
})
SCORER_DEPRECATION[name] = (msg.format(name))
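The net effect of the rewritten loop, shown here as plain assignments purely for illustration: the bare metric names now score the binary, positive-label case, while the suffixed names pin ``pos_label=None`` together with the matching ``average``:

    from functools import partial
    from sklearn.metrics import f1_score, make_scorer

    # Roughly what the loop registers under 'f1' and 'f1_macro'; the same
    # pattern applies to precision_score and recall_score.
    f1 = make_scorer(f1_score)  # average='binary' default
    f1_macro = make_scorer(partial(f1_score, pos_label=None, average='macro'))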
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -967,7 +967,7 @@ def test_prf_average_compat():
score_weighted = assert_no_warnings(metric, y_true, y_pred,
average='weighted')
assert_equal(score, score_weighted,
'average is not "weighted" by default')
'average does not act like "weighted" by default')

# check binary passes without warning
assert_no_warnings(metric, [0, 1, 1], [0, 1, 0])
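A rough way to reproduce what this test checks, using only the standard library's warning machinery instead of scikit-learn's private test helpers (illustrative sketch, toy labels assumed):

    import warnings
    from sklearn.metrics import f1_score

    y_true, y_pred = [0, 1, 2, 0], [0, 2, 1, 0]

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        default_score = f1_score(y_true, y_pred)  # multiclass, no average given

    # The compat path warns but still behaves like average='weighted'.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert default_score == f1_score(y_true, y_pred, average='weighted')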
[Diffs for the remaining 2 of the 8 changed files were not loaded.]