// KBPAnnotatorBenchmarkSlowITest.java
package edu.stanford.nlp.pipeline;
import junit.framework.TestCase;
import edu.stanford.nlp.ie.util.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.util.CoreMap;
import java.io.*;
import java.util.*;
/**
 * Slow integration test that runs a full CoreNLP pipeline ending in the {@code kbp}
 * annotator over a directory of benchmark documents and compares the extracted
 * relation triples against a file of gold relations.
 *
 * <p>The test passes when the F1 score accumulated over all documents is at least
 * {@link #KBP_MINIMUM_SCORE}. It needs the files referenced by {@link #KBP_DOCS_DIR}
 * and {@link #GOLD_RELATIONS_PATH} to be present on the machine running it.
 */
public class KBPAnnotatorBenchmarkSlowITest extends TestCase {

  /** Document ID (the file name inside {@link #KBP_DOCS_DIR}) mapped to the file's text. */
  public HashMap<String,String> docIDToText;

  /** Document ID mapped to the set of gold relation strings listed for that document. */
  public HashMap<String,Set<String>> docIDToRelations;

  /** Pipeline built in {@link #setUp()} and used to annotate every benchmark document. */
  public StanfordCoreNLP pipeline;

  /** Directory whose entries are the benchmark documents, one document per file. */
  public String KBP_DOCS_DIR = "/scr/nlp/data/kbp-benchmark//kbp-docs";

  /** Tab-separated gold file: first field is the document ID, second the relation string. */
  public String GOLD_RELATIONS_PATH = "/scr/nlp/data/kbp-benchmark/kbp-gold-relations.txt";

  /** Minimum F1 (as a fraction in [0, 1]) the annotator must reach for the test to pass. */
  public double KBP_MINIMUM_SCORE = .450;

  /**
   * Rewrites a relation name produced by the annotator to the spelling that is compared
   * against the gold relations; names without a mapping are returned unchanged.
   *
   * @param relationName relation name as reported by {@code RelationTriple.relationGloss()}
   * @return the mapped name, or {@code relationName} itself when no mapping applies
   */
  private String convertRelationName(String relationName) {
    // NOTE(review): a mapping for "org:top_members/employees" ->
    // "org:top_members_employees" was commented out in the original and stays disabled.
    if (relationName.equals("per:employee_of")) {
      return "per:employee_or_member_of";
    }
    if (relationName.equals("per:stateorprovinces_of_residence")) {
      return "per:statesorprovinces_of_residence";
    }
    if (relationName.equals("org:number_of_employees/members")) {
      return "org:number_of_employees_members";
    }
    if (relationName.equals("org:stateorprovince_of_headquarters")) {
      return "org:stateprovince_of_headquarters";
    }
    if (relationName.equals("per:other_family")) {
      return "per:otherfamily";
    }
    if (relationName.equals("org:founded")) {
      return "org:date_founded";
    }
    if (relationName.equals("org:political/religious_affiliation")) {
      return "org:political_religious_affiliation";
    }
    return relationName;
  }

  /**
   * Loads the gold relations and the benchmark documents, then builds the pipeline.
   *
   * @throws IllegalStateException if a non-blank gold line has fewer than two
   *         tab-separated fields, or if {@link #KBP_DOCS_DIR} cannot be listed
   */
  @Override
  public void setUp() {
    docIDToText = new HashMap<String,String>();
    docIDToRelations = new HashMap<String,Set<String>>();

    // load the gold relations from gold relations file
    for (String relationLine : IOUtils.linesFromFile(GOLD_RELATIONS_PATH)) {
      if (relationLine.trim().isEmpty()) {
        continue; // tolerate blank lines instead of failing on a missing field
      }
      String[] docIDAndRelation = relationLine.split("\t");
      if (docIDAndRelation.length < 2) {
        throw new IllegalStateException(
            "Malformed line in " + GOLD_RELATIONS_PATH
                + " (expected docID<TAB>relation): " + relationLine);
      }
      Set<String> relationsForDoc = docIDToRelations.get(docIDAndRelation[0]);
      if (relationsForDoc == null) {
        relationsForDoc = new HashSet<String>();
        docIDToRelations.put(docIDAndRelation[0], relationsForDoc);
      }
      relationsForDoc.add(docIDAndRelation[1]);
    }

    // load the text for each docID
    File[] allFiles = new File(KBP_DOCS_DIR).listFiles();
    if (allFiles == null) {
      // listFiles() returns null rather than throwing when the path is not a
      // readable directory; fail with the path instead of a bare NPE.
      throw new IllegalStateException("Cannot list KBP benchmark documents in " + KBP_DOCS_DIR);
    }
    for (File kbpTestDocFile : allFiles) {
      if (!kbpTestDocFile.isFile()) {
        continue; // only regular files are documents
      }
      docIDToText.put(kbpTestDocFile.getName(),
          IOUtils.stringFromFile(kbpTestDocFile.getAbsolutePath()));
    }

    // set up the pipeline
    Properties props = new Properties();
    props.put("annotators",
        "tokenize,ssplit,pos,lemma,ner,regexner,parse,mention,entitymentions,coref,kbp");
    props.put("coref.md.type", "RULE");
    pipeline = new StanfordCoreNLP(props);
  }

  /**
   * Renders each triple as {@code relation(subject,object)}, with the relation name
   * passed through {@link #convertRelationName(String)}; duplicates collapse.
   *
   * @param relationTripleList triples extracted from one document
   * @return the distinct rendered relation strings
   */
  public Set<String> convertKBPTriplesToStrings(List<RelationTriple> relationTripleList) {
    HashSet<String> foundRelationStrings = new HashSet<String>();
    for (RelationTriple rt : relationTripleList) {
      String relationName = convertRelationName(rt.relationGloss());
      foundRelationStrings.add(relationName + "(" + rt.subjectGloss() + "," + rt.objectGloss() + ")");
    }
    return foundRelationStrings;
  }

  /**
   * Annotates every loaded document, prints the running recall / precision / F1
   * (accumulated over all documents processed so far) after each one, and asserts
   * that the final F1 reaches {@link #KBP_MINIMUM_SCORE}.
   */
  public void testKBPAnnotatorResults() {
    int totalGoldRelations = 0;
    int totalCorrectFoundRelations = 0;
    int totalGuessRelations = 0;
    double finalF1 = 0.0;

    for (Map.Entry<String,String> docEntry : docIDToText.entrySet()) {
      String docID = docEntry.getKey();
      System.out.println("---");
      System.out.println(docID);
      Annotation currAnnotation = new Annotation(docEntry.getValue());
      pipeline.annotate(currAnnotation);

      // increment number of seen gold relations (documents without gold entries add 0)
      Set<String> goldRelations = docIDToRelations.get(docID);
      if (goldRelations != null) {
        totalGoldRelations += goldRelations.size();
      }

      // gather the triples of every sentence in this document
      List<RelationTriple> relationTriplesForThisDoc = new ArrayList<RelationTriple>();
      for (CoreMap sentence : currAnnotation.get(CoreAnnotations.SentencesAnnotation.class)) {
        List<RelationTriple> rtList = sentence.get(CoreAnnotations.KBPTriplesAnnotation.class);
        if (rtList == null) {
          continue; // no triples recorded for this sentence
        }
        for (RelationTriple rt : rtList) {
          System.out.println("\t" + rt.toString());
          relationTriplesForThisDoc.add(rt);
        }
      }

      Set<String> foundRelationStrings = convertKBPTriplesToStrings(relationTriplesForThisDoc);
      if (goldRelations != null) {
        Set<String> intersectionOfFoundAndGold = new HashSet<String>(foundRelationStrings);
        intersectionOfFoundAndGold.retainAll(goldRelations);
        totalCorrectFoundRelations += intersectionOfFoundAndGold.size();
      }
      totalGuessRelations += foundRelationStrings.size();

      System.out.println("curr score: ");
      double recall = ratio(totalCorrectFoundRelations, totalGoldRelations);
      double precision = ratio(totalCorrectFoundRelations, totalGuessRelations);
      System.out.println("\trecall: " + recall);
      System.out.println("\tprecision: " + precision);
      double f1 = harmonicMean(precision, recall);
      System.out.println("\tf1: " + f1);
      finalF1 = f1;
    }

    // check final F1 score reaches the configured minimum
    assertTrue("f1 score: " + finalF1 + " is below threshold of " + KBP_MINIMUM_SCORE,
        finalF1 >= KBP_MINIMUM_SCORE);
  }

  /** Returns {@code numerator / denominator}, or 0.0 when the denominator is zero (avoids NaN). */
  private static double ratio(int numerator, int denominator) {
    return denominator == 0 ? 0.0 : ((double) numerator) / ((double) denominator);
  }

  /** Returns the F1 (harmonic mean) of precision and recall, or 0.0 when both are zero. */
  private static double harmonicMean(double precision, double recall) {
    double sum = precision + recall;
    return sum == 0.0 ? 0.0 : (2 * (precision * recall)) / sum;
  }
}