Small bug fix in the extended features. Added support for simple curriculum learning.
1 parent 6790318 commit 2ec1806e5a48fb52d8c451982e1e5d41b8a20bc5 Sanjeev Satheesh committed May 11, 2012
8 USAGE
@@ -124,16 +124,20 @@ Refer run.sh for sample usage.
 	All lambdas are weights on the regularization terms. LAMBDAL is the weight
 	on the embedding matrix We (refer to the paper for more details).
--lambdaCat
+-lambdaCat LAMBDACAT
 	All lambdas are weights on the regularization terms. LAMBDACAT is the
 	weight on the classifier weights.
--lambdaRAE
+-lambdaRAE LAMBDARAE
 	All lambdas are weights on the regularization terms. LAMBDARAE is also a
 	weight on the classifier weights; it differs from LAMBDACAT in that it is
 	applied in the second phase, where the RAE is being fine-tuned (refer to
 	the paper for more details).
+-CurriculumLearning
+	FLAG is set to False by default. Set to True to turn on curriculum
+	learning; refer to Bengio, Y. et al., "Curriculum Learning", ICML 2009,
+	for details. (A sample invocation follows this option list.)
+
--help
Print this message and quit.
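
For example, a hypothetical invocation enabling the new flag (the real classpath and entry point are whatever run.sh uses; this line is a sketch only):

    java -cp bin main.RAEBuilder -CurriculumLearning True ...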
12 src/classify/ReviewDatum.java
@@ -26,8 +26,7 @@ public ReviewDatum(Word[] ReviewTokens, int label, int itemNo, int[] indices)
this(ReviewTokens,label,itemNo);
assert ReviewTokens.length == indices.length;
this.Indices = new int[ indices.length ];
- for(int i=0; i<indices.length; i++)
- this.Indices[i] = indices[i];
+ System.arraycopy(indices, 0, this.Indices, 0, indices.length);
}
@@ -43,10 +42,15 @@ public ReviewDatum(String[] ReviewTokens, int label, int itemNo, int[] indices)
this(ReviewTokens,label,itemNo);
assert ReviewTokens.length == indices.length;
this.Indices = new int[ indices.length ];
- for(int i=0; i<indices.length; i++)
- this.Indices[i] = indices[i];
+ System.arraycopy(indices, 0, this.Indices, 0, indices.length);
}
+ public ReviewDatum(int[] indices, int label, int itemNo) {
+ ReviewTokens = null;
+ this.Indices = new int[indices.length];
+ System.arraycopy(indices, 0, this.Indices, 0, indices.length);
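+		// NOTE: the label and itemNo parameters are not stored here; if the
+		// class keeps corresponding fields, callers presumably set them elsewhere.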
+ }
+
public void indexWords(Map<String,Integer> WordsIndexer)
{
Indices = new int[ ReviewTokens.length ];
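
As an aside, System.arraycopy is the standard fast copy; for a fresh copy of a whole array, clone() is a behaviorally equivalent one-liner (a minor style note, not part of the commit):

    // equivalent to allocating and arraycopy-ing the full array above
    this.Indices = indices.clone();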
2 src/classify/SoftmaxClassifier.java
@@ -9,6 +9,7 @@
import org.jblas.*;
import util.Counter;
+import util.DoubleMatrixFunctions;
/**
* TODO Make it more generic later
@@ -80,6 +81,7 @@ public Accuracy train(List<LabeledDatum<F,L>> Data)
public Accuracy test(List<LabeledDatum<F,L>> Data)
{
DoubleMatrix Features = makeFeatureMatrix(Data);
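+		// (debug) prints every feature row followed by its absolute-value row sum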
+ DoubleMatrixFunctions.prettyPrint(Features);
int[] Labels = makeLabelVector(Data);
CostFunction = new SoftmaxCost (CatSize, ClassifierTheta.FeatureLength, Lambda);
testScores = CostFunction.getPredictions(ClassifierTheta,Features);
5 src/main/Arguments.java
@@ -29,7 +29,7 @@
String ProbabilitiesOutputFile = null;
- boolean TrainModel = false;
+ boolean TrainModel = false, CurriculumLearning = false;
int NumFolds = 10, MaxIterations = 80, EmbeddingSize = 50, CatSize = 1;
int DictionarySize, hiddenSize, visibleSize;
double AlphaCat = 0.2, Beta = 0.5;
@@ -42,6 +42,9 @@
public void parseArguments(String[] args) throws IOException {
Map<String, String> argMap = CommandLineUtils
.simpleCommandLineParser(args);
+
+ if (argMap.containsKey("-CurriculumLearning"))
+ CurriculumLearning = Boolean.parseBoolean(argMap.get("-CurriculumLearning"));
if (argMap.containsKey("-minCount"))
minCount = Integer.parseInt(argMap.get("-minCount"));
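
CommandLineUtils.simpleCommandLineParser is not part of this diff; below is a minimal sketch of the "-key value" parsing it presumably performs. The bare-flag-to-"true" behavior is an assumption, though it would explain why -CurriculumLearning also works as a plain flag:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical sketch of CommandLineUtils.simpleCommandLineParser;
    // the real implementation may differ.
    final class SimpleParserSketch {
        static Map<String, String> parse(String[] args) {
            Map<String, String> map = new HashMap<String, String>();
            for (int i = 0; i < args.length; i++) {
                if (!args[i].startsWith("-"))
                    continue; // values are consumed below, not treated as keys
                boolean hasValue = i + 1 < args.length && !args[i + 1].startsWith("-");
                // a bare flag (no value) maps to "true", so
                // Boolean.parseBoolean(argMap.get("-CurriculumLearning")) turns it on
                map.put(args[i], hasValue ? args[i + 1] : "true");
            }
            return map;
        }
    }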
56 src/main/RAEBuilder.java
@@ -175,23 +175,65 @@ private FineTunableTheta train(Arguments params) throws IOException,
InitialTheta = new FineTunableTheta(params.EmbeddingSize,
params.EmbeddingSize, params.CatSize, params.DictionarySize, true);
+ DoubleMatrix InitialWe = InitialTheta.We.dup();
+
+ RAECost RAECost = null;
FineTunableTheta tunedTheta = null;
-
- RAECost RAECost = new RAECost(params.AlphaCat, params.CatSize, params.Beta,
+ Minimizer<DifferentiableFunction> minFunc = null;
+
+ if(params.CurriculumLearning)
+ slowTrain(params, InitialTheta, InitialWe);
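+		// NOTE: slowTrain's return value is discarded here; the full pass
+		// below restarts from InitialTheta rather than the curriculum result.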
+
+ RAECost = new RAECost(params.AlphaCat, params.CatSize, params.Beta,
params.DictionarySize, params.hiddenSize, params.visibleSize,
- params.Lambda, InitialTheta.We, params.Dataset.Data, null, f);
+ params.Lambda, InitialWe, params.Dataset.Data, null, f);
- Minimizer<DifferentiableFunction> minFunc = new QNMinimizer(10,
- params.MaxIterations);
+ minFunc = new QNMinimizer(10, params.MaxIterations);
double[] minTheta = minFunc.minimize(RAECost, 1e-6, InitialTheta.Theta,
params.MaxIterations);
tunedTheta = new FineTunableTheta(minTheta, params.hiddenSize,
params.visibleSize, params.CatSize, params.DictionarySize);
-
+
+
// Important step
- tunedTheta.setWe(tunedTheta.We.add(InitialTheta.We));
+ tunedTheta.setWe(tunedTheta.We.add(InitialWe));
+ return tunedTheta;
+ }
+
+	private FineTunableTheta slowTrain(Arguments params,
+			FineTunableTheta tunedTheta, DoubleMatrix InitialWe) {
+
+ CurriculumLearning slowLearner = new CurriculumLearning(params.Dataset);
+ final int MILLION = 10000;
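+		// NOTE: despite the name, MILLION is 10,000 here; presumably a cap on
+		// the number of samples fed to each curriculum stage.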
+
+ int [] curriculum = new int[]{2,3,4,6,8,10};
+
+ RAECost RAECost = null;
+ List<LabeledDatum<Integer,Integer>> Data = null;
+ Minimizer<DifferentiableFunction> minFunc = null;
+
+ for (int ngram : curriculum)
+ {
+ Data = slowLearner.getNGrams(ngram, MILLION);
+
+ System.out.println("SLOW LEARNING : " + ngram + " with " + Data.size() + " data points.");
+
+ RAECost = new RAECost(params.AlphaCat, params.CatSize, params.Beta,
+ params.DictionarySize, params.hiddenSize, params.visibleSize,
+ params.Lambda, InitialWe, Data, null, f);
+
+ minFunc = new QNMinimizer(10, params.MaxIterations);
+
+ double[] minTheta = minFunc.minimize(RAECost, 1e-6, tunedTheta.Theta,
+ params.MaxIterations);
+
+ tunedTheta = new FineTunableTheta(minTheta, params.hiddenSize,
+ params.visibleSize, params.CatSize, params.DictionarySize);
+
+ tunedTheta.setWe(tunedTheta.We.add(InitialWe));
+ }
return tunedTheta;
}
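
The CurriculumLearning class itself is not in this diff; below is a minimal sketch of what getNGrams plausibly does, following Bengio et al.'s easy-to-hard ordering. Whether it filters short examples (as here) or slices longer sentences into n-grams is not shown, so treat the class and method bodies as assumptions:

    import java.util.ArrayList;
    import java.util.List;

    // Hypothetical sketch of one curriculum stage: keep only examples of at
    // most n tokens, up to maxSamples of them.
    final class CurriculumSketch {
        static List<int[]> getNGrams(List<int[]> tokenized, int n, int maxSamples) {
            List<int[]> stage = new ArrayList<int[]>();
            for (int[] tokens : tokenized) {
                if (tokens.length > n)
                    continue;                 // too hard for this stage
                stage.add(tokens);
                if (stage.size() >= maxSamples)
                    break;                    // per-stage cap (MILLION above)
            }
            return stage;
        }
    }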
4 src/rae/LabeledRAETree.java
@@ -109,10 +109,10 @@ public Integer getLabel() {
tf.putColumn(i, T[i].Features);
meanScores = DoubleArrays.addi(meanScores, T[i].scores);
if (T[i].isLeaf()) {
- leafFeatures.add(T[i].Features);
+ leafFeatures.addi(T[i].Features);
leafScores = DoubleArrays.addi(leafScores, T[i].scores);
} else {
- interFeatures.add(T[i].Features);
+ interFeatures.addi(T[i].Features);
interScores = DoubleArrays.addi(interScores, T[i].scores);
}
}
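
The one-character fix matters because of jblas semantics: add returns a new matrix and leaves the receiver unchanged, while addi accumulates in place. A self-contained illustration:

    import org.jblas.DoubleMatrix;

    public class AddVsAddi {
        public static void main(String[] args) {
            DoubleMatrix acc = DoubleMatrix.zeros(1, 3);
            DoubleMatrix x = new DoubleMatrix(new double[][] {{1, 2, 3}});

            acc.add(x);  // result is returned, not stored: acc stays all zeros (the old bug)
            acc.addi(x); // in-place accumulation: acc now holds 1, 2, 3 (the fix)
            System.out.println(acc);
        }
    }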
3 src/rae/RAEFeatureExtractor.java
@@ -10,6 +10,7 @@
import java.util.*;
import java.util.concurrent.locks.*;
import math.DifferentiableMatrixFunction;
+
import org.jblas.*;
import util.ArraysHelper;
import classify.*;
@@ -137,7 +138,7 @@ protected void writeTree(BufferedWriter treeStructuresStream, LabeledRAETree tre
} catch (IOException e) {
e.printStackTrace();
}
- new TreeDumpThread(tree, treeDumpDir, dataset, (ReviewDatum) data, index);
+ new TreeDumpThread(tree, treeDumpDir, dataset, data, index);
}
public List<LabeledDatum<Double, Integer>> extractFeaturesIntoArray(final List<LabeledDatum<Integer, Integer>> data) {
13 src/util/DoubleMatrixFunctions.java
@@ -38,13 +38,16 @@ public static double SquaredNorm(DoubleMatrix inp)
public static void prettyPrint(DoubleMatrix inp)
{
- System.out.println(">>");
+// System.out.println(">>");
for(int i=0; i< Math.min(Integer.MAX_VALUE,inp.rows); i++)
{
- for(int j=0; j< Math.min(Integer.MAX_VALUE,inp.columns); j++)
- System.out.printf("%.4f ", inp.get(i, j));
- System.out.println();
+ double total = 0;
+ for(int j=0; j< Math.min(Integer.MAX_VALUE,inp.columns); j++){
+ System.out.printf("%.4f ", inp.get(i, j));
+ total += Math.abs(inp.get(i,j));
+ }
+ System.out.printf(" [%f]\n", total);
}
- System.out.println("<<");
+// System.out.println("<<");
}
}
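
With this change, prettyPrint appends each row's absolute-value sum in brackets, which is what the debug call in SoftmaxClassifier.test relies on. (The Math.min(Integer.MAX_VALUE, ...) bounds are no-ops, presumably leftovers of a configurable row/column cap.) A minimal usage sketch:

    import org.jblas.DoubleMatrix;
    import util.DoubleMatrixFunctions;

    public class PrettyPrintDemo {
        public static void main(String[] args) {
            DoubleMatrix m = new DoubleMatrix(new double[][] {
                {0.5, -0.25},
                {1.0,  2.0},
            });
            // prints each row's entries followed by its absolute-value sum:
            // 0.5000 -0.2500  [0.750000]
            // 1.0000 2.0000   [3.000000]
            DoubleMatrixFunctions.prettyPrint(m);
        }
    }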
