ENH Implement Jensen-Shannon Divergence #3213


@luispedro

Adds two functions: one computes a pairwise divergence, the other a matrix of all pairwise divergences.

This is used to compute divergences between distributions. The interface is similar to that of entropy().
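
For reference, a self-contained sketch of the pairwise function as it appears in the diff excerpts below (the companion jsd_matrix function, which computes all pairwise divergences for a set of distributions, is not reproduced here):

import numpy as np
from scipy import special

def jensen_shannon_divergence(a, b):
    # Sketch of the pairwise function from this PR's diff.
    a = np.asanyarray(a, dtype=float)
    b = np.asanyarray(b, dtype=float)
    # Normalize the inputs so they are proper probability distributions.
    a = a / a.sum()
    b = b / b.sum()
    # Mixture distribution; replace zeros by 1 so the log is well defined
    # (xlogy already returns 0 where its first argument is 0).
    m = 0.5 * (a + b)
    m = np.where(m, m, 1.)
    return 0.5 * np.sum(special.xlogy(a, a / m) + special.xlogy(b, b / m))

# Example usage: divergence between two (unnormalized) count vectors.
p = np.array([4., 3., 2., 1.])
q = np.array([1., 2., 3., 4.])
print(jensen_shannon_divergence(p, q))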

@pv pv and 1 other commented on an outdated diff Jan 14, 2014
scipy/stats/tests/test_stats.py
@@ -2691,5 +2691,28 @@ def test_three_groups(self):
assert_approx_equal(p, stats.chisqprob(h, 2))
+def test_jensen_shannon_divergence():
+ for _ in xrange(8):
@pv
pv Jan 14, 2014 SciPy member

Python 3 issue here

@luispedro
luispedro Jan 14, 2014

Tx! Fixed & ran with python3 to check the rest of the code.

@josef-pkt josef-pkt and 2 others commented on an outdated diff Jan 14, 2014
scipy/stats/stats.py
@@ -4105,6 +4106,70 @@ def friedmanchisquare(*args):
return chisq, chisqprob(chisq,k-1)
+def jensen_shannon_divergence(a, b):
@josef-pkt
josef-pkt Jan 14, 2014 SciPy member

I wonder whether this shouldn't be rolled together with entropy, with a symmetric keyword for example
https://github.com/scipy/scipy/blob/c254b4ff5def78c2a5e2d9decbd35882a2f0f650/scipy/stats/_distn_infrastructure.py#L2292

Otherwise we should get the same broadcasting behavior, i.e. take the sum along axis=0.

@argriffing
argriffing Jan 14, 2014 collaborator

If entropy had a symmetric keyword I'd expect it to do symmetrized K-L divergence which is a little different from the divergence in this PR.

@josef-pkt
josef-pkt Jan 14, 2014 SciPy member

Thanks for the link, I didn't know that.
I agree.

I think entropy should then be in the See Also.

@argriffing
argriffing Jan 14, 2014 collaborator

Honestly, I'd prefer entropy not to be overloaded to do K-L divergence, but I guess this would be hard to change for non-technical reasons. When I see entropy(p, q=None), I expect it to do cross entropy when q is specified, not K-L divergence. But I am getting off topic...

@luispedro
luispedro Jan 15, 2014

@josef-pkt wrt the broadcast behaviour: unlike entropy, which makes sense for a single sample, this is only meaningful in a pairwise fashion. Thus, I think an interface like the ones in scipy.spatial.distance is more appropriate. If others disagree, then I can change the PR, but as a user, it'd probably take me a while to figure out what the interface is.
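
To make the distinction raised above concrete (symmetrized K-L divergence vs. the Jensen-Shannon divergence), a small example using the existing scipy.stats.entropy, not code from this PR:

import numpy as np
from scipy.stats import entropy  # entropy(p, q) computes the K-L divergence D(p || q)

a = np.array([0.1, 0.2, 0.3, 0.4])
b = np.array([0.4, 0.3, 0.2, 0.1])
m = 0.5 * (a + b)

# Symmetrized K-L divergence: average of the two directed divergences.
sym_kl = 0.5 * (entropy(a, b) + entropy(b, a))

# Jensen-Shannon divergence: each distribution is compared to the mixture m.
js = 0.5 * (entropy(a, m) + entropy(b, m))

print(sym_kl, js)  # the two quantities differ in general; JS is always finite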


@argriffing
Collaborator

If it weren't for having to keep backwards compatibility, I would say to just add this basically as-is. But maybe it would be worth a little planning for a collection of statistical distances whose function names and implementations are consistent with each other.

@luispedro

@argriffing The alternative place to put this would be scipy.spatial.distance. There, distances are computed in two ways: a function f(a, b) for each distance (which may take additional parameters where appropriate), and the pdist function. I used the same interface here.
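
For readers not familiar with that module, the two conventions look roughly like this (a sketch using existing scipy.spatial.distance functions, not the code added by this PR):

import numpy as np
from scipy.spatial import distance

x = np.random.random((5, 16))  # 5 observations with 16 features each

# Convention 1: a pairwise function f(a, b) for each distance.
d01 = distance.euclidean(x[0], x[1])

# Convention 2: pdist computes all pairwise distances at once, returned as a
# condensed vector; squareform expands it to a full 5x5 matrix.
condensed = distance.pdist(x, metric='euclidean')
square = distance.squareform(condensed)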


@luispedro

Is there anything that is still needed for this PR to be merged?

@argriffing
Collaborator

Looks OK to me. I personally don't have any experience with scipy.spatial, so I'm not familiar with its interfaces. At some point the assertions in the new tests should be changed to use numpy.testing instead of raw Python assert, but this could be done after this PR is merged.

@WarrenWeckesser
Collaborator

The plain assert statements should be changed to the appropriate functions from numpy.testing before this is merged. There is no reason to put that off until later.

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/stats.py
+ j : float
+
+ See Also
+ --------
+ jsd_matrix : function
+ Computes all pair-wise distances for a set of measurements
+ entropy : function
+ Computes entropy and K-L divergence
+ """
+ a = np.asanyarray(a, dtype=float)
+ b = np.asanyarray(b, dtype=float)
+ a = a/a.sum()
+ b = b/b.sum()
+ m = (a + b)
+ m /= 2.
+ m = np.where(m,m,1.)
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

PEP8: spaces after commas: np.where(m, m, 1.). (Here and elsewhere...)

@josef-pkt
Member

I still think we get the "wrong" result if users call jensen_shannon_divergence with 2-D arrays ("wrong" for stats).
From reading the function, I guess we can get some strange raveled results (and no exception).

@luispedro

I added code to raise an exception when the inputs are not 1-dimensional.

I think it's not very obvious what the behaviour should be for multi-dimensional arrays. Still, this leaves open the possibility that someone will extend the function later.


@josef-pkt
Member

What's the problem with adding axis=0 to all np.sum?
IMO we would get straightforward vectorized behavior, but not pairwise.

(I never use spatial, because I almost never need all the pairwise distances, and if I do I have different broadcasting. But I haven't looked at spatial in a long time.)

I don't really care, and an exception will prevent tying us to a specific behavior, but I'm not a big fan of having a hodge-podge of different ways of axis handling, behavior, or restrictions in the stats functions.
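
To make the suggestion concrete, a rough sketch (my own illustration, not the PR code) of what adding axis=0 would mean: each column of a 2-D input is treated as one distribution, and one divergence is returned per column.

import numpy as np
from scipy import special

def jsd_axis0(a, b):
    # Hypothetical vectorized variant: each column is one distribution.
    a = np.asanyarray(a, dtype=float)
    b = np.asanyarray(b, dtype=float)
    a = a / a.sum(axis=0)
    b = b / b.sum(axis=0)
    m = 0.5 * (a + b)
    m = np.where(m, m, 1.)
    return 0.5 * np.sum(special.xlogy(a, a / m) + special.xlogy(b, b / m),
                        axis=0)

a = np.random.random((16, 3))  # three distributions, one per column
b = np.random.random((16, 3))
print(jsd_axis0(a, b))  # array of three divergences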

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/stats.py
+ j : float
+
+ See Also
+ --------
+ jsd_matrix : function
+ Computes all pair-wise distances for a set of measurements
+ entropy : function
+ Computes entropy and K-L divergence
+ """
+ a = np.asanyarray(a, dtype=float)
+ b = np.asanyarray(b, dtype=float)
+ if a.ndim != 1 or b.ndim != 1:
+ raise ValueError('jensen_shannon_divergence only accepts 1-dimensional arguments')
+ a = a/a.sum()
+ b = b/b.sum()
+ m = (a + b)
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

Minor style nit: remove the extraneous parentheses.

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/stats.py
+ --------
+ jsd_matrix : function
+ Computes all pair-wise distances for a set of measurements
+ entropy : function
+ Computes entropy and K-L divergence
+ """
+ a = np.asanyarray(a, dtype=float)
+ b = np.asanyarray(b, dtype=float)
+ if a.ndim != 1 or b.ndim != 1:
+ raise ValueError('jensen_shannon_divergence only accepts 1-dimensional arguments')
+ a = a/a.sum()
+ b = b/b.sum()
+ m = (a + b)
+ m /= 2.
+ m = np.where(m,m,1.)
+ return 0.5*np.sum(special.xlogy(a,a/m)+special.xlogy(b,b/m))
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

PEP8: Spaces around + in an expression like this. Spaces after ,.

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/tests/test_stats.py
@@ -2691,5 +2691,30 @@ def test_three_groups(self):
assert_approx_equal(p, stats.chisqprob(h, 2))
+def test_jensen_shannon_divergence():
+ for _ in range(8):
+ a = np.random.random(16)
+ b = np.random.random(16)
+ c = (a+b)
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

Remove extraneous parentheses.

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/tests/test_stats.py
@@ -2691,5 +2691,30 @@ def test_three_groups(self):
assert_approx_equal(p, stats.chisqprob(h, 2))
+def test_jensen_shannon_divergence():
+ for _ in range(8):
+ a = np.random.random(16)
+ b = np.random.random(16)
+ c = (a+b)
+
+ assert_(stats.jensen_shannon_divergence(a,a) < 1e-4)
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

PEP8: Space after , (here and elsewhere).

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/tests/test_stats.py
@@ -2691,5 +2691,30 @@ def test_three_groups(self):
assert_approx_equal(p, stats.chisqprob(h, 2))
+def test_jensen_shannon_divergence():
+ for _ in range(8):
+ a = np.random.random(16)
+ b = np.random.random(16)
+ c = (a+b)
+
+ assert_(stats.jensen_shannon_divergence(a,a) < 1e-4)
+ assert_(stats.jensen_shannon_divergence(a,b) > 0.)
+ assert_(stats.jensen_shannon_divergence(a,b) > stats.jensen_shannon_divergence(a,c))
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

PEP8: Line too long (here and elsewhere). Lines must be at most 79 characters.

@WarrenWeckesser WarrenWeckesser commented on an outdated diff Jan 27, 2014
scipy/stats/tests/test_stats.py
@@ -2691,5 +2691,33 @@ def test_three_groups(self):
assert_approx_equal(p, stats.chisqprob(h, 2))
+def test_jensen_shannon_divergence():
@WarrenWeckesser
WarrenWeckesser Jan 27, 2014 collaborator

The unit tests here are all indirect. Either they check that an inequality is satisfied, or that scaling an input results in the same output. None of them actually test that, for a given input, the function returns the correct value. A really basic test could be to check the result for, say, a = np.array([1, 0, 0, 0]), b = np.array([0, 1, 0, 0]) (or something similarly simple). The correct result is easily worked out "by hand", or it can be compared to 0.5*entropy(a, 0.5*(a + b)) + 0.5*entropy(b, 0.5*(a + b)).
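
For example, a direct-value test along those lines might look like the following sketch (this assumes the divergence uses natural logarithms, as scipy.stats.entropy does, so the expected value for these inputs is log(2)):

import numpy as np
from numpy.testing import assert_allclose
from scipy.stats import entropy

a = np.array([1., 0., 0., 0.])
b = np.array([0., 1., 0., 0.])
m = 0.5 * (a + b)

# The value worked out "by hand" agrees with the entropy-based formula.
expected = 0.5 * entropy(a, m) + 0.5 * entropy(b, m)
assert_allclose(expected, np.log(2))

# With this PR applied, the direct check would then be:
# assert_allclose(stats.jensen_shannon_divergence(a, b), np.log(2))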

@luispedro

I added the axis=0 in the sums. I am not a big fan of the resulting interface, but it's a superset of the simple one I like, so sure.


@pv pv added the PR label Feb 19, 2014
@pv pv removed the PR label Aug 13, 2014
@argriffing
Collaborator

I think this PR is slow to be merged because it looks like it's about a statistical divergence, but it's actually about a metric which happens to be the square root of a symmetrized statistical divergence. Metrics care about things like symmetry and the triangle inequality, whereas divergences care about things like non-negativity and convexity. So I think that if it keeps the compressed output matrix format, it should go into scipy.spatial, where the symmetric interface is standard, rather than into scipy.stats, where divergences are generally not symmetric and so don't have that interface. Alternatively, keep the Jensen-Shannon divergence function in stats, possibly generalized to the lambda divergence accepting lambda other than 0.5, and move the other function (the one that computes all-pairs divergences and uses the special symmetry property of this particular divergence to compress the output) into scipy.spatial.
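
The lambda generalization mentioned here is straightforward to sketch (my own illustration, not part of the PR; lam=0.5 recovers the Jensen-Shannon divergence):

import numpy as np
from scipy import special

def lambda_divergence(a, b, lam=0.5):
    # Skewed ("lambda") Jensen-Shannon divergence:
    #   lam * KL(a || m) + (1 - lam) * KL(b || m),  with m = lam*a + (1 - lam)*b
    a = np.asanyarray(a, dtype=float)
    b = np.asanyarray(b, dtype=float)
    a = a / a.sum()
    b = b / b.sum()
    m = lam * a + (1. - lam) * b
    m = np.where(m, m, 1.)
    return np.sum(lam * special.xlogy(a, a / m)
                  + (1. - lam) * special.xlogy(b, b / m))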
