
Commit: Further changes

ahoy196 committed Jun 27, 2010
1 parent f9288f8, commit 46243ea
Showing 7 changed files with 177 additions and 148 deletions.
RBMDriver.java

@@ -66,8 +66,7 @@ public RBMDriver(DistributedRowMatrix inputUserMatrix, DistributedRowMatrix inpu

public static void runJob() throws IOException, InterruptedException, ClassNotFoundException {

-Path stateIn = new Path(output, "state-0");
-writeInitialState(stateIn, numTopics, numWords);
+//Path stateIn = new Path(output, "state-0");
boolean converged = false;

DRand randn;
@@ -137,7 +136,7 @@ public static void runJob() throws IOException, InterruptedException, ClassNotFo


while(userVector.hasNext()) {
-state.nrmse += this.runIteration(userVector.next());
+state.nrmse += runIteration(userVector.next());
}

state.zero(state.CDpos, state.numItems, state.softmax, state.totalFeatures);
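A note on the change above: runJob is static, so qualifying the call as this.runIteration(...) would not compile; dropping the qualifier is the fix. For readers outside the Hadoop context, a minimal, self-contained sketch of the loop's shape, with a stand-in runIteration that just returns a squared error for one row (all names and values here are illustrative, not the actual Mahout job):

// Sketch only: accumulating per-row error from a static method.
public class LoopSketch {

  // Stand-in for the driver's runIteration(MatrixSlice): returns a
  // made-up squared error for one user row.
  static double runIteration(double[] row) {
    double err = 0.0;
    for (double v : row) {
      err += v * v;
    }
    return err;
  }

  public static void main(String[] args) {
    double[][] userVectors = { {1, 0, 2}, {0, 3, 1} };
    double nrmse = 0.0;
    for (double[] row : userVectors) {
      // "this.runIteration(row)" would be a compile error here:
      // main is static, so the call must be unqualified or class-qualified.
      nrmse += runIteration(row);
    }
    System.out.println(nrmse); // 5.0 + 10.0 = 15.0
  }
}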
@@ -148,7 +147,7 @@ public static void runJob() throws IOException, InterruptedException, ClassNotFo
state.zero(state.negvisact, state.numItems, state.softmax);
state.zero(state.moviecount, state.numItems);

-state.nrmse=Math.sqrt(state.nrmse/ntrain);
+state.nrmse = Math.sqrt(state.nrmse/ntrain);
state.prmse = Math.sqrt(s/n);

if ( state.totalFeatures == 200 ) {
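A note on the error bookkeeping a few lines up: state.nrmse and state.prmse convert accumulated sums of squared error into root-mean-square errors (ntrain counts training ratings; s and n are probe-set accumulators defined elsewhere in the driver). A one-method arithmetic sketch with invented numbers:

// Sketch: converting an accumulated sum of squared errors to an RMSE.
public class RmseSketch {
  public static void main(String[] args) {
    double sumSquaredError = 42.5; // invented sum of (rating - prediction)^2
    long ntrain = 100;             // invented number of training ratings
    double nrmse = Math.sqrt(sumSquaredError / ntrain);
    System.out.println(nrmse);     // prints ~0.6519
  }
}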
@@ -179,9 +178,12 @@ public static void runJob() throws IOException, InterruptedException, ClassNotFo
state.EpsilonVB *= 0.78;
state.EpsilonHB *= 0.78;
}

+//recordErrors();
}

-public static double runIteration() throws IOException, InterruptedException, ClassNotFoundException {
+public static double runIteration(MatrixSlice userVector) throws IOException, InterruptedException, ClassNotFoundException {
Configuration conf = new Configuration();

Job job = new Job(conf);
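The tail of runJob above anneals the learning rates: once state.totalFeatures reaches 200, each epsilon shrinks geometrically by a factor of 0.78 per pass. A standalone sketch of that schedule (the field names EpsilonVB and EpsilonHB come from the diff; the starting values are invented):

// Sketch: geometric learning-rate decay, as in the driver's runJob tail.
public class AnnealSketch {
  public static void main(String[] args) {
    double epsilonVB = 0.012; // invented starting visible-bias rate
    double epsilonHB = 0.006; // invented starting hidden-bias rate
    for (int epoch = 1; epoch <= 5; epoch++) {
      epsilonVB *= 0.78; // same decay factor as the commit
      epsilonHB *= 0.78;
      System.out.printf("epoch %d: epsilonVB=%.6f epsilonHB=%.6f%n",
          epoch, epsilonVB, epsilonHB);
    }
  }
}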
RBMInputDriver.java

@@ -24,14 +24,15 @@
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;

public class RBMInputDriver {

-public static void runJob(Path input, Path output)
-throws IOException, InterruptedException, ClassNotFoundException {
+public static void runJob(Path input, Path output) throws IOException,
+InterruptedException, ClassNotFoundException {
HadoopUtil.overwriteOutput(output);

Configuration conf = new Configuration();
@@ -42,7 +43,7 @@ public static void runJob(Path input, Path output)
job.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class);
job.setOutputKeyClass(IntWritable.class);
job.setOutputValueClass(VectorWritable.class);
-job.setOutputFormatClass(SequenceFileOutputFormat.class);
+job.setOutputValueClass(SequenceFileOutputFormat.class);
job.setMapperClass(RBMInputMapper.class);
job.setReducerClass(RBMInputReducer.class);

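One line in the hunk above looks like a typo: job.setOutputValueClass(SequenceFileOutputFormat.class) overwrites the value class that was just (correctly) set to VectorWritable, and an OutputFormat is configured with setOutputFormatClass. A hedged sketch of how the output wiring presumably should read, using the same 2010-era mapreduce API as the commit:

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;

// Sketch: the job wiring the diff appears to intend, with the format
// class and the value class kept apart.
public class RBMJobWiringSketch {
  public static Job configure(Configuration conf) throws IOException {
    Job job = new Job(conf); // constructor used in the commit's Hadoop version
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(VectorWritable.class);            // reducer emits vectors
    job.setOutputFormatClass(SequenceFileOutputFormat.class); // file format, not a value class
    return job;
  }
}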
RBMInputMapper.java

@@ -24,17 +24,18 @@
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;

-public class RBMInputMapper extends Mapper<LongWritable, Text, IntWritable, DistributedRowMatrix.MatrixEntryWritable> {
+public class RBMInputMapper
+extends
+Mapper<LongWritable,Text,IntWritable,DistributedRowMatrix.MatrixEntryWritable> {

@Override
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {

String[] entry = value.toString().split(",");

// User is the key for the Reducer
DistributedRowMatrix.MatrixEntryWritable record = new DistributedRowMatrix.MatrixEntryWritable();
IntWritable row = new IntWritable(Integer.valueOf(entry[0]));
record.setRow(-1);
record.setCol(Integer.valueOf(entry[1]));
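For orientation: the mapper splits each CSV rating line into (user, item, rating), keys the output by user, and leaves the record's row at -1 since the user id already travels as the key. A self-contained sketch of the parsing step with plain Java types standing in for the Hadoop writables (the tail of the method, including any setVal call, is cut off in the diff, so the rating handling is assumed):

// Sketch: the mapper's CSV parsing with plain types.
// Assumed input format: "userId,itemId,rating".
public class ParseSketch {
  public static void main(String[] args) {
    String line = "42,7,3.5"; // made-up rating record
    String[] entry = line.split(",");
    int user = Integer.valueOf(entry[0]);     // the mapper's output key
    int item = Integer.valueOf(entry[1]);     // record.setCol(...)
    double rating = Double.valueOf(entry[2]); // presumably record.setVal(...)
    System.out.printf("key=%d, col=%d, val=%.1f (row stays -1)%n",
        user, item, rating);
  }
}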
RBMInputReducer.java

@@ -25,20 +25,22 @@
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;

-public class RBMInputReducer extends Reducer<IntWritable, DistributedRowMatrix.MatrixEntryWritable,
-IntWritable, VectorWritable> {
+public class RBMInputReducer
+extends
+Reducer<IntWritable,DistributedRowMatrix.MatrixEntryWritable,IntWritable,VectorWritable> {

@Override
-protected void reduce(IntWritable record,
-Iterable<DistributedRowMatrix.MatrixEntryWritable> recordEntries,
-Context context)
-throws IOException, InterruptedException {
-RandomAccessSparseVector toWrite = new RandomAccessSparseVector(Integer.MAX_VALUE, 100); //100? or something else?
+protected void reduce(IntWritable record,
+Iterable<DistributedRowMatrix.MatrixEntryWritable> recordEntries,
+Context context) throws IOException, InterruptedException {
+RandomAccessSparseVector toWrite = new RandomAccessSparseVector(
+Integer.MAX_VALUE, 100); // 100? or something else?
for (DistributedRowMatrix.MatrixEntryWritable entryItem : recordEntries) {
toWrite.setQuick(entryItem.getCol(), entryItem.getVal());
}
-SequentialAccessSparseVector output = new SequentialAccessSparseVector(toWrite);
+SequentialAccessSparseVector output = new SequentialAccessSparseVector(
+toWrite);
context.write(record, new VectorWritable(output));
}
}
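The reducer collects all of one user's entries into a RandomAccessSparseVector, which is cheap to write out of order, then copies it into a SequentialAccessSparseVector, which is compact and fast to iterate downstream. The in-code question about 100 is answered by the Mahout API: it is an initial capacity hint, not a hard size. A standalone sketch using the same Mahout math classes, no Hadoop required:

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;

// Sketch: the reducer's vector assembly, outside Hadoop.
public class VectorAssemblySketch {
  public static void main(String[] args) {
    // 100 is an initial capacity hint, as the diff's comment suspects.
    RandomAccessSparseVector toWrite =
        new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
    toWrite.setQuick(7, 3.5);   // made-up (item, rating) pairs
    toWrite.setQuick(123, 2.0);

    // Sequential-access copy: compact storage, ordered iteration.
    SequentialAccessSparseVector output =
        new SequentialAccessSparseVector(toWrite);
    System.out.println(output.getNumNondefaultElements()); // prints 2
  }
}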
