Commit

save and load CF model; closes #2, also working on #3
philipp94831 committed Jun 12, 2016
1 parent cf76b74 commit 1a6cfe6
Showing 10 changed files with 82 additions and 535 deletions.
27 changes: 13 additions & 14 deletions cf/src/main/java/de/hpi/mmds/wiki/cf/CollaborativeFiltering.java
@@ -27,12 +27,10 @@
import java.util.NoSuchElementException;
import java.util.stream.Collectors;

@SuppressWarnings("unused")
public class CollaborativeFiltering implements Serializable, Recommender {

public static final String PRODUCT_PATH = "/product";
public static final String USER_PATH = "/user";
private static final int RANK = 35;
private static final double LOG2 = Math.log(2);
private static final boolean MANUAL_SAVE_LOAD = true;
private static final int NUM_ITERATIONS = 10;
@@ -46,15 +44,15 @@ public CollaborativeFiltering(JavaSparkContext jsc, String filterDir) {
model = loadModel(jsc, filterDir);
}

private MatrixFactorizationModel loadModel(JavaSparkContext jsc, String filterDir) {
private static MatrixFactorizationModel loadModel(JavaSparkContext jsc, String filterDir) {
final MatrixFactorizationModel model;
if (!MANUAL_SAVE_LOAD) {
model = MatrixFactorizationModel.load(jsc.sc(), filterDir);
} else {
final int rank;
try (BufferedReader in = new BufferedReader(new FileReader(new File(filterDir + "/meta")))) {
rank = Integer.parseInt(in.readLine());
} catch (IOException e) {
} catch (Exception e) {
throw new RuntimeException("Error reading metadata", e);
}
final JavaRDD<Tuple2<Object, double[]>> userFeatures = jsc.<Tuple2<Object, double[]>>objectFile(
@@ -63,11 +61,11 @@ private MatrixFactorizationModel loadModel(JavaSparkContext jsc, String filterDi
filterDir + PRODUCT_PATH).cache();
model = new MatrixFactorizationModel(rank, userFeatures.rdd(), productFeatures.rdd());
}
logger.info("Model loaded");
jsc.sc().log().info("Model loaded");
return model;
}

public CollaborativeFiltering(JavaSparkContext jsc, String filterDir, String path) {
public CollaborativeFiltering(JavaSparkContext jsc, String filterDir, String path, int rank, double lambda, double alpha) {
logger = jsc.sc().log();
JavaRDD<String> data = jsc.textFile(path);
JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
@@ -78,28 +76,29 @@ public CollaborativeFiltering(JavaSparkContext jsc, String filterDir, String pat
public Rating call(String s) {
String[] sarray = s.split(",");
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
Math.log(Double.parseDouble(sarray[2])) / LOG2 + 1);
Double.parseDouble(sarray[2]));
}
});
ratings.cache();
}).cache();
logger.info("Ratings imported");
model = ALS.trainImplicit(JavaRDD.toRDD(ratings), RANK, NUM_ITERATIONS);
model = ALS.trainImplicit(JavaRDD.toRDD(ratings), rank, NUM_ITERATIONS, lambda, alpha);
logger.info("Model trained");
try {
saveModel(jsc, filterDir);
saveModel(filterDir);
} catch (IOException e) {
throw new RuntimeException("Error saving model to disk", e);
}
ratings.unpersist();
}

private void saveModel(JavaSparkContext jsc, String filterDir) throws IOException {
public void saveModel(String filterDir) throws IOException {
FileUtils.deleteDirectory(new File(filterDir));
if (!MANUAL_SAVE_LOAD) {
model.save(jsc.sc(), filterDir);
model.save(model.productFeatures().sparkContext(), filterDir);
} else {
new File(filterDir).mkdirs();
File metadata = new File(filterDir + "/meta");
try (BufferedWriter out = new BufferedWriter(new FileWriter(metadata))) {
out.write(model.rank());
out.write(Integer.toString(model.rank()));
out.newLine();
}
model.userFeatures().saveAsObjectFile(filterDir + USER_PATH);
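
For orientation, a minimal usage sketch of the two constructors touched above. This driver is not part of the commit; the class name CfDemo, the paths, and the rank/lambda/alpha values are placeholder assumptions.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

import de.hpi.mmds.wiki.cf.CollaborativeFiltering;

public class CfDemo {

    public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext(new SparkConf().setAppName("cf-demo").setMaster("local[*]"));
        // Train on "user,article,rating" CSV lines and persist the model manually under /tmp/cf-model:
        CollaborativeFiltering trained = new CollaborativeFiltering(jsc, "/tmp/cf-model", "/data/edits.csv", 35, 0.01, 1.0);
        // A later run can reload the saved model instead of retraining:
        CollaborativeFiltering loaded = new CollaborativeFiltering(jsc, "/tmp/cf-model");
        jsc.stop();
    }
}

The reload path reads the rank from the /meta file and reconstructs the MatrixFactorizationModel from the user and product object files, matching the MANUAL_SAVE_LOAD branch above.
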
31 changes: 21 additions & 10 deletions commons/src/main/java/de/hpi/mmds/wiki/Edits.java
@@ -1,29 +1,30 @@
package de.hpi.mmds.wiki;

import de.hpi.mmds.wiki.spark.KeyFilter;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.storage.StorageLevel;

import scala.Tuple2;

import java.io.Serializable;

import static de.hpi.mmds.wiki.spark.SparkFunctions.identity;
import static de.hpi.mmds.wiki.spark.SparkFunctions.keyFilter;

public class Edits implements Serializable {

private static final long serialVersionUID = 1668840974181477332L;
private final JavaPairRDD<Integer, Integer> edits;
private final JavaPairRDD<Integer, Iterable<Integer>> edits;

public Edits(JavaSparkContext jsc, String dataDir) {
edits = parseEdits(jsc, dataDir);
edits.cache();
}

private static JavaPairRDD<Integer, Integer> parseEdits(JavaSparkContext jsc, String dataDir) {
private static JavaPairRDD<Integer, Iterable<Integer>> parseEdits(JavaSparkContext jsc, String dataDir) {
JavaRDD<String> data = jsc.textFile(dataDir);
JavaPairRDD<Integer, Integer> edits = data.mapToPair(new PairFunction<String, Integer, Integer>() {
JavaPairRDD<Integer, Iterable<Integer>> edits = data.mapToPair(new PairFunction<String, Integer, Integer>() {

private static final long serialVersionUID = -4781040078296911266L;

@@ -32,25 +33,35 @@ public Tuple2<Integer, Integer> call(String s) throws Exception {
String[] sarray = s.split(",");
return new Tuple2<>(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]));
}
});
}).groupByKey();
jsc.sc().log().info("Edit data loaded");
return edits;
}

public JavaPairRDD<Integer, Iterable<Integer>> getAggregatedEdits() {
return edits.groupByKey();
return edits;
}

public JavaPairRDD<Integer, Integer> getAllEdits() {
return edits;
return edits.flatMapValues(identity());
}

public JavaRDD<Integer> getEdits(int user) {
return edits.filter(new KeyFilter<>(user)).values();
return edits.filter(keyFilter(user)).flatMap(t -> t._2);
}

public JavaRDD<Integer> getUsers() {
return edits.keys();
}

public Edits cache() {
edits.cache();
return this;
}

public Edits persist(StorageLevel level) {
edits.persist(level);
return this;
}

}
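
A hedged usage sketch of the reworked Edits API, assuming a JavaSparkContext jsc and an edits file of "user,article" lines; the path, the user id 42, and the storage level are illustrative:

// Edits now groups the pairs by user once, inside parseEdits
Edits edits = new Edits(jsc, "/data/training").persist(StorageLevel.MEMORY_AND_DISK());
JavaPairRDD<Integer, Iterable<Integer>> byUser = edits.getAggregatedEdits(); // already grouped, no extra shuffle
JavaPairRDD<Integer, Integer> perEdit = edits.getAllEdits();                 // rebuilt via flatMapValues(identity())
JavaRDD<Integer> articlesOfUser = edits.getEdits(42);                        // keyFilter on the user, then flatten the group
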
63 changes: 31 additions & 32 deletions commons/src/main/java/de/hpi/mmds/wiki/Evaluator.java
@@ -1,17 +1,19 @@
package de.hpi.mmds.wiki;

import com.google.common.collect.Sets;

import org.apache.commons.io.FileUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;

import scala.Tuple2;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
@@ -33,13 +35,18 @@ public Evaluator(Recommender recommender, Edits test, Edits training, File out)
}

public Map<Integer, Result> evaluate(int num) {
return evaluate(num, 1L);
return evaluate(num, new Random().nextLong());
}

public Map<Integer, Result> evaluate(int num, long seed) {
Map<Integer, Result> results = new HashMap<>();
List<Integer> uids = new ArrayList<>(test.getUsers().intersection(training.getUsers()).collect());
int i = 0;
JavaPairRDD<Integer, Set<Integer>> groundTruths = test.getAggregatedEdits().join(training.getAggregatedEdits())
.mapValues(t -> {
Set<Integer> gt = Sets.newHashSet(t._1);
gt.removeAll(Sets.newHashSet(t._2));
return gt;
}).filter(t -> !t._2.isEmpty());
double totalPrecision = 0.0;
double totalRecall = 0.0;
if (out.exists()) {
@@ -49,35 +56,27 @@ public Map<Integer, Result> evaluate(int num, long seed) {
throw new RuntimeException("Could not delete output file " + out.getPath(), e);
}
}
Collections.shuffle(uids, new Random(seed));
try (BufferedWriter writer = new BufferedWriter(new FileWriter(out))) {
for (int user : uids) {
if (i >= num) {
break;
}
for (Tuple2<Integer, Set<Integer>> t : groundTruths.takeSample(false, num, seed)) {
int user = t._1;
JavaRDD<Integer> articles = training.getEdits(user);
List<Integer> a = test.getEdits(user).collect();
List<Integer> p = articles.collect();
Set<Integer> groundTruth = new HashSet<>(a);
groundTruth.removeAll(p);
if (!groundTruth.isEmpty()) {
Set<Integer> recommendations = recommender.recommend(user, articles, NUM_RECOMMENDATIONS).stream()
.map(Recommendation::getArticle).collect(Collectors.toSet());
Result<Integer> result = new Result<>(recommendations, groundTruth);
results.put(user, result);
totalPrecision += result.precision();
totalRecall += result.recall();
i++;
writer.write("User: " + user + "\n");
writer.write(result.printResult());
writer.newLine();
writer.write("AVG Precision: " + totalPrecision / i + "\n");
writer.write("AVG Recall: " + totalRecall / i + "\n");
writer.write("Processed: " + i + "\n");
writer.write("---");
writer.newLine();
writer.flush();
}
Set<Integer> groundTruth = t._2;
Set<Integer> recommendations = recommender.recommend(user, articles, NUM_RECOMMENDATIONS).stream()
.map(Recommendation::getArticle).collect(Collectors.toSet());
Result<Integer> result = new Result<>(recommendations, groundTruth);
results.put(user, result);
totalPrecision += result.precision();
totalRecall += result.recall();
i++;
writer.write("User: " + user + "\n");
writer.write(result.printResult());
writer.newLine();
writer.write("AVG Precision: " + totalPrecision / i + "\n");
writer.write("AVG Recall: " + totalRecall / i + "\n");
writer.write("Processed: " + i + "\n");
writer.write("---");
writer.newLine();
writer.flush();
}
} catch (IOException e) {
throw new RuntimeException("Error writing to output file " + out.getPath(), e);
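
The Result class is not part of this diff; as a point of reference, a sketch of the standard definitions its precision() and recall() are assumed to implement for a single user:

// Illustrative only — assumed semantics, not code from the repository:
Set<Integer> hits = new HashSet<>(recommendations);
hits.retainAll(groundTruth); // recommended articles the user actually edited in the test period
double precision = recommendations.isEmpty() ? 0.0 : (double) hits.size() / recommendations.size();
double recall = (double) hits.size() / groundTruth.size(); // groundTruth is non-empty thanks to the filter above
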
20 changes: 0 additions & 20 deletions commons/src/main/java/de/hpi/mmds/wiki/spark/KeyFilter.java

This file was deleted.

17 changes: 17 additions & 0 deletions commons/src/main/java/de/hpi/mmds/wiki/spark/SparkFunctions.java
@@ -0,0 +1,17 @@
package de.hpi.mmds.wiki.spark;

import org.apache.spark.api.java.function.Function;

import scala.Tuple2;

public class SparkFunctions {

public static <T, U> Function<Tuple2<T, U>, Boolean> keyFilter(T value) {
return t -> t._1.equals(value);
}

public static <T> Function<T, T> identity() {
return (v) -> v;
}

}
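
A short usage sketch of the two helpers, which replace the deleted KeyFilter class; the grouped RDD and the user id 42 are illustrative:

import static de.hpi.mmds.wiki.spark.SparkFunctions.identity;
import static de.hpi.mmds.wiki.spark.SparkFunctions.keyFilter;

// grouped: a JavaPairRDD<Integer, Iterable<Integer>> such as Edits#getAggregatedEdits()
JavaRDD<Integer> articlesOfUser = grouped
        .filter(keyFilter(42))   // keeps only the pairs whose key equals 42
        .flatMap(t -> t._2);     // flattens that user's Iterable of article ids
JavaPairRDD<Integer, Integer> perEdit = grouped.flatMapValues(identity());
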