Skip to content

Commit

Permalink
First commit of RawThresholdClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
fabuzaid21 committed Aug 6, 2017
1 parent 6407090 commit bcd203c
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

import edu.stanford.futuredata.macrobase.analysis.classify.Classifier;
import edu.stanford.futuredata.macrobase.analysis.classify.PercentileClassifier;
import edu.stanford.futuredata.macrobase.analysis.classify.RawThresholdClassifier;
import edu.stanford.futuredata.macrobase.analysis.summary.APrioriSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.BatchSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.Explanation;
import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import edu.stanford.futuredata.macrobase.datamodel.Schema;
import edu.stanford.futuredata.macrobase.ingest.CSVDataFrameLoader;
import edu.stanford.futuredata.macrobase.util.MacrobaseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Simplest default pipeline: load, classify, and then explain
Expand All @@ -21,44 +23,54 @@
public class BasicBatchPipeline implements Pipeline {
Logger log = LoggerFactory.getLogger(Pipeline.class);

private String inputURI = null;
// All classifier-specific fields need to be retrieved from ``conf''
private final PipelineConfig conf;

private String classifierType = "percentile";
private String metric = null;
private double cutoff = 1.0;
private boolean pctileHigh = true;
private boolean pctileLow = true;
// PipelineConfig params applicable to all classifiers
private final String inputURI;
private final String classifierType;
private final String metric;

// PipelineConfig params applicable to all summarizers
private String summarizerType = "apriori";
private List<String> attributes = null;
private double minSupport = 0.01;
private double minRiskRatio = 5.0;


public BasicBatchPipeline (PipelineConfig conf) {
this.conf = conf;
// these fields must be defined explicitly in the conf.yaml file
inputURI = conf.get("inputURI");

classifierType = conf.get("classifier");
metric = conf.get("metric");
cutoff = conf.get("cutoff");
pctileHigh = conf.get("includeHi");
pctileLow = conf.get("includeLo");

summarizerType = conf.get("summarizer");
attributes = conf.get("attributes");
minRiskRatio = conf.get("minRiskRatio");
minSupport = conf.get("minSupport");

}

public Classifier getClassifier() throws MacrobaseException {
switch (classifierType.toLowerCase()) {
case "percentile": {
PercentileClassifier classifier = new PercentileClassifier(metric);
classifier.setPercentile(cutoff);
classifier.setIncludeHigh(pctileHigh);
classifier.setIncludeLow(pctileLow);
return classifier;
// default values for PercentileClassifier:
// {cuttoff: 1.0, includeHi: true, includeLo: true}
final double cutoff = conf.get("cutoff", 1.0);
final boolean pctileHigh = conf.get("includeHi", true);
final boolean pctileLow = conf.get("includeLo", true);

return new PercentileClassifier(metric)
.setPercentile(cutoff)
.setIncludeHigh(pctileHigh)
.setIncludeLow(pctileLow);
}
case "raw_threshold": {
// default values for RawThresholdClassifier
// {predicate: "==", value: 1.0}
final String predicateStr = conf.get("predicate", "==");
final double metricValue = conf.get("value", 1.0);
return new RawThresholdClassifier(metric, predicateStr, metricValue);
}
default : {
throw new MacrobaseException("Bad Classifier Type");
Expand Down Expand Up @@ -109,8 +121,6 @@ public Explanation results() throws Exception {
summarizer.process(df);
elapsed = System.currentTimeMillis() - startTime;
log.info("Summarization time: {}", elapsed);
Explanation output = summarizer.getResults();

return output;
return summarizer.getResults();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,13 @@ public void testDemoQuery() throws Exception {
assertEquals(3, e.getNumInliers());
}

@Test
public void testRawThresholdClassifier() throws Exception {
PipelineConfig conf = PipelineConfig.fromYamlFile(
"src/test/resources/tiny_raw_threshold_conf.yaml"
);
BasicBatchPipeline p = new BasicBatchPipeline(conf);
Explanation e = p.results();
assertEquals(2, e.getNumInliers());
}
}
15 changes: 15 additions & 0 deletions core/src/test/resources/tiny_raw_threshold_conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pipeline: "BasicBatchPipeline"

inputURI: "csv://src/test/resources/tiny.csv"

classifier: "raw_threshold"
metric: "usage"
predicate: "=="
value: 2.0

summarizer: "apriori"
attributes:
- "location"
- "version"
minRiskRatio: 10.0
minSupport: 0.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package edu.stanford.futuredata.macrobase.analysis.classify;

import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import edu.stanford.futuredata.macrobase.util.MacrobaseException;

import java.util.function.DoublePredicate;

/**
* RawThresholdClassifier classifies an outlier based on a predicate(e.g., equality, less than, greater than)
* and a hard-coded sentinel value. Unlike {@link PercentileClassifier}, outlier values are not determined based on a
* proportion of the values in the metric column. Instead, the outlier values are defined explicitly by the user in the
* conf.yaml file; for example:
* <code>
* classifier: "raw_threshold"
* metric: "usage"
* predicate: "=="
* value: 1.0
* </code>
* This would instantiate a RawThresholdClassifier that classifies every value in the "usage" column equal to 1.0
* as an outlier. Currently, we support six different predicates: "==", "!=", "<", ">", "<=", and ">=".
*/
public class RawThresholdClassifier extends Classifier {

private final DoublePredicate predicate;
private DataFrame output;

/**
* @param metricName Column on which to classifier outliers
* @param predicateStr Predicate used for classification: "==", "!=", "<", ">", "<=", or ">="
* @param sentinel Sentinel value used when evaluating the predicate to determine outlier
* @throws MacrobaseException
*/
public RawThresholdClassifier(final String metricName, final String predicateStr, final double sentinel)
throws MacrobaseException {
super(metricName);
this.predicate = getPredicate(predicateStr, sentinel);
}

/**
* @return Lambda function corresponding to the ``predicateStr''. The Lambda function takes in a single
* argument, which will correspond to the value in the metric column. (A closure is created around the ``sentinel''
* parameter.)
* @throws MacrobaseException
*/
private DoublePredicate getPredicate(final String predicateStr, final double sentinel) throws MacrobaseException {
switch (predicateStr) {
case "==":
return (double x) -> x == sentinel;
case "!=":
return (double x) -> x != sentinel;
case "<":
return (double x) -> x < sentinel;
case ">":
return (double x) -> x > sentinel;
case "<=":
return (double x) -> x <= sentinel;
case ">=":
return (double x) -> x >= sentinel;
default:
throw new MacrobaseException("RawThresholdClassifier: Predicate string " + predicateStr +
" not suppported.");
}
}

/**
* Scan through the metric column, and evaluate the predicate on every value in the column. The ``input'' DataFrame
* remains unmodified; a copy is created and all modifications are made on the copy.
* @throws Exception
*/
@Override
public void process(DataFrame input) throws Exception {
double[] metrics = input.getDoubleColumnByName(columnName);
int len = metrics.length;
output = input.copy();
double[] resultColumn = new double[len];
for (int i = 0; i < len; i++) {
final double curVal = metrics[i];
if (predicate.test(curVal)) {
resultColumn[i] = 1.0;
}
}
output.addDoubleColumn(outputColumnName, resultColumn);
}

@Override
public DataFrame getResults() {
return output;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package edu.stanford.futuredata.macrobase.analysis.classify;

import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;

/**
* Created by fabuzaid21 on 8/4/17.
*/
public class RawThresholdClassifierTest {

private static final int NUM_ROWS = 1000;
private static final int METRIC_CARDINALITY = 4;

private DataFrame df;

@Rule
public final ExpectedException exception = ExpectedException.none();

@Before
public void setUp() {
df = new DataFrame();
final int num_rows_per_value = NUM_ROWS / METRIC_CARDINALITY;
double[] vals = new double[NUM_ROWS];
double metricVal = 0.0;
int j = 0;
for (int i = 0; i < METRIC_CARDINALITY; ++i) {
for (; j < num_rows_per_value * (i + 1); ++j) {
vals[j] = metricVal;
}
metricVal += 1.0;
}
df.addDoubleColumn("val", vals);
}

@Test
public void testEquals() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", "==", 0.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == NUM_ROWS / METRIC_CARDINALITY);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val == 0.0);
}
}

@Test
public void testNotEquals() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", "!=", 0.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == (METRIC_CARDINALITY - 1) * NUM_ROWS / METRIC_CARDINALITY);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val != 0.0);
}
}

@Test
public void testLessThan() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", "<", 3.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == (METRIC_CARDINALITY - 1) * NUM_ROWS / METRIC_CARDINALITY);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val < 3.0);
}
}

@Test
public void testGreaterThan() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", ">", 1.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == NUM_ROWS / 2);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val > 1.0);
}
}


@Test
public void testLessThanOrEqual() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", "<=", 1.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == NUM_ROWS / 2);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val <= 1.0);
}
}


@Test
public void testGreaterThanOrEqual() throws Exception {
assertEquals(NUM_ROWS, df.getNumRows());
RawThresholdClassifier pc = new RawThresholdClassifier("val", ">=", 3.0);
pc.process(df);
DataFrame output = pc.getResults();
assertEquals(df.getNumRows(), output.getNumRows());
assertEquals(1, df.getSchema().getNumColumns());
assertEquals(2, output.getSchema().getNumColumns());

DataFrame outliers = output.filter(
pc.getOutputColumnName(), (double d) -> d != 0.0
);
int numOutliers = outliers.getNumRows();
assertTrue(numOutliers == NUM_ROWS / METRIC_CARDINALITY);
double[] vals = outliers.getDoubleColumnByName("val");
for (double val : vals) {
assertTrue(val >= 3.0);
}
}
}

0 comments on commit bcd203c

Please sign in to comment.