Commit 488d89f

removing trainClassifier(List<RVFDatum<...
Grace Muzny authored and Stanford NLP committed Sep 1, 2015
1 parent af0b447 commit 488d89f
Showing 8 changed files with 107 additions and 173 deletions.
@@ -40,12 +40,6 @@ int numClasses() {
     return labelIndex.size();
   }
 
-  public Classifier<L,F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    Dataset<L, F> dataset = new Dataset<L, F>();
-    dataset.addAll(examples);
-    return trainClassifier(dataset);
-  }
-
   protected abstract double[][] trainWeights(GeneralDataset<L, F> dataset) ;
 
   /**
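
Migration note (not part of this commit): callers that used to hand a factory a raw List<RVFDatum> can wrap the list in a GeneralDataset implementation, as the deleted convenience method itself did, and call the dataset-based overload. A minimal sketch, with invented class name and toy data; RVFDataset is used here because it preserves real feature values:

import java.util.Arrays;
import java.util.List;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class TrainClassifierMigrationSketch {

  public static void main(String[] args) {
    // Toy data (invented for illustration).
    Counter<String> ok = new ClassicCounter<String>();
    ok.incrementCount("leftLight", 1.0);
    ok.incrementCount("rightLight", 1.0);
    Counter<String> broken = new ClassicCounter<String>();
    broken.incrementCount("leftLight", 0.0);
    broken.incrementCount("rightLight", 0.0);
    List<RVFDatum<String, String>> examples = Arrays.asList(
        new RVFDatum<String, String>(ok, "OK"),
        new RVFDatum<String, String>(broken, "BROKEN"));

    // Before: factory.trainClassifier(examples);  // removed overload
    // After: wrap the datums in a GeneralDataset and train on that.
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    for (RVFDatum<String, String> d : examples) {
      dataset.add(d);
    }
    LinearClassifier<String, String> classifier =
        new LinearClassifierFactory<String, String>().trainClassifier(dataset);
    System.out.println(classifier.classOf(examples.get(0)));
  }
}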
3 changes: 0 additions & 3 deletions src/edu/stanford/nlp/classify/ClassifierFactory.java
@@ -16,9 +16,6 @@
 
 public interface ClassifierFactory<L, F, C extends Classifier<L, F>> extends Serializable {
 
-  @Deprecated //ClassifierFactory should implement trainClassifier(GeneralDataset) instead.
-  public C trainClassifier(List<RVFDatum<L, F>> examples);
-
   public C trainClassifier(GeneralDataset<L,F> dataset);
 
 }
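
Note for implementers (not part of this commit): with the deprecated List-based method gone, a ClassifierFactory only has to provide the GeneralDataset-based method. A minimal sketch, assuming trainClassifier(GeneralDataset) is the interface's sole remaining method as the hunk above suggests; the class name and the delegation to LinearClassifierFactory are invented for illustration:

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifierFactory;

// Hypothetical implementor: only the dataset-based method is required now.
public class DelegatingClassifierFactory<L, F>
    implements ClassifierFactory<L, F, Classifier<L, F>> {

  private static final long serialVersionUID = 1L;

  @Override
  public Classifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
    // Delegate to a stock factory; a real implementation would train here.
    return new LinearClassifierFactory<L, F>().trainClassifier(dataset);
  }
}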
6 changes: 0 additions & 6 deletions src/edu/stanford/nlp/classify/LinearClassifierFactory.java
@@ -989,12 +989,6 @@ public static LinearClassifier<String, String> loadFromFilename(String file) {
     }
   }
 
-  @Deprecated
-  @Override
-  public LinearClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    throw new UnsupportedOperationException("Unsupported deprecated method");
-  }
-
   public void setEvaluators(int iters, Evaluator[] evaluators)
   {
     this.evalIters = iters;
6 changes: 0 additions & 6 deletions src/edu/stanford/nlp/classify/LogisticClassifierFactory.java
@@ -108,10 +108,4 @@ else if(data instanceof RVFDataset<?,?>)
     return new LogisticClassifier<L,F>(weights,featureIndex,classes);
   }
 
-  @Deprecated //this method no longer required by the ClassifierFactory Interface.
-  public LogisticClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    // TODO Auto-generated method stub
-    return null;
-  }
-
 }
183 changes: 78 additions & 105 deletions src/edu/stanford/nlp/classify/NaiveBayesClassifierFactory.java
@@ -34,6 +34,7 @@
 import edu.stanford.nlp.optimization.QNMinimizer;
 import edu.stanford.nlp.stats.ClassicCounter;
 import edu.stanford.nlp.stats.Counter;
+import edu.stanford.nlp.stats.Counters;
 import edu.stanford.nlp.util.Generics;
 import edu.stanford.nlp.util.Index;
 import edu.stanford.nlp.util.Pair;
@@ -98,43 +99,11 @@ private NaiveBayesClassifier<L, F> trainClassifier(int[][] data, int[] labels, i
   }
 
-  /**
-   * The examples are assumed to be a list of RFVDatum.
-   * The datums are assumed to contain the zeroes as well.
-   */
-  @Override
-  @Deprecated
-  public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    RVFDatum<L, F> d0 = examples.get(0);
-    int numFeatures = d0.asFeatures().size();
-    int[][] data = new int[examples.size()][numFeatures];
-    int[] labels = new int[examples.size()];
-    labelIndex = new HashIndex<L>();
-    featureIndex = new HashIndex<F>();
-    for (int d = 0; d < examples.size(); d++) {
-      RVFDatum<L, F> datum = examples.get(d);
-      Counter<F> c = datum.asFeaturesCounter();
-      for (F feature: c.keySet()) {
-        if(featureIndex.add(feature)) {
-          int fNo = featureIndex.indexOf(feature);
-          int value = (int) c.getCount(feature);
-          data[d][fNo] = value;
-        }
-      }
-      labelIndex.add(datum.label());
-      labels[d] = labelIndex.indexOf(datum.label());
-
-    }
-    int numClasses = labelIndex.size();
-    return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
-  }
-
   /**
    * The examples are assumed to be a list of RFVDatum.
    * The datums are assumed to not contain the zeroes and then they are added to each instance.
    */
-  public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples, Set<F> featureSet) {
+  public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> examples, Set<F> featureSet) {
     int numFeatures = featureSet.size();
     int[][] data = new int[examples.size()][numFeatures];
     int[] labels = new int[examples.size()];
@@ -144,7 +113,7 @@ public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples,
       featureIndex.add(feat);
     }
     for (int d = 0; d < examples.size(); d++) {
-      RVFDatum<L, F> datum = examples.get(d);
+      RVFDatum<L, F> datum = examples.getRVFDatum(d);
       Counter<F> c = datum.asFeaturesCounter();
       for (F feature : c.keySet()) {
         int fNo = featureIndex.indexOf(feature);
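
Usage note for the changed overload above (not part of this commit): per its javadoc, the datums may omit zero-valued features, and featureSet names the full feature universe so the factory can fill in the explicit zeroes. A minimal sketch; the toy data are invented, and the no-argument NaiveBayesClassifierFactory constructor is an assumption, not something this diff shows:

import java.util.Set;

import edu.stanford.nlp.classify.NaiveBayesClassifier;
import edu.stanford.nlp.classify.NaiveBayesClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;

public class FeatureSetOverloadSketch {

  public static void main(String[] args) {
    // Each datum lists only the features that fired; zeroes stay implicit.
    Counter<String> ok = new ClassicCounter<String>();
    ok.incrementCount("leftLight", 1.0);
    Counter<String> broken = new ClassicCounter<String>(); // no features fired
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    dataset.add(new RVFDatum<String, String>(ok, "OK"));
    dataset.add(new RVFDatum<String, String>(broken, "BROKEN"));

    // featureSet supplies the full feature universe for the explicit zeroes.
    Set<String> featureSet = Generics.newHashSet();
    featureSet.add("leftLight");
    featureSet.add("rightLight");

    NaiveBayesClassifier<String, String> nb =
        new NaiveBayesClassifierFactory<String, String>()  // assumed constructor
            .trainClassifier(dataset, featureSet);
    nb.print();
  }
}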
Expand Down Expand Up @@ -292,77 +261,81 @@ static class NBWeights {
} }
} }


public static void main(String[] args) { // public static void main(String[] args) {

// List examples = new ArrayList();
/* // String leftLight = "leftLight";
List examples = new ArrayList(); // String rightLight = "rightLight";
String leftLight = "leftLight"; // String broken = "BROKEN";
String rightLight = "rightLight"; // String ok = "OK";
String broken = "BROKEN"; // Counter c1 = new ClassicCounter<>();
String ok = "OK"; // c1.incrementCount(leftLight, 0);
Counter c1 = new Counter(); // c1.incrementCount(rightLight, 0);
c1.incrementCount(leftLight, 0); // RVFDatum d1 = new RVFDatum(c1, broken);
c1.incrementCount(rightLight, 0); // examples.add(d1);
RVFDatum d1 = new RVFDatum(c1, broken); // Counter c2 = new ClassicCounter<>();
examples.add(d1); // c2.incrementCount(leftLight, 1);
Counter c2 = new Counter(); // c2.incrementCount(rightLight, 1);
c2.incrementCount(leftLight, 1); // RVFDatum d2 = new RVFDatum(c2, ok);
c2.incrementCount(rightLight, 1); // examples.add(d2);
RVFDatum d2 = new RVFDatum(c2, ok); // Counter c3 = new ClassicCounter<>();
examples.add(d2); // c3.incrementCount(leftLight, 0);
Counter c3 = new Counter(); // c3.incrementCount(rightLight, 1);
c3.incrementCount(leftLight, 0); // RVFDatum d3 = new RVFDatum(c3, ok);
c3.incrementCount(rightLight, 1); // examples.add(d3);
RVFDatum d3 = new RVFDatum(c3, ok); // Counter c4 = new ClassicCounter<>();
examples.add(d3); // c4.incrementCount(leftLight, 1);
Counter c4 = new Counter(); // c4.incrementCount(rightLight, 0);
c4.incrementCount(leftLight, 1); // RVFDatum d4 = new RVFDatum(c4, ok);
c4.incrementCount(rightLight, 0); // examples.add(d4);
RVFDatum d4 = new RVFDatum(c4, ok); // Dataset data = new Dataset(examples.size());
examples.add(d4); // data.addAll(examples);
NaiveBayesClassifier classifier = (NaiveBayesClassifier) new NaiveBayesClassifierFactory(200, 200, 1.0, LogPrior.QUADRATIC.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(examples); // NaiveBayesClassifier classifier = (NaiveBayesClassifier)
classifier.print(); // new NaiveBayesClassifierFactory(200, 200, 1.0,
//now classifiy // LogPrior.LogPriorType.QUADRATIC.ordinal(),
for (int i = 0; i < examples.size(); i++) { // NaiveBayesClassifierFactory.CL)
RVFDatum d = (RVFDatum) examples.get(i); // .trainClassifier(data);
Counter scores = classifier.scoresOf(d); // classifier.print();
System.out.println("for datum " + d + " scores are " + scores.toString()); // //now classifiy
System.out.println(" class is " + scores.argmax()); // for (int i = 0; i < examples.size(); i++) {
} // RVFDatum d = (RVFDatum) examples.get(i);
// Counter scores = classifier.scoresOf(d);
} // System.out.println("for datum " + d + " scores are " + scores.toString());
*/ // System.out.println(" class is " + Counters.topKeys(scores, 1));
String trainFile = args[0]; // System.out.println(" class should be " + d.label());
String testFile = args[1]; // }
NominalDataReader nR = new NominalDataReader(); // }
Map<Integer, Index<String>> indices = Generics.newHashMap();
List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices); // String trainFile = args[0];
System.out.println("Constrained conditional likelihood no prior :"); // String testFile = args[1];
for (int j = 0; j < 100; j++) { // NominalDataReader nR = new NominalDataReader();
NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(train); // Map<Integer, Index<String>> indices = Generics.newHashMap();
classifier.print(); // List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
//now classifiy // List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices);

// System.out.println("Constrained conditional likelihood no prior :");
float accTrain = classifier.accuracy(train.iterator()); // for (int j = 0; j < 100; j++) {
System.err.println("training accuracy " + accTrain); // NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(train);
float accTest = classifier.accuracy(test.iterator()); // classifier.print();
System.err.println("test accuracy " + accTest); // //now classifiy

//
} // float accTrain = classifier.accuracy(train.iterator());
System.out.println("Unconstrained conditional likelihood no prior :"); // System.err.println("training accuracy " + accTrain);
for (int j = 0; j < 100; j++) { // float accTest = classifier.accuracy(test.iterator());
NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.UCL).trainClassifier(train); // System.err.println("test accuracy " + accTest);
classifier.print(); //
//now classify // }

// System.out.println("Unconstrained conditional likelihood no prior :");
float accTrain = classifier.accuracy(train.iterator()); // for (int j = 0; j < 100; j++) {
System.err.println("training accuracy " + accTrain); // NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.UCL).trainClassifier(train);
float accTest = classifier.accuracy(test.iterator()); // classifier.print();
System.err.println("test accuracy " + accTest); // //now classify

//
} // float accTrain = classifier.accuracy(train.iterator());
} // System.err.println("training accuracy " + accTrain);
// float accTest = classifier.accuracy(test.iterator());
// System.err.println("test accuracy " + accTest);
// }
// }


@Override @Override
public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) { public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
Expand Down
6 changes: 0 additions & 6 deletions src/edu/stanford/nlp/classify/SVMLightClassifierFactory.java
@@ -317,12 +317,6 @@ public void heldOutSetC(final GeneralDataset<L, F> trainSet, final GeneralDatase
     useSigmoid = oldUseSigmoid;
   }
 
-  @Deprecated
-  public SVMLightClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    // TODO Auto-generated method stub
-    return null;
-  }
-
   private boolean tuneHeldOut = false;
   private boolean tuneCV = false;
   private Scorer<L> scorer = new MultiClassAccuracyStats<L>();
@@ -97,10 +97,5 @@ private int[][] convertLabels(int[] labels) {
     }
     return result;
   }
 
-  @Override
-  @Deprecated
-  public MultinomialLogisticClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
-    return null;
-  }
 }
63 changes: 28 additions & 35 deletions src/edu/stanford/nlp/optimization/QNMinimizer.java
@@ -13,6 +13,8 @@
 import edu.stanford.nlp.io.RuntimeIOException;
 import edu.stanford.nlp.math.ArrayMath;
 import edu.stanford.nlp.util.CallbackFunction;
+import edu.stanford.nlp.util.Generics;
+
 
 /**
  *
@@ -1209,62 +1211,53 @@ public void useOWLQN(boolean use, double lambda) {
     this.lambdaOWL = lambda;
   }
 
-  private static double[] projectOWL(double[] x, double[] orthant, Function func) {
-    if (func instanceof HasRegularizerParamRange) {
-      Set<Integer> paramRange = ((HasRegularizerParamRange)func).getRegularizerParamRange(x);
-      for (int i : paramRange) {
-        if (x[i] * orthant[i] <= 0)
-          x[i] = 0;
-      }
-    } else {
-      for (int i=0; i!=x.length; ++i){
-        if (x[i] * orthant[i] <= 0)
-          x[i] = 0;
-      }
-    }
+  private static Set<Integer> initializeParamRange(Function func, double[] x) {
+    Set<Integer> paramRange;
+    if (func instanceof HasRegularizerParamRange) {
+      paramRange = ((HasRegularizerParamRange)func).getRegularizerParamRange(x);
+    } else {
+      paramRange = Generics.newHashSet(x.length);
+      for (int i = 0; i < x.length; i++) {
+        paramRange.add(i);
+      }
+    }
+    return paramRange;
+  }
+
+  private static double[] projectOWL(double[] x, double[] orthant, Function func) {
+    Set<Integer> paramRange = initializeParamRange(func, x);
+    for (int i : paramRange) {
+      if (x[i] * orthant[i] <= 0)
+        x[i] = 0;
+    }
     return x;
   }
 
   private static double l1NormOWL(double[] x, Function func) {
+    Set<Integer> paramRange = initializeParamRange(func, x);
     double sum = 0.0;
-    if (func instanceof HasRegularizerParamRange) {
-      Set<Integer> paramRange = ((HasRegularizerParamRange)func).getRegularizerParamRange(x);
-      for (int i: paramRange) {
-        sum += Math.abs(x[i]);
-      }
-    } else {
-      for (double v : x){
-        sum += Math.abs(v);
-      }
+    for (int i: paramRange) {
+      sum += Math.abs(x[i]);
     }
     return sum;
   }
 
   private static void constrainSearchDir(double[] dir, double[] fg, double[] x, Function func) {
-    if (func instanceof HasRegularizerParamRange) {
-      Set<Integer> paramRange = ((HasRegularizerParamRange)func).getRegularizerParamRange(x);
-      for (int i: paramRange) {
-        if (dir[i] * fg[i] >= 0.0) {
-          dir[i] = 0.0;
-        }
-      }
-    } else {
-      for (int i=0; i!=x.length; ++i){
-        if (dir[i] * fg[i] >= 0.0) {
-          dir[i] = 0.0;
-        }
-      }
+    Set<Integer> paramRange = initializeParamRange(func, x);
+    for (int i: paramRange) {
+      if (dir[i] * fg[i] >= 0.0) {
+        dir[i] = 0.0;
+      }
     }
   }
 
   private double[] pseudoGradientOWL(double[] x, double[] grad, Function func) {
-    Set<Integer> paramRange = func instanceof HasRegularizerParamRange ?
-        ((HasRegularizerParamRange)func).getRegularizerParamRange(x) : null ; // initialized below
+    Set<Integer> paramRange = initializeParamRange(func, x); // initialized below
    double[] newGrad = new double[grad.length];
 
     // compute pseudo gradient
     for (int i = 0; i < x.length; i++) {
-      if (paramRange == null || paramRange.contains(i)) {
+      if (paramRange.contains(i)) {
         if (x[i] < 0.0) {
           // Differentiable
           newGrad[i] = grad[i] - lambdaOWL;
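
Context note (not part of this commit): the QNMinimizer refactor above only extracts the duplicated HasRegularizerParamRange branching into initializeParamRange; OWL-QN behavior is unchanged. For reference, a minimal sketch of what these helpers support, with an invented toy objective: useOWLQN(true, lambda) makes the minimizer add lambda * ||x||_1 to a smooth objective, and the projection/pseudo-gradient steps let coordinates land exactly on zero:

import java.util.Arrays;

import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.QNMinimizer;

public class OwlQnSketch {

  public static void main(String[] args) {
    // Smooth part only: f(x) = (x0 - 3)^2 + (x1 + 0.2)^2.
    // useOWLQN adds the L1 term lambda * ||x||_1 internally.
    DiffFunction f = new DiffFunction() {
      @Override
      public int domainDimension() {
        return 2;
      }
      @Override
      public double valueAt(double[] x) {
        return (x[0] - 3) * (x[0] - 3) + (x[1] + 0.2) * (x[1] + 0.2);
      }
      @Override
      public double[] derivativeAt(double[] x) {
        return new double[] {2 * (x[0] - 3), 2 * (x[1] + 0.2)};
      }
    };

    QNMinimizer qn = new QNMinimizer(10); // keep 10 previous iterations
    qn.useOWLQN(true, 1.0);               // lambda = 1.0
    double[] argmin = qn.minimize(f, 1e-8, new double[] {0.0, 0.0});
    // With lambda = 1.0 the solution should be near x0 = 2.5, while x1 is
    // driven to exactly 0, since the smooth gradient at x1 = 0 is 0.4 < lambda.
    System.out.println(Arrays.toString(argmin));
  }
}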
