
[MRG+1] Added support for multiclass Matthews correlation coefficient #8094

Merged
merged 37 commits into scikit-learn:master on Jun 19, 2017

Conversation

@Erotemic (Contributor) commented Dec 21, 2016

What does this implement/fix? Explain your changes.

This extends the current matthews_corrcoef to handle the multiclass case.

Also fixes #7929 and #8354

The extension is defined here: http://www.sciencedirect.com/science/article/pii/S1476927104000799
(the PDF is behind a paywall, but the author has a website with details here: http://rk.kvl.dk/introduction/index.html )

and my implementation follows equation (2) in this paper:
http://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0041882&type=printable

The new implementation can handle both the binary and multiclass cases. I've left in the original binary implementation for now (it is a bit faster and clearer about what is going on).

I've added new tests that inspect properties of the multiclass case as well as ensure that the multiclass case reduces to the binary case.

There's not much else to say. This is a pretty straightforward change.
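For concreteness, a minimal usage sketch of the extended metric on a 3-class toy example (the labels and the quoted value are illustrative only):

    from sklearn.metrics import matthews_corrcoef

    y_true = [0, 1, 2, 2, 1, 0]
    y_pred = [0, 2, 2, 2, 1, 0]
    # matthews_corrcoef previously handled only binary labels;
    # with this change, multiclass input is scored as well.
    print(matthews_corrcoef(y_true, y_pred))  # ~0.783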

@Erotemic (Contributor, Author) commented Dec 22, 2016

I made some updates to this code to both simplify it and unify the binary and multiclass cases.

For my own reference, and to describe my process, here are the iterations I went through to change the code.

The original binary implementation computed the MCC as follows:

        mean_yt = np.average(y_true, weights=sample_weight)
        mean_yp = np.average(y_pred, weights=sample_weight)

        y_true_u_cent = y_true - mean_yt
        y_pred_u_cent = y_pred - mean_yp

        cov_ytyp = np.average(y_true_u_cent * y_pred_u_cent,
                              weights=sample_weight)
        var_yt = np.average(y_true_u_cent ** 2, weights=sample_weight)
        var_yp = np.average(y_pred_u_cent ** 2, weights=sample_weight)

        mcc = cov_ytyp / np.sqrt(var_yt * var_yp)
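
As an aside, for unweighted inputs this is exactly the Pearson correlation between the two 0/1 label vectors, so a quick sanity check (a throwaway snippet, not part of the PR) is:

    import numpy as np

    y_true = np.array([0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 0, 1])
    # Binary MCC equals the Pearson correlation of the label vectors.
    print(np.corrcoef(y_true, y_pred)[0, 1])  # ~0.667, matches the MCC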

My first pass at computing the multiclass MCC looked like this, directly following the Jurman et al. paper linked above:

    C = confusion_matrix(y_pred, y_true, sample_weight=sample_weight)
    N = len(C)
    cov_ytyp = sum([
        C[k, k] * C[m, l] - C[l, k] * C[k, m]
        for k in range(N) for m in range(N) for l in range(N)
    ])
    cov_ytyt = sum([
        C[:, k].sum() *
        np.sum([C[g, f] for f in range(N) for g in range(N) if f != k])
        for k in range(N)
    ])
    cov_ypyp = np.sum([
        C[k, :].sum() *
        np.sum([C[f, g] for f in range(N) for g in range(N) if f != k])
        for k in range(N)
    ])
    mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

I was able to improve on this a bit using numpy shortcuts:

        C = confusion_matrix(y_pred, y_true, sample_weight=sample_weight)
        N = len(C)
        cov_ytyp = ((np.diag(C)[:, np.newaxis, np.newaxis] * C).sum() -
                    (C[np.newaxis, :, :] * C[:, :, np.newaxis]).sum())
        cov_ytyt = np.sum([
            (C[:, k].sum() * (C[:, :k].sum() + C[:, k + 1:].sum()))
            for k in range(N)
        ])
        cov_ypyp = np.sum([
            (C[k, :].sum() * (C[:k, :].sum() + C[k + 1:, :].sum()))
            for k in range(N)
        ])
        mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

My latest iteration significantly simplifies and increases the interpretability of the code. It runs the fastest out of the multiclass options I've written and only runs marginally slower for the binary case (231.1230 µs vs 285.2201 µs on a set of binary labels of length 200).

    class_covariances = [
        np.cov(y_pred == k, y_true == k, bias=True, fweights=sample_weight)
        for k in range(len(lb.classes_))
    ]  # a list rather than a generator, so np.sum(..., axis=0) gets an array-like
    covariance = np.sum(class_covariances, axis=0)
    cov_ypyp, cov_ytyp, _, cov_ytyt = covariance.ravel()
    mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

Moving to this simplified code exposed a small bug in the original tests.
I had to remove the following lines:

    y_true_inv2 = label_binarize(y_true, ["a", "b"]) * -1
    assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)

because the only reason it was working was bug #8098.

@jnothman (Member) left a comment

A description of this extension is due in doc/modules/model_evaluation.rst.

I also wonder why you prefer the calculation based on cov over using confusion_matrix, which I suspect would be more readable given the discrete application.

@Erotemic (Contributor, Author) commented Dec 27, 2016

The np.cov implementation more closely resembles the calculation used in the original paper. The fancy indexing in the list comprehensions looks a bit more confusing to me. However, it seems older versions of numpy don't support the fweights keyword argument, so regardless, I'll have to switch back to the confusion matrix implementation.

I'll make that change and update doc/modules/model_evaluation.rst
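
For context, a sketch of the constraint that forces this: fweights was only added in numpy 1.10, and it must be non-negative integer frequency counts, so arbitrary float sample_weight wouldn't work there either:

    import numpy as np

    x = np.array([0, 1, 1, 0])
    y = np.array([0, 1, 0, 0])
    w = np.array([2, 1, 1, 3])  # integer frequency weights
    print(np.cov(x, y, bias=True, fweights=w))  # 2x2 covariance matrix
    # np.cov(x, y, bias=True, fweights=w * 0.5)  # non-integer -> TypeError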

@Erotemic (Contributor, Author) commented Dec 27, 2016

My second comment shows my iterations in approaching the problem. The 3rd block of code shows my best confusion_matrix implementation. However, I'll repost the relevant code blocks here to avoid confusion.

My confusion_matrix implementation is:

        C = confusion_matrix(y_pred, y_true, sample_weight=sample_weight)
        N = len(C)
        cov_ytyp = ((np.diag(C)[:, np.newaxis, np.newaxis] * C).sum() -
                    (C[np.newaxis, :, :] * C[:, :, np.newaxis]).sum())
        cov_ytyt = np.sum([
            (C[:, k].sum() * (C[:, :k].sum() + C[:, k + 1:].sum()))
            for k in range(N)
        ])
        cov_ypyp = np.sum([
            (C[k, :].sum() * (C[:k, :].sum() + C[k + 1:, :].sum()))
            for k in range(N)
        ])
        mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

Whereas the np.cov implementation looks like this:

    class_covariances = [
        np.cov(y_pred == k, y_true == k, bias=True, fweights=sample_weight)
        for k in range(len(lb.classes_))
    ]  # a list rather than a generator, so np.sum(..., axis=0) gets an array-like
    covariance = np.sum(class_covariances, axis=0)
    cov_ypyp, cov_ytyp, _, cov_ytyt = covariance.ravel()
    mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)

Perhaps there is a way, which I've been unable to think of, to make the confusion_matrix implementation more concise?

@jnothman (Member) commented Dec 28, 2016

Sorry I'd failed to read your comment above.

Isn't C[:k, :].sum() + C[k + 1:, :].sum() the same as C.sum() - C[k].sum()? Or am I misreading?

If so, I get

s = C.sum(axis=1)
cov_ypyp = s.sum() ** 2 - np.dot(s, s)

I must be doing something wrong.
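
(For the record, a quick numeric check of that identity, a throwaway snippet, suggests the two forms do agree:)

    import numpy as np

    rng = np.random.RandomState(0)
    C = rng.randint(0, 10, size=(4, 4))
    N = len(C)
    lhs = sum(C[k, :].sum() * (C[:k, :].sum() + C[k + 1:, :].sum())
              for k in range(N))
    s = C.sum(axis=1)
    rhs = s.sum() ** 2 - np.dot(s, s)
    assert lhs == rhs  # C[:k].sum() + C[k+1:].sum() == C.sum() - C[k].sum()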

@Erotemic (Contributor, Author) commented Dec 28, 2016

I think you are correct. Following your observation, I was also able to see a similar pattern in computing cov_ytyp. I've been able to greatly simplify the above code, removing all need for list comprehensions and np.newaxis. The new code also runs about 7x faster.

C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
t_sum = C.sum(axis=1)
p_sum = C.sum(axis=0)
n_correct = np.diag(C).sum()
n_samples = p_sum.sum()
cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)
cov_ytyt = n_samples ** 2 - np.dot(t_sum, t_sum)
mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
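
For reference, a self-contained sketch of this final version (the function name is hypothetical; it uses np.trace as suggested in the review comment below and guards the zero-denominator edge case):

    import numpy as np
    from sklearn.metrics import confusion_matrix

    def mcc_multiclass(y_true, y_pred, sample_weight=None):
        # C[i, j] counts samples with true label i predicted as label j.
        C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
        t_sum = C.sum(axis=1)    # per-class counts of true labels
        p_sum = C.sum(axis=0)    # per-class counts of predicted labels
        n_samples = p_sum.sum()
        n_correct = np.trace(C)  # correctly classified samples
        cov_ytyp = n_correct * n_samples - np.dot(t_sum, p_sum)
        cov_ypyp = n_samples ** 2 - np.dot(p_sum, p_sum)
        cov_ytyt = n_samples ** 2 - np.dot(t_sum, t_sum)
        denom = np.sqrt(float(cov_ytyt) * float(cov_ypyp))
        return cov_ytyp / denom if denom != 0 else 0.0

    # Binary input reproduces the original Pearson-correlation value:
    print(mcc_multiclass([0, 1, 1, 0], [0, 1, 0, 0]))  # ~0.577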
Erotemic added 3 commits Dec 28, 2016
… version. Edited relevant documentation.
@jnothman (Member) commented Dec 29, 2016

7x faster than using np.cov? or than list comprehensions?

@jnothman (Member) left a comment

I'll try to review tests and correctness soon.

C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
t_sum = C.sum(axis=1)
p_sum = C.sum(axis=0)
n_correct = np.diag(C).sum()

@jnothman (Member) commented Dec 29, 2016

can use np.trace(C)

@Erotemic (Author, Contributor) commented Dec 29, 2016

duh @me, fixed.

@jnothman (Member) left a comment

otherwise, LGTM!!

# These two weighted vectors have 0 correlation and hence mcc should be 0
y_1 = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_2 = [1, 1, 1, 2, 2, 2, 0, 0, 0]
np.cov(y_1, y_2)

@Erotemic (Author, Contributor) commented Dec 29, 2016

A mistake in the comment, and a leftover np.cov from testing. Fixing.

.. math::

    MCC = \frac{
        c \times s - \sum_{k}^{K} p_k \times t_k
    }{\sqrt{
        (s^2 - \sum_{k}^{K} p_k^2) \times
        (s^2 - \sum_{k}^{K} t_k^2)
    }}

@jnothman (Member) commented Dec 29, 2016

You should probably note that this no longer ranges from -1 to 1...?

@Erotemic (Author, Contributor) commented Dec 29, 2016

Technically it still does range from -1 to +1, because the multiclass case does encompass the binary case. However, when there are more than 2 labels it will not be possible to achieve -1. I'll note that:

When there are more than two labels, the value of the MCC will no longer range
between -1 and +1. Instead the minimum value will be somewhere between -1 and 0
depending on the number and distribution of ground truth labels. The maximum
value is always +1.


.. math::

    MCC = \frac{tp \times tn - fp \times fn}{\sqrt{(tp + fp)(tp + fn)(tn + fp)(tn + fn)}}.

In the multiclass case, the Matthews correlation coefficient can be `defined
<http://rk.kvl.dk/introduction/index.html>` in terms of a
:ref:`sphx_glr_auto_examples_model_selection_plot_confusion_matrix.py`

@jnothman (Member) commented Dec 29, 2016

I think

:func:`confusion_matrix`

would be more apt here.

@Erotemic (Author, Contributor) commented Dec 29, 2016

fixing

@jnothman changed the title from "Added support for multiclass Matthews correlation coefficient" to "[MRG+1] Added support for multiclass Matthews correlation coefficient" on Dec 29, 2016
@Erotemic (Contributor, Author) commented Dec 29, 2016

@jnothman It was 7x faster than list comprehensions. That benchmark does not include the time it takes to compute the confusion matrix, which is the bottleneck of the function.


On a separate note, I'm noticing on AppVeyor that a test for MCC is failing, and I'm not sure why, as it seems to pass on my machine as well as on the other CI machines.

The failing test, test_common.py.test_sample_weight_invariance:check_sample_weight_invariance(matthews_corrcoef_score), is a yield test, which is something I had problems with in #7654. I'm not sure if that is causing issues here. I've manually tried scaling the sample weights on some dummy data and it always seems consistent when I do it.

[00:07:30] ======================================================================
[00:07:30] FAIL: C:\Python27-x64\lib\site-packages\sklearn\metrics\tests\test_common.py.test_sample_weight_invariance:check_sample_weight_invariance(matthews_corrcoef_score)
[00:07:30] ----------------------------------------------------------------------
[00:07:30] Traceback (most recent call last):
[00:07:30]   File "C:\Python27-x64\lib\site-packages\nose\case.py", line 197, in runTest
[00:07:30]     self.test(*self.arg)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\utils\testing.py", line 741, in __call__
[00:07:30]     return self.check(*args, **kwargs)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\utils\testing.py", line 292, in wrapper
[00:07:30]     return fn(*args, **kwargs)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\metrics\tests\test_common.py", line 1006, in check_sample_weight_invariance
[00:07:30]     "under scaling" % name)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\numpy\testing\utils.py", line 490, in assert_almost_equal
[00:07:30]     raise AssertionError(_build_err_msg())
[00:07:30] AssertionError: 
[00:07:30] Arrays are not almost equal to 7 decimals
[00:07:30] matthews_corrcoef_score sample_weight is not invariant under scaling
[00:07:30]  ACTUAL: 0.19988003199146895
[00:07:30]  DESIRED: 0.61482001028003908
[00:07:30] 
[00:07:30] ======================================================================
[00:07:30] FAIL: C:\Python27-x64\lib\site-packages\sklearn\metrics\tests\test_common.py.test_sample_weight_invariance:check_sample_weight_invariance(matthews_corrcoef_score)
[00:07:30] ----------------------------------------------------------------------
[00:07:30] Traceback (most recent call last):
[00:07:30]   File "C:\Python27-x64\lib\site-packages\nose\case.py", line 197, in runTest
[00:07:30]     self.test(*self.arg)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\utils\testing.py", line 741, in __call__
[00:07:30]     return self.check(*args, **kwargs)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\utils\testing.py", line 292, in wrapper
[00:07:30]     return fn(*args, **kwargs)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\sklearn\metrics\tests\test_common.py", line 1006, in check_sample_weight_invariance
[00:07:30]     "under scaling" % name)
[00:07:30]   File "C:\Python27-x64\lib\site-packages\numpy\testing\utils.py", line 490, in assert_almost_equal
[00:07:30]     raise AssertionError(_build_err_msg())
[00:07:30] AssertionError: 
[00:07:30] Arrays are not almost equal to 7 decimals
[00:07:30] matthews_corrcoef_score sample_weight is not invariant under scaling
[00:07:30]  ACTUAL: 0.0
[00:07:30]  DESIRED: -0.039763715905510061

When I run

nosetests "sklearn/metrics/tests/test_common.py:test_sample_weight_invariance" --verbose 2>&1 | grep matthew

it outputs

/home/joncrall/code/scikit-learn/sklearn/metrics/tests/test_common.py.test_sample_weight_invariance:check_sample_weight_invariance(matthews_corrcoef_score) ... ok
/home/joncrall/code/scikit-learn/sklearn/metrics/tests/test_common.py.test_sample_weight_invariance:check_sample_weight_invariance(matthews_corrcoef_score) ... ok

Am I running the test wrong? Is there anything about this test or AppVeyor that is known to be unstable?


I'm continuing to look into the AppVeyor failure and I'm just unable to reproduce the issue. I wrote the following standalone script with additional cases, and I don't see how the failure numbers could be generated. The function seems perfectly scale invariant.

from sklearn.metrics import matthews_corrcoef
import numpy as np

def test_scaled(metric, y1, y2, rng):
    sample_weight = rng.randint(1, 10, size=len(y1))
    mcc_want = metric(y1, y2, sample_weight=sample_weight)
    print('mcc_want = %r' % (mcc_want,))
    # print('sample_weight = %r' % (sample_weight,))
    # print('y1 = %r' % (y1,))
    # print('y2 = %r' % (y2,))
    # print('mcc_want   = %r' % (mcc_want,))
    for s in [.003, .03, .5, 2, 2.1, 10.9]:
        weight = sample_weight * s
        mcc = metric(y1, y2, sample_weight=weight)
        # print('weight = %r' % (weight,))
        # print('mcc = %r' % (mcc,))
        assert np.isclose(mcc, mcc_want)

# rng = np.random
rng = np.random.RandomState(0)
metric = matthews_corrcoef

for n_classes in range(1, 10):
    for n_samples in [1, 2, 5, 20, 50, 100, 1000]:
        y1 = rng.randint(0, n_classes, size=(n_samples, ))
        y2 = rng.randint(0, n_classes, size=(n_samples, ))
        print('n_classes, n_samples = %r, %r' % (n_classes, n_samples,))
        test_scaled(metric, y1, y2, rng)
Erotemic and others added 2 commits Dec 29, 2016
@jnothman (Member) commented Dec 29, 2016

I hope you don't mind me hacking your branch to try to debug this. I'm as lost as you are... except that I thought I'd check that this isn't an issue of sample_weight being modified somehow between checks (not that I see how it can be).

@jnothman (Member) commented Dec 30, 2016

Still failing. It's hard to comprehend what might make this fail on Windows but work elsewhere, if not for some kind of interaction across generated tests. :\

@jnothman (Member) commented Dec 30, 2016

I've considered replacing y1 with y1.copy(), etc, again to isolate assertions/metrics from one another, but I've not done it yet.

Any bright ideas for debugging an appveyor failure... @ogrisel, @lesteve?


.. math::

    MCC = \frac{tp \times tn - fp \times fn}{\sqrt{(tp + fp)(tp + fn)(tn + fp)(tn + fn)}}.

In the multiclass case, the Matthews correlation coefficient can be `defined
<http://rk.kvl.dk/introduction/index.html>` in terms of a

@GaelVaroquaux (Member) commented Mar 5, 2017

Aren't we missing an underscore at the end of the markup for this link?

    if sample_weight.dtype.kind in {'i', 'u', 'b'}:
        dtype = np.int64
    else:
        dtype = np.float64

@GaelVaroquaux (Member) commented Mar 5, 2017

I don't understand the logic of upcasting everything to the maximum resolution.

Typically, I expect code to keep the same types as what I put in. If I put in float32, it is often a deliberate choice to limit memory consumption.

@Erotemic (Author, Contributor) commented Mar 5, 2017

This is because the confusion matrix accumulates values. It's common for accumulation functions to have a dtype that differs from the input dtype (see the documentation of np.sum). The default accumulator dtype depends on the platform, and on one of these platforms (Windows) tests were failing due to this behavior. Always choosing int64 maintains consistent cross-platform behavior.
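
A minimal illustration of the platform dependence (a sketch assuming a platform such as Windows, where the default integer is 32-bit):

    import numpy as np

    counts = np.ones(10, dtype=np.int32)
    # For integer input, np.sum accumulates in the platform's default
    # integer type: int32 on Windows builds, int64 on most Linux builds.
    print(np.sum(counts).dtype)
    print(np.sum(counts, dtype=np.int64).dtype)  # always int64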

# The minimum will be different depending on the input
y_true = [0, 0, 1, 1, 2, 2]
y_pred_min = [1, 1, 0, 0, 0, 0]
assert_almost_equal(matthews_corrcoef(y_true, y_pred_min), -0.6123724)

@GaelVaroquaux (Member) commented Mar 5, 2017

Where does this value come from? I am not very comfortable with tests comparing against such a hard-coded value if it is not easily understandable why the value is the correct one.

@Erotemic (Author, Contributor) commented Mar 5, 2017

This is simply the correct output for this specific multiclass instance. The reason this specific example takes a weird value rather than -1 is that, technically, some of the negative predictions are correct. With 2 classes you can construct an instance that is completely wrong, but with more than 2 classes, every time you say class 2 when it should have been class 1 you are still correct that it wasn't class 0, so you always get something right.

Perhaps -12 / np.sqrt(24 * 16) would be better? I'm not sure how to give a better intuition without redefining the function itself or using a lot of terms.

I actually found this particular example with a brute-force search over predictions for 6 samples and 3 labels, looking for the minimum value the MCC can take in this setting.
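
To make the hard-coded value traceable, here is the arithmetic behind -12 / np.sqrt(24 * 16), worked through with the covariance identities from earlier in this thread:

    import numpy as np
    from sklearn.metrics import confusion_matrix

    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [1, 1, 0, 0, 0, 0]
    C = confusion_matrix(y_true, y_pred)
    # C = [[0, 2, 0],
    #      [2, 0, 0],
    #      [2, 0, 0]]
    t_sum, p_sum, n = C.sum(axis=1), C.sum(axis=0), C.sum()  # [2 2 2], [4 2 0], 6
    cov_ytyp = np.trace(C) * n - np.dot(t_sum, p_sum)  # 0 * 6 - 12 = -12
    cov_ytyt = n ** 2 - np.dot(t_sum, t_sum)           # 36 - 12 = 24
    cov_ypyp = n ** 2 - np.dot(p_sum, p_sum)           # 36 - 20 = 16
    print(cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp))     # -12 / sqrt(24*16) ~ -0.6123724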

@GaelVaroquaux (Member) commented Mar 5, 2017

I made a few small comments. Overall, this looks good.


@@ -366,8 +397,6 @@ def test_matthews_corrcoef():
y_true_inv = ["b" if i == "a" else "a" for i in y_true]

assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)
y_true_inv2 = label_binarize(y_true, ["a", "b"]) * -1

@lesteve (Member) commented Mar 6, 2017

I think #8377 should be merged before this PR, which would reduce its scope. It already has @jnothman's +1; maybe @GaelVaroquaux you can have a look?

@lesteve (Member) commented Apr 25, 2017

> I think #8377 should be merged before this PR, which would reduce its scope. It already has @jnothman's +1; maybe @GaelVaroquaux you can have a look?

#8377 has been merged, I fixed the conflicts via the web interface, let's see what the CIs have to say.

@agramfort (Member) commented Jun 8, 2017

@jnothman @lesteve all green here

good to go?

.. [4] `Jurman, Riccadonna, Furlanello, (2012). A Comparison of MCC and CEN
Error Measures in MultiClass Prediction
<http://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0041882>`_

@Erotemic (Author, Contributor) commented Jun 8, 2017

fixed

@lesteve (Member) commented Jun 8, 2017

It would be nice to merge this one during the sprint. It looks like it is useful and has been sitting idle for a while.


@lesteve (Member) commented Jun 8, 2017

then merge and fix the link on master :)

For completeness, you can even push into people's branches now (or edit inline via the GitHub web interface for small things). So if you have the necessary rights, you can do the fix yourself before merging.

@lesteve (Member) commented Jun 8, 2017

Also I am not familiar at all with the ML aspects of this PR.


@jnothman added this to the 0.19 milestone on Jun 18, 2017
@jnothman (Member) commented Jun 19, 2017

I think there is consensus to merge this. I'm taking Gael's "overall this looks good" to be a +1. Enough other eyes have looked at it.

@jnothman merged commit e339240 into scikit-learn:master on Jun 19, 2017
2 of 3 checks passed:
continuous-integration/appveyor/pr: AppVeyor build failed
ci/circleci: Your tests passed on CircleCI!
continuous-integration/travis-ci/pr: The Travis CI build passed
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
dmohns added a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
AishwaryaRK added a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
…scikit-learn#8094)

Also ensure confusion matrix is accumulated with high precision.