[MRG+1] Added support for multiclass Matthews correlation coefficient #8094
@@ -259,7 +259,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
         raise ValueError("At least one label specified must be in y_true")
 
     if sample_weight is None:
-        sample_weight = np.ones(y_true.shape[0], dtype=np.int)
+        sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
     else:
         sample_weight = np.asarray(sample_weight)
@@ -278,8 +278,14 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
     # also eliminate weights of eliminated items
     sample_weight = sample_weight[ind]
 
+    # Choose the accumulator dtype to always have high precision
+    if sample_weight.dtype.kind in {'i', 'u', 'b'}:
+        dtype = np.int64
+    else:
+        dtype = np.float64
+
     CM = coo_matrix((sample_weight, (y_true, y_pred)),
-                    shape=(n_labels, n_labels)
+                    shape=(n_labels, n_labels), dtype=dtype,
                     ).toarray()
 
     return CM

Review comment: I don't understand the logic of upcasting everything to the maximum resolution. Typically, I expect code to keep the same types as what I put in. If I put in float32, that is often a deliberate choice to limit memory consumption.

Reply: This is because the confusion matrix accumulates values. It's common for accumulation functions to have a dtype that differs from the input dtype (see the documentation of np.sum). That default dtype depends on the platform, and one platform (Windows) had failing tests due to this behavior. Always choosing int64 keeps the behavior consistent across platforms.
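The accumulator-dtype logic discussed in the thread above can be sketched as follows. This is a minimal illustration with assumed example labels and a hand-picked label count, not the scikit-learn source itself:

```python
import numpy as np
from scipy.sparse import coo_matrix

# Assumed example data: 4 samples over 3 classes.
y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 2, 2, 2])
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)

# Integer-like weights accumulate in int64, floats in float64,
# so the result does not depend on the platform's default int size.
if sample_weight.dtype.kind in {'i', 'u', 'b'}:
    dtype = np.int64
else:
    dtype = np.float64

n_labels = 3
# Duplicate (row, col) pairs are summed by coo_matrix, which is
# exactly the accumulation a confusion matrix needs.
CM = coo_matrix((sample_weight, (y_true, y_pred)),
                shape=(n_labels, n_labels), dtype=dtype).toarray()
print(CM)
# [[1 0 0]
#  [0 0 1]
#  [0 0 2]]
print(CM.dtype)  # int64
```

Note that `coo_matrix` sums repeated coordinates on conversion, which is why the two samples with `y_true=2, y_pred=2` appear as a single count of 2.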
@@ -452,7 +458,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
 
 
 def matthews_corrcoef(y_true, y_pred, sample_weight=None):
-    """Compute the Matthews correlation coefficient (MCC) for binary classes
+    """Compute the Matthews correlation coefficient (MCC)
 
     The Matthews correlation coefficient is used in machine learning as a
     measure of the quality of binary (two-class) classifications. It takes into
@@ -463,8 +469,9 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
     an average random prediction and -1 an inverse prediction. The statistic
     is also known as the phi coefficient. [source: Wikipedia]
 
-    Only in the binary case does this relate to information about true and
-    false positives and negatives. See references below.
+    Binary and multiclass labels are supported. Only in the binary case does
+    this relate to information about true and false positives and negatives.
+    See references below.
 
     Read more in the :ref:`User Guide <matthews_corrcoef>`.
 
@@ -495,35 +502,40 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
     .. [2] `Wikipedia entry for the Matthews Correlation Coefficient
        <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_
 
+    .. [3] `Gorodkin, (2004). Comparing two K-category assignments by a
+       K-category correlation coefficient
+       <http://www.sciencedirect.com/science/article/pii/S1476927104000799>`_
+
+    .. [4] `Jurman, Riccadonna, Furlanello, (2012). A Comparison of MCC and CEN
+       Error Measures in MultiClass Prediction
+       <http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0041882>`_
+
     Examples
     --------
     >>> from sklearn.metrics import matthews_corrcoef
     >>> y_true = [+1, +1, +1, -1]
     >>> y_pred = [+1, -1, +1, +1]
     >>> matthews_corrcoef(y_true, y_pred)  # doctest: +ELLIPSIS
     -0.33...
 
     """
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
 
-    if y_type != "binary":
+    if y_type not in {"binary", "multiclass"}:
         raise ValueError("%s is not supported" % y_type)
 
     lb = LabelEncoder()
     lb.fit(np.hstack([y_true, y_pred]))
     y_true = lb.transform(y_true)
     y_pred = lb.transform(y_pred)
-    mean_yt = np.average(y_true, weights=sample_weight)
-    mean_yp = np.average(y_pred, weights=sample_weight)
-
-    y_true_u_cent = y_true - mean_yt
-    y_pred_u_cent = y_pred - mean_yp
-
-    cov_ytyp = np.average(y_true_u_cent * y_pred_u_cent, weights=sample_weight)
-    var_yt = np.average(y_true_u_cent ** 2, weights=sample_weight)
-    var_yp = np.average(y_pred_u_cent ** 2, weights=sample_weight)
-
-    mcc = cov_ytyp / np.sqrt(var_yt * var_yp)
+    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
+    t_sum = C.sum(axis=1)
+    p_sum = C.sum(axis=0)
+    n_correct = np.trace(C)
+    n_samples = p_sum.sum()
+    cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
+    cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)
+    cov_ytyt = n_samples ** 2 - np.dot(t_sum, t_sum)
+    mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
 
     if np.isnan(mcc):
         return 0.
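The covariance-based computation in the diff above can be sketched standalone from a confusion matrix. This is an illustrative re-derivation with an assumed helper name (`mcc_from_confusion`), not the PR's actual function:

```python
import numpy as np

def mcc_from_confusion(C):
    """Multiclass MCC from a K x K confusion matrix C (Gorodkin, 2004)."""
    t_sum = C.sum(axis=1)    # per-class counts in the true labels
    p_sum = C.sum(axis=0)    # per-class counts in the predictions
    n_correct = np.trace(C)  # correctly classified samples
    n_samples = p_sum.sum()
    cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
    cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)
    cov_ytyt = n_samples ** 2 - np.dot(t_sum, t_sum)
    denom = np.sqrt(cov_ytyt * cov_ypyp)
    if denom == 0:           # degenerate case: define MCC as 0
        return 0.0
    return cov_ytyp / denom

# Confusion matrix for the docstring example, after label encoding:
# y_true = [1, 1, 1, 0], y_pred = [1, 0, 1, 1]
C = np.array([[0, 1],
              [1, 2]])
print(round(mcc_from_confusion(C), 2))  # -0.33
```

In the two-class case this expression is algebraically identical to the familiar binary formula (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)), which is why the rewrite does not change binary results.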
Review comment: You should probably note that this no longer ranges from -1 to 1...?

Reply: Technically it still does range from -1 to +1, because the multiclass case encompasses the binary case. However, when there are more than two labels it will not be possible to achieve -1. I'll note that:

    When there are more than two labels, the value of the MCC will no longer range
    between -1 and +1. Instead the minimum value will be somewhere between -1 and 0
    depending on the number and distribution of ground truth labels. The maximum
    value is always +1.
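The point in the reply can be illustrated with a small assumed example: with three classes, even a prediction that gets every sample wrong does not reach -1.

```python
import numpy as np

# Assumed worst case: each class is always predicted as the "next" class,
# so no prediction is correct. Confusion matrix for y_true = [0, 1, 2],
# y_pred = [1, 2, 0].
C = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]])
t_sum, p_sum = C.sum(axis=1), C.sum(axis=0)
n_correct, n = np.trace(C), C.sum()
mcc = (n_correct * n - t_sum @ p_sum) / np.sqrt(
    (n ** 2 - t_sum @ t_sum) * (n ** 2 - p_sum @ p_sum))
print(mcc)  # -0.5, not -1, despite zero correct predictions
```

By contrast, a binary example with every prediction inverted (swap the two labels of a perfect prediction) still yields exactly -1.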