From 2b9fce9533c348e81e37c4f33c6481019f5e2f0e Mon Sep 17 00:00:00 2001 From: Gabor Angeli Date: Thu, 12 Mar 2015 16:55:45 -0700 Subject: [PATCH] More changes than I should be making in a commit --- src/edu/stanford/nlp/classify/Classifier.java | 68 ++ .../nlp/naturalli/ClauseSearcher.java | 584 +++++++++++++++--- .../nlp/naturalli/NaturalLogicRelation.java | 3 + .../nlp/naturalli/SentenceFragment.java | 5 + src/edu/stanford/nlp/naturalli/Util.java | 74 +++ src/edu/stanford/nlp/util/Pointer.java | 29 + 6 files changed, 664 insertions(+), 99 deletions(-) create mode 100644 src/edu/stanford/nlp/naturalli/Util.java create mode 100644 src/edu/stanford/nlp/util/Pointer.java diff --git a/src/edu/stanford/nlp/classify/Classifier.java b/src/edu/stanford/nlp/classify/Classifier.java index 4998323f7d..d983b7f6a8 100644 --- a/src/edu/stanford/nlp/classify/Classifier.java +++ b/src/edu/stanford/nlp/classify/Classifier.java @@ -1,7 +1,9 @@ package edu.stanford.nlp.classify; import edu.stanford.nlp.ling.Datum; +import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.Counter; +import edu.stanford.nlp.util.Pair; import java.io.Serializable; import java.util.Collection; @@ -25,4 +27,70 @@ public interface Classifier extends Serializable { public Counter scoresOf(Datum example); public Collection labels(); + + /** + * Evaluates the precision and recall of this classifier against a dataset, and the target label. + * + * @param testData The dataset to evaluate the classifier on. + * @param targetLabel The target label (e.g., for relation extraction, this is the relation we're interested in). + * @return A pair of the precision (first) and recall (second) of the classifier on the target label. 
+ */ + public default Pair evaluatePrecisionAndRecall(GeneralDataset testData, L targetLabel) { + if (targetLabel == null) { + throw new IllegalArgumentException("Must supply a target label to compute precision and recall against"); + } + // Variables to count + int numCorrectAndTarget = 0; + int numTargetGuess = 0; + int numTargetGold = 0; + // Iterate over dataset + for (RVFDatum datum : testData) { + // Get the gold label + L label = datum.label(); + if (label == null) { + throw new IllegalArgumentException("Cannot compute precision and recall on unlabelled dataset. Offending datum: " + datum); + } + // Get the guess label + L guess = classOf(datum); + // Compute statistics on datum + if (label.equals(targetLabel)) { + numTargetGold += 1; + } + if (guess.equals(targetLabel)) { + numTargetGuess += 1; + if (guess.equals(label)) { + numCorrectAndTarget += 1; + } + } + } + // Aggregate statistics + double precision = numTargetGuess == 0 ? 0.0 : ((double) numCorrectAndTarget) / ((double) numTargetGuess); + double recall = numTargetGold == 0 ? 1.0 : ((double) numCorrectAndTarget) / ((double) numTargetGold); + return Pair.makePair(precision, recall); + } + + /** + * Evaluate the accuracy of this classifier on the given dataset. + * + * @param testData The dataset to evaluate the classifier on. + * @return The accuracy of the classifier on the given dataset. + */ + public default double evaluateAccuracy(GeneralDataset testData) { + int numCorrect = 0; + for (RVFDatum datum : testData) { + // Get the gold label + L label = datum.label(); + if (label == null) { + throw new IllegalArgumentException("Cannot compute precision and recall on unlabelled dataset. 
Offending datum: " + datum); + } + // Get the guess + L guess = classOf(datum); + // Compute statistics + if (label.equals(guess)) { + numCorrect += 1; + } + } + return ((double) numCorrect) / ((double) testData.size); + } + } diff --git a/src/edu/stanford/nlp/naturalli/ClauseSearcher.java b/src/edu/stanford/nlp/naturalli/ClauseSearcher.java index 84f3723b27..ab18d51492 100644 --- a/src/edu/stanford/nlp/naturalli/ClauseSearcher.java +++ b/src/edu/stanford/nlp/naturalli/ClauseSearcher.java @@ -1,9 +1,17 @@ package edu.stanford.nlp.naturalli; +import edu.stanford.nlp.classify.*; +import edu.stanford.nlp.ie.machinereading.structure.Span; +import edu.stanford.nlp.ie.util.RelationTriple; +import edu.stanford.nlp.io.IOUtils; +import edu.stanford.nlp.io.RuntimeIOException; +import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.IndexedWord; +import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.math.SloppyMath; import edu.stanford.nlp.semgraph.SemanticGraph; +import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations; import edu.stanford.nlp.semgraph.SemanticGraphEdge; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; @@ -12,10 +20,20 @@ import edu.stanford.nlp.util.*; import edu.stanford.nlp.util.PriorityQueue; +import java.io.*; +import java.text.DecimalFormat; import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import java.util.zip.GZIPOutputStream; + +import static edu.stanford.nlp.util.logging.Redwood.Util.*; /** * A search problem for finding clauses in a sentence. @@ -36,6 +54,15 @@ public class ClauseSearcher { * A mapping from a word to the extra edges that come out of it. 
*/ private final Map> extraEdgesByGovernor = new HashMap<>(); + /** + * The classifier for whether a particular dependency edge defines a clause boundary. + */ + private final Optional> isClauseClassifier; + /** + * An optional featurizer to use with the clause classifier ({@link ClauseSearcher#isClauseClassifier}). + * If that classifier is defined, this should be as well. + */ + private final Optional, Counter>> featurizer; /** * A mapping from edges in the tree, to an index. @@ -75,19 +102,61 @@ public State(State source, boolean isDone) { this.thunk = source.thunk; this.isDone = isDone; } + + public SemanticGraph originalTree() { + return ClauseSearcher.this.tree; + } } + /** + * An action being taken; that is, the type of clause splitting going on. + */ public static interface Action { public String signature(); public Optional applyTo(SemanticGraph tree, State source, - SemanticGraphEdge outgoingEdge, - SemanticGraphEdge subjectOrNull, - SemanticGraphEdge ppOrNull); + SemanticGraphEdge outgoingEdge, + SemanticGraphEdge subjectOrNull, + SemanticGraphEdge ppOrNull); } - public ClauseSearcher(SemanticGraph tree) { + /** + * The options used for training the clause searcher. + */ + public static class TrainingOptions { + @Execution.Option(name = "negativeSubsampleRatio", gloss = "The percent of negative datums to take") + public double negativeSubsampleRatio = 0.05; + @Execution.Option(name = "positiveDatumWeight", gloss = "The weight to assign every positive datum.") + public float positiveDatumWeight = 10.0f; + @Execution.Option(name = "seed", gloss = "The random seed to use") + public int seed = 42; + @Execution.Option(name = "classifierFactory", gloss = "The class of the classifier factory to use for training the various classifiers") + public Class>> classifierFactory = (Class>>) ((Object) LinearClassifierFactory.class); + } + + /** + * Mostly just an alias, but make sure our featurizer is serializable! 
+ */ + public static interface Featurizer extends Function, Counter>, Serializable { } + + /** + * Create a searcher manually, suppling a dependency tree, an optional classifier for when to split clauses, + * and a featurizer for that classifier. + * You almost certainly want to use {@link edu.stanford.nlp.naturalli.ClauseSearcher#factory(java.io.File)} instead of this + * constructor. + * + * @param tree The dependency tree to search over. + * @param isClauseClassifier The classifier for whether a given dependency arc should be a new clause. If this is not given, all arcs are treated as clause separators. + * @param featurizer The featurizer for the classifier. If no featurizer is given, one should be given in {@link ClauseSearcher#search(java.util.function.Predicate, edu.stanford.nlp.stats.Counter, java.util.function.Function, int)}, or else the classifier will be useless. + * @see edu.stanford.nlp.naturalli.ClauseSearcher#factory(java.io.File) + */ + public ClauseSearcher(SemanticGraph tree, + Optional> isClauseClassifier, + Optional, Counter>> featurizer + ) { this.tree = new SemanticGraph(tree); + this.isClauseClassifier = isClauseClassifier; + this.featurizer = featurizer; // Index edges this.tree.edgeIterable().forEach(edgeToIndex::addToIndex); // Get length @@ -104,16 +173,26 @@ public ClauseSearcher(SemanticGraph tree) { } } + /** + * Create a clause searcher which searches naively through every possible subtree as a clause. + * For an end-user, this is almost certainly not what you want. + * However, it is very useful for training time. + * + * @param tree The dependency tree to search over. + */ + protected ClauseSearcher(SemanticGraph tree) { + this(tree, Optional.empty(), Optional.empty()); + } + /** * Fix some bizarre peculiarities with certain trees. * So far, these include: *
<ul>
- * <li>Sometimes there's a node from a word to itself. This seems wrong.</li>
+ * <li>Sometimes there's a node from a word to itself. This seems wrong.</li>
* </ul>
* * @param tree The tree to clean (in place!). - * * @return A list of extra edges, which are valid but were removed. */ private static List cleanTree(SemanticGraph tree) { @@ -128,7 +207,9 @@ private static List cleanTree(SemanticGraph tree) { } } } - for (IndexedWord v : toDelete) { tree.removeVertex(v); } + for (IndexedWord v : toDelete) { + tree.removeVertex(v); + } // Clean edges Iterator iter = tree.edgeIterable().iterator(); @@ -336,6 +417,14 @@ private static boolean isTree(SemanticGraph tree) { return true; } + /** + * Create a mock node, to be added to the dependency tree but which is not part of the original sentence. + * + * @param toCopy The CoreLabel to copy from initially. + * @param word The new word to add. + * @param POS The new part of speech to add. + * @return + */ private CoreLabel mockNode(CoreLabel toCopy, String word, String POS) { CoreLabel mock = new CoreLabel(toCopy); mock.setWord(word); @@ -348,28 +437,54 @@ private CoreLabel mockNode(CoreLabel toCopy, String word, String POS) { } /** - * A dummy action denoting that we're done trying to find a clause. 
+ * TODO(gabor) JavaDoc + * @param thresholdProbability + * @return */ - private final Action STOP = new Action() { - @Override - public String signature() { - return "$STOP$"; - } - @Override - public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEdge outgoingEdge, SemanticGraphEdge subjectOrNull, SemanticGraphEdge ppOrNull) { - return Optional.of(new State(source, true)); + public List topClauses(double thresholdProbability) { + List results = new ArrayList<>(); + search(triple -> { + if (triple.first >= thresholdProbability) { + results.add(triple.third.get()); + return true; + } else { + return false; + } + }); + return results; + } + + /** + * TODO(gabor) JavaDoc + * @param candidateFragments + */ + public void search(final Predicate>, Supplier>> candidateFragments) { + if (!isClauseClassifier.isPresent() || + !(isClauseClassifier.get() instanceof LinearClassifier)) { + throw new IllegalArgumentException("For now, only linear classifiers are supported"); } - }; + search(candidateFragments, + ((LinearClassifier) isClauseClassifier.get()).weightsAsMapOfCounters().get(true), + this.featurizer.get(), + 10000); + } + /** + * TODO(gabor) JavaDoc + * + * @param candidateFragments + * @param weights + * @param featurizer + */ public void search( - // The output specs - final Consumer,Supplier>> candidateFragments, - // The learning specs - final Counter weights, - final Function, Counter> featurizer - ) { + // The output specs + final Predicate>, Supplier>> candidateFragments, + // The learning specs + final Counter weights, + final Function, Counter> featurizer, + final int maxTicks + ) { Collection actionSpace = new ArrayList<>(); - actionSpace.add(STOP); // SIMPLE SPLIT actionSpace.add(new Action() { @@ -377,6 +492,7 @@ public void search( public String signature() { return "simple"; } + @Override public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEdge outgoingEdge, SemanticGraphEdge subjectOrNull, SemanticGraphEdge ppOrNull) { 
return Optional.of(new State( @@ -399,6 +515,7 @@ public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEd public String signature() { return "clone_root_as_nsubjpass"; } + @Override public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEdge outgoingEdge, SemanticGraphEdge subjectOrNull, SemanticGraphEdge ppOrNull) { return Optional.of(new State( @@ -423,6 +540,7 @@ public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEd public String signature() { return "clone_nsubj"; } + @Override public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEdge outgoingEdge, SemanticGraphEdge subjectOrNull, SemanticGraphEdge ppOrNull) { if (subjectOrNull != null && !outgoingEdge.equals(subjectOrNull)) { @@ -446,108 +564,376 @@ public Optional applyTo(SemanticGraph tree, State source, SemanticGraphEd }); for (IndexedWord root : tree.getRoots()) { - search(root, candidateFragments, weights, featurizer, actionSpace); + search(root, candidateFragments, weights, featurizer, actionSpace, maxTicks); } } + /** + * TODO(gabor) JavaDoc + * + * @param root + * @param candidateFragments + * @param weights + * @param featurizer + * @param actionSpace + */ public void search( - // The root to search from - IndexedWord root, - // The output specs - final Consumer,Supplier>> candidateFragments, - // The learning specs - final Counter weights, - final Function, Counter> featurizer, - final Collection actionSpace - ) { + // The root to search from + IndexedWord root, + // The output specs + final Predicate>, Supplier>> candidateFragments, + // The learning specs + final Counter weights, + final Function, Counter> featurizer, + final Collection actionSpace, + final int maxTicks + ) { // (the fringe) - PriorityQueue>> fringe = new FixedPrioritiesPriorityQueue<>(); + PriorityQueue>>> fringe = new FixedPrioritiesPriorityQueue<>(); // (a helper list) List ppEdges = new ArrayList<>(); + // (avoid duplicate work) + Set seenWords = new 
HashSet<>(); - State firstState = new State(null, null, -9000, null, x -> { }, false); - fringe.add(Pair.makePair(firstState, new ClassicCounter<>()), -0.0); + State firstState = new State(null, null, -9000, null, x -> { + }, false); + fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0); + int ticks = 0; while (!fringe.isEmpty()) { + if (++ticks > maxTicks) { + System.err.println("WARNING! Timed out on search with " + ticks + " ticks"); + return; + } // Useful variables double logProbSoFar = fringe.getPriority(); - Pair> lastStatePair = fringe.removeFirst(); + Pair>> lastStatePair = fringe.removeFirst(); State lastState = lastStatePair.first; - Counter featuresSoFar = lastStatePair.second; + List> featuresSoFar = lastStatePair.second; IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent(); +// System.err.println("Looking at " + rootWord); // Register thunk - if (lastState.isDone) { - candidateFragments.accept(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> { - SemanticGraph copy = new SemanticGraph(tree); - lastState.thunk.andThen( x -> { - // Add the extra edges back in, if they don't break the tree-ness of the extraction - for (IndexedWord newTreeRoot : x.getRoots()) { - for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) { - assert isTree(x); - //noinspection unchecked - addSubtree(x, newTreeRoot, extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot)); - assert isTree(x); - } + if (!candidateFragments.test(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> { + SemanticGraph copy = new SemanticGraph(tree); + lastState.thunk.andThen(x -> { + // Add the extra edges back in, if they don't break the tree-ness of the extraction + for (IndexedWord newTreeRoot : x.getRoots()) { + for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) { + assert isTree(x); + //noinspection unchecked + addSubtree(x, newTreeRoot, 
extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot)); + assert isTree(x); } - }).accept(copy); - return new SentenceFragment(copy, false); - })); - } else { + } + }).accept(copy); + return new SentenceFragment(copy, false); + }))) { + break; + } - // Find relevant auxilliary terms - ppEdges.clear(); - ppEdges.add(null); - SemanticGraphEdge subjOrNull = null; - for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) { - String relString = auxEdge.getRelation().toString(); - if (relString.startsWith("prep")) { - ppEdges.add(auxEdge); - } else if (relString.contains("subj")) { - subjOrNull = auxEdge; + // Find relevant auxilliary terms + ppEdges.clear(); + ppEdges.add(null); + SemanticGraphEdge subjOrNull = null; + for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) { + String relString = auxEdge.getRelation().toString(); + if (relString.startsWith("prep")) { + ppEdges.add(auxEdge); + } else if (relString.contains("subj")) { + subjOrNull = auxEdge; + } + } + + // Iterate over children + for (Action action : actionSpace) { + // For each action... + for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) { + // For each outgoing edge... + // 1. Find the best aux information to carry along + double max = Double.NEGATIVE_INFINITY; + Pair>> argmax = null; + for (SemanticGraphEdge ppEdgeOrNull : ppEdges) { + Optional candidate = action.applyTo(tree, lastState, + outgoingEdge, subjOrNull, + ppEdgeOrNull); + if (candidate.isPresent()) { + Counter features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get())); + double probability = SloppyMath.sigmoid(Counters.dotProduct(features, weights)); + if (probability > max) { + max = probability; + argmax = Pair.makePair(candidate.get(), new ArrayList>(featuresSoFar) {{ + add(features); + }}); + } + } + } + // 2. 
Register the child state + if (argmax != null && !seenWords.contains(argmax.first.edge.getDependent())) { +// System.err.println(" pushing " + action.signature() + " with " + argmax.first.edge); + fringe.add(argmax, Math.log(max)); } } + } + + seenWords.add(rootWord); + } + } + + + + /** + * TODO(gabor) JavaDoc + * @param classifier + * @param dataset + */ + private static void dumpAccuracy(Classifier classifier, GeneralDataset dataset) { + DecimalFormat df = new DecimalFormat("0.000"); + log("size: " + dataset.size()); + log("true count: " + StreamSupport.stream(dataset.spliterator(), false).filter(RVFDatum::label).collect(Collectors.toList()).size()); + Pair pr = classifier.evaluatePrecisionAndRecall(dataset, true); + log("precision: " + df.format(pr.first)); + log("recall: " + df.format(pr.second)); + log("f1: " + df.format(2 * pr.first * pr.second / (pr.first + pr.second))); + } - // Iterate over children - for (Action action : actionSpace) { - // For each action... - if (action == STOP) { - // Special case the STOP action - State candidate = action.applyTo(tree, lastState, - lastState.edge, lastState.subjectOrNull, lastState.ppOrNull).get(); - Counter features = featurizer.apply(Triple.makeTriple(lastState, action, candidate)); - features.addAll(featuresSoFar); - double probability = SloppyMath.sigmoid(Counters.dotProduct(features, weights)); - fringe.add(Pair.makePair(candidate, features), Math.log(probability)); + /** + * TODO(gabor) JavaDoc + * + * @param trainingData + * @param featurizer + * @param options + * @param modelPath + * @param trainingDataDump + * @return + */ + public static Function trainFactory( + Stream> trainingData, + Featurizer featurizer, + TrainingOptions options, + Optional modelPath, + Optional trainingDataDump) { + // Parse options + ClassifierFactory> classifierFactory = MetaClass.create(options.classifierFactory).createInstance(); + // Generally useful objects + OpenIE openie = new OpenIE(); + Random rand = new 
Random(options.seed); + WeightedDataset dataset = new WeightedDataset<>(); + AtomicInteger numExamplesProcessed = new AtomicInteger(0); + final Optional datasetDumpWriter = trainingDataDump.map(file -> { + try { + return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get())))); + } catch (IOException e) { + throw new RuntimeIOException(e); + } + }); + + // Step 1: Inference over training sentences + forceTrack("Training inference"); + trainingData.forEach(triple -> { + // Parse training datum + CoreMap sentence = triple.first; + List tokens = sentence.get(CoreAnnotations.TokensAnnotation.class); + Span subjectSpan = Util.extractNER(tokens, triple.second); + Span objectSpan = Util.extractNER(tokens, triple.third); +// log("inference on " + StringUtils.join(tokens.subList(0, Math.min(10, tokens.size())).stream().map(CoreLabel::word), " ")); + // Create raw clause searcher (no classifier) + SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.CollapsedDependenciesAnnotation.class); + ClauseSearcher problem = new ClauseSearcher(sentence.get(SemanticGraphCoreAnnotations.CollapsedDependenciesAnnotation.class)); + Pointer anyCorrect = new Pointer<>(false); + + // Run search + problem.search(fragmentAndScore -> { + // Parse the search output + double score = fragmentAndScore.first; + List> features = fragmentAndScore.second; + Supplier fragmentSupplier = fragmentAndScore.third; + SentenceFragment fragment = fragmentSupplier.get(); + // Search for extractions + List extractions = openie.relationInClause(fragment.parseTree); + boolean correct = false; + RelationTriple bestExtraction = null; + for (RelationTriple extraction : extractions) { + // Clean up the guesses + Span subjectGuess = Util.extractNER(tokens, Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index())); + Span objectGuess = Util.extractNER(tokens, 
Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index())); + // Check if it matches + if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || + (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan)) + ) { + correct = true; + anyCorrect.set(true); + bestExtraction = extraction; + } else if ((subjectGuess.contains(subjectSpan) && objectGuess.contains(objectSpan)) || + (subjectGuess.contains(objectSpan) && objectGuess.contains(subjectSpan)) + ) { + correct = true; + anyCorrect.set(true); + if (bestExtraction == null) { + bestExtraction = extraction; + } } else { - // All other actions can apply to any edge - for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) { - // For each outgoing edge... - // 1. Find the best aux information to carry along - double max = Double.NEGATIVE_INFINITY; - Pair> argmax = null; - for (SemanticGraphEdge ppEdgeOrNull : ppEdges) { - Optional candidate = action.applyTo(tree, lastState, - outgoingEdge, subjOrNull, - ppEdgeOrNull); - if (candidate.isPresent()) { - Counter features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get())); - double probability = SloppyMath.sigmoid(Counters.dotProduct(features, weights)); - if (probability > max) { - max = probability; - argmax = Pair.makePair(candidate.get(), features); - } - } - } - // 2. Register the child state - if (argmax != null) { - argmax.second.addAll(featuresSoFar); - fringe.add(argmax, Math.log(max)); + if (bestExtraction == null && !correct) { + bestExtraction = extraction; + } + correct = false; + } + } + // Dump the datum + if (bestExtraction != null || fragment.length() == 1) { + if (correct || rand.nextDouble() > (1.0 - options.negativeSubsampleRatio)) { // Subsample + for (Counter decision : features) { + // Add datum to dataset + RVFDatum datum = new RVFDatum<>(decision); + datum.setLabel(correct); + dataset.add(datum, correct ? 
options.positiveDatumWeight : 1.0f); + // Dump datum to debug log + if (datasetDumpWriter.isPresent()) { + datasetDumpWriter.get().println("" + correct + "\t" + StringUtils.join(decision.entrySet().stream().map(entry -> "" + entry.getKey() + "->" + entry.getValue()), ";")); } } } } + return true; + }, new ClassicCounter<>(), featurizer, 10000); + // Debug info + if (numExamplesProcessed.incrementAndGet() % 100 == 0) { + log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums"); } + }); + // Close dataset dump + datasetDumpWriter.ifPresent(PrintWriter::close); + endTrack("Training inference"); + + // Step 2: Train classifier + forceTrack("Training"); + Classifier fullClassifier = classifierFactory.trainClassifier(dataset); + endTrack("Training"); + if (modelPath.isPresent()) { + Pair, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer); + try { + IOUtils.writeObjectToFile(toSave, modelPath.get()); + log("SUCCESS: wrote model to " + modelPath.get().getPath()); + } catch (IOException e) { + log("ERROR: failed to save model to path: " + modelPath.get().getPath()); + err(e); + } + } + + // Step 3: Check accuracy of classifier + forceTrack("Training accuracy"); + dataset.randomize(options.seed); + dumpAccuracy(fullClassifier, dataset); + endTrack("Training accuracy"); + + int numFolds = 5; + forceTrack("" + numFolds + " fold cross-validation"); + for (int fold = 0; fold < numFolds; ++fold) { + forceTrack("Fold " + (fold + 1)); + forceTrack("Training"); + Pair, GeneralDataset> foldData = dataset.splitOutFold(fold, numFolds); + Classifier classifier = classifierFactory.trainClassifier(foldData.first); + endTrack("Training"); + forceTrack("Test"); + dumpAccuracy(classifier, foldData.second); + endTrack("Test"); + endTrack("Fold " + (fold + 1)); + } + endTrack("" + numFolds + " fold cross-validation"); + + + // Step 5: return factory + return tree -> new ClauseSearcher(tree, Optional.of(fullClassifier), 
Optional.of(featurizer)); + } + + + /** + * TODO(gabor) JavaDoc + * @param trainingData + * @param modelPath + * @param trainingDataDump + * @return + */ + public static Function trainFactory( + Stream> trainingData, + File modelPath, + File trainingDataDump) { + // Featurizer + Featurizer featurizer = triple -> { + // Variables + ClauseSearcher.State from = triple.first; + ClauseSearcher.Action action = triple.second; + ClauseSearcher.State to = triple.third; + String signature = action.signature(); + String edgeRelTaken = to.edge == null ? "root" : to.edge.getRelation().toString(); + String edgeRelShort = to.edge == null ? "root" : to.edge.getRelation().getShortName(); + if (edgeRelShort.contains("_")) { + edgeRelShort = edgeRelShort.substring(0, edgeRelShort.indexOf("_")); + } + String edgeRelSpecific = to.edge == null ? null : to.edge.getRelation().getSpecific(); + + // -- Featurize -- + // Variables to aggregate + boolean parentHasSubj = false; + boolean parentHasObj = false; + boolean childHasSubj = false; + boolean childHasObj = false; + + // 1. edge taken + Counter feats = new ClassicCounter<>(); + feats.incrementCount(signature + "&edge:" + edgeRelTaken); + feats.incrementCount(signature + "&edge_type:" + edgeRelShort); + + if (to.edge != null) { + // 2. other edges at parent + for (SemanticGraphEdge parentNeighbor : from.originalTree().outgoingEdgeIterable(to.edge.getGovernor())) { + if (parentNeighbor != to.edge) { + String parentNeighborRel = parentNeighbor.getRelation().toString(); + if (parentNeighborRel.contains("subj")) { parentHasSubj = true; } + if (parentNeighborRel.contains("obj")) { parentHasObj = true; } + // (add feature) + feats.incrementCount(signature + "&parent_neighbor:" + parentNeighborRel); + feats.incrementCount(signature + "&edge_type:" + edgeRelShort + "&parent_neighbor:" + parentNeighborRel); + } + } + + // 3. 
Other edges at child + for (SemanticGraphEdge childNeighbor : from.originalTree().outgoingEdgeIterable(to.edge.getDependent())) { + String childNeighborRel = childNeighbor.getRelation().toString(); + if (childNeighborRel.contains("subj")) { childHasSubj = true; } + if (childNeighborRel.contains("obj")) { childHasObj = true; } + // (add feature) + feats.incrementCount(signature + "&child_neighbor:" + childNeighborRel); + feats.incrementCount(signature + "&edge_type:" + edgeRelShort + "&child_neighbor:" + childNeighborRel); + } + + // 4. Subject/Object stats + feats.incrementCount(signature + "&parent_neighbor_subj:" + parentHasSubj); + feats.incrementCount(signature + "&parent_neighbor_obj:" + parentHasObj); + feats.incrementCount(signature + "&child_neighbor_subj:" + childHasSubj); + feats.incrementCount(signature + "&child_neighbor_obj:" + childHasObj); + } + + // Return + return feats; + }; + // Train + return trainFactory(trainingData, featurizer, new TrainingOptions(), Optional.of(modelPath), Optional.of(trainingDataDump)); + } + + + /** + * TODO(gabor) JavaDoc + * @return + */ + public static Function factory(File serializedModel) throws IOException { + try { + System.err.println("Loading clause searcher from " + serializedModel.getPath() + " ..."); + Pair, Featurizer> data = IOUtils.readObjectFromFile(serializedModel); + return tree -> new ClauseSearcher(tree, Optional.of(data.first), Optional.of(data.second)); + } catch (ClassNotFoundException e) { + throw new IllegalStateException("Invalid model at path: " + serializedModel.getPath(), e); } } + } diff --git a/src/edu/stanford/nlp/naturalli/NaturalLogicRelation.java b/src/edu/stanford/nlp/naturalli/NaturalLogicRelation.java index 807e190407..cdeb245399 100644 --- a/src/edu/stanford/nlp/naturalli/NaturalLogicRelation.java +++ b/src/edu/stanford/nlp/naturalli/NaturalLogicRelation.java @@ -203,6 +203,7 @@ public NaturalLogicRelation join(NaturalLogicRelation other) { put("prep_after", 
NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_against", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_along", NaturalLogicRelation.REVERSE_ENTAILMENT); // + put("prep_alongside", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_amid", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_among", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_anti", NaturalLogicRelation.REVERSE_ENTAILMENT); // @@ -225,6 +226,7 @@ public NaturalLogicRelation join(NaturalLogicRelation other) { put("prep_down", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_during", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_except", NaturalLogicRelation.REVERSE_ENTAILMENT); // + put("prep_en", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_excepting", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_excluding", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_following", NaturalLogicRelation.REVERSE_ENTAILMENT); // @@ -243,6 +245,7 @@ public NaturalLogicRelation join(NaturalLogicRelation other) { put("prep_on", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_onto", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_opposite", NaturalLogicRelation.REVERSE_ENTAILMENT); // + put("prep_out", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_outside", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_over", NaturalLogicRelation.REVERSE_ENTAILMENT); // put("prep_past", NaturalLogicRelation.REVERSE_ENTAILMENT); // diff --git a/src/edu/stanford/nlp/naturalli/SentenceFragment.java b/src/edu/stanford/nlp/naturalli/SentenceFragment.java index d6f72d0130..4087706299 100644 --- a/src/edu/stanford/nlp/naturalli/SentenceFragment.java +++ b/src/edu/stanford/nlp/naturalli/SentenceFragment.java @@ -27,6 +27,11 @@ public SentenceFragment(SemanticGraph tree, boolean copy) { words.addAll(this.parseTree.vertexListSorted().stream().map(IndexedWord::backingLabel).collect(Collectors.toList())); } + /** The length of this fragment, in 
words */ + public int length() { + return words.size(); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/src/edu/stanford/nlp/naturalli/Util.java b/src/edu/stanford/nlp/naturalli/Util.java new file mode 100644 index 0000000000..7653dabdcd --- /dev/null +++ b/src/edu/stanford/nlp/naturalli/Util.java @@ -0,0 +1,74 @@ +package edu.stanford.nlp.naturalli; + +import edu.stanford.nlp.ie.machinereading.structure.Span; +import edu.stanford.nlp.ling.CoreAnnotations; +import edu.stanford.nlp.ling.CoreLabel; +import edu.stanford.nlp.pipeline.Annotation; +import edu.stanford.nlp.pipeline.AnnotationPipeline; +import edu.stanford.nlp.stats.ClassicCounter; +import edu.stanford.nlp.stats.Counter; +import edu.stanford.nlp.stats.Counters; +import edu.stanford.nlp.util.CoreMap; +import edu.stanford.nlp.util.StringUtils; + +import java.util.Collections; +import java.util.List; + +/** + * TODO(gabor) JavaDoc + * + * @author Gabor Angeli + */ +public class Util { + + public static String guessNER(List tokens, Span span) { + Counter nerGuesses = new ClassicCounter<>(); + for (int i : span) { + nerGuesses.incrementCount(tokens.get(i).ner()); + } + nerGuesses.remove("O"); + nerGuesses.remove(null); + if (nerGuesses.size() > 0) { + return Counters.argmax(nerGuesses); + } else { + return "O"; + } + } + + public static String guessNER(List tokens) { + return guessNER(tokens, new Span(0, tokens.size())); + } + + /** + * TODO(gabor) JavaDoc + * + * @param tokens + * @param seed + * @return + */ + public static Span extractNER(List tokens, Span seed) { + if (seed == null) { + return new Span(0, 1); + } + if (tokens.get(seed.start()).ner() == null) { + return seed; + } + int begin = seed.start(); + while (begin > 0 && tokens.get(begin - 1).ner().equals(tokens.get(seed.start()).ner())) { + begin -= 1; + } + int end = seed.end() - 1; + while (end < tokens.size() - 1 && tokens.get(end + 1).ner().equals(tokens.get(seed.end() - 1).ner())) { + end += 1; + } + 
return Span.fromValues(begin, end + 1); + } + + + public static void annotate(CoreMap sentence, AnnotationPipeline pipeline) { + Annotation ann = new Annotation(StringUtils.join(sentence.get(CoreAnnotations.TokensAnnotation.class), " ")); + ann.set(CoreAnnotations.TokensAnnotation.class, sentence.get(CoreAnnotations.TokensAnnotation.class)); + ann.set(CoreAnnotations.SentencesAnnotation.class, Collections.singletonList(sentence)); + pipeline.annotate(ann); + } +} diff --git a/src/edu/stanford/nlp/util/Pointer.java b/src/edu/stanford/nlp/util/Pointer.java new file mode 100644 index 0000000000..c5d830143d --- /dev/null +++ b/src/edu/stanford/nlp/util/Pointer.java @@ -0,0 +1,29 @@ +package edu.stanford.nlp.util; + +import java.util.Optional; + +/** + * A pointer to an object, to get around not being able to access non-final + * variables within an anonymous function. + * + * @author Gabor Angeli + */ +public class Pointer { + + private Optional impl; + + public Pointer() { + this.impl = Optional.empty(); + } + + @SuppressWarnings("UnusedDeclaration") + public Pointer(T impl) { + this.impl = Optional.of(impl); + } + + public Optional dereference() { return impl; } + + public void set(T impl) { this.impl = Optional.of(impl); } + + public void set(Optional impl) { this.impl = impl.isPresent() ? impl : this.impl; } +}