Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
manning authored and Stanford NLP committed Jan 25, 2015
1 parent 46d2051 commit ca588d5
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 17 deletions.
36 changes: 26 additions & 10 deletions src/edu/stanford/nlp/stats/Counters.java
@@ -492,7 +492,19 @@ public static <E> Counter<E> multiplyInPlace(Counter<E> target, Counter<E> mult)
    * @param <E> Type of elements in Counter
    */
   public static <E> void normalize(Counter<E> target) {
-    multiplyInPlace(target, 1.0 / target.totalCount());
+    divideInPlace(target, target.totalCount());
+  }
+
+  /**
+   * L1 normalize a counter. Return a counter that is a probability distribution,
+   * so the sum of the resulting value equals 1.
+   *
+   * @param c The {@link Counter} to be L1 normalized. This counter is not
+   *          modified.
+   * @return A new L1-normalized Counter based on c.
+   */
+  public static <E, C extends Counter<E>> C asNormalizedCounter(C c) {
+    return scale(c, 1.0 / c.totalCount());
   }


@@ -1406,13 +1418,17 @@ public static <E> double klDivergence(Counter<E> from, Counter<E> to) {
   /**
    * Calculates the Jensen-Shannon divergence between the two counters. That is,
    * it calculates 1/2 [KL(c1 || avg(c1,c2)) + KL(c2 || avg(c1,c2))] .
+   * This code assumes that the Counters have only non-negative values in them.
    *
    * @return The Jensen-Shannon divergence between the distributions
    */
   public static <E> double jensenShannonDivergence(Counter<E> c1, Counter<E> c2) {
-    Counter<E> average = average(c1, c2);
-    double kl1 = klDivergence(c1, average);
-    double kl2 = klDivergence(c2, average);
+    // need to normalize the counters first before averaging them! Else buggy if not a probability distribution
+    Counter<E> d1 = asNormalizedCounter(c1);
+    Counter<E> d2 = asNormalizedCounter(c2);
+    Counter<E> average = average(d1, d2);
+    double kl1 = klDivergence(d1, average);
+    double kl2 = klDivergence(d2, average);
     return (kl1 + kl2) / 2.0;
   }


Expand All @@ -1424,8 +1440,10 @@ public static <E> double jensenShannonDivergence(Counter<E> c1, Counter<E> c2) {
* @return The skew divergence between the distributions * @return The skew divergence between the distributions
*/ */
public static <E> double skewDivergence(Counter<E> c1, Counter<E> c2, double skew) { public static <E> double skewDivergence(Counter<E> c1, Counter<E> c2, double skew) {
Counter<E> average = linearCombination(c2, skew, c1, (1.0 - skew)); Counter<E> d1 = asNormalizedCounter(c1);
return klDivergence(c1, average); Counter<E> d2 = asNormalizedCounter(c2);
Counter<E> average = linearCombination(d2, skew, d1, (1.0 - skew));
return klDivergence(d1, average);
} }


/** /**
@@ -1701,10 +1719,8 @@ public static <E> Counter<Double> getCountCounts(Counter<E> c) {
   /**
    * Returns a new Counter which is scaled by the given scale factor.
    *
-   * @param c
-   *          The counter to scale. It is not changed
-   * @param s
-   *          The constant to scale the counter by
+   * @param c The counter to scale. It is not changed
+   * @param s The constant to scale the counter by
    * @return A new Counter which is the argument scaled by the given scale
    *         factor.
    */
Expand Down
6 changes: 3 additions & 3 deletions src/edu/stanford/nlp/stats/Distributions.java
@@ -126,7 +126,7 @@ public static <K> double klDivergence(Distribution<K> from, Distribution<K> to)
         p2 = (1.0 - assignedMass2) / numKeysRemaining;
         double logFract = Math.log(p1 / p2);
         if (logFract == Double.POSITIVE_INFINITY) {
-          System.out.println("Didtributions.kldivergence (remaining mass) returning +inf: p1=" + p1 + ", p2=" +p2);
+          System.out.println("Distributions.klDivergence (remaining mass) returning +inf: p1=" + p1 + ", p2=" +p2);
           System.out.flush();
           return Double.POSITIVE_INFINITY; // can't recover
         }
Expand All @@ -136,7 +136,7 @@ public static <K> double klDivergence(Distribution<K> from, Distribution<K> to)
return result; return result;
} }


/** /**
* Calculates the Jensen-Shannon divergence between the two distributions. * Calculates the Jensen-Shannon divergence between the two distributions.
* That is, it calculates 1/2 [KL(d1 || avg(d1,d2)) + KL(d2 || avg(d1,d2))] . * That is, it calculates 1/2 [KL(d1 || avg(d1,d2)) + KL(d2 || avg(d1,d2))] .
* *
Expand All @@ -150,7 +150,7 @@ public static <K> double jensenShannonDivergence(Distribution<K> d1, Distributio
return js; return js;
} }


/** /**
* Calculates the skew divergence between the two distributions. * Calculates the skew divergence between the two distributions.
* That is, it calculates KL(d1 || (d2*skew + d1*(1-skew))) . * That is, it calculates KL(d1 || (d2*skew + d1*(1-skew))) .
* In other words, how well can d1 be represented by a "smoothed" d2. * In other words, how well can d1 be represented by a "smoothed" d2.
Expand Down
10 changes: 10 additions & 0 deletions test/src/edu/stanford/nlp/math/ArrayMathTest.java
@@ -242,4 +242,14 @@ public void testSafeSumAndMean() {
     helpTestSafeSumAndMean(d4);
   }

+  public void testJensenShannon() {
+    double[] a = { 0.1, 0.1, 0.7, 0.1, 0.0, 0.0 };
+    double[] b = { 0.0, 0.1, 0.1, 0.7, 0.1, 0.0 };
+    assertEquals(0.46514844544032313, ArrayMath.jensenShannonDivergence(a, b), 1e-5);
+
+    double[] c = { 1.0, 0.0, 0.0 };
+    double[] d = { 0.0, 0.5, 0.5 };
+    assertEquals(1.0, ArrayMath.jensenShannonDivergence(c, d), 1e-5);
+  }
+
 }
31 changes: 27 additions & 4 deletions test/src/edu/stanford/nlp/stats/CountersTest.java
@@ -5,6 +5,7 @@
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
@@ -340,7 +341,7 @@ public void testPearsonsCorrelationCoefficient(){
     setUp();
     Counters.pearsonsCorrelationCoefficient(c1, c2);
   }

   public void testToTiedRankCounter(){
     setUp();
     c1.setCount("t",1.0);
@@ -351,8 +352,8 @@ public void testToTiedRankCounter(){
     assertEquals(1.5, rank.getCount("z"));
     assertEquals(7.0, rank.getCount("t"));
   }

-  public void testTransformWithValuesAdd(){
+  public void testTransformWithValuesAdd() {
     setUp();
     c1.setCount("P",2.0);
     System.out.println(c1);
@@ -366,7 +367,7 @@ public String apply(String in) {

   }

-  public void testEquals(){
+  public void testEquals() {
     setUp();
     c1.clear();
     c2.clear();
Expand All @@ -389,6 +390,28 @@ public void testEquals(){
c2.setCount("2", 3.0 + 8e-5); c2.setCount("2", 3.0 + 8e-5);
c2.setCount("s", 4.0 + 8e-5); c2.setCount("s", 4.0 + 8e-5);
assertFalse(Counters.equals(c1, c2, 1e-5)); // fails totalCount() equality check assertFalse(Counters.equals(c1, c2, 1e-5)); // fails totalCount() equality check
}


public void testJensenShannonDivergence() {
// borrow from ArrayMathTest
Counter<String> a = new ClassicCounter<>();
a.setCount("a", 1.0);
a.setCount("b", 1.0);
a.setCount("c", 7.0);
a.setCount("d", 1.0);

Counter<String> b = new ClassicCounter<>();
b.setCount("b", 1.0);
b.setCount("c", 1.0);
b.setCount("d", 7.0);
b.setCount("e", 1.0);
b.setCount("f", 0.0);

assertEquals(0.46514844544032313, Counters.jensenShannonDivergence(a, b), 1e-5);

Counter<String> c = new ClassicCounter<>(Arrays.asList("A"));
Counter<String> d = new ClassicCounter<>(Arrays.asList("B", "C"));
assertEquals(1.0, Counters.jensenShannonDivergence(c, d), 1e-5);
} }

} }

0 comments on commit ca588d5

Please sign in to comment.