new coreference refactor
J38 authored and Stanford NLP committed Sep 24, 2016
1 parent 2775e63 commit 18bef51
Showing 107 changed files with 3,206 additions and 1,533 deletions.
18 changes: 12 additions & 6 deletions src/edu/stanford/nlp/coref/CorefAlgorithm.java
@@ -5,20 +5,21 @@
 import edu.stanford.nlp.coref.CorefProperties.CorefAlgorithmType;
 import edu.stanford.nlp.coref.data.Dictionaries;
 import edu.stanford.nlp.coref.data.Document;
-import edu.stanford.nlp.coref.hybrid.HybridCorefSystem;
+//import edu.stanford.nlp.coref.hybrid.HybridCorefSystem;
 import edu.stanford.nlp.coref.neural.NeuralCorefAlgorithm;
-import edu.stanford.nlp.coref.statistical.ClusteringCorefAlgorithm;
-import edu.stanford.nlp.coref.statistical.StatisticalCorefAlgorithm;
+//import edu.stanford.nlp.coref.statistical.ClusteringCorefAlgorithm;
+//import edu.stanford.nlp.coref.statistical.StatisticalCorefAlgorithm;
 
 /**
  * A CorefAlgorithm makes coreference decisions on the provided {@link Document} after
  * mention detection has been performed.
  * @author Kevin Clark
  */
 public interface CorefAlgorithm {
-  public void runCoref(Document document);
 
-  public static CorefAlgorithm fromProps(Properties props, Dictionaries dictionaries) {
+  public void runCoref(Document document);
+
+  /*public static CorefAlgorithm fromProps(Properties props, Dictionaries dictionaries) {
     CorefAlgorithmType algorithm = CorefProperties.algorithm(props);
     if (algorithm == CorefAlgorithmType.CLUSTERING) {
       return new ClusteringCorefAlgorithm(props, dictionaries);
@@ -33,5 +34,10 @@ public static CorefAlgorithm fromProps(Properties props, Dictionaries dictionari
       throw new RuntimeException("Error creating hybrid coref system", e);
     }
   }
-  }
+  }*/
+
+  public static CorefAlgorithm fromProps(Properties props, Dictionaries dictionaries) {
+    return new NeuralCorefAlgorithm(props, dictionaries);
+  }
+
 }
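
For orientation (not part of the diff): after this refactor, fromProps ignores the coref.algorithm property and always returns the neural system, since the clustering, statistical, and hybrid branches are commented out. A minimal sketch of driving the new interface follows; the Dictionaries(Properties) constructor is assumed, and CorefDocMaker is taken from the new file added later in this commit.

import java.util.Properties;

import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefDocMaker;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;

public class CorefAlgorithmSketch {
  public static void main(String[] args) throws Exception {
    Properties props = new Properties();
    Dictionaries dictionaries = new Dictionaries(props); // assumed constructor

    // After this commit, fromProps always returns a NeuralCorefAlgorithm.
    CorefAlgorithm algorithm = CorefAlgorithm.fromProps(props, dictionaries);

    CorefDocMaker docMaker = new CorefDocMaker(props, dictionaries);
    for (Document doc = docMaker.nextDoc(); doc != null; doc = docMaker.nextDoc()) {
      algorithm.runCoref(doc); // coreference decisions are written into doc
    }
  }
}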
5 changes: 2 additions & 3 deletions src/edu/stanford/nlp/coref/CorefCoreAnnotations.java
@@ -6,15 +6,14 @@
 
 import edu.stanford.nlp.coref.data.CorefChain;
 import edu.stanford.nlp.coref.data.Mention;
-
 import edu.stanford.nlp.ling.CoreAnnotation;
 import edu.stanford.nlp.ling.CoreLabel;
 import edu.stanford.nlp.util.ErasureUtils;
 import edu.stanford.nlp.util.IntTuple;
 import edu.stanford.nlp.util.Pair;
 
 /**
  * Similar to {@link edu.stanford.nlp.ling.CoreAnnotations},
  * but this class contains
  * annotations made specifically for storing Coref data. This is kept
  * separate from CoreAnnotations so that systems which only need
@@ -80,7 +79,7 @@ public Class<Set<CoreLabel>> getType() {
       return ErasureUtils.uncheckedCast(Set.class);
     }
   }
 
   /**
    * CorefChainID - CorefChain map
    */
252 changes: 252 additions & 0 deletions src/edu/stanford/nlp/coref/CorefDocMaker.java
@@ -0,0 +1,252 @@
package edu.stanford.nlp.coref;

import java.util.ArrayList;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.List;
import java.util.Locale;
import java.util.Properties;

import edu.stanford.nlp.classify.LogisticClassifier;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.InputDoc;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.coref.docreader.CoNLLDocumentReader;
import edu.stanford.nlp.coref.docreader.DocReader;
import edu.stanford.nlp.coref.md.CorefMentionFinder;
import edu.stanford.nlp.coref.md.DependencyCorefMentionFinder;
import edu.stanford.nlp.coref.md.HybridCorefMentionFinder;
import edu.stanford.nlp.coref.md.RuleBasedCorefMentionFinder;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.SemanticHeadFinder;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.TreeLemmatizer;
import edu.stanford.nlp.trees.international.pennchinese.ChineseSemanticHeadFinder;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Makes a Document for coref input from an Annotation and optional info.
 * Input reading (raw, CoNLL, etc.) via a DocReader, mention detection, and
 * document preprocessing are all done here.
 *
 * @author heeyoung
 */
public class CorefDocMaker {

Properties props;
DocReader reader;
final HeadFinder headFinder;
CorefMentionFinder md;
Dictionaries dict;
StanfordCoreNLP corenlp;
final TreeLemmatizer treeLemmatizer;
LogisticClassifier<String, String> singletonPredictor;

boolean addMissingAnnotations;

public CorefDocMaker(Properties props, Dictionaries dictionaries) throws ClassNotFoundException, IOException {
this.props = props;
this.dict = dictionaries;
reader = getDocumentReader(props);
headFinder = getHeadFinder(props);
md = getMentionFinder(props, dictionaries, headFinder);
// corenlp = new StanfordCoreNLP(props, false);
// the property coref.addMissingAnnotations must be set to true to get the CorefDocMaker to add annotations
if (CorefProperties.addMissingAnnotations(props)) {
addMissingAnnotations = true;
corenlp = loadStanfordProcessor(props);
} else {
addMissingAnnotations = false;
}
treeLemmatizer = new TreeLemmatizer();
singletonPredictor = CorefProperties.useSingletonPredictor(props) ?
getSingletonPredictorFromSerializedFile(CorefProperties.getPathSingletonPredictor(props)) : null;
}

/** Load the Stanford pipeline, skipping unnecessary annotators. */
protected StanfordCoreNLP loadStanfordProcessor(Properties props) {

Properties pipelineProps = new Properties(props);
StringBuilder annoSb = new StringBuilder("");
if (!CorefProperties.useGoldPOS(props)) {
annoSb.append("pos, lemma");
} else {
annoSb.append("lemma");
}
if (CorefProperties.USE_TRUECASE) {
annoSb.append(", truecase");
}
if (!CorefProperties.useGoldNE(props) || CorefProperties.getLanguage(props) == Locale.CHINESE) {
annoSb.append(", ner");
}
if (!CorefProperties.useGoldParse(props)) {
if (CorefProperties.useConstituencyTree(props)) annoSb.append(", parse");
else annoSb.append(", depparse");
}
// need to add mentions
annoSb.append(", mention");
String annoStr = annoSb.toString();
Redwood.log("MentionExtractor ignores specified annotators, using annotators=" + annoStr);
pipelineProps.put("annotators", annoStr);
return new StanfordCoreNLP(pipelineProps, false);
}


private static DocReader getDocumentReader(Properties props) {
switch (CorefProperties.getInputType(props)) {
case CONLL:
String corpusPath = CorefProperties.getPathInput(props);
CoNLLDocumentReader.Options options = new CoNLLDocumentReader.Options();
options.annotateTokenCoref = false;
if (CorefProperties.useCoNLLAuto(props)) options.setFilter(".*_auto_conll$");
options.lang = CorefProperties.getLanguage(props);
return new CoNLLDocumentReader(corpusPath, options);

case ACE:
// TODO
return null;

case MUC:
// TODO
return null;

case RAW:
default: // default is raw text
// TODO
return null;
}
}

private static HeadFinder getHeadFinder(Properties props) {
Locale lang = CorefProperties.getLanguage(props);
if (lang == Locale.ENGLISH) return new SemanticHeadFinder();
else if (lang == Locale.CHINESE) return new ChineseSemanticHeadFinder();
else {
throw new RuntimeException("Invalid language setting: cannot load HeadFinder");
}
}

private static CorefMentionFinder getMentionFinder(Properties props, Dictionaries dictionaries, HeadFinder headFinder) throws ClassNotFoundException, IOException {

switch (CorefProperties.getMDType(props)) {
case RULE:
return new RuleBasedCorefMentionFinder(headFinder, props);

case HYBRID:
return new HybridCorefMentionFinder(headFinder, props);

case DEPENDENCY:
default: // default is dependency
return new DependencyCorefMentionFinder(props);
}
}

public Document makeDocument(Annotation anno) throws Exception {
return makeDocument(new InputDoc(anno, null, null));
}

/**
* Make Document for coref (for method coref(Document doc, StringBuilder[] outputs)).
* Mention detection and document preprocessing are done here.
* @throws Exception
*/
public Document makeDocument(InputDoc input) throws Exception {
if (input == null) return null;
Annotation anno = input.annotation;

if (Boolean.parseBoolean(props.getProperty("coref.useMarkedDiscourse", "false"))) {
anno.set(CoreAnnotations.UseMarkedDiscourseAnnotation.class, true);
}

// add missing annotation
if (addMissingAnnotations) {
addMissingAnnotation(anno);
}

// remove nested NPs with the same headword, except for Chinese newswire documents

//if(input.conllDoc != null && CorefProperties.getLanguage(props)==Locale.CHINESE){
//CorefProperties.setRemoveNested(props, !input.conllDoc.documentID.contains("nw"));
//}

// each sentence should have a CorefCoreAnnotations.CorefMentionsAnnotation.class which maps to List<Mention>
// this is set by the mentions annotator
List<List<Mention>> mentions = new ArrayList<>();
for (CoreMap sentence : anno.get(CoreAnnotations.SentencesAnnotation.class)) {
mentions.add(sentence.get(CorefCoreAnnotations.CorefMentionsAnnotation.class));
}

Document doc = new Document(input, mentions);

// find headword for gold mentions
if (input.goldMentions != null) findGoldMentionHeads(doc);

// document preprocessing: initialization (assign ID), mention processing (gender, number, type, etc), speaker extraction, etc
Preprocessor.preprocess(doc, dict, singletonPredictor, headFinder);

return doc;
}

private void findGoldMentionHeads(Document doc) {
List<CoreMap> sentences = doc.annotation.get(SentencesAnnotation.class);
for (int i = 0; i < sentences.size(); i++) {
// md.findHead(sentences.get(i), doc.goldMentions.get(i));
DependencyCorefMentionFinder.findHeadInDependency(sentences.get(i), doc.goldMentions.get(i));
}
}

private void addMissingAnnotation(Annotation anno) {
if (addMissingAnnotations) {
boolean useConstituency = CorefProperties.useConstituencyTree(props);
final boolean LEMMATIZE = true;

List<CoreMap> sentences = anno.get(CoreAnnotations.SentencesAnnotation.class);
for (CoreMap sentence : sentences) {
boolean hasTree = sentence.containsKey(TreeCoreAnnotations.TreeAnnotation.class);
Tree tree = sentence.get(TreeCoreAnnotations.TreeAnnotation.class);

if (!useConstituency) { // TODO: temp for dev: make sure we don't use constituency tree
sentence.remove(TreeCoreAnnotations.TreeAnnotation.class);
}
if (LEMMATIZE && hasTree && useConstituency) treeLemmatizer.transformTree(tree); // TODO don't need?
}
corenlp.annotate(anno);
} else {
throw new RuntimeException("Error: must set coref.addMissingAnnotations = true to call method addMissingAnnotation");
}
}

public void resetDocs() {
reader.reset();
}

public Document nextDoc() throws Exception {
InputDoc input = reader.nextDoc();
return (input == null) ? null : makeDocument(input);
}

public static LogisticClassifier<String, String> getSingletonPredictorFromSerializedFile(String serializedFile) {
try {
ObjectInputStream ois = IOUtils.readStreamFromString(serializedFile);
Object o = ois.readObject();
if (o instanceof LogisticClassifier<?, ?>) {
return (LogisticClassifier<String, String>) o;
}
throw new ClassCastException("Wanted SingletonPredictor, got " + o.getClass());
} catch (IOException e) {
throw new RuntimeIOException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}

}
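
A hedged usage sketch (not in the commit): producing a coref Document from a raw Annotation with CorefDocMaker. Per the comment in makeDocument, each sentence must already carry CorefCoreAnnotations.CorefMentionsAnnotation, which the "mention" annotator sets; the exact annotator list below and the Dictionaries(Properties) constructor are assumptions.

import java.util.Properties;

import edu.stanford.nlp.coref.CorefDocMaker;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;

public class CorefDocMakerSketch {
  public static void main(String[] args) throws Exception {
    Properties props = new Properties();
    props.setProperty("annotators", "tokenize, ssplit, pos, lemma, ner, depparse, mention");
    StanfordCoreNLP pipeline = new StanfordCoreNLP(props);

    Annotation anno = new Annotation("John went home. He was tired.");
    pipeline.annotate(anno); // sets CorefMentionsAnnotation on each sentence

    Dictionaries dictionaries = new Dictionaries(props); // assumed constructor
    CorefDocMaker docMaker = new CorefDocMaker(props, dictionaries);

    // mention heads, IDs, gender/number/type, and speakers are assigned here
    Document doc = docMaker.makeDocument(anno);
  }
}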
12 changes: 12 additions & 0 deletions src/edu/stanford/nlp/coref/CorefDocumentProcessor.java
@@ -40,13 +40,25 @@ public default void run(DocumentMaker docMaker) throws Exception {
     Document document = docMaker.nextDoc();
     long time = System.currentTimeMillis();
     while (document != null) {
+      /*if (docId < 130) {
+        Redwood.log(getName(), "Processed document " + docId + " in "
+            + (System.currentTimeMillis() - time) / 1000.0 + "s");
+        time = System.currentTimeMillis();
+        docId++;
+        document = docMaker.nextDoc();
+        continue;
+      }*/
       document.extractGoldCorefClusters();
       process(docId, document);
       Redwood.log(getName(), "Processed document " + docId + " in "
           + (System.currentTimeMillis() - time) / 1000.0 + "s");
       time = System.currentTimeMillis();
       docId++;
       document = docMaker.nextDoc();
+
+      /*if (docId > 15) {
+        break;
+      }*/
     }
     finish();
   }
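
For context, a minimal sketch of a CorefDocumentProcessor implementation; process, finish, and getName are inferred from the calls visible in run() above, so treat the exact signatures as assumptions.

import edu.stanford.nlp.coref.CorefDocumentProcessor;
import edu.stanford.nlp.coref.data.Document;

public class DocumentCounter implements CorefDocumentProcessor {
  private int count = 0;

  @Override
  public void process(int id, Document document) {
    count += 1; // inspect document.predictedMentions, corefClusters, etc. here
  }

  @Override
  public void finish() throws Exception {
    System.out.println("Processed " + count + " documents");
  }

  @Override
  public String getName() {
    return "DocumentCounter";
  }
}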
15 changes: 12 additions & 3 deletions src/edu/stanford/nlp/coref/CorefPrinter.java
@@ -14,6 +14,14 @@
 import edu.stanford.nlp.util.CoreMap;
 import edu.stanford.nlp.util.Generics;
 
+
+import edu.stanford.nlp.util.logging.Redwood;
+
+
+import java.text.DecimalFormat;
+import java.util.logging.*;
+import java.util.regex.*;
+
 /**
  * Class for printing out coreference output.
  * @author Heeyoung Lee
@@ -28,9 +36,9 @@ public static String printConllOutput(Document document, boolean gold, boolean f
     List<List<Mention>> orderedMentions = gold ? document.goldMentions : document.predictedMentions;
     if (filterSingletons) {
       orderedMentions = orderedMentions.stream().map(
           ml -> ml.stream().filter(m -> document.corefClusters.get(m.corefClusterID) != null &&
-              document.corefClusters.get(m.corefClusterID).size() > 1)
+              document.corefClusters.get(m.corefClusterID).getCorefMentions().size() > 1)
           .collect(Collectors.toList()))
           .collect(Collectors.toList());
     }
     return CorefPrinter.printConllOutput(document, orderedMentions, gold);
@@ -105,4 +113,5 @@ public static String printConllOutput(Document document,
 
     return sb.toString();
   }
+
 }
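
A short sketch of the call site this change affects (assumed, not from the diff): the singleton filter now measures cluster size through getCorefMentions() on the cluster rather than a size() method, and the static printConllOutput(Document, boolean, boolean) entry point shown in the hunk header drives it.

import edu.stanford.nlp.coref.CorefPrinter;
import edu.stanford.nlp.coref.data.Document;

public class CorefPrinterSketch {
  // gold = false prints predicted mentions; filterSingletons = true drops
  // mentions whose cluster contains only a single mention
  public static void printPredicted(Document doc) {
    String conll = CorefPrinter.printConllOutput(doc, false, true);
    System.out.println(conll);
  }
}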
