Permalink
Browse files

Tree dumping without storing them all in memory

  • Loading branch information...
1 parent b7b11ba commit 7b027bef09c92b7fb1fb11a423fe3fe68a6a407c Sanjeev Satheesh committed Apr 24, 2012
Showing with 126 additions and 112 deletions.
  1. +16 −14 run.sh
  2. +4 −7 src/main/FullRun.java
  3. +10 −68 src/main/RAEBuilder.java
  4. +5 −0 src/rae/RAECost.java
  5. +91 −23 src/rae/RAEFeatureExtractor.java
View
30 run.sh 100644 → 100755
@@ -1,16 +1,18 @@
-
java -jar jar/jrae.jar \
--DataDir data/mov \
--maxIterations 80 \
--ModelFile data/mov/tunedTheta.rae \
--NumCores 4 \
--TrainModel True
+-DataDir data/tiny.mov \
+-MaxIterations 20 \
+-ModelFile data/tiny.mov/tunedTheta.rae \
+-ClassifierFile data/tiny.mov/Softmax.clf \
+-NumCores 3 \
+-TrainModel True \
+-ProbabilitiesOutputFile data/tiny.mov/prob.out \
+-TreeDumpDir data/tiny.mov/trees
-java -jar jar/jrae.jar \
--DataDir data/
--MaxIterations 80
--ModelFile data/mov/tunedTheta.rae
--ClassifierFile data/mov/Softmax.clf
--NumCores 2
--TrainModel False
--ProbabilitiesOutputFile data/tiny/prob.out
+#java -jar jar/jrae.jar \
+#-DataDir data/
+#-MaxIterations 80
+#-ModelFile data/mov/tunedTheta.rae
+#-ClassifierFile data/mov/Softmax.clf
+#-NumCores 2
+#-TrainModel False
+#-ProbabilitiesOutputFile data/tiny/prob.out
View
@@ -8,7 +8,6 @@
import org.jblas.*;
import rae.FineTunableTheta;
-import rae.LabeledRAETree;
import rae.RAECost;
import rae.RAEFeatureExtractor;
@@ -25,7 +24,7 @@ public static void main(final String[] args) throws Exception
Arguments params = new Arguments();
params.parseArguments(args);
- if( params.exitOnReturn )
+ if(params.exitOnReturn)
return;
RAECost RAECost = null;
@@ -80,13 +79,11 @@ public static void main(final String[] args) throws Exception
FeatureExtractor = new RAEFeatureExtractor(params.EmbeddingSize, tunedTheta,
params.AlphaCat, params.Beta, params.CatSize, params.DictionarySize, f);
- List<LabeledRAETree> trainTrees = FeatureExtractor.getRAETrees (trainingData);
List<LabeledDatum<Double, Integer>> classifierTrainingData
- = FeatureExtractor.extractFeaturesIntoArray(trainTrees);
-
- List<LabeledRAETree> testTrees = FeatureExtractor.getRAETrees (testData);
+ = FeatureExtractor.extractFeaturesIntoArray(trainingData);
+
List<LabeledDatum<Double, Integer>> classifierTestingData
- = FeatureExtractor.extractFeaturesIntoArray(testTrees);
+ = FeatureExtractor.extractFeaturesIntoArray(testData);
SoftmaxClassifier<Double,Integer> classifier = new SoftmaxClassifier<Double,Integer>( );
View
@@ -1,8 +1,5 @@
package main;
-import io.LabeledDataSet;
-
-import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -22,15 +19,12 @@
import classify.Accuracy;
import classify.ClassifierTheta;
import classify.LabeledDatum;
-import classify.ReviewDatum;
+
import classify.SoftmaxClassifier;
import rae.FineTunableTheta;
import rae.RAECost;
import rae.RAEFeatureExtractor;
-import rae.LabeledRAETree;
-import rae.RAENode;
-import util.ArraysHelper;
public class RAEBuilder {
FineTunableTheta InitialTheta;
@@ -67,10 +61,10 @@ public static void main(final String[] args) throws Exception {
params.EmbeddingSize, tunedTheta, params.AlphaCat,
params.Beta, params.CatSize, params.Dataset.Vocab.size(),
rae.f);
-
- List<LabeledRAETree> Trees = fe.getRAETrees (params.Dataset.Data);
- List<LabeledDatum<Double, Integer>> classifierTrainingData = fe.extractFeaturesIntoArray(Trees);
+ List<LabeledDatum<Double, Integer>> classifierTrainingData
+ = fe.extractFeaturesIntoArray(params.Dataset, params.Dataset.Data, params.TreeDumpDir);
+
SoftmaxClassifier<Double, Integer> classifier = new SoftmaxClassifier<Double, Integer>();
Accuracy TrainAccuracy = classifier.train(classifierTrainingData);
System.out.println("Train Accuracy :" + TrainAccuracy.toString());
@@ -88,8 +82,8 @@ public static void main(final String[] args) throws Exception {
rae.DumpProbabilities(params.ProbabilitiesOutputFile,
classifier.getTrainScores());
- if (params.TreeDumpDir != null)
- rae.DumpTrees(Trees, params.TreeDumpDir, params.Dataset, params.Dataset.Data);
+// if (params.TreeDumpDir != null)
+// rae.DumpTrees(Trees, params.TreeDumpDir, params.Dataset, params.Dataset.Data);
System.out.println("Dumping complete");
@@ -113,13 +107,8 @@ public static void main(final String[] args) throws Exception {
System.err.println("It will be ignored when you are not in the training mode.");
}
- List<LabeledRAETree> testTrees = null;
- if (params.Dataset.TestData.size() > 0)
- testTrees = fe.getRAETrees(params.Dataset.TestData);
- else
- testTrees = fe.getRAETrees(params.Dataset.Data);
-
- List<LabeledDatum<Double, Integer>> classifierTestingData = fe.extractFeaturesIntoArray(testTrees);
+ List<LabeledDatum<Double, Integer>> classifierTestingData
+ = fe.extractFeaturesIntoArray(params.Dataset, params.Dataset.TestData, params.TreeDumpDir);
Accuracy TestAccuracy = classifier.test(classifierTestingData);
if (params.isTestLabelsKnown) {
@@ -134,60 +123,13 @@ public static void main(final String[] args) throws Exception {
rae.DumpProbabilities(params.ProbabilitiesOutputFile,
classifier.getTestScores());
- if (params.TreeDumpDir != null)
- rae.DumpTrees(testTrees, params.TreeDumpDir, params.Dataset, params.Dataset.TestData);
+// if (params.TreeDumpDir != null)
+// rae.DumpTrees(testTrees, params.TreeDumpDir, params.Dataset, params.Dataset.TestData);
}
System.exit(0);
}
- private void DumpTrees( List<LabeledRAETree> trees, String treeDumpDir,
- LabeledDataSet<LabeledDatum<Integer, Integer>, Integer, Integer> dataset,
- List<LabeledDatum<Integer, Integer>> data) throws Exception {
-
- if (trees.size () != data.size())
- throw new Exception ("Inconsistent data!");
-
- File treeStructuresFile = new File (treeDumpDir, "treeStructures.txt");
- PrintStream treeStructuresStream = new PrintStream(treeStructuresFile);
-
- for (int i=0; i<trees.size(); i++)
- {
- LabeledRAETree tree = trees.get(i);
- ReviewDatum datum = (ReviewDatum) data.get(i);
- int[] parentStructure = tree.getStructureString();
-
- treeStructuresStream.println(ArraysHelper.makeStringFromIntArray(parentStructure));
- File vectorsFile = new File (treeDumpDir, "sent"+(i+1)+"_nodeVecs.txt");
- PrintStream vectorsStream = new PrintStream(vectorsFile);
-
- File substringsFile = new File (treeDumpDir, "sent"+(i+1)+"_strings.txt");
- PrintStream substringsStream = new PrintStream(substringsFile);
-
- File classifierOutputFile = new File (treeDumpDir, "sent"+(i+1)+"_classifierOutput.txt");
- PrintStream classifierOutputStream = new PrintStream(classifierOutputFile);
-
- for (RAENode node : tree.getNodes())
- {
- double[] features = node.getFeatures();
- double[] scores = node.getScores();
- List<Integer> subTreeWords = node.getSubtreeWordIndices();
-
- String subTreeString = subTreeWords.size() + " ";
- for (int pos : subTreeWords)
- subTreeString += datum.getToken(pos) + " ";
-
- vectorsStream.println(ArraysHelper.makeStringFromDoubleArray(features));
- classifierOutputStream.println(ArraysHelper.makeStringFromDoubleArray(scores));
- substringsStream.println(subTreeString);
- }
-
- vectorsStream.close();
- classifierOutputStream.close();
- substringsStream.close();
- }
- }
-
public void DumpFeatures(String featuresOutputFile,
List<LabeledDatum<Double, Integer>> Features)
throws FileNotFoundException {
View
@@ -85,6 +85,11 @@ public void perform(int index, LabeledDatum<Integer,Integer> Data)
for(int i=0; i<gradRAE.length; i++)
gradient[i] += gradRAE[i];
+ System.gc(); System.gc();
+ System.gc(); System.gc();
+ System.gc(); System.gc();
+ System.gc(); System.gc();
+
return value;
}
}
@@ -1,5 +1,11 @@
package rae;
+import io.LabeledDataSet;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.locks.*;
import math.DifferentiableMatrixFunction;
@@ -32,45 +38,107 @@ public RAEFeatureExtractor(int HiddenSize, FineTunableTheta Theta, double AlphaC
Propagator = new RAEPropagation(AlphaCat, Beta, HiddenSize, CatSize, DictionaryLength, f);
}
- public List<LabeledDatum<Double,Integer>>
- extractFeaturesIntoArray(List<LabeledRAETree> trees)
- {
- int numExamples = trees.size();
+ private synchronized void dumpTree(LabeledRAETree tree, String treeDumpDir,
+ final LabeledDataSet<LabeledDatum<Integer, Integer>, Integer, Integer> dataset,
+ ReviewDatum datum, int index){
+ try{
+ File treeStructuresFile = new File(treeDumpDir, "treeStructures.txt");
+ FileWriter treeStructuresFileWriter = new FileWriter(treeStructuresFile.getAbsolutePath(), true);
+ BufferedWriter treeStructuresStream = new BufferedWriter(treeStructuresFileWriter);
+
+ int[] parentStructure = tree.getStructureString();
+
+ treeStructuresStream.write(ArraysHelper.makeStringFromIntArray(parentStructure)+"\n");
+ File vectorsFile = new File (treeDumpDir, "sent"+(index)+"_nodeVecs.txt");
+ PrintStream vectorsStream = new PrintStream(vectorsFile);
+
+ File substringsFile = new File (treeDumpDir, "sent"+(index)+"_strings.txt");
+ PrintStream substringsStream = new PrintStream(substringsFile);
+
+ File classifierOutputFile = new File (treeDumpDir, "sent"+(index)+"_classifierOutput.txt");
+ PrintStream classifierOutputStream = new PrintStream(classifierOutputFile);
+
+ for (RAENode node : tree.getNodes())
+ {
+ double[] features = node.getFeatures();
+ double[] scores = node.getScores();
+ List<Integer> subTreeWords = node.getSubtreeWordIndices();
+
+ String subTreeString = subTreeWords.size() + " ";
+ for (int pos : subTreeWords)
+ subTreeString += datum.getToken(pos) + " ";
+
+ vectorsStream.println(ArraysHelper.makeStringFromDoubleArray(features));
+ classifierOutputStream.println(ArraysHelper.makeStringFromDoubleArray(scores));
+ substringsStream.println(subTreeString);
+ }
+ vectorsStream.close();
+ classifierOutputStream.close();
+ substringsStream.close();
+ treeStructuresStream.close();
+ treeStructuresFileWriter.close();
+ }
+ catch(Exception e)
+ {
+ System.err.println(e.getMessage());
+ e.printStackTrace();
+ }
+ }
+
+ public List<LabeledDatum<Double,Integer>> extractFeaturesIntoArray(
+ final LabeledDataSet<LabeledDatum<Integer, Integer>, Integer, Integer> dataset,
+ final List<LabeledDatum<Integer,Integer>> Data,
+ final String treeDumpDir)
+ {
+ final int numExamples = Data.size();
final LabeledDatum<Double,Integer>[] DataFeatures = new ReviewFeatures[numExamples];
+ final boolean dump = (dataset != null && treeDumpDir != null);
- Parallel.For(trees, new Parallel.Operation<LabeledRAETree>(){
+ Parallel.For(Data, new Parallel.Operation<LabeledDatum<Integer,Integer>>() {
@Override
- public void perform(int index, LabeledRAETree tree) {
+ public void perform(int index, LabeledDatum<Integer, Integer> data) {
+ LabeledRAETree tree = getRAETree(Propagator, data);
+ if(dump)
+ dumpTree(tree, treeDumpDir, dataset, (ReviewDatum) data, index);
double[] feature = tree.getFeaturesVector();
lock.lock();
{
ReviewFeatures r =
- new ReviewFeatures (null, tree.getLabel(), index, feature);
- DataFeatures[index] = r;
+ new ReviewFeatures (null, data.getLabel(), index, feature);
+ DataFeatures[index] = r;
}
lock.unlock();
+ System.gc();
}
- });
+ });
return Arrays.asList(DataFeatures);
}
+ public List<LabeledDatum<Double,Integer>> extractFeaturesIntoArray(
+ final List<LabeledDatum<Integer, Integer>> data)
+ {
+ return extractFeaturesIntoArray(null, data, null);
+ }
+
public DoubleMatrix extractFeatures(List<LabeledDatum<Integer,Integer>> Data)
{
int numExamples = Data.size();
features = DoubleMatrix.zeros(2*HiddenSize,numExamples);
- ThreadPool.map (Data, Propagator,
- new ThreadPool.Operation<RAEPropagation, LabeledDatum<Integer,Integer>>() {
- public void perform(RAEPropagation locPropagator, int index,
- LabeledDatum<Integer,Integer> data)
- {
- double[] feature = extractFeatures(locPropagator, data);
+
+ Parallel.For(Data,
+ new Parallel.Operation<LabeledDatum<Integer,Integer>>() {
+ @Override
+ public void perform(int index, LabeledDatum<Integer, Integer> data) {
+ double[] feature = extractFeatures(Propagator, data);
lock.lock();
{
features.putColumn(index, new DoubleMatrix(feature));
}
lock.unlock();
}
- });
+ }
+ );
+
return features;
}
@@ -89,19 +157,19 @@ public void perform(RAEPropagation locPropagator, int index,
int numExamples = Data.size();
final LabeledRAETree[] ExtractedTrees = new LabeledRAETree[numExamples];
- ThreadPool.map (Data, Propagator,
- new ThreadPool.Operation<RAEPropagation, LabeledDatum<Integer,Integer>>() {
- public void perform(RAEPropagation locPropagator, int index,
- LabeledDatum<Integer,Integer> data)
- {
- LabeledRAETree tree = getRAETree(locPropagator, data);
+ Parallel.For(Data,
+ new Parallel.Operation<LabeledDatum<Integer,Integer>>() {
+ @Override
+ public void perform(int index, LabeledDatum<Integer, Integer> data) {
+ LabeledRAETree tree = getRAETree(Propagator, data);
lock.lock();
{
ExtractedTrees[index] = tree;
}
lock.unlock();
}
- });
+ }
+ );
return Arrays.asList(ExtractedTrees);
}

0 comments on commit 7b027be

Please sign in to comment.