Skip to content

Commit

Permalink
Change signature of printProbsDocument to return a Pair of Counters.
Browse files Browse the repository at this point in the history
  • Loading branch information
manning authored and Stanford NLP committed Mar 22, 2015
1 parent cc60290 commit 8e357e2
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 22 deletions.
36 changes: 34 additions & 2 deletions src/edu/stanford/nlp/ie/AbstractSequenceClassifier.java
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Sampler; import edu.stanford.nlp.stats.Sampler;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.*; import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.concurrent.*; import edu.stanford.nlp.util.concurrent.*;


Expand Down Expand Up @@ -949,10 +950,41 @@ public void printProbs(String filename,
* {@link CoreMap}. * {@link CoreMap}.
*/ */
public void printProbsDocuments(ObjectBank<List<IN>> documents) { public void printProbsDocuments(ObjectBank<List<IN>> documents) {
Counter<Integer> calibration = new ClassicCounter<>();
TwoDimensionalCounter<Integer,String> calibratedTokens = new TwoDimensionalCounter<>();

for (List<IN> doc : documents) { for (List<IN> doc : documents) {
printProbsDocument(doc); Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> pair = printProbsDocument(doc);
if (pair != null) {
Counters.addInPlace(calibration, pair.first());
calibratedTokens.addAll(pair.second());
}
System.out.println(); System.out.println();
} }
if (calibration.size() > 0) {
// we stored stuff, so print it out
PrintWriter pw = new PrintWriter(System.err);
outputCalibrationInfo(pw, calibration, calibratedTokens);
pw.flush();
}
}

public static void outputCalibrationInfo(PrintWriter pw,
Counter<Integer> calibration,
TwoDimensionalCounter<Integer,String> calibratedTokens) {
final int numBins = 10;
pw.println("----------------------------------------");
pw.println("Probability distribution given to tokens");
pw.println("----------------------------------------");
for (int i = 0; i < numBins; i++) {
pw.printf("[%.1f-%.1f%c: %.1f %s%n",
((double) i) / numBins,
((double) (i+1)) / numBins,
i == (numBins - 1) ? ']': ')',
calibration.getCount(i),
Counters.toSortedString(calibratedTokens.getCounter(i), 10, "%s=%.1f", ", ", "[%s]"));
}
pw.println("----------------------------------------");
} }


public void classifyStdin() public void classifyStdin()
Expand All @@ -974,7 +1006,7 @@ public void classifyStdin(DocumentReaderAndWriter<IN> readerWriter)
} }
} }


public abstract void printProbsDocument(List<IN> document); public abstract Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<IN> document);


/** /**
* Load a test file, run the classifier on it, and then print the answers to * Load a test file, run the classifier on it, and then print the answers to
Expand Down
5 changes: 4 additions & 1 deletion src/edu/stanford/nlp/ie/ClassifierCombiner.java
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import edu.stanford.nlp.ling.HasWord; import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.pipeline.DefaultPaths; import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter; import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils; import edu.stanford.nlp.util.StringUtils;


import java.io.FileNotFoundException; import java.io.FileNotFoundException;
Expand Down Expand Up @@ -395,7 +398,7 @@ public void train(Collection<List<IN>> docs,
} }


@Override @Override
public void printProbsDocument(List<IN> document) { public Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<IN> document) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }


Expand Down
46 changes: 33 additions & 13 deletions src/edu/stanford/nlp/ie/crf/CRFClassifier.java
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import edu.stanford.nlp.sequences.*; import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.*; import edu.stanford.nlp.util.*;


import java.io.*; import java.io.*;
Expand All @@ -46,6 +47,7 @@
import java.text.NumberFormat; import java.text.NumberFormat;
import java.util.*; import java.util.*;
import java.util.regex.*; import java.util.regex.*;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;


Expand Down Expand Up @@ -1338,18 +1340,27 @@ public List<IN> classifyGibbs(List<IN> document, Triple<int[][][], int[], double
* the likelihood of each possible label at each point. * the likelihood of each possible label at each point.
* *
* @param document A {@link List} of something that extends CoreMap. * @param document A {@link List} of something that extends CoreMap.
* @return If verboseMode is set, a Pair of Counters recording classification decisions, else null.
*/ */
@Override @Override
public void printProbsDocument(List<IN> document) { public Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<IN> document) {
final int numBins = 10;
boolean verbose = flags.verboseMode;


Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document); Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);

CRFCliqueTree<String> cliqueTree = getCliqueTree(p); CRFCliqueTree<String> cliqueTree = getCliqueTree(p);


Counter<Integer> calibration = new ClassicCounter<>();
TwoDimensionalCounter<Integer,String> calibratedTokens = new TwoDimensionalCounter<>();

// for (int i = 0; i < factorTables.length; i++) { // for (int i = 0; i < factorTables.length; i++) {
for (int i = 0; i < cliqueTree.length(); i++) { for (int i = 0; i < cliqueTree.length(); i++) {
IN wi = document.get(i); IN wi = document.get(i);
System.out.print(wi.get(CoreAnnotations.TextAnnotation.class)); String token = wi.get(CoreAnnotations.TextAnnotation.class);
String goldAnswer = wi.get(CoreAnnotations.GoldAnswerAnnotation.class);
System.out.print(token);
System.out.print('\t');
System.out.print(goldAnswer);
for (String label : classIndex) { for (String label : classIndex) {
int index = classIndex.indexOf(label); int index = classIndex.indexOf(label);
// double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index)); // double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index));
Expand All @@ -1358,9 +1369,24 @@ public void printProbsDocument(List<IN> document) {
System.out.print(label); System.out.print(label);
System.out.print('='); System.out.print('=');
System.out.print(prob); System.out.print(prob);
if (verbose) {
  // Bin the probability into [0, numBins-1]. The parentheses around the
  // product are essential: "(int) prob * numBins" casts prob first (cast
  // binds tighter than '*'), truncating every prob < 1.0 to 0 and putting
  // all tokens in bin 0.
  int binnedProb = (int) (prob * numBins);
  if (binnedProb > (numBins - 1)) {
    binnedProb = numBins - 1;  // clamp prob == 1.0 into the top bin
  }
  calibration.incrementCount(binnedProb);
  if (label.equals(goldAnswer)) {
    // Track which tokens were correctly labeled at this confidence level.
    calibratedTokens.incrementCount(binnedProb, token);
  }
}
} }
System.out.println(); System.out.println();
} }
if (verbose) {
return new Pair<>(calibration, calibratedTokens);
} else {
return null;
}
} }


/** /**
Expand All @@ -1382,8 +1408,7 @@ public void printFirstOrderProbs(String filename, DocumentReaderAndWriter<IN> re
* Takes a {@link List} of documents and prints the likelihood of each * Takes a {@link List} of documents and prints the likelihood of each
* possible label at each point. * possible label at each point.
* *
* @param documents * @param documents A {@link List} of {@link List} of INs.
* A {@link List} of {@link List} of INs.
*/ */
public void printFirstOrderProbsDocuments(ObjectBank<List<IN>> documents) { public void printFirstOrderProbsDocuments(ObjectBank<List<IN>> documents) {
for (List<IN> doc : documents) { for (List<IN> doc : documents) {
Expand All @@ -1395,8 +1420,7 @@ public void printFirstOrderProbsDocuments(ObjectBank<List<IN>> documents) {
/** /**
* Takes the file, reads it in, and prints out the factor table at each position. * Takes the file, reads it in, and prints out the factor table at each position.
* *
* @param filename * @param filename The path to the specified file
* The path to the specified file
*/ */
public void printFactorTable(String filename, DocumentReaderAndWriter<IN> readerAndWriter) { public void printFactorTable(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
// only for the OCR data does this matter // only for the OCR data does this matter
Expand All @@ -1410,8 +1434,7 @@ public void printFactorTable(String filename, DocumentReaderAndWriter<IN> reader
* Takes a {@link List} of documents and prints the factor table * Takes a {@link List} of documents and prints the factor table
* at each point. * at each point.
* *
* @param documents * @param documents A {@link List} of {@link List} of INs.
* A {@link List} of {@link List} of INs.
*/ */
public void printFactorTableDocuments(ObjectBank<List<IN>> documents) { public void printFactorTableDocuments(ObjectBank<List<IN>> documents) {
for (List<IN> doc : documents) { for (List<IN> doc : documents) {
Expand Down Expand Up @@ -2964,10 +2987,7 @@ public static void main(String[] args) throws Exception {
} }


if (testFiles != null) { if (testFiles != null) {
List<File> files = new ArrayList<File>(); List<File> files = Arrays.asList(testFiles.split(",")).stream().map(File::new).collect(Collectors.toList());
for (String filename : testFiles.split(",")) {
files.add(new File(filename));
}
crf.classifyFilesAndWriteAnswers(files, crf.defaultReaderAndWriter(), true); crf.classifyFilesAndWriteAnswers(files, crf.defaultReaderAndWriter(), true);
} }


Expand Down
8 changes: 4 additions & 4 deletions src/edu/stanford/nlp/ie/ner/CMMClassifier.java
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import edu.stanford.nlp.sequences.SequenceModel; import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.Generics;
Expand Down Expand Up @@ -1547,15 +1548,14 @@ public Counter<String> scoresOf(List<IN> lineInfos, int pos) {
/** /**
* Takes a {@link List} of {@link CoreLabel}s and prints the likelihood * Takes a {@link List} of {@link CoreLabel}s and prints the likelihood
* of each possible label at each point. * of each possible label at each point.
* TODO: Finish or delete this method! * TODO: Write this method!
* *
* @param document A {@link List} of {@link CoreLabel}s. * @param document A {@link List} of {@link CoreLabel}s.
*/ */
@Override @Override
public void printProbsDocument(List<IN> document) { public Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<IN> document) {

//ClassicCounter<String> c = scoresOf(document, 0); //ClassicCounter<String> c = scoresOf(document, 0);

throw new UnsupportedOperationException();
} }


/** Command-line version of the classifier. See the class /** Command-line version of the classifier. See the class
Expand Down
6 changes: 5 additions & 1 deletion src/edu/stanford/nlp/ie/regexp/NumberSequenceClassifier.java
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import edu.stanford.nlp.pipeline.Annotation; import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter; import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter; import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.time.TimeAnnotations; import edu.stanford.nlp.time.TimeAnnotations;
import edu.stanford.nlp.time.TimeExpressionExtractor; import edu.stanford.nlp.time.TimeExpressionExtractor;
import edu.stanford.nlp.time.TimeExpressionExtractorFactory; import edu.stanford.nlp.time.TimeExpressionExtractorFactory;
import edu.stanford.nlp.time.Timex; import edu.stanford.nlp.time.Timex;
import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.PaddedList; import edu.stanford.nlp.util.PaddedList;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils; import edu.stanford.nlp.util.StringUtils;


import java.io.ObjectInputStream; import java.io.ObjectInputStream;
Expand Down Expand Up @@ -809,7 +812,8 @@ public void train(Collection<List<CoreLabel>> docs,
} }


@Override @Override
public void printProbsDocument(List<CoreLabel> document) { public Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<CoreLabel> document) {
throw new UnsupportedOperationException();
} }


@Override @Override
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter; import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.CoreMap; import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;


/** /**
* A sequence classifier that labels tokens with types based on a simple manual mapping from * A sequence classifier that labels tokens with types based on a simple manual mapping from
Expand Down Expand Up @@ -357,7 +360,9 @@ public void train(Collection<List<CoreLabel>> docs,
DocumentReaderAndWriter<CoreLabel> readerAndWriter) {} DocumentReaderAndWriter<CoreLabel> readerAndWriter) {}


/**
 * Not supported by this rule-based classifier, which assigns labels from a
 * manual mapping rather than a probability model.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public Pair<Counter<Integer>, TwoDimensionalCounter<Integer,String>> printProbsDocument(List<CoreLabel> document) {
  // Give a message so stack traces identify the unsupported operation's source.
  throw new UnsupportedOperationException("This classifier does not support printProbsDocument");
}


@Override @Override
public void serializeClassifier(String serializePath) {} public void serializeClassifier(String serializePath) {}
Expand Down

0 comments on commit 8e357e2

Please sign in to comment.