Skip to content

Commit

Permalink
More changes than I should be making in a commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabor Angeli authored and Stanford NLP committed Mar 12, 2015
1 parent 8ad8d17 commit 2b9fce9
Show file tree
Hide file tree
Showing 6 changed files with 664 additions and 99 deletions.
68 changes: 68 additions & 0 deletions src/edu/stanford/nlp/classify/Classifier.java
@@ -1,7 +1,9 @@
package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

import java.io.Serializable;
import java.util.Collection;
Expand All @@ -25,4 +27,70 @@ public interface Classifier<L, F> extends Serializable {
/**
 * Returns a counter over labels for the given example.
 *
 * NOTE(review): presumably each label's count is the classifier's score for that
 * label on this example -- confirm against the implementing classes.
 *
 * @param example The datum to score.
 * @return A counter keyed by label.
 */
public Counter<L> scoresOf(Datum<L, F> example);

/**
 * Returns a collection of labels.
 *
 * NOTE(review): presumably the set of labels this classifier can assign -- confirm
 * against the implementing classes.
 *
 * @return A collection of labels of type {@code L}.
 */
public Collection<L> labels();

/**
 * Evaluates the precision and recall of this classifier on a dataset, with respect to a single target label.
 *
 * <p>Each datum is classified with {@code classOf} and the guess is compared against the datum's gold label.
 * A null guess is treated as "not the target label" rather than causing a failure.
 *
 * @param testData The dataset to evaluate the classifier on. Every datum must carry a gold label.
 * @param targetLabel The target label (e.g., for relation extraction, this is the relation we're interested in).
 * @return A pair of the precision (first) and recall (second) of the classifier on the target label.
 *         Precision is 0.0 if the classifier never guessed the target label; recall is 1.0 if the
 *         target label never occurs as a gold label in the dataset.
 * @throws IllegalArgumentException If {@code targetLabel} is null, or if any datum has a null gold label.
 */
public default Pair<Double, Double> evaluatePrecisionAndRecall(GeneralDataset<L, F> testData, L targetLabel) {
  if (targetLabel == null) {
    throw new IllegalArgumentException("Must supply a target label to compute precision and recall against");
  }
  // Variables to count
  int numCorrectAndTarget = 0;  // guess == gold == target
  int numTargetGuess = 0;       // guess == target (precision denominator)
  int numTargetGold = 0;        // gold == target (recall denominator)
  // Iterate over dataset
  for (RVFDatum<L, F> datum : testData) {
    // Get the gold label
    L label = datum.label();
    if (label == null) {
      throw new IllegalArgumentException("Cannot compute precision and recall on unlabelled dataset. Offending datum: " + datum);
    }
    // Get the guess label
    L guess = classOf(datum);
    // Compute statistics on datum
    if (label.equals(targetLabel)) {
      numTargetGold += 1;
    }
    // Compare from the side known to be non-null, so that a null guess is simply
    // counted as a non-target guess instead of throwing a NullPointerException.
    if (targetLabel.equals(guess)) {
      numTargetGuess += 1;
      // guess is non-null here, since it was equal to the (non-null) target label
      if (guess.equals(label)) {
        numCorrectAndTarget += 1;
      }
    }
  }
  // Aggregate statistics, guarding against empty denominators
  double precision = numTargetGuess == 0 ? 0.0 : ((double) numCorrectAndTarget) / ((double) numTargetGuess);
  double recall = numTargetGold == 0 ? 1.0 : ((double) numCorrectAndTarget) / ((double) numTargetGold);
  return Pair.makePair(precision, recall);
}

/**
 * Evaluate the accuracy of this classifier on the given dataset.
 *
 * <p>Each datum is classified with {@code classOf}; a guess counts as correct when the gold label
 * equals it. A null guess is counted as incorrect.
 *
 * @param testData The dataset to evaluate the classifier on. Every datum must carry a gold label.
 * @return The accuracy of the classifier on the given dataset, i.e. the fraction of datums whose
 *         guess equals the gold label. For an empty dataset this is {@code NaN} (0.0 / 0.0).
 * @throws IllegalArgumentException If any datum has a null gold label.
 */
public default double evaluateAccuracy(GeneralDataset<L, F> testData) {
  int numCorrect = 0;
  // Counted during iteration so the denominator always matches the examples actually scored,
  // without reaching into the dataset's internal size field.
  int numExamples = 0;
  for (RVFDatum<L, F> datum : testData) {
    // Get the gold label
    L label = datum.label();
    if (label == null) {
      throw new IllegalArgumentException("Cannot compute accuracy on unlabelled dataset. Offending datum: " + datum);
    }
    numExamples += 1;
    // Get the guess
    L guess = classOf(datum);
    // Compute statistics
    if (label.equals(guess)) {
      numCorrect += 1;
    }
  }
  return ((double) numCorrect) / ((double) numExamples);
}

}

0 comments on commit 2b9fce9

Please sign in to comment.