FIX bug where expected_mutual_information may miscalculate
jnothman committed May 14, 2016
1 parent 8f199fe commit a89f754
Showing 3 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions doc/whats_new.rst
@@ -200,6 +200,9 @@ Bug fixes
:class:`linear_model.SGDClassifier` and :class:`linear_model.SGDRegressor`
(`#6764 <https://github.com/scikit-learn/scikit-learn/pull/6764>`_). By `Wenhua Yang`_.

- Fix bug where expected and adjusted mutual information were incorrect if
cluster contingency cells exceeded ``2**16``. By `Joel Nothman`_.


API changes summary
-------------------
@@ -4179,5 +4182,8 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Ori Ziv: https://github.com/zivori

.. _Sears Merritt: https://github.com/merritts

.. _Wenhua Yang: https://github.com/geekoala
2 changes: 1 addition & 1 deletion sklearn/metrics/cluster/expected_mutual_info_fast.pyx
@@ -39,7 +39,7 @@ def expected_mutual_information(contingency, int n_samples):
term1 = nijs / N
# term2 is log((N*nij) / (a * b)) == log(N * nij) - log(a * b)
# term2 uses the outer product
-    log_ab_outer = np.log(np.outer(a, b))
+    log_ab_outer = np.log(a)[:, np.newaxis] + np.log(b)
# term2 uses N * nij
log_Nnij = np.log(N * nijs)
# term3 is large, and involves many factorials. Calculate these in log
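The one-line change above avoids ever forming the integer product ``a * b``: ``np.outer`` multiplies the marginal counts in their integer dtype, so once a cell exceeds ``2**16`` the product can exceed ``2**32`` and silently wrap around, while ``log(a) + log(b)`` stays in floating point throughout. A minimal standalone sketch of the failure mode (plain NumPy, not the scikit-learn code itself; the dtype is pinned to ``int32`` so the wraparound is reproducible on any platform):

```python
import numpy as np

# Marginal counts with a single cell above 2**16. With 32-bit integers,
# 70000 * 70000 = 4.9e9 exceeds 2**31 - 1 and wraps around silently.
a = np.array([70000], dtype=np.int32)
b = np.array([70000], dtype=np.int32)

# Old expression: the integer product overflows BEFORE the log is taken,
# so the result is the log of a wrapped-around value.
buggy = np.log(np.outer(a, b))

# Fixed expression: log(a*b) == log(a) + log(b), computed entirely in
# floating point, so no integer product is ever formed.
fixed = np.log(a)[:, np.newaxis] + np.log(b)

print(buggy[0, 0], fixed[0, 0])  # fixed equals 2 * log(70000), ~22.31
```

On a 64-bit default-integer platform the bug only bites for even larger counts, which is why it went unnoticed until cells crossed ``2**16`` on 32-bit builds of ``np.outer``'s accumulation dtype.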
6 changes: 6 additions & 0 deletions sklearn/metrics/cluster/tests/test_supervised.py
@@ -158,6 +158,12 @@ def test_adjusted_mutual_info_score():
assert_almost_equal(ami, 0.37, 2)


def test_expected_mutual_info_overflow():
    # Regression test: a contingency cell exceeding 2**16 used to overflow
    # the integer product in np.outer, yielding EMI > 1
assert expected_mutual_information(np.array([[70000]]), 70000) <= 1


def test_entropy():
ent = entropy([0, 0, 42.])
assert_almost_equal(ent, 0.6365141, 5)
