Commit

Improve and fix classify code.

- Fix bug I introduced before in ColumnDataClassifier.
- Change LinearClassifierFactory to have 1 not 2 loggers.
- Change LinearClassifierFactory to have QNMinimizers respect the verbose setting correctly.
- Remove a few of the many LinearClassifierFactory constructors.
- Change RVFDataset to have 1 not 2 loggers.
manning authored and Stanford NLP committed Feb 28, 2016
1 parent 2f2b3ea commit 7c77c66
Showing 4 changed files with 73 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/edu/stanford/nlp/classify/ColumnDataClassifier.java
@@ -1464,7 +1464,6 @@ public Classifier<String,String> makeClassifier(GeneralDataset<String,String> tr
lcf = new LinearClassifierFactory<>(globalFlags.tolerance, globalFlags.useSum, globalFlags.prior, globalFlags.sigma, globalFlags.epsilon, globalFlags.QNsize);
}
lcf.setVerbose(globalFlags.verboseOptimization);
lcf.useQuasiNewton(); // redundantly specify default a second time so verbose is read. todo: fix architecture of lcf
if ( ! globalFlags.useQN) {
lcf.useConjugateGradientAscent();
}
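For reference, a minimal sketch of the resulting ColumnDataClassifier setup, mirroring the hunk above (the globalFlags fields are the class's existing option flags):

  LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<>(
      globalFlags.tolerance, globalFlags.useSum, globalFlags.prior,
      globalFlags.sigma, globalFlags.epsilon, globalFlags.QNsize);
  lcf.setVerbose(globalFlags.verboseOptimization);
  // The redundant lcf.useQuasiNewton() call is no longer needed: the default
  // minimizer factory now reads the verbose flag when it creates each
  // QNMinimizer, not when the LinearClassifierFactory itself is constructed.
  if ( ! globalFlags.useQN) {
    lcf.useConjugateGradientAscent();
  }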
137 changes: 62 additions & 75 deletions src/edu/stanford/nlp/classify/LinearClassifierFactory.java
@@ -1,6 +1,6 @@
// Stanford Classifier - a multiclass maxent classifier
// LinearClassifierFactory
// Copyright (c) 2003-2007 The Board of Trustees of
// Copyright (c) 2003-2016 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
@@ -31,6 +31,7 @@
import java.io.BufferedReader;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.Datum;
@@ -42,12 +43,9 @@
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.*;

import java.util.function.Function;


import edu.stanford.nlp.util.logging.Redwood;


/**
* Builds various types of linear classifiers, with functionality for
* setting objective function, optimization method, and other parameters.
@@ -68,9 +66,6 @@

public class LinearClassifierFactory<L, F> extends AbstractLinearClassifierFactory<L, F> {

/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(LinearClassifierFactory.class);

private static final long serialVersionUID = 7893768984379107397L;
private double TOL;
//public double sigma;
@@ -85,96 +80,85 @@ public class LinearClassifierFactory<L, F> extends AbstractLinearClassifierFacto
private boolean tuneSigmaCV = false;
//private boolean resetWeight = true;
private int folds;
// range of values to tune sigma across
private double min = 0.1;
private double max = 10.0;
private boolean retrainFromScratchAfterSigmaTuning = false;

private Factory<Minimizer<DiffFunction>> minimizerCreator = null;
private int evalIters = -1;
private Evaluator[] evaluators = null;
private Evaluator[] evaluators; // = null;

final static Redwood.RedwoodChannels logger = Redwood.channels(LinearClassifierFactory.class);
/** A logger for this class */
private final static Redwood.RedwoodChannels logger = Redwood.channels(LinearClassifierFactory.class);

/** This is the {@code Factory<Minimizer<DiffFunction>>} that we use over and over again. */
private static class Factory15 implements Factory<Minimizer<DiffFunction>> {
private class QNFactory implements Factory<Minimizer<DiffFunction>> {

private static final long serialVersionUID = 6215752553371189173L;
private static final long serialVersionUID = 9028306475652690036L;

@Override
public Minimizer<DiffFunction> create() {
return new QNMinimizer(15);
QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem);
if (! verbose) {
qnMinimizer.shutUp();
}
return qnMinimizer;
}

} // end class Factory15
} // end class QNFactory


public LinearClassifierFactory() {
this(new Factory15());
this.mem = 15;
this.useQuasiNewton();
this((Factory<Minimizer<DiffFunction>>) null);
}

/** NOTE: Constructors that take in a Minimizer create a LinearClassifierFactory that will reuse the minimizer
* and will not be threadsafe (unless the Minimizer itself is ThreadSafe which is probably not the case).
* and will not be threadsafe (unless the Minimizer itself is ThreadSafe, which is probably not the case).
*/
public LinearClassifierFactory(Minimizer<DiffFunction> min) {
this(min, false);
this(min, 1e-4, false);
}

public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min) {
this(min, false);
this(min, 1e-4, false);
}

public LinearClassifierFactory(boolean useSum) {
this(new Factory15(), useSum);
this.mem = 15;
this.useQuasiNewton();
}

public LinearClassifierFactory(double tol) {
this(new Factory15(), tol, false);
this.mem = 15;
this.useQuasiNewton();
}
public LinearClassifierFactory(Minimizer<DiffFunction> min, boolean useSum) {
this(min, 1e-4, useSum);
}
public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, boolean useSum) {
this(min, 1e-4, useSum);
}
public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum) {
this(min, tol, useSum, 1.0);
}

public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum) {
this(min, tol, useSum, 1.0);
}

public LinearClassifierFactory(double tol, boolean useSum, double sigma) {
this(new Factory15(), tol, useSum, sigma);
this.mem = 15;
this.useQuasiNewton();
this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, sigma);
}

public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, double sigma) {
this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
}

public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, double sigma) {
this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
}

public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma) {
this(min, tol, useSum, prior, sigma, 0.0);
}

public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma) {
this(min, tol, useSum, prior, sigma, 0.0);
}

public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon) {
this(new Factory15(), tol, useSum, new LogPrior(prior, sigma, epsilon));
this.mem = 15;
this.useQuasiNewton();
this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, new LogPrior(prior, sigma, epsilon));
}

public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon, final int mem) {
this(new Factory15(), tol, useSum, new LogPrior(prior, sigma, epsilon));
this.useQuasiNewton();
this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, new LogPrior(prior, sigma, epsilon));
this.mem = mem;
}

/**
@@ -195,6 +179,7 @@ public LinearClassifierFactory(double tol, boolean useSum, int prior, double sig
public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
}

public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
}
@@ -213,8 +198,24 @@ public Minimizer<DiffFunction> create() {
this.logPrior = logPrior;
}

/**
* Create a factory that builds linear classifiers from training data. This is the recommended constructor to
* bottom out with. Use of a minimizerCreator makes the classifier threadsafe.
*
* @param minimizerCreator A Factory for creating minimizers. If this is null, a standard quasi-Newton minimizer
* factory will be used.
* @param tol The convergence threshold for the minimization (default: 1e-4)
* @param useSum Asks to the optimizer to minimize the sum of the
* likelihoods of individual data items rather than their product (Klein and Manning 2001 WSD.)
* NOTE: this is currently ignored!!! At some point support for this option was deleted
* @param logPrior What kind of prior to use, this class specifies its type and hyperparameters.
*/
public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerCreator, double tol, boolean useSum, LogPrior logPrior) {
this.minimizerCreator = minimizerCreator;
if (minimizerCreator == null) {
this.minimizerCreator = new QNFactory();
} else {
this.minimizerCreator = minimizerCreator;
}
this.TOL = tol;
//this.useSum = useSum;
this.logPrior = logPrior;
@@ -272,17 +273,7 @@ public double getSigma() {
* Sets the minimizer to QuasiNewton. {@link QNMinimizer} is the default.
*/
public void useQuasiNewton() {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = 9028306475652690036L;
@Override
public Minimizer<DiffFunction> create() {
QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem);
if (!verbose) {
qnMinimizer.shutUp();
}
return qnMinimizer;
}
};
this.minimizerCreator = new QNFactory();
}

public void useQuasiNewton(final boolean useRobust) {
@@ -515,11 +506,7 @@ public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial) {
}

public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean bypassTuneSigma) {
return trainWeights(dataset, initial, bypassTuneSigma, null);
}

public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean bypassTuneSigma, Minimizer<DiffFunction> minimizer) {
if (minimizer == null) minimizer = getMinimizer();
Minimizer<DiffFunction> minimizer = getMinimizer();
if(dataset instanceof RVFDataset)
((RVFDataset<L,F>)dataset).ensureRealValues();
double[] interimWeights = null;
@@ -777,8 +764,7 @@ public void crossValidateSetSigma(GeneralDataset<L, F> dataset,int kfold, final
//sigma = sigmaToTry;
setSigma(sigmaToTry);
Double averageScore = crossValidator.computeAverage(scoreFn);
log.info("##sigma = "+getSigma()+" ");
logger.info("-> average Score: "+averageScore);
logger.info("##sigma = "+getSigma() + " -> average Score: " + averageScore);
return -averageScore;
};

@@ -812,15 +798,14 @@ public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F>
public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, final Scorer<L> scorer) {
return heldOutSetSigma(train, dev, scorer, new GoldenSectionLineSearch(true, 1e-2, min, max));
}
public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, LineSearcher minimizer) {

public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, LineSearcher minimizer) {
return heldOutSetSigma(train, dev, new MultiClassAccuracyStats<>(MultiClassAccuracyStats.USE_LOGLIKELIHOOD), minimizer);
}

/**
* Sets the sigma parameter to a value that optimizes the held-out score given by <code>scorer</code>. Search for an optimal value
* is carried out by <code>minimizer</code>
* dataset the data set to optimize sigma on.
* kfold
* Sets the sigma parameter to a value that optimizes the held-out score given by {@code scorer}. Search for an
* optimal value is carried out by {@code minimizer} dataset the data set to optimize sigma on. kfold
*
* @return an interim set of optimal weights: the weights
*/
@@ -869,8 +854,7 @@ public Double apply(Double sigmaToTry) {
double score = scorer.score(classifier, devSet);
//System.out.println("score: "+score);
//System.out.print(".");
log.info("##sigma = "+getSigma()+" ");
logger.info("-> average Score: " + score);
logger.info("##sigma = " + getSigma() + " -> average Score: " + score);
logger.info("##time elapsed: " + timer.stop() + " milliseconds.");
timer.restart();
return -score;
@@ -925,10 +909,12 @@ public Classifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, float[] da
public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
return trainClassifier(dataset, null);
}

public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, double[] initial) {
// Sanity check
if(dataset instanceof RVFDataset)
((RVFDataset<L,F>)dataset).ensureRealValues();
if (dataset instanceof RVFDataset) {
((RVFDataset<L, F>) dataset).ensureRealValues();
}
if (initial != null) {
for (double weight : initial) {
if (Double.isNaN(weight) || Double.isInfinite(weight)) {
@@ -941,10 +927,12 @@ public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, doub
LinearClassifier<L, F> classifier = new LinearClassifier<>(weights, dataset.featureIndex(), dataset.labelIndex());
return classifier;
}

public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, double[][] initialWeights2D) {
double[] initialWeights = (initialWeights2D != null)? ArrayUtils.flatten(initialWeights2D):null;
return trainClassifier(dataset, initialWeights);
}

public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, LinearClassifier<L,F> initialClassifier) {
double[][] initialWeights2D = (initialClassifier != null)? initialClassifier.weights():null;
return trainClassifierWithInitialWeights(dataset, initialWeights2D);
@@ -999,8 +987,7 @@ public static LinearClassifier<String, String> loadFromFilename(String file) {
}
}

public void setEvaluators(int iters, Evaluator[] evaluators)
{
public void setEvaluators(int iters, Evaluator[] evaluators) {
this.evalIters = iters;
this.evaluators = evaluators;
}
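As the constructor note above says, thread safety comes from passing a Factory<Minimizer<DiffFunction>> rather than a single Minimizer. A hedged usage sketch (the memory size 15 is just the class's default, and the tol/useSum/sigma values are illustrative, not from this commit):

  import edu.stanford.nlp.classify.LinearClassifierFactory;
  import edu.stanford.nlp.optimization.DiffFunction;
  import edu.stanford.nlp.optimization.Minimizer;
  import edu.stanford.nlp.optimization.QNMinimizer;
  import edu.stanford.nlp.util.Factory;

  public class FactoryVsMinimizerDemo {
    public static void main(String[] args) {
      // Threadsafe: a fresh QNMinimizer is created for every training run.
      // A custom factory like this manages its own verbosity (the built-in
      // QNFactory calls shutUp() on the minimizer when verbose is false).
      Factory<Minimizer<DiffFunction>> creator = () -> new QNMinimizer(15);
      LinearClassifierFactory<String, String> safe =
          new LinearClassifierFactory<>(creator, 1e-4, false, 1.0);

      // Not threadsafe: every training run reuses this one Minimizer instance.
      LinearClassifierFactory<String, String> shared =
          new LinearClassifierFactory<>(new QNMinimizer(15), 1e-4, false, 1.0);
    }
  }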
20 changes: 10 additions & 10 deletions src/edu/stanford/nlp/classify/RVFDataset.java
@@ -46,10 +46,7 @@
* @param <L> The type of the labels in the Dataset
* @param <F> The type of the features in the Dataset
*/
public class RVFDataset<L, F> extends GeneralDataset<L, F> {

/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(RVFDataset.class); // implements Iterable<RVFDatum<L, F>>, Serializable
public class RVFDataset<L, F> extends GeneralDataset<L, F> { // implements Iterable<RVFDatum<L, F>>, Serializable

private static final long serialVersionUID = -3841757837680266182L;

@@ -61,7 +58,8 @@ public class RVFDataset<L, F> extends GeneralDataset<L, F> {
double[] means;
double[] stdevs; // means and stdevs of features, used for

final static Redwood.RedwoodChannels logger = Redwood.channels(RVFDataset.class);
/** A logger for this class */
private static final Redwood.RedwoodChannels logger = Redwood.channels(RVFDataset.class);

/*
* Store source and id of each datum; optional, and not fully supported.
@@ -521,20 +519,22 @@ protected void initialize(int numDatums) {
}

/**
* Prints some summary statistics to stderr for the Dataset.
* Prints some summary statistics to the logger for the Dataset.
*/
@Override
public void summaryStatistics() {
logger.info("numDatums: " + size);
log.info("numLabels: " + labelIndex.size() + " [");
StringBuilder sb = new StringBuilder("numLabels: ");
sb.append(labelIndex.size()).append(" [");
Iterator<L> iter = labelIndex.iterator();
while (iter.hasNext()) {
log.info(iter.next());
sb.append(iter.next());
if (iter.hasNext()) {
log.info(", ");
sb.append(", ");
}
}
logger.info("]");
sb.append(']');
logger.info(sb.toString());
logger.info("numFeatures (Phi(X) types): " + featureIndex.size());
/*for(int i = 0; i < data.length; i++) {
for(int j = 0; j < data[i].length; j++) {
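The point of the StringBuilder change in summaryStatistics(), as a standalone sketch (hypothetical labels; each Redwood info() call emits its own log record, so buffering keeps the bracketed label list on one line):

  import java.util.Arrays;
  import java.util.Iterator;
  import java.util.List;

  public class SummaryLineDemo {
    public static void main(String[] args) {
      List<String> labels = Arrays.asList("POS", "NEG", "NEU");  // hypothetical label set
      StringBuilder sb = new StringBuilder("numLabels: ").append(labels.size()).append(" [");
      for (Iterator<String> it = labels.iterator(); it.hasNext(); ) {
        sb.append(it.next());
        if (it.hasNext()) {
          sb.append(", ");
        }
      }
      sb.append(']');
      System.out.println(sb);  // one record: numLabels: 3 [POS, NEG, NEU]
    }
  }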
3 changes: 1 addition & 2 deletions src/edu/stanford/nlp/optimization/QNMinimizer.java
@@ -100,7 +100,7 @@ public class QNMinimizer implements Minimizer<DiffFunction>, HasEvaluators {
private int mem = 10; // the number of s,y pairs to retain for BFGS
private int its; // = 0; // the number of iterations through the main do-while loop of L-BFGS's minimize()
private final Function monitor;
private boolean quiet;
private boolean quiet; // = false
private static final NumberFormat nf = new DecimalFormat("0.000E0");
private static final NumberFormat nfsec = new DecimalFormat("0.00"); // for times
private static final double ftol = 1e-4; // Linesearch parameters
@@ -272,7 +272,6 @@ public boolean wasSuccessful() {
public void shutUp() {
this.quiet = true;
}

public void setM(int m) {
mem = m;
}
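A usage sketch for the quiet flag (the toy objective and starting point are placeholders, not from this commit):

  import java.util.Arrays;

  import edu.stanford.nlp.optimization.DiffFunction;
  import edu.stanford.nlp.optimization.QNMinimizer;

  public class QuietMinimizeDemo {
    public static void main(String[] args) {
      // Toy objective f(x) = x0^2 + x1^2, with gradient (2*x0, 2*x1).
      DiffFunction f = new DiffFunction() {
        @Override public int domainDimension() { return 2; }
        @Override public double valueAt(double[] x) { return x[0] * x[0] + x[1] * x[1]; }
        @Override public double[] derivativeAt(double[] x) { return new double[] { 2 * x[0], 2 * x[1] }; }
      };
      QNMinimizer qn = new QNMinimizer(15);  // retain 15 s,y pairs for L-BFGS
      qn.shutUp();                           // quiet = true: no per-iteration output
      double[] argmin = qn.minimize(f, 1e-4, new double[] { 3.0, -4.0 });
      System.out.println(Arrays.toString(argmin));  // approximately [0.0, 0.0]
    }
  }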
