Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Feature extraction also made faster.

  • Loading branch information...
commit c1c08b0a6008b8920cb0a9a659c4b2609611893b 1 parent 31176cf
Sanjeev Satheesh authored
View
4 src/math/Norm1Tanh.java
@@ -14,7 +14,7 @@
* diag(1-x.^2)./nrm - y*x'./nrm^3
*/
@Override
- public synchronized DoubleMatrix derivativeAt(DoubleMatrix M)
+ public DoubleMatrix derivativeAt(DoubleMatrix M)
{
double norm = M.norm2();
DoubleMatrix Squared = M.mul(M);
@@ -31,7 +31,7 @@ public synchronized DoubleMatrix derivativeAt(DoubleMatrix M)
* as in the original matlab file.
*/
@Override
- public synchronized DoubleMatrix valueAt(DoubleMatrix M)
+ public DoubleMatrix valueAt(DoubleMatrix M)
{
return MatrixFunctions.tanh(M);
}
View
4 src/parallel/ParallelTest.java
@@ -7,12 +7,8 @@
import java.util.Random;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
-
import org.junit.Test;
-import rae.RAEPropagation;
-import classify.LabeledDatum;
-
import util.Reducible;
public class ParallelTest {
View
13 src/parallel/ThreadPool.java
@@ -16,10 +16,9 @@ public static void setPoolSize(int poolSize)
{
ThreadPool.poolSize = poolSize;
}
-
-// @SuppressWarnings("unchecked")
- public static <T,E extends Reducible<E>,F> E mapReduce
- (final Collection<T> pElements, final E Operator, final Operation<E,T> pOperation) {
+
+ public static <T,E extends Reducible<E>,F> Collection<E> map
+ (final Collection<T> pElements, final E Operator, final Operation<E,T> pOperation) {
final LinkedList<E> queue = new LinkedList<E>();
for (int i=0; i<ThreadPool.poolSize; i++)
@@ -79,6 +78,12 @@ public void run() {
System.err.println ("Some data processing was lost! " + "Only " +
queue.size() + " processors of " + poolSize + " exists now");
+ return queue;
+ }
+
+ public static <T,E extends Reducible<E>,F> E mapReduce
+ (final Collection<T> pElements, final E Operator, final Operation<E,T> pOperation) {
+ Collection<E> queue = map(pElements, Operator, pOperation);
Reducer<E> accumulator = new Reducer<E>();
return accumulator.reduce(queue);
}
View
64 src/rae/RAEFeatureExtractor.java
@@ -9,19 +9,27 @@
import parallel.*;
public class RAEFeatureExtractor {
- int HiddenSize;
+ int HiddenSize, CatSize, DictionaryLength;
+ double AlphaCat, Beta;
FineTunableTheta Theta;
RAEPropagation Propagator;
DoubleMatrix features;
Lock lock;
+ DifferentiableMatrixFunction f;
public RAEFeatureExtractor(int HiddenSize, FineTunableTheta Theta, double AlphaCat, double Beta,
int CatSize, int DictionaryLength, DifferentiableMatrixFunction f)
{
this.HiddenSize = HiddenSize;
this.Theta = Theta;
- Propagator = new RAEPropagation(AlphaCat, Beta, HiddenSize, CatSize, DictionaryLength, f);
+ this.AlphaCat = AlphaCat;
+ this.Beta = Beta;
+ this.HiddenSize = HiddenSize;
+ this.CatSize = CatSize;
+ this.DictionaryLength = DictionaryLength;
+ this.f = f;
lock = new ReentrantLock();
+ Propagator = new RAEPropagation(AlphaCat, Beta, HiddenSize, CatSize, DictionaryLength, f);
}
public List<LabeledDatum<Double,Integer>>
@@ -49,25 +57,31 @@ public void perform(int index, LabeledRAETree tree) {
public DoubleMatrix extractFeatures(List<LabeledDatum<Integer,Integer>> Data)
{
int numExamples = Data.size();
- features = DoubleMatrix.zeros(2*HiddenSize,numExamples);
-
- Parallel.For(Data, new Parallel.Operation<LabeledDatum<Integer,Integer>>(){
- @Override
- public void perform(int index, LabeledDatum<Integer, Integer> data) {
- double[] feature = extractFeatures(data);
- lock.lock();
+ features = DoubleMatrix.zeros(2*HiddenSize,numExamples);
+ ThreadPool.map (Data, Propagator,
+ new ThreadPool.Operation<RAEPropagation, LabeledDatum<Integer,Integer>>() {
+ public void perform(RAEPropagation locPropagator, int index,
+ LabeledDatum<Integer,Integer> data)
{
- features.putColumn(index, new DoubleMatrix(feature));
+ double[] feature = extractFeatures(locPropagator, data);
+ lock.lock();
+ {
+ features.putColumn(index, new DoubleMatrix(feature));
+ }
+ lock.unlock();
}
- lock.unlock();
- }
- });
+ });
return features;
}
public double[] extractFeatures (LabeledDatum<Integer,Integer> Data)
{
- return getRAETree (Data).getFeaturesVector();
+ return getRAETree (Propagator, Data).getFeaturesVector();
+ }
+
+ public double[] extractFeatures (RAEPropagation Propagator, LabeledDatum<Integer,Integer> Data)
+ {
+ return getRAETree (Propagator, Data).getFeaturesVector();
}
public List<LabeledRAETree> getRAETrees(List<LabeledDatum<Integer,Integer>> Data)
@@ -75,21 +89,23 @@ public void perform(int index, LabeledDatum<Integer, Integer> data) {
int numExamples = Data.size();
final LabeledRAETree[] ExtractedTrees = new LabeledRAETree[numExamples];
- Parallel.For(Data, new Parallel.Operation<LabeledDatum<Integer,Integer>>(){
- @Override
- public void perform(int index, LabeledDatum<Integer, Integer> data) {
- LabeledRAETree tree = getRAETree(data);
- lock.lock();
+ ThreadPool.map (Data, Propagator,
+ new ThreadPool.Operation<RAEPropagation, LabeledDatum<Integer,Integer>>() {
+ public void perform(RAEPropagation locPropagator, int index,
+ LabeledDatum<Integer,Integer> data)
{
- ExtractedTrees[index] = tree;
+ LabeledRAETree tree = getRAETree(locPropagator, data);
+ lock.lock();
+ {
+ ExtractedTrees[index] = tree;
+ }
+ lock.unlock();
}
- lock.unlock();
- }
- });
+ });
return Arrays.asList(ExtractedTrees);
}
- public LabeledRAETree getRAETree(LabeledDatum<Integer,Integer> data)
+ public LabeledRAETree getRAETree(RAEPropagation Propagator, LabeledDatum<Integer,Integer> data)
{
int SentenceLength = data.getFeatures().size();
Please sign in to comment.
Something went wrong with that request. Please try again.