BasicParsingReranker.java

package edu.berkeley.nlp.assignments.rerank.student;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

import edu.berkeley.nlp.assignments.rerank.KbestList;
import edu.berkeley.nlp.assignments.rerank.ParsingReranker;
import edu.berkeley.nlp.ling.Tree;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.IntCounter;
import edu.berkeley.nlp.util.Pair;

/**
 * k-best discriminative parsing reranker. Uses a perceptron to learn the
 * feature weights.
 *
 * @author Samridhi
 */
public class BasicParsingReranker implements ParsingReranker
{
  // Feature weights learned by the perceptron, indexed by feature id.
  IntCounter weightVector;
  // Maps feature strings to integer feature ids.
  Indexer<String> featureIndexer;
  // Per-sentence training data for loss-augmented training (not populated in this class).
  List<FeatureData> lossAugmentedTrainData;

  /**
   * Class constructor for training: extracts features for every tree in each
   * k-best list, caches them, and then learns feature weights with a perceptron.
   */
  public BasicParsingReranker(Iterable<Pair<KbestList, Tree<String>>> kbestListsAndGoldTrees)
  {
    featureIndexer = new Indexer<String>();
    lossAugmentedTrainData = new ArrayList<FeatureData>();
    // Extract the features for each tree in the k-best list and build the cache.
    FeatureExtractor featExtract = new FeatureExtractor();
    List<FeatureData> trainingDataList = new ArrayList<FeatureData>();
    System.out.println("Extracting Features");
    for (Pair<KbestList, Tree<String>> iter : kbestListsAndGoldTrees)
    {
      KbestList kbestList = iter.getFirst();
      List<Tree<String>> kBestTrees = kbestList.getKbestTrees();
      int kListSize = kBestTrees.size();
      Tree<String> goldTree = iter.getSecond();
      double[] kBestLoss = new double[kListSize];
      // Get the feature vector for the gold tree.
      int[] goldList = featExtract.ExtractGoldTreeFeatures(goldTree, kbestList, featureIndexer, true);
      IntCounter goldFeatures = CreateFeatureCounterList(goldList);
      int[][] featureList = new int[kListSize][];
      // Get the feature vector and loss corresponding to each candidate parse tree.
      for (int i = 0; i < kListSize; i++)
      {
        int[] curFeatureList = featExtract.ExtractKListFeatures(kbestList, i, featureIndexer, true);
        featureList[i] = curFeatureList;
        kBestLoss[i] = GetLossFunctionValue(goldTree, kBestTrees.get(i), false);
      }
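      // Bundle the gold-tree features, per-candidate losses, and candidate feature
      // vectors for this sentence into a single training instance.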
      FeatureData trainData = new FeatureData(goldFeatures, kBestLoss, featureList);
      trainingDataList.add(trainData);
    }
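    // Initialize every feature weight to zero before perceptron training.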
    IntCounter tempWeightVector = new IntCounter(featureIndexer.size());
    for (int i = 0; i < featureIndexer.size(); i++)
    {
      tempWeightVector.put(i, 0);
    }
    System.out.println("Training with perceptron");
    int maxIter = 30;
    double tolerance = 0.05;
    System.out.println("Norm squared before calling the perceptron = " + tempWeightVector.normSquared());
    PerceptronLearner trainer = new PerceptronLearner(maxIter, tolerance);
    weightVector = trainer.Train(tempWeightVector, trainingDataList);
  }
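
  /**
   * Converts an array of feature ids into a sparse vector of feature counts.
   */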
  public IntCounter CreateFeatureCounterList(int[] featureList)
  {
    IntCounter finalList = new IntCounter();
    // Update the counts for the occurrences of a feature.
    for (int curFeature : featureList)
    {
      double val = finalList.get(curFeature) + 1;
      finalList.put(curFeature, val);
    }
    return finalList;
  }
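
  /**
   * Rescores every tree in the k-best list with the learned weights and returns
   * the highest-scoring one.
   */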
  public Tree<String> getBestParse(List<String> sentence, KbestList kbestList)
  {
    double maxVal = Double.NEGATIVE_INFINITY;
    double curVal;
    Tree<String> argmax = null;
    FeatureExtractor featExtract = new FeatureExtractor();
    List<Tree<String>> kBestTrees = kbestList.getKbestTrees();
    for (int idx = 0; idx < kBestTrees.size(); idx++)
    {
      int[] curFeature = featExtract.ExtractKListFeatures(kbestList, idx, featureIndexer, false);
      IntCounter featVector = CreateFeatureCounterList(curFeature);
      curVal = featVector.dotProduct(weightVector);
      if (curVal > maxVal)
      {
        maxVal = curVal;
        argmax = kBestTrees.get(idx);
      }
    }
    return argmax;
  }
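
  /**
   * Loss between the gold tree and a guess tree: 1 - F1 when useF1 is true,
   * otherwise 0/1 exact-match loss on the bracketed tree strings.
   */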
  private double GetLossFunctionValue(Tree<String> goldTree, Tree<String> guessTree, boolean useF1)
  {
    double lossValue = 0;
    if (useF1)
    {
      EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval =
          new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
              Collections.singleton("ROOT"),
              new HashSet<String>(Arrays.asList(new String[] { "''", "``", ".", ":", "," })));
      double f1 = eval.evaluateF1(guessTree, goldTree);
      lossValue = 1 - f1;
    }
    else
    {
      lossValue = (guessTree.toString().equalsIgnoreCase(goldTree.toString())) ? 0 : 1;
    }
    return lossValue;
  }
}
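
/*
 * Usage sketch (illustrative only; the surrounding harness and variable names are
 * assumptions, not part of this class):
 *
 *   Iterable<Pair<KbestList, Tree<String>>> trainData = ...;  // k-best lists paired with gold trees
 *   ParsingReranker reranker = new BasicParsingReranker(trainData);
 *   Tree<String> best = reranker.getBestParse(sentence, kbestListForSentence);
 */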