Skip to content

Commit

Permalink
weights of multinomialLR
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalg authored and Stanford NLP committed Jan 23, 2015
1 parent 1ffb2a4 commit e1bf1ab
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
20 changes: 20 additions & 0 deletions src/edu/stanford/nlp/classify/MultinomialLogisticClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
Expand Down Expand Up @@ -132,4 +134,22 @@ private void save(String path) throws IOException {

System.out.println("done.");
}

public Map<L, Counter<F>> weightsAsGenericCounter() {

Map<L, Counter<F>> allweights = new HashMap<L, Counter<F>>();
for(int i = 0; i < weights.length; i++){
Counter<F> c = new ClassicCounter<F>();
L label = labelIndex.get(i);
double[] w = weights[i];
for (F f : featureIndex) {
int indexf = featureIndex.indexOf(f);
if(w[indexf] != 0.0)
c.setCount(f, w[indexf]);

}
allweights.put(label, c);
}
return allweights;
}
}
16 changes: 5 additions & 11 deletions src/edu/stanford/nlp/patterns/ScorePhrasesLearnFeatWt.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
import edu.stanford.nlp.patterns.dep.ExtractPhraseFromPattern;
import edu.stanford.nlp.patterns.dep.ExtractedPhrase;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.stats.*;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.Execution.Option;
import edu.stanford.nlp.util.concurrent.AtomicDouble;
Expand Down Expand Up @@ -162,13 +159,10 @@ public edu.stanford.nlp.classify.Classifier learnClassifier(String label, boolea
classifier = factory.trainClassifier(newdataset);

//print weights
LogisticClassifier logcl = ((LogisticClassifier) classifier);
String l = (String) logcl.getLabelForInternalPositiveClass();
Counter<String> weights = logcl.weightsAsGenericCounter();
if (l.equals(Boolean.FALSE.toString())) {
Counters.multiplyInPlace(weights, -1);
}
List<Pair<String, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
MultinomialLogisticClassifier<String, ScorePhraseMeasures> logcl = ((MultinomialLogisticClassifier) classifier);
Counter<ScorePhraseMeasures> weights = logcl.weightsAsGenericCounter().get("true");

List<Pair<ScorePhraseMeasures, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));

} else if(scoreClassifierType.equals(ClassifierType.LINEAR)){
Expand Down

0 comments on commit e1bf1ab

Please sign in to comment.