125 changes: 95 additions & 30 deletions imblearn/metrics/classification.py
@@ -17,6 +17,7 @@
from inspect import getcallargs

import numpy as np
import scipy as sp

from sklearn.metrics.classification import (_check_targets, _prf_divide,
precision_recall_fscore_support)
@@ -460,20 +461,27 @@ def geometric_mean_score(y_true,
y_pred,
labels=None,
pos_label=1,
average='binary',
sample_weight=None):
average='multiclass',
sample_weight=None,
correction=0.0):
"""Compute the geometric mean

The geometric mean is the squared root of the product of the sensitivity
and specificity. This measure tries to maximize the accuracy on each
of the two classes while keeping these accuracies balanced.
The geometric mean (G-mean) is the root of the product of the class-wise
sensitivities. This measure tries to maximize the accuracy on each of the
classes while keeping these accuracies balanced. For binary classification
G-mean is the square root of the product of the sensitivity and the
specificity. For multi-class problems it is the k-th root of the product
of the sensitivity of each of the k classes, i.e.
``(recall_1 * recall_2 * ... * recall_k) ** (1 / k)``.

The specificity is the ratio ``tp / (tp + fn)`` where ``tp`` is the number
of true positives and ``fn`` the number of false negatives. The specificity
is intuitively the ability of the classifier to find all the positive
samples.
For compatibility with other imbalance performance measures, G-mean can
be calculated for each class separately on a one-vs-rest basis when
``average != 'multiclass'``.

The best value is 1 and the worst value is 0.
The best value is 1 and the worst value is 0. Traditionally, if at least
one class is entirely missed by the classifier, the G-mean resolves to
zero. To alleviate this property, in highly multi-class problems the
sensitivity of unrecognized classes can be "corrected" to a user-specified
value (instead of zero). This option works only when
``average == 'multiclass'``.

Parameters
----------
@@ -492,11 +500,11 @@ def geometric_mean_score(y_true,

pos_label : str or int, optional (default=1)
The class to report if ``average='binary'`` and the data is binary.
If the data are multiclass or multilabel, this will be ignored;
If the data are multiclass, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.

average : str or None, optional (default=None)
average : str or None, optional (default=``'multiclass'``)
If ``None``, the scores for each class are returned. Otherwise, this
determines the type of averaging performed on the data:

@@ -519,24 +527,26 @@ def geometric_mean_score(y_true,
meaningful for multilabel classification where this differs from
:func:`accuracy_score`).

warn_for : tuple or set, for internal use
This determines which warnings will be made in the case that this
function is being used to return only one of its metrics.

sample_weight : ndarray, shape (n_samples, )
Sample weights.

correction : float, optional (default=0.0)
Replaces the zero sensitivity of unrecognized classes by the given
value when computing the multiclass G-mean.

Returns
-------
geometric_mean : float (if ``average`` = None) or ndarray, \
shape (n_unique_labels, )
geometric_mean : float

Examples
--------
>>> import numpy as np
>>> from imblearn.metrics import geometric_mean_score
>>> y_true = [0, 1, 2, 0, 1, 2]
>>> y_pred = [0, 2, 1, 0, 0, 1]
>>> geometric_mean_score(y_true, y_pred)
0.0
>>> geometric_mean_score(y_true, y_pred, correction=0.001)
0.010000000000000004
>>> geometric_mean_score(y_true, y_pred, average='macro')
0.47140452079103168
>>> geometric_mean_score(y_true, y_pred, average='micro')
@@ -556,18 +566,66 @@
36(3), (2003), pp 849-851.

"""
sen, spe, _ = sensitivity_specificity_support(
y_true,
y_pred,
labels=labels,
pos_label=pos_label,
average=average,
warn_for=('specificity', 'specificity'),
sample_weight=sample_weight)
if average is None or average != 'multiclass':
sen, spe, _ = sensitivity_specificity_support(
y_true,
y_pred,
labels=labels,
pos_label=pos_label,
average=average,
warn_for=('specificity', 'specificity'),
sample_weight=sample_weight)

LOGGER.debug('The sensitivity and specificity are : %s - %s' %
(sen, spe))

return np.sqrt(sen * spe)
else:
present_labels = unique_labels(y_true, y_pred)

if labels is None:
labels = present_labels
n_labels = None
else:
n_labels = len(labels)
labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
assume_unique=True)])

le = LabelEncoder()
le.fit(labels)
y_true = le.transform(y_true)
y_pred = le.transform(y_pred)
sorted_labels = le.classes_

# labels are now from 0 to len(labels) - 1 -> use bincount
tp = y_true == y_pred
tp_bins = y_true[tp]

if sample_weight is not None:
tp_bins_weights = np.asarray(sample_weight)[tp]
else:
tp_bins_weights = None

if len(tp_bins):
tp_sum = bincount(tp_bins, weights=tp_bins_weights,
minlength=len(labels))
else:
# Pathological case
true_sum = tp_sum = np.zeros(len(labels))
if len(y_true):
true_sum = bincount(y_true, weights=sample_weight,
minlength=len(labels))

# Retain only selected labels
indices = np.searchsorted(sorted_labels, labels[:n_labels])
tp_sum = tp_sum[indices]
true_sum = true_sum[indices]

LOGGER.debug('The sensitivity and specificity are : %s - %s' % (sen, spe))
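# Per-class sensitivity (recall) is tp / support; _prf_divide returns 0.0
# instead of raising for classes that have no true samples.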
recall = _prf_divide(tp_sum, true_sum, "recall", "true", None,
"recall")
recall[recall == 0] = correction

return np.sqrt(sen * spe)
return sp.stats.mstats.gmean(recall)
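For reference, a minimal standalone sketch of what the multiclass branch above computes, assuming only scikit-learn's recall_score and SciPy's gmean (neither call is part of this patch); the result reproduces the ``correction=0.001`` example from the docstring:

import numpy as np
from scipy.stats.mstats import gmean
from sklearn.metrics import recall_score

y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]

# Per-class sensitivity (recall): class 0 -> 1.0, classes 1 and 2 -> 0.0
recall = recall_score(y_true, y_pred, average=None)

# Replace zero recalls by the correction value and take the k-th root of
# the product: (1.0 * 0.001 * 0.001) ** (1 / 3) == 0.01
correction = 0.001
recall = np.where(recall == 0, correction, recall)
print(gmean(recall))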


def make_index_balanced_accuracy(alpha=0.1, squared=True):
@@ -616,7 +674,14 @@ def compute_score(*args, **kwargs):
# Get the signature of the sens/spec function
sens_spec_sig = signature(sensitivity_specificity_support)
# Filter the inputs required by the sens/spec function
tags_sens_spec = sens_spec_sig.bind(**tags_scoring_func)
if scoring_func != geometric_mean_score:
tags_sens_spec = sens_spec_sig.bind(**tags_scoring_func)
else:
# Adapt the parameters to sens/spec function
del tags_scoring_func['correction']
if "average" not in kwargs:
tags_scoring_func['average'] = 'binary'
tags_sens_spec = sens_spec_sig.bind(**tags_scoring_func)
# Call the sens/spec function
sen, spe, _ = sensitivity_specificity_support(
*tags_sens_spec.args,
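As a side note, a small sketch (using a hypothetical stand-in function, not the real sensitivity_specificity_support) of why ``correction`` is deleted in the branch above: Signature.bind raises a TypeError for keyword arguments the bound callable does not accept, so geometric-mean-only parameters must be stripped before binding.

from inspect import signature

def sens_spec_stub(y_true, y_pred, labels=None, pos_label=1,
                   average=None, sample_weight=None):
    # Stand-in exposing the same keyword interface as the sens/spec function.
    return 0.5, 0.5, None

tags = {'y_true': [0, 1], 'y_pred': [0, 1],
        'average': 'multiclass', 'correction': 0.001}

sig = signature(sens_spec_stub)
try:
    sig.bind(**tags)  # fails: 'correction' is not a parameter of the stub
except TypeError as exc:
    print(exc)

tags.pop('correction')  # mirrors the adaptation performed above
tags['average'] = 'binary'
print(sig.bind(**tags).arguments)  # binds cleanly now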
93 changes: 78 additions & 15 deletions imblearn/metrics/tests/test_classification.py
@@ -6,7 +6,7 @@

import numpy as np

from numpy.testing import (assert_allclose, assert_array_equal,
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
assert_no_warnings, assert_equal,
assert_almost_equal, assert_raises)
from sklearn.utils.testing import assert_warns_message, ignore_warnings
@@ -27,7 +27,6 @@
from imblearn.metrics import classification_report_imbalanced

RND_SEED = 42
R_TOL = 1e-2

###############################################################################
# Utilities for testing
@@ -88,8 +87,8 @@ def test_sensitivity_specificity_score_binary():
# detailed measures for each class
sen, spe, sup = sensitivity_specificity_support(
y_true, y_pred, average=None)
assert_allclose(sen, [0.88, 0.68], rtol=R_TOL)
assert_allclose(spe, [0.68, 0.88], rtol=R_TOL)
assert_array_almost_equal(sen, [0.88, 0.68], 2)
assert_array_almost_equal(spe, [0.68, 0.88], 2)
assert_array_equal(sup, [25, 25])

# individual scoring function that can be used for grid search: in the
@@ -99,10 +98,10 @@ def test_sensitivity_specificity_score_binary():
'average': 'binary'
}, assert_no_warnings)]:
sen = my_assert(sensitivity_score, y_true, y_pred, **kwargs)
assert_allclose(sen, 0.68, rtol=R_TOL)
assert_array_almost_equal(sen, 0.68, 2)

spe = my_assert(specificity_score, y_true, y_pred, **kwargs)
assert_allclose(spe, 0.88, rtol=R_TOL)
assert_array_almost_equal(spe, 0.88, 2)


def test_sensitivity_specificity_f_binary_single_class():
@@ -125,22 +124,22 @@ def test_sensitivity_specificity_extra_labels():
# No average: zeros in array
actual = specificity_score(
y_true, y_pred, labels=[0, 1, 2, 3, 4], average=None)
assert_allclose([1., 0.67, 1., 1., 1.], actual, rtol=R_TOL)
assert_array_almost_equal([1., 0.67, 1., 1., 1.], actual, 2)

# Macro average is changed
actual = specificity_score(
y_true, y_pred, labels=[0, 1, 2, 3, 4], average='macro')
assert_allclose(np.mean([1., 0.67, 1., 1., 1.]), actual, rtol=R_TOL)
assert_array_almost_equal(np.mean([1., 0.67, 1., 1., 1.]), actual, 2)

# Check for micro
actual = specificity_score(
y_true, y_pred, labels=[0, 1, 2, 3, 4], average='micro')
assert_allclose(15. / 16., actual, rtol=R_TOL)
assert_array_almost_equal(15. / 16., actual)

# Check for weighted
actual = specificity_score(
y_true, y_pred, labels=[0, 1, 2, 3, 4], average='macro')
assert_allclose(np.mean([1., 0.67, 1., 1., 1.]), actual, rtol=R_TOL)
assert_array_almost_equal(np.mean([1., 0.67, 1., 1., 1.]), actual, 2)


@ignore_warnings
@@ -152,7 +151,7 @@ def test_sensitivity_specificity_ignored_labels():
specificity_13 = partial(specificity_score, y_true, y_pred, labels=[1, 3])
specificity_all = partial(specificity_score, y_true, y_pred, labels=None)

assert_allclose([1., 0.33], specificity_13(average=None), rtol=R_TOL)
assert_array_almost_equal([1., 0.33], specificity_13(average=None), 2)
assert_almost_equal(
np.mean([1., 0.33]), specificity_13(average='macro'), 2)
assert_almost_equal(
@@ -224,20 +223,84 @@ def test_geometric_mean_support_binary():


def test_geometric_mean_multiclass():
"""Test geometric mean for multiclass classification task"""
y_true = [0, 0, 1, 1]
y_pred = [0, 0, 1, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred), 1.0, 10)

y_true = [0, 0, 0, 0]
y_pred = [1, 1, 1, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred), 0.0, 10)

cor = 0.001
y_true = [0, 0, 0, 0]
y_pred = [0, 0, 0, 0]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
1.0, 10)

y_true = [0, 0, 0, 0]
y_pred = [1, 1, 1, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
cor, 10)

y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 0]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
0.5, 10)

y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
(1*cor*cor)**(1.0/3.0), 10)

y_true = [0, 1, 2, 3, 4, 5]
y_pred = [0, 1, 2, 3, 4, 5]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
1, 10)

y_true = [0, 1, 1, 1, 1, 0]
y_pred = [0, 0, 1, 1, 1, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred, correction=cor),
(0.5*0.75)**0.5, 10)

y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred, average='macro'),
0.47140452079103168, 10)
assert_almost_equal(geometric_mean_score(y_true, y_pred, average='micro'),
0.47140452079103168, 10)
assert_almost_equal(geometric_mean_score(y_true, y_pred,
average='weighted'),
0.47140452079103168, 10)
assert_almost_equal(geometric_mean_score(y_true, y_pred, average=None),
[0.8660254, 0.0, 0.0])

y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 1, 1, 0, 0, 1]
assert_almost_equal(geometric_mean_score(y_true, y_pred, labels=[0, 1]),
0.70710678118654752, 10)
assert_almost_equal(geometric_mean_score(y_true, y_pred, labels=[0, 1],
sample_weight=[1, 2, 1, 1, 2, 1]),
0.70710678118654752, 10)
assert_almost_equal(geometric_mean_score(y_true, y_pred, labels=[0, 1],
sample_weight=[1, 2, 1, 1, 2, 1],
average='weighted'),
0.3333333333, 10)

y_true, y_pred, _ = make_prediction(binary=False)

geo_mean = geometric_mean_score(y_true, y_pred)
assert_array_almost_equal(geo_mean, 0.41, 2)

# Compute the geometric mean for each of the classes
geo_mean = geometric_mean_score(y_true, y_pred, average=None)
assert_allclose(geo_mean, [0.85, 0.29, 0.7], rtol=R_TOL)
assert_array_almost_equal(geo_mean, [0.85, 0.29, 0.7], 2)

# average tests
geo_mean = geometric_mean_score(y_true, y_pred, average='macro')
assert_almost_equal(geo_mean, 0.68, 2)

geo_mean = geometric_mean_score(y_true, y_pred, average='weighted')
assert_allclose(geo_mean, 0.65, rtol=R_TOL)

assert_array_almost_equal(geo_mean, 0.65, 2)

def test_iba_geo_mean_binary():
"""Test to test the iba using the geometric mean"""