Made the rae_cost generic to support multi-class.
But the gradient check still does not pass.
Sanjeev Satheesh committed Mar 9, 2012
1 parent 804c4f6 commit 6d840c9
Showing 3 changed files with 45 additions and 28 deletions.
12 changes: 9 additions & 3 deletions src/classify/SoftmaxCost.java
@@ -1,9 +1,7 @@
package classify;

import math.*;

import org.jblas.*;

import util.*;

public class SoftmaxCost extends MemoizedDifferentiableFunction
@@ -77,7 +75,15 @@ public DoubleMatrix getPredictions (ClassifierTheta Theta, DoubleMatrix Features
int numDataItems = Features.columns;
DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);
Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(1,numDataItems));
-return Activation.valueAt(Input);
+return Activation.valueAt (Input);
}

+public DoubleMatrix getGradient (ClassifierTheta Theta, DoubleMatrix Features)
+{
+int numDataItems = Features.columns;
+DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);
+Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(1,numDataItems));
+return Activation.derivativeAt (Input);
+}

private double getNetLogLoss (DoubleMatrix Prediction, DoubleMatrix Labels)
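Both methods above append a row of zero logits before applying the activation: one class's score is pinned at 0, which removes the softmax's redundant degree of freedom, so only the remaining rows carry learned parameters. A standalone sketch of the same computation with jblas; the class name and the softmax helper are illustrative inventions, not part of this repository:

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

// Sketch of a softmax over learned logits plus one pinned zero row,
// mirroring getPredictions above. All names here are illustrative.
public class PinnedSoftmaxSketch {
    // Column-wise softmax: exponentiate, then normalize each column to sum to 1.
    // (A production version would subtract the column max first for stability.)
    static DoubleMatrix softmax(DoubleMatrix input) {
        DoubleMatrix e = MatrixFunctions.exp(input);
        return e.diviRowVector(e.columnSums());
    }

    public static void main(String[] args) {
        int hiddenSize = 4, numClasses = 3, numItems = 2;
        DoubleMatrix W = DoubleMatrix.randn(hiddenSize, numClasses - 1);
        DoubleMatrix b = DoubleMatrix.zeros(numClasses - 1);
        DoubleMatrix features = DoubleMatrix.randn(hiddenSize, numItems);

        DoubleMatrix input = W.transpose().mmul(features).addColumnVector(b);
        // Pin the last class's logit at zero, as getPredictions does.
        input = DoubleMatrix.concatVertically(input, DoubleMatrix.zeros(1, numItems));
        DoubleMatrix predictions = softmax(input);
        System.out.println(predictions.columnSums()); // each column sums to 1
    }
}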
2 changes: 1 addition & 1 deletion src/rae/RAECostTest.java
@@ -27,7 +27,7 @@ public void test() throws Exception
dataset.add( new ReviewDatum(new String[]{"tmp2"}, 1, 1, data2));
double[] lambda = new double[]{1e-05, 0.0001, 1e-05, 0.01};

-RAECost cost = new RAECost(alphaCat, 1, beta, DictionarySize, hiddenSize, hiddenSize,
+RAECost cost = new RAECost(alphaCat, 2, beta, DictionarySize, hiddenSize, hiddenSize,
lambda, DoubleMatrix.zeros(hiddenSize, DictionarySize), dataset, null, f);

assertTrue( GradientChecker.check(cost) );
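GradientChecker.check(cost) compares the cost's analytic gradient against numerical finite differences, and per the commit message it still fails on the multi-class path. For reference, a generic central-difference check looks like the sketch below; the Differentiable interface is hypothetical and stands in for whatever this repository's checker actually consumes:

// Hedged sketch of a central-difference gradient check; the interface is
// hypothetical and may not match this repository's GradientChecker.
interface Differentiable {
    double valueAt(double[] x);
    double[] gradientAt(double[] x);
}

class FiniteDifferenceCheck {
    static boolean check(Differentiable f, double[] x) {
        final double eps = 1e-6, tol = 1e-4;
        double[] analytic = f.gradientAt(x);
        for (int i = 0; i < x.length; i++) {
            double orig = x[i];
            x[i] = orig + eps;
            double plus = f.valueAt(x);
            x[i] = orig - eps;
            double minus = f.valueAt(x);
            x[i] = orig; // restore the perturbed coordinate
            double numeric = (plus - minus) / (2 * eps);
            // Relative comparison guards against badly scaled coordinates.
            if (Math.abs(numeric - analytic[i]) > tol * Math.max(1.0, Math.abs(numeric)))
                return false;
        }
        return true;
    }
}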
59 changes: 35 additions & 24 deletions src/rae/RAEPropagation.java
@@ -3,6 +3,9 @@
import math.*;

import org.jblas.*;

+import classify.ClassifierTheta;
+import classify.SoftmaxCost;
import util.*;
import java.util.*;

@@ -160,11 +163,13 @@ public LabeledRAETree ForwardPropagate(Theta theta, DoubleMatrix WordsEmbedded,
public LabeledRAETree ForwardPropagate(FineTunableTheta theta,
DoubleMatrix WordsEmbedded, FloatMatrix Freq, int CurrentLabel,
int SentenceLength, Structure TreeStructure) {
-int CatSize = theta.Wcat.rows;
+CatSize = theta.Wcat.rows;
int TreeSize = 2 * SentenceLength - 1;
LabeledRAETree tree = new LabeledRAETree(SentenceLength, CurrentLabel, HiddenSize, CatSize, WordsEmbedded);
int[] SubtreeSize = new int[TreeSize];
-DoubleMatrix labelVector = makeLabelVector (CurrentLabel);
+int[] requiredEntries = ArraysHelper.makeArray(0, CatSize-1);
+DoubleMatrix Labels = makeLabelVector (CurrentLabel, SentenceLength);
+DoubleMatrix LabelVector = Labels.getColumn(0);

for (int i = SentenceLength; i < TreeSize; i++) {
int LeftChild = TreeStructure.get(i).getFirst(),
@@ -173,20 +178,20 @@ public LabeledRAETree ForwardPropagate(FineTunableTheta theta,
SubtreeSize[i] = SubtreeSize[LeftChild] + SubtreeSize[RightChild];
}

-// classifier on single words
-DifferentiableMatrixFunction SigmoidCalc = CatSize > 1 ? new Softmax() : new Sigmoid();
-
-DoubleMatrix Input = theta.Wcat.mmul(WordsEmbedded).addColumnVector(theta.bcat);
-DoubleMatrix SM = SigmoidCalc.valueAt(Input);
-DoubleMatrix Diff = SM.sub(CurrentLabel);
+// We only use Soft-max to get the prediction, the error is calculated differently!
+SoftmaxCost softmaxCalc = new SoftmaxCost (CatSize, HiddenSize, 0);
+ClassifierTheta ClassifierTheta = theta.getClassifierParameters();
+DoubleMatrix Predictions = softmaxCalc.getPredictions(ClassifierTheta, WordsEmbedded);
+DoubleMatrix Diff = Predictions.sub(Labels);
DoubleMatrix SquaredError = (Diff.mul(Diff)).mul((1 - AlphaCat) * 0.5f);
-DoubleMatrix ErrorGradient = Diff.mul(1 - AlphaCat).mul(SigmoidCalc.derivativeAt(Input));
+DoubleMatrix ErrorGradient = Diff.mul(1 - AlphaCat).mul(
+softmaxCalc.getGradient(ClassifierTheta, WordsEmbedded));

for (int i = 0; i < TreeSize; i++) {
Node CurrentNode = tree.T[i];
if (i < SentenceLength) {
CurrentNode.scores = SquaredError.getColumn(i).data;
-CurrentNode.catDelta = ErrorGradient.getColumn(i);
+CurrentNode.catDelta = ErrorGradient.getColumn(i).getRows(requiredEntries);
tree.TotalScore += SquaredError.getColumn(i).sum();
} else {
int LeftChild = TreeStructure.get(i).getFirst(), RightChild = TreeStructure
@@ -209,12 +214,12 @@ public LabeledRAETree ForwardPropagate(FineTunableTheta theta,
CurrentNode.Features = pNorm1;

// Eq. (7) in the paper (for special case of 1d label)
-Input = (theta.Wcat.mmul(pNorm1)).addColumnVector(theta.bcat);
-SM = SigmoidCalc.valueAt(Input);
-Diff = SM.subColumnVector(labelVector);
+DoubleMatrix Prediction = softmaxCalc.getPredictions(ClassifierTheta, pNorm1);
+Diff = Prediction.sub(LabelVector);
CurrentNode.catDelta = (Diff.mul(Beta * (1 - AlphaCat)))
-.mul(SigmoidCalc.derivativeAt(Input));
-CurrentNode.scores = SM.data;
+.mul(softmaxCalc.getGradient(ClassifierTheta, pNorm1))
+.getRows(requiredEntries);
+CurrentNode.scores = Prediction.data;
tree.TotalScore += DoubleMatrixFunctions.SquaredNorm(Diff) * 0.5 * Beta * (1 - AlphaCat);

CurrentNode.SubtreeSize = SubtreeSize[i];
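Reading only what this hunk computes, one plausible reason the gradient check keeps failing (an observation, not something the diff states): the score accumulates a squared error proportional to the squared norm of Diff, yet catDelta multiplies that residual elementwise with getGradient's output. The elementwise chain rule is exact for a sigmoid, which acts coordinate by coordinate, but the softmax Jacobian is not diagonal, so the delta for a squared error through a softmax needs the full Jacobian-vector product:

\frac{\partial E}{\partial z_j}
  = (1-\alpha)\sum_k (p_k - y_k)\,\frac{\partial p_k}{\partial z_j}
  = (1-\alpha)\sum_k (p_k - y_k)\,p_k(\delta_{kj} - p_j),
\qquad p = \mathrm{softmax}(z).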
@@ -302,7 +307,10 @@ public void BackPropagate(LabeledRAETree tree, FineTunableTheta theta,
System.err.println("Bad Tree for backpropagation!");

DoubleMatrix GL = DoubleMatrix.zeros(HiddenSize, SentenceLength);

+//TODO Move this row of zeros into Wcat!
+// DoubleMatrix TWcat = DoubleMatrix.concatVertically (
+// theta.Wcat.dup (), DoubleMatrix.zeros(1, HiddenSize));

// Stack of currentNode, Left(1) or Right(2), Parent Node pointer
Stack<Triplet<Node, Integer, Node>> ToPopulate = new Stack<Triplet<Node, Integer, Node>>();
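The commented-out TWcat explains the getRows(requiredEntries) trims in the forward pass: the raw deltas carry 1 + CatSize rows because of the pinned zero row, while theta.Wcat has only CatSize rows, so the delta must be cut down before Wcat.transpose().mmul(CurrentNode.catDelta) in the next hunk can conform. Padding Wcat with a zero row, as the TODO suggests, would make the trim unnecessary. A small shape-bookkeeping sketch, with all names local to the sketch:

import org.jblas.DoubleMatrix;

// Why trimming the delta and zero-padding Wcat give the same result.
public class WcatPaddingSketch {
    public static void main(String[] args) {
        int catSize = 2, hiddenSize = 4;
        DoubleMatrix Wcat = DoubleMatrix.randn(catSize, hiddenSize);
        DoubleMatrix fullDelta = DoubleMatrix.randn(catSize + 1, 1); // 1 + CatSize rows

        // Option A (current code): keep only rows 0..catSize-1 of the delta.
        DoubleMatrix trimmed = fullDelta.getRows(new int[]{0, 1});
        DoubleMatrix a = Wcat.transpose().mmul(trimmed);

        // Option B (the TODO): pad Wcat with a zero row so shapes conform directly.
        DoubleMatrix TWcat = DoubleMatrix.concatVertically(
                Wcat.dup(), DoubleMatrix.zeros(1, hiddenSize));
        DoubleMatrix b = TWcat.transpose().mmul(fullDelta);

        System.out.println(a.sub(b).normmax()); // ~0: identical results
    }
}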

@@ -339,9 +347,9 @@ public void BackPropagate(LabeledRAETree tree, FineTunableTheta theta,
DoubleMatrix PD = CurrentNode.ParentDelta;

DoubleMatrix Activation = ((theta.W3.transpose()).mmul(ND1))
-.addi((theta.W4.transpose()).mmul(ND2));
-Activation.addi(((NodeW.transpose()).mmul(PD)).addi((theta.Wcat
-.transpose()).mmul(CurrentNode.catDelta)));
+.addi(theta.W4.transpose().mmul(ND2))
+.addi(NodeW.transpose().mmul(PD))
+.addi(theta.Wcat.transpose().mmul(CurrentNode.catDelta));
Activation.subi(delta);
DoubleMatrix CurrentDelta = f.derivativeAt(A1).mmul(Activation);

@@ -372,12 +380,15 @@ public void BackPropagate(LabeledRAETree tree, FineTunableTheta theta,
incrementWordEmbedding(GL,WordsIndexed);
}

-private DoubleMatrix makeLabelVector (int label)
+private DoubleMatrix makeLabelVector (int label, int numDataItems)
{
-DoubleMatrix labelVector = DoubleMatrix.zeros (CatSize);
-if (label > 0 && label <= CatSize)
-labelVector.put(label-1, 0, 1);
-return labelVector;
+if (label < 0 || label > CatSize)
+System.err.println("makeLabelVector : over CatSize ->" + CatSize + "<" + label);
+
+DoubleMatrix LabelRepresentation = DoubleMatrix.zeros(1+CatSize,numDataItems);
+for (int i = 0; i < numDataItems; i++)
+LabelRepresentation.put(label, i, 1);
+return LabelRepresentation;
}

private synchronized void incrementWordEmbedding(DoubleMatrix GL, int[] WordsIndexed)
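For concreteness, the rewritten makeLabelVector produces a one-hot matrix with 1 + CatSize rows (matching the pinned extra row in the predictions) and one column per data item. Assuming 0-based labels, as the bounds check above implies, CatSize = 2, label = 1, and numDataItems = 3 give:

0 0 0
1 1 1
0 0 0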
