test for evaluator #3
philipp94831 committed Jun 8, 2016
1 parent b0dbb15 commit 6e63f3b
Showing 12 changed files with 162 additions and 64 deletions.
27 changes: 14 additions & 13 deletions commons/pom.xml
@@ -1,15 +1,16 @@
<?xml version="1.0"?>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>de.hpi.mmds.wiki</groupId>
<artifactId>wiki</artifactId>
<version>0.0.1-SNAPSHOT</version>
</parent>
<artifactId>commons</artifactId>
<name>mmds-wiki-commons</name>
<url>http://maven.apache.org</url>
<dependencies>
</dependencies>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"
xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>de.hpi.mmds.wiki</groupId>
<artifactId>wiki</artifactId>
<version>0.0.1-SNAPSHOT</version>
</parent>
<artifactId>commons</artifactId>
<name>mmds-wiki-commons</name>
<url>http://maven.apache.org</url>
<dependencies>
</dependencies>
</project>
7 changes: 3 additions & 4 deletions commons/src/main/java/de/hpi/mmds/wiki/Edits.java
@@ -1,14 +1,13 @@
package de.hpi.mmds.wiki;

import java.io.Serializable;

import de.hpi.mmds.wiki.spark.KeyFilter;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;

import scala.Tuple2;
import de.hpi.mmds.wiki.spark.KeyFilter;

import java.io.Serializable;

public class Edits implements Serializable {

79 changes: 54 additions & 25 deletions commons/src/main/java/de/hpi/mmds/wiki/Evaluator.java
@@ -2,7 +2,6 @@

import org.apache.commons.io.FileUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import java.io.BufferedWriter;
import java.io.File;
@@ -19,25 +18,60 @@ public class Evaluator {
private final File out;
private final Recommender recommender;

public static class Result<T> {

private final Set<T> recommendations;
private final Set<T> groundTruth;
private final Set<T> intersect;

public Result(Set<T> recommendations, Set<T> groundTruth) {
this.recommendations = recommendations;
this.groundTruth = groundTruth;
this.intersect = new HashSet<>(groundTruth);
this.intersect.retainAll(recommendations);
}

public double precision() {
return recommendations.isEmpty() ? 0 : (double) intersect.size() / recommendations.size();
}

public double recall() {
return (double) intersect.size() / groundTruth.size();
}

public String printResult() {
StringBuilder sb = new StringBuilder();
sb.append("Recommendations: " + recommendations + "\n");
sb.append("Gold standard: " + groundTruth + "\n");
sb.append("Matches: " + intersect + "\n");
sb.append("Precision: " + precision() + "\n");
sb.append("Recall: " + recall());
return sb.toString();
}
}

public Evaluator(Recommender recommender, Edits test, Edits training, File out) {
this.recommender = recommender;
this.training = training;
this.test = test;
this.out = out;
}

public void evaluate(int num, long seed) {
public Map<Integer, Result> evaluate(int num, long seed) {
Map<Integer, Result> results = new HashMap<>();
List<Integer> uids = new ArrayList<>(test.getUsers().intersection(training.getUsers()).collect());
int i = 0;
double totalPrecision = 0.0;
double totalRecall = 0.0;
try {
FileUtils.forceDelete(out);
} catch (IOException e) {
throw new RuntimeException("Could not delete output file " + out.getPath(), e);
if (out.exists()) {
try {
FileUtils.forceDelete(out);
} catch (IOException e) {
throw new RuntimeException("Could not delete output file " + out.getPath(), e);
}
}
Collections.shuffle(uids, new Random(seed));
try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.out))) {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(out))) {
for (int user : uids) {
if (i >= num) {
break;
@@ -48,27 +82,21 @@ public void evaluate(int num, long seed) {
Set<Integer> groundTruth = new HashSet<>(a);
groundTruth.removeAll(p);
if (!groundTruth.isEmpty()) {
List<Integer> recommendations = recommender.recommend(user, articles, NUM_RECOMMENDATIONS).stream()
.map(Recommendation::getArticle).collect(Collectors.toList());
Set<Integer> intersect = new HashSet<>(groundTruth);
intersect.retainAll(recommendations);
double precision = recommendations.isEmpty() ?
0 :
(double) intersect.size() / recommendations.size();
double recall = (double) intersect.size() / groundTruth.size();
totalPrecision += precision;
totalRecall += recall;
Set<Integer> recommendations = recommender.recommend(user, articles, NUM_RECOMMENDATIONS).stream()
.map(Recommendation::getArticle).collect(Collectors.toSet());
Result<Integer> result = new Result<>(recommendations, groundTruth);
results.put(user, result);
totalPrecision += result.precision();
totalRecall += result.recall();
i++;
writer.write("User: " + user + "\n");
writer.write("Recommendations: " + recommendations + "\n");
writer.write("Gold standard: " + groundTruth + "\n");
writer.write("Matches: " + intersect + "\n");
writer.write("Precision: " + precision + "\n");
writer.write(result.printResult());
writer.newLine();
writer.write("AVG Precision: " + totalPrecision / i + "\n");
writer.write("Recall: " + recall + "\n");
writer.write("AVG Recall: " + totalRecall / i + "\n");
writer.write("Processed: " + i + "\n");
writer.write("---\n");
writer.write("---");
writer.newLine();
writer.flush();
}
}
@@ -79,9 +107,10 @@ public void evaluate(int num, long seed) {
System.out.println("AVG Precision: " + totalPrecision / i);
System.out.println("AVG Recall: " + totalRecall / i);
}
return results;
}

public void evaluate(int num) {
evaluate(num, 1L);
public Map<Integer, Result> evaluate(int num) {
return evaluate(num, 1L);
}
}
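
Since evaluate(int, long) now returns its per-user results as a Map<Integer, Result> instead of only writing them to the report file, callers can aggregate the metrics themselves. A minimal usage sketch, not part of this commit (recommender, testEdits, and trainingEdits are hypothetical names for instances built as in the tests further down):

    // Hypothetical driver code; construction of recommender, testEdits and
    // trainingEdits is assumed, as in EvaluatorTest below.
    Evaluator evaluator = new Evaluator(recommender, testEdits, trainingEdits, new File("eval.txt"));
    Map<Integer, Evaluator.Result> results = evaluator.evaluate(100, 42L);
    // Average the per-user metrics exposed by the new Result class.
    double avgPrecision = results.values().stream().mapToDouble(Evaluator.Result::precision).average().orElse(0);
    double avgRecall = results.values().stream().mapToDouble(Evaluator.Result::recall).average().orElse(0);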
5 changes: 4 additions & 1 deletion commons/src/main/java/de/hpi/mmds/wiki/MultiRecommender.java
Expand Up @@ -3,7 +3,10 @@
import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;

import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class MultiRecommender implements Recommender {
4 changes: 2 additions & 2 deletions commons/src/main/java/de/hpi/mmds/wiki/Recommender.java
@@ -1,9 +1,9 @@
package de.hpi.mmds.wiki;

import java.util.List;

import org.apache.spark.api.java.JavaRDD;

import java.util.List;

public interface Recommender {

default List<Recommendation> recommend(int userId, JavaRDD<Integer> articles) {
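
Recommender stays effectively a functional interface: the two-argument default above presumably delegates to a single abstract three-argument recommend, which is what lets the new EvaluatorTest below stub the whole interface with a lambda that ignores its inputs and always proposes articles 10 and 11:

    // Stub taken from EvaluatorTest below: constant recommendations, regardless of input.
    Recommender recommender = (userId, articles, howMany) -> Arrays
            .asList(new Recommendation(1.0, 10), new Recommendation(1.0, 11));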
1 change: 0 additions & 1 deletion commons/src/main/java/de/hpi/mmds/wiki/SparkUtil.java
Expand Up @@ -2,7 +2,6 @@

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class SparkUtil {
1 change: 0 additions & 1 deletion commons/src/main/java/de/hpi/mmds/wiki/spark/KeyFilter.java
@@ -1,7 +1,6 @@
package de.hpi.mmds.wiki.spark;

import org.apache.spark.api.java.function.Function;

import scala.Tuple2;

public final class KeyFilter<T> implements Function<Tuple2<T, T>, Boolean> {
16 changes: 8 additions & 8 deletions commons/src/test/java/de/hpi/mmds/wiki/EditsTest.java
@@ -1,18 +1,17 @@
package de.hpi.mmds.wiki;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.List;
import java.util.Map;

import org.apache.spark.api.java.JavaSparkContext;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import scala.Tuple2;

import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class EditsTest {

private static Edits edits;
@@ -21,7 +20,8 @@ public class EditsTest {
@BeforeClass
public static void setup() {
jsc = SparkUtil.getContext();
edits = new Edits(jsc, Thread.currentThread().getContextClassLoader().getResource("test_data.txt").getPath());
edits = new Edits(jsc,
Thread.currentThread().getContextClassLoader().getResource("training_data.txt").getPath());
}

@AfterClass
60 changes: 60 additions & 0 deletions commons/src/test/java/de/hpi/mmds/wiki/EvaluatorTest.java
@@ -0,0 +1,60 @@
package de.hpi.mmds.wiki;

import de.hpi.mmds.wiki.Evaluator.Result;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.*;

import java.io.File;
import java.util.Arrays;
import java.util.Map;

import static org.junit.Assert.assertEquals;

public class EvaluatorTest {

public static final double DOUBLE_TOLERANCE = 0.001;
private static Edits test;
private static Edits training;
private static JavaSparkContext jsc;
private static Recommender recommender;
private File out;

@BeforeClass
public static void setupClass() {
jsc = SparkUtil.getContext();
test = new Edits(jsc, Thread.currentThread().getContextClassLoader().getResource("test_data.txt").getPath());
training = new Edits(jsc,
Thread.currentThread().getContextClassLoader().getResource("training_data.txt").getPath());
recommender = (userId, articles, howMany) -> Arrays
.asList(new Recommendation(1.0, 10), new Recommendation(1.0, 11));
}

@AfterClass
public static void tearDownClass() {
jsc.close();
}

@Before
public void setup() {
out = new File("out.txt");
}

@After
public void tearDown() {
out.delete();
}

@Test
public void test() {
Evaluator eval = new Evaluator(recommender, test, training, out);
Map<Integer, Result> results = eval.evaluate(3);
assertEquals(3, results.size());
assertEquals(1.0, results.get(1).precision(), DOUBLE_TOLERANCE);
assertEquals(1.0, results.get(1).recall(), DOUBLE_TOLERANCE);
assertEquals(0.5, results.get(2).precision(), DOUBLE_TOLERANCE);
assertEquals(1.0 / 3, results.get(2).recall(), DOUBLE_TOLERANCE);
assertEquals(0.0, results.get(3).precision(), DOUBLE_TOLERANCE);
assertEquals(0.0, results.get(3).recall(), DOUBLE_TOLERANCE);
}

}
6 changes: 4 additions & 2 deletions commons/src/test/java/de/hpi/mmds/wiki/MultiRecommenderTest.java
@@ -5,12 +5,13 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class MultiRecommenderTest {

public static final double DOUBLE_TOLERANCE = 0.001;
@@ -20,7 +21,8 @@ public class MultiRecommenderTest {
@BeforeClass
public static void setup() {
jsc = SparkUtil.getContext();
edits = new Edits(jsc, Thread.currentThread().getContextClassLoader().getResource("test_data.txt").getPath());
edits = new Edits(jsc,
Thread.currentThread().getContextClassLoader().getResource("training_data.txt").getPath());
}

@AfterClass
13 changes: 6 additions & 7 deletions commons/src/test/resources/test_data.txt
@@ -1,7 +1,6 @@
1,1
1,2
2,1
2,3
2,4
3,2
3,5
1,10
1,11
2,10
2,12
2,13
3,13
7 changes: 7 additions & 0 deletions commons/src/test/resources/training_data.txt
@@ -0,0 +1,7 @@
1,1
1,2
2,1
2,3
2,4
3,2
3,5
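
Read as user,article pairs, the two fixtures make the assertions in the new EvaluatorTest straightforward to verify by hand against the stub recommender, which always returns articles 10 and 11. For user 2, the gold standard is the test-set articles minus anything already edited in training, i.e. {10, 12, 13}; the single match {10} gives precision 1/2 = 0.5 and recall 1/3. User 1's gold standard {10, 11} is matched exactly (precision and recall 1.0), while user 3's {13} is missed entirely (both 0.0), matching the values asserted above.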
