Skip to content

Commit

Permalink
Add evaluation functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
h-huss committed Jul 27, 2018
1 parent 89781b2 commit 4ac580e
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 3 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ Alternatively, only the change count can be predicted (entry point: [FileControl
curl -d '{"predictionHorizon": 256,"gitServer": {"url":"https://github.com/apache/commons-io"},"h2oUrl":"http://localhost:54321"}' -H "Content-Type: application/json" -X POST http://localhost:5432/files/predict
```

## Evaluation

To evaluate the prediction quality, run the main class / maven call above with the argument `--evaluate`. This will predict the future for a default project, and compare
it with the real, unseen future. This process does not start the REST api.

This evaluation can further be customized with the arguments `--repo=`, `--user=`, `--password=` and `--horizon=`

## Functionality

This tool consists of two major parts, the first one being the Desirability-Score for calculating the importance / desire for a fix of a specific issue.
Expand Down
62 changes: 62 additions & 0 deletions src/main/java/de/viadee/sonarIssueScoring/Evaluator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package de.viadee.sonarIssueScoring;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import de.viadee.sonarIssueScoring.service.PredictionParams;
import de.viadee.sonarIssueScoring.service.desirability.ServerInfo;
import de.viadee.sonarIssueScoring.service.desirability.ServerInfo.Builder;
import de.viadee.sonarIssueScoring.service.prediction.EvaluationResult;
import de.viadee.sonarIssueScoring.service.prediction.PredictionService;

@Component
public class Evaluator implements ApplicationRunner {
private static final Logger log = LoggerFactory.getLogger(Evaluator.class);

private final PredictionService predictionService;

public Evaluator(PredictionService predictionService) {this.predictionService = predictionService;}

/**
* Evaluates the prediction quality versus the actual future on a given sample project.
*/
@Override
public void run(ApplicationArguments args) {
if (willRunEvaluation(args)) {
log.info("Starting evaluation. No web server is started."); //Web server is disabled in SonarIssueScoringApplication

Builder builder = ServerInfo.builder();

if (args.containsOption("repo"))
builder.url(args.getOptionValues("repo").get(0));
else {
log.info("No repository provided, using default");
builder.url("https://github.com/apache/commons-lang");
}

if (args.containsOption("user"))
builder.user(args.getOptionValues("user").get(0));

if (args.containsOption("password"))
builder.password(args.getOptionValues("password").get(0));

int horizon = 384;
if (args.containsOption("horizon"))
horizon = Integer.parseInt(args.getOptionValues("horizon").get(0));

ServerInfo server = builder.build();

log.info("Running evaluation for {} with horizon", predictionService); //Password is redacted automatically
EvaluationResult result = predictionService.evaluate(PredictionParams.of(server, horizon), "http://localhost:54321");
log.info("Evaluation result {}", result);
}
}

/** Static, because the context is not yet setup when this is called. */
public static boolean willRunEvaluation(ApplicationArguments args) {
return args.containsOption("evaluate");
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package de.viadee.sonarIssueScoring;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.DefaultApplicationArguments;
import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jackson.Jackson2ObjectMapperBuilderCustomizer;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

Expand All @@ -18,6 +20,11 @@ public Jackson2ObjectMapperBuilderCustomizer configureObjectMapper() {
}

public static void main(String[] args) {
SpringApplication.run(SonarIssueScoringApplication.class, args);
boolean runEvaluation = Evaluator.willRunEvaluation(new DefaultApplicationArguments(args));
new SpringApplicationBuilder().
main(SonarIssueScoringApplication.class).
sources(SonarIssueScoringApplication.class).
web(runEvaluation ? WebApplicationType.NONE : WebApplicationType.SERVLET).
run(args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package de.viadee.sonarIssueScoring.service.prediction;

import org.immutables.value.Value.Immutable;

import com.google.common.collect.Table;

import de.viadee.sonarIssueScoring.misc.ImmutableStyle;

@Immutable
@ImmutableStyle
public interface BaseEvaluationResult {
public double rmse();

public double r2();

/** Rows: actual value > 80% percentile, Cols: predicted value > 80% percentile */
public Table<Boolean, Boolean, Integer> confusionMatrix();
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
package de.viadee.sonarIssueScoring.service.prediction;

import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.stereotype.Component;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableTable;
import com.google.common.math.PairedStatsAccumulator;
import com.google.common.math.Quantiles;
import com.google.common.math.Quantiles.ScaleAndIndex;

import de.viadee.sonarIssueScoring.service.PredictionParams;
import de.viadee.sonarIssueScoring.service.prediction.load.SnapshotStrategy;
import de.viadee.sonarIssueScoring.service.prediction.load.RepositoryLoader;
import de.viadee.sonarIssueScoring.service.prediction.load.SnapshotStrategy;
import de.viadee.sonarIssueScoring.service.prediction.load.SplitRepository;
import de.viadee.sonarIssueScoring.service.prediction.train.Instance;
import de.viadee.sonarIssueScoring.service.prediction.train.MLInput;
Expand All @@ -33,5 +41,56 @@ public PredictionResult predict(PredictionParams params, String h2oServer) {

return mlService.predict(MLInput.of(instances, predictableInstances, h2oServer));
}

/** Extract data and build a model for the past, and compare it with the more recent, not learned past to gauge prediction quality */
public EvaluationResult evaluate(PredictionParams params, String h2oServer) {
SplitRepository data = repositoryLoader.loadSplitRepository(params, SnapshotStrategy.NO_OVERLAP_ON_MOST_RECENT);
//Use the most recent pastFuturePair as actual future, which has to be predicted. The SnapshotStrategy assures this future is not used as training data, even partially

Preconditions.checkState(data.trainingData().size() > 1, "Not enough historical data");

List<Instance> instances = instanceSource.extractInstances(data.trainingData().subList(1, data.trainingData().size())); //Training data, based on past
List<Instance> predictableInstances = instanceSource.extractInstances(data.trainingData().subList(0, 1));

PredictionResult result = mlService.predict(MLInput.of(instances, predictableInstances, h2oServer));

// Collect predicted vs actual future
List<ResultPair> pairs = predictableInstances.stream().map(
instance -> new ResultPair(result.results().get(instance.path()).predictedChangeCount(), instance.targetEditCountPercentile())).collect(
Collectors.toList());

return EvaluationResult.of(rmse(pairs), r2(pairs), confusionMatrix(pairs));
}

static ImmutableTable<Boolean, Boolean, Integer> confusionMatrix(Collection<ResultPair> pairs) {
//Identify commonly-edited files: all files edited more than the percentile below
ScaleAndIndex requirement = Quantiles.percentiles().index(80);
double thresholdActual = requirement.computeInPlace(pairs.stream().mapToDouble(p -> p.actual).toArray());
double thresholdPredicted = requirement.computeInPlace(requirement.computeInPlace(pairs.stream().mapToDouble(p -> p.predicted).toArray()));

return pairs.stream().collect(ImmutableTable.toImmutableTable(//
pair -> pair.actual >= thresholdActual, // Row = actual
pair -> pair.predicted >= thresholdPredicted, // Col == predicted
pair -> 1, (a, b) -> a + b));
}

static double rmse(Collection<ResultPair> pair) {
return Math.sqrt(pair.stream().mapToDouble(p -> Math.pow(p.actual - p.predicted, 2)).average().orElse(0));
}

static double r2(Collection<ResultPair> pair) {
PairedStatsAccumulator acc = new PairedStatsAccumulator();
pair.forEach(p -> acc.add(p.actual, p.predicted));
return Math.pow(acc.pearsonsCorrelationCoefficient(), 2);
}

static class ResultPair {
private final double predicted, actual;

ResultPair(double predicted, double actual) {
this.predicted = predicted;
this.actual = actual;
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package de.viadee.sonarIssueScoring.service.prediction;

import java.util.ArrayList;
import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Table;

import de.viadee.sonarIssueScoring.service.prediction.PredictionService.ResultPair;

public class PredictionServiceTest {

@Test
public void confusionMatrix() {
List<ResultPair> values = new ArrayList<>(ImmutableList.of(//
new ResultPair(0, 1),//
new ResultPair(0, 1),//
new ResultPair(1, 0),//
new ResultPair(1, 0),//
new ResultPair(1, 0),//
new ResultPair(1, 1)));//

for (int i = 0; i < 8; i++)
values.add(new ResultPair(0, 0));

Table<Boolean, Boolean, Integer> matrix = PredictionService.confusionMatrix(values);
// actual, predicted
Assert.assertEquals(8, (int) matrix.get(false, false));
Assert.assertEquals(1, (int) matrix.get(true, true));
Assert.assertEquals(3, (int) matrix.get(false, true));
Assert.assertEquals(2, (int) matrix.get(true, false));
}

@Test
public void rmse() {
ImmutableList<ResultPair> values = ImmutableList.of(//
new ResultPair(90, 80),// 10 * 10 = 100
new ResultPair(50, 70),// 20 * 20 = 400
new ResultPair(50, 50));// 0 * 0 = 0

//sqrt(500/3)
Assert.assertEquals(12.909, PredictionService.rmse(values), 1.0e-3);
}

@Test
public void r2() {
ImmutableList<ResultPair> values = ImmutableList.of(//
new ResultPair(90, 80),//
new ResultPair(76, 70),//
new ResultPair(50, 0),//
new ResultPair(33, 30),//
new ResultPair(40, 40));

Assert.assertEquals(0.58920, PredictionService.r2(values), 5.0e-5);
}
}

0 comments on commit 4ac580e

Please sign in to comment.