Skip to content

Commit

Permalink
We can now dump out the learned trees
Browse files Browse the repository at this point in the history
Also includes features calculated at each internal node. Read USAGE.
  • Loading branch information
Sanjeev Satheesh committed Mar 13, 2012
1 parent 5bfdd36 commit 2f1f5e6
Show file tree
Hide file tree
Showing 10 changed files with 287 additions and 125 deletions.
38 changes: 32 additions & 6 deletions USAGE
Expand Up @@ -55,6 +55,32 @@ Refer run.sh for sample usage.
NOTE :: FILE should either not be a txt file or it should not point to the
same directory as the -DataDir option. The next time jrae processes the
directory, it will read in the probabilities as training data.

-TreeDumpDir DIR
    Set DIR to point to a directory where you want to dump out all
the trees. You can use this parameter both during training and testing.
For each data item, it writes out three files as follows with each line
containing information about a subtree of the RAE model.
(# stands for data item number, with index starting at 1):
* sent#_strings.txt
Each line is of the form <n word1 word2 ... wordn>
        n indicates the number of words in the subtree.

* sent#_classifierOutput.txt
The probability emitted by the classifier indicating which
        class it belongs to. The probabilities are listed in increasing order of label.
NOTE: The ordering of the labels is listed in "labels.map".

* sent#_nodeVecs.txt
Each line contains the features calculated by the RAE model at
each node. This is the feature of the entire subtree underneath
this node.

There is also a treeStructures.txt file which lists the tree structure
built by the RAE, one data item per line. The first "n" values indicate
the index of the parent of the individual tokens in the data item. The
next "n-1" entries each correspond to the internal nodes built by the
RAE model.

-NumCores NUMCORES
Indicates how many parallel threads to use for feature processing.
Expand Down Expand Up @@ -99,14 +125,14 @@ Refer run.sh for sample usage.
on the embedding We (Refer the paper for more details)

-lambdaCat
All lambda are weights on the regularization terms. LAMBDACAT is the weight
on the classifier weights.
All lambda are weights on the regularization terms. LAMBDACAT is the
weight on the classifier weights.

-lambdaRAE
All lambda are weights on the regularization terms. LAMBDARAE is the weight
on the classifier weights. This differs from LAMBDACAT in that this is applied
in the second phase where the RAE is being fine-tuned (Refer the paper for more
details)
All lambda are weights on the regularization terms. LAMBDARAE is the
weight on the classifier weights. This differs from LAMBDACAT in that
this is applied in the second phase where the RAE is being fine-tuned
(Refer the paper for more details)

--help
Print this message and quit.
Expand Down
9 changes: 8 additions & 1 deletion src/classify/ReviewDatum.java
Expand Up @@ -75,12 +75,19 @@ public Integer getLabel()
@Override
public String toString()
{
String retString = Index + " // ";
String retString = ""; //Index + " // ";
for (int i=0; i<ReviewTokens.length; i++)
retString += ReviewTokens[i] + " ";
return retString;
}

public String getToken (int pos)
{
if (pos > ReviewTokens.length)
System.err.println ("Invalid query for item #" + Index);
return ReviewTokens[pos];
}

public int[] getIndices()
{
return Indices;
Expand Down
15 changes: 15 additions & 0 deletions src/main/Arguments.java
Expand Up @@ -120,6 +120,21 @@ public void parseArguments(String[] args) throws IOException {
if (argMap.containsKey("-FeaturesOutputFile"))
featuresOutputFile = argMap.get("-FeaturesOutputFile");

if (argMap.containsKey("-TreeDumpDir"))
{
TreeDumpDir = argMap.get("-TreeDumpDir");

File treeDumpFile = new File (TreeDumpDir);
if (!treeDumpFile.exists())
treeDumpFile.mkdir();
else if (!treeDumpFile.isDirectory())
{
System.err.println ("TreeDumpDir file exists but it is not a directory.");
exitOnReturn = true;
printUsage();
}
}

if (argMap.containsKey("-ProbabilitiesOutputFile"))
ProbabilitiesOutputFile = argMap.get("-ProbabilitiesOutputFile");

Expand Down
81 changes: 65 additions & 16 deletions src/main/RAEBuilder.java
@@ -1,5 +1,8 @@
package main;

import io.LabeledDataSet;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
Expand All @@ -19,12 +22,15 @@
import classify.Accuracy;
import classify.ClassifierTheta;
import classify.LabeledDatum;
import classify.ReviewDatum;
import classify.SoftmaxClassifier;

import rae.FineTunableTheta;
import rae.RAECost;
import rae.RAEFeatureExtractor;
import rae.LabeledRAETree;
import rae.RAENode;
import util.ArraysHelper;

public class RAEBuilder {
FineTunableTheta InitialTheta;
Expand Down Expand Up @@ -78,18 +84,17 @@ public static void main(final String[] args) throws Exception {
rae.DumpFeatures(params.featuresOutputFile,
classifierTrainingData);

if (params.featuresOutputFile != null)
if (params.ProbabilitiesOutputFile != null)
rae.DumpProbabilities(params.ProbabilitiesOutputFile,
classifier.getTrainScores());

if (params.featuresOutputFile != null)
rae.DumpTrees(params.TreeDumpDir, Trees);
if (params.TreeDumpDir != null)
rae.DumpTrees(Trees, params.TreeDumpDir, params.Dataset, params.Dataset.Data);

} else {
System.out
.println("Using the trained RAE. Model file retrieved from "
+ params.ModelFile
+ "\nNote that this overrides all RAE specific arguments you passed.");
System.out.println
("Using the trained RAE. Model file retrieved from " + params.ModelFile
+ "\nNote that this overrides all RAE specific arguments you passed.");

FineTunableTheta tunedTheta = rae.loadRAE(params);
assert tunedTheta.getNumCategories() == params.Dataset.getCatSize();
Expand All @@ -103,8 +108,7 @@ public static void main(final String[] args) throws Exception {

if (params.Dataset.Data.size() > 0) {
System.err.println("There is training data in the directory.");
System.err
.println("It will be ignored when you are not in the training mode.");
System.err.println("It will be ignored when you are not in the training mode.");
}

List<LabeledRAETree> testTrees = fe.getRAETrees (params.Dataset.TestData);
Expand All @@ -119,9 +123,59 @@ public static void main(final String[] args) throws Exception {
rae.DumpFeatures(params.featuresOutputFile,
classifierTestingData);

if (params.featuresOutputFile != null)
if (params.ProbabilitiesOutputFile != null)
rae.DumpProbabilities(params.ProbabilitiesOutputFile,
classifier.getTestScores());

if (params.TreeDumpDir != null)
rae.DumpTrees(testTrees, params.TreeDumpDir, params.Dataset, params.Dataset.TestData);
}
}

private void DumpTrees( List<LabeledRAETree> trees, String treeDumpDir,
LabeledDataSet<LabeledDatum<Integer, Integer>, Integer, Integer> dataset,
List<LabeledDatum<Integer, Integer>> data) throws Exception {

if (trees.size () != data.size())
throw new Exception ("Inconsistent data!");

File treeStructuresFile = new File (treeDumpDir, "treeStructures.txt");
PrintStream treeStructuresStream = new PrintStream(treeStructuresFile);

for (int i=0; i<trees.size(); i++)
{
LabeledRAETree tree = trees.get(i);
ReviewDatum datum = (ReviewDatum) data.get(i);
int[] parentStructure = tree.getStructureString();

treeStructuresStream.println(ArraysHelper.makeStringFromIntArray(parentStructure));
File vectorsFile = new File (treeDumpDir, "sent"+(i+1)+"_nodeVecs.txt");
PrintStream vectorsStream = new PrintStream(vectorsFile);

File substringsFile = new File (treeDumpDir, "sent"+(i+1)+"_strings.txt");
PrintStream substringsStream = new PrintStream(substringsFile);

File classifierOutputFile = new File (treeDumpDir, "sent"+(i+1)+"_classifierOutput.txt");
PrintStream classifierOutputStream = new PrintStream(classifierOutputFile);

for (RAENode node : tree.getNodes())
{
double[] features = node.getFeatures();
double[] scores = node.getScores();
List<Integer> subTreeWords = node.getSubtreeWordIndices();

String subTreeString = subTreeWords.size() + " ";
for (int pos : subTreeWords)
subTreeString += datum.getToken(pos) + " ";

vectorsStream.println(ArraysHelper.makeStringFromDoubleArray(features));
classifierOutputStream.println(ArraysHelper.makeStringFromDoubleArray(scores));
substringsStream.println(subTreeString);
}

vectorsStream.close();
classifierOutputStream.close();
substringsStream.close();
}
}

Expand Down Expand Up @@ -152,12 +206,7 @@ public void DumpProbabilities(String ProbabilitiesOutputFile,
}
out.close();
}

public void DumpTrees(String TreesDumpDirectory, List<LabeledRAETree> trees)
throws Exception {
throw new Exception("Dumping trees not implemented yet!");
}


private FineTunableTheta train(Arguments params) throws IOException,
ClassNotFoundException {

Expand Down
91 changes: 29 additions & 62 deletions src/rae/LabeledRAETree.java
Expand Up @@ -9,7 +9,7 @@
import java.util.*;

public class LabeledRAETree implements LabeledDatum<Double, Integer>{
Node[] T;
RAENode[] T;
double[] feature;
Structure structure;
int SentenceLength, TreeSize, Label;
Expand All @@ -19,7 +19,7 @@ public LabeledRAETree(int SentenceLength, int Label)
{
this.SentenceLength = SentenceLength;
TreeSize = 2 * SentenceLength - 1;
T = new Node[TreeSize];
T = new RAENode[TreeSize];
structure = new Structure( TreeSize );
this.Label = Label;
}
Expand All @@ -29,30 +29,45 @@ public LabeledRAETree(int SentenceLength, int Label, int HiddenSize, DoubleMatri
this(SentenceLength, Label);
for(int i=0; i<TreeSize; i++)
{
T[i] = new Node(i,SentenceLength,HiddenSize,WordsEmbedded);
T[i] = new RAENode(i,SentenceLength,HiddenSize,WordsEmbedded);
structure.add(new Pair<Integer,Integer>(-1,-1));
}
}

public RAENode[] getNodes ()
{
return T;
}

public LabeledRAETree(int SentenceLength, int Label, int HiddenSize, int CatSize, DoubleMatrix WordsEmbedded)
{
this(SentenceLength, Label);
for(int i=0; i<TreeSize; i++)
{
T[i] = new Node(i,SentenceLength,HiddenSize,CatSize,WordsEmbedded);
T[i] = new RAENode(i,SentenceLength,HiddenSize,CatSize,WordsEmbedded);
structure.add(new Pair<Integer,Integer>(-1,-1));
}
}

public String getStructureString()
public int[] getStructureString()
{
int[] parents = new int[ TreeSize ];
Arrays.fill(parents, -1);

for (int i=TreeSize-1; i>=0; i--)
{
parents[ structure.get(i).getFirst() ] = i;
parents[ structure.get(i).getSecond() ] = i;
int leftChild = structure.get(i).getFirst();
int rightChild = structure.get(i).getSecond();
if (leftChild != -1 && rightChild != -1)
{
if (parents[ leftChild ] != -1
|| parents[ rightChild ] != -1)
System.err.println ("TreeStructure is messed up!");
parents[ leftChild ] = i;
parents[ rightChild ] = i;
}
}
return ArraysHelper.makeStringFromIntArray(parents);
return parents;
}

@Override
Expand Down Expand Up @@ -108,60 +123,12 @@ public Structure(int Capacity)
{
super(Capacity);
}
}

class Node {
Node parent, LeftChild, RightChild;
int NodeName, SubtreeSize;
double[] scores; //, Freq;
DoubleMatrix UnnormalizedFeatures,
Features, LeafFeatures, Z,
DeltaOut1, DeltaOut2, ParentDelta,
catDelta, dW1, dW2, dW3, dW4, dL, Y1C1, Y2C2;

/**
* Specialized Constructor for fitting in that list
* @param NodeIndex
* @param SentenceLength
* @param HiddenSize
* @param WordsEmbedded
*/
public Node(int NodeIndex, int SentenceLength, int HiddenSize, DoubleMatrix WordsEmbedded)
{
NodeName = NodeIndex;
parent = LeftChild = RightChild = null;
scores = null;
// Freq = 0;
SubtreeSize = 0;
if( NodeIndex < SentenceLength )
{
Features = WordsEmbedded.getColumn(NodeIndex);
UnnormalizedFeatures = WordsEmbedded.getColumn(NodeIndex);
}
}

public Node(int NodeIndex, int SentenceLength, int HiddenSize, int CatSize, DoubleMatrix WordsEmbedded)
{
this(NodeIndex,SentenceLength,HiddenSize,WordsEmbedded);
DeltaOut1 = DoubleMatrix.zeros(HiddenSize,1);
DeltaOut2 = DoubleMatrix.zeros(HiddenSize,1);
ParentDelta = DoubleMatrix.zeros(HiddenSize,1);
Y1C1 = DoubleMatrix.zeros(HiddenSize,1);
Y2C2 = DoubleMatrix.zeros(HiddenSize,1);
if( NodeIndex >= SentenceLength )
{
Features = DoubleMatrix.zeros(HiddenSize, 1);
UnnormalizedFeatures = DoubleMatrix.zeros(HiddenSize, 1);
}
}

public boolean isLeaf()
public String toString ()
{
if( LeftChild == null && RightChild == null )
return true;
else if( LeftChild != null && RightChild != null )
return false;
System.err.println("Broken tree, node has one child " + NodeName);
return false;
String retString = "";
for (Pair<Integer,Integer> pii : this)
retString += "<"+pii.getFirst()+","+pii.getSecond()+">";
return retString;
}
}
}

0 comments on commit 2f1f5e6

Please sign in to comment.