-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First commit of RawThresholdClassifier
- Loading branch information
1 parent
6407090
commit bcd203c
Showing
5 changed files
with
312 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Test configuration exercising the "raw_threshold" classifier.
# NOTE(review): "BasicBatchPipeline" is presumably the pipeline class name resolved by the runner -- confirm.
pipeline: "BasicBatchPipeline" | ||
|
||
# Input data: a CSV file from the test resources directory.
inputURI: "csv://src/test/resources/tiny.csv" | ||
|
||
# Classifier settings: every row whose "usage" value satisfies `usage == 2.0` is classified as an outlier.
# Supported predicates are "==", "!=", "<", ">", "<=" and ">=".
classifier: "raw_threshold" | ||
metric: "usage" | ||
predicate: "==" | ||
value: 2.0 | ||
|
||
# Summarizer settings.
# NOTE(review): "attributes" presumably lists the columns used to explain outliers, and
# minRiskRatio / minSupport the thresholds an explanation must meet -- confirm against the summarizer.
summarizer: "apriori" | ||
attributes: | ||
- "location" | ||
- "version" | ||
minRiskRatio: 10.0 | ||
minSupport: 0.2 |
89 changes: 89 additions & 0 deletions
89
...main/java/edu/stanford/futuredata/macrobase/analysis/classify/RawThresholdClassifier.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package edu.stanford.futuredata.macrobase.analysis.classify; | ||
|
||
import edu.stanford.futuredata.macrobase.datamodel.DataFrame; | ||
import edu.stanford.futuredata.macrobase.util.MacrobaseException; | ||
|
||
import java.util.function.DoublePredicate; | ||
|
||
/** | ||
* RawThresholdClassifier classifies an outlier based on a predicate(e.g., equality, less than, greater than) | ||
* and a hard-coded sentinel value. Unlike {@link PercentileClassifier}, outlier values are not determined based on a | ||
* proportion of the values in the metric column. Instead, the outlier values are defined explicitly by the user in the | ||
* conf.yaml file; for example: | ||
* <code> | ||
* classifier: "raw_threshold" | ||
* metric: "usage" | ||
* predicate: "==" | ||
* value: 1.0 | ||
* </code> | ||
* This would instantiate a RawThresholdClassifier that classifies every value in the "usage" column equal to 1.0 | ||
* as an outlier. Currently, we support six different predicates: "==", "!=", "<", ">", "<=", and ">=". | ||
*/ | ||
public class RawThresholdClassifier extends Classifier { | ||
|
||
private final DoublePredicate predicate; | ||
private DataFrame output; | ||
|
||
/** | ||
* @param metricName Column on which to classifier outliers | ||
* @param predicateStr Predicate used for classification: "==", "!=", "<", ">", "<=", or ">=" | ||
* @param sentinel Sentinel value used when evaluating the predicate to determine outlier | ||
* @throws MacrobaseException | ||
*/ | ||
public RawThresholdClassifier(final String metricName, final String predicateStr, final double sentinel) | ||
throws MacrobaseException { | ||
super(metricName); | ||
this.predicate = getPredicate(predicateStr, sentinel); | ||
} | ||
|
||
/** | ||
* @return Lambda function corresponding to the ``predicateStr''. The Lambda function takes in a single | ||
* argument, which will correspond to the value in the metric column. (A closure is created around the ``sentinel'' | ||
* parameter.) | ||
* @throws MacrobaseException | ||
*/ | ||
private DoublePredicate getPredicate(final String predicateStr, final double sentinel) throws MacrobaseException { | ||
switch (predicateStr) { | ||
case "==": | ||
return (double x) -> x == sentinel; | ||
case "!=": | ||
return (double x) -> x != sentinel; | ||
case "<": | ||
return (double x) -> x < sentinel; | ||
case ">": | ||
return (double x) -> x > sentinel; | ||
case "<=": | ||
return (double x) -> x <= sentinel; | ||
case ">=": | ||
return (double x) -> x >= sentinel; | ||
default: | ||
throw new MacrobaseException("RawThresholdClassifier: Predicate string " + predicateStr + | ||
" not suppported."); | ||
} | ||
} | ||
|
||
/** | ||
* Scan through the metric column, and evaluate the predicate on every value in the column. The ``input'' DataFrame | ||
* remains unmodified; a copy is created and all modifications are made on the copy. | ||
* @throws Exception | ||
*/ | ||
@Override | ||
public void process(DataFrame input) throws Exception { | ||
double[] metrics = input.getDoubleColumnByName(columnName); | ||
int len = metrics.length; | ||
output = input.copy(); | ||
double[] resultColumn = new double[len]; | ||
for (int i = 0; i < len; i++) { | ||
final double curVal = metrics[i]; | ||
if (predicate.test(curVal)) { | ||
resultColumn[i] = 1.0; | ||
} | ||
} | ||
output.addDoubleColumn(outputColumnName, resultColumn); | ||
} | ||
|
||
@Override | ||
public DataFrame getResults() { | ||
return output; | ||
} | ||
} |
168 changes: 168 additions & 0 deletions
168
.../java/edu/stanford/futuredata/macrobase/analysis/classify/RawThresholdClassifierTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
package edu.stanford.futuredata.macrobase.analysis.classify; | ||
|
||
import edu.stanford.futuredata.macrobase.datamodel.DataFrame; | ||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
|
||
import static junit.framework.TestCase.assertEquals; | ||
import static junit.framework.TestCase.assertTrue; | ||
|
||
/** | ||
* Created by fabuzaid21 on 8/4/17. | ||
*/ | ||
public class RawThresholdClassifierTest { | ||
|
||
private static final int NUM_ROWS = 1000; | ||
private static final int METRIC_CARDINALITY = 4; | ||
|
||
private DataFrame df; | ||
|
||
@Rule | ||
public final ExpectedException exception = ExpectedException.none(); | ||
|
||
@Before | ||
public void setUp() { | ||
df = new DataFrame(); | ||
final int num_rows_per_value = NUM_ROWS / METRIC_CARDINALITY; | ||
double[] vals = new double[NUM_ROWS]; | ||
double metricVal = 0.0; | ||
int j = 0; | ||
for (int i = 0; i < METRIC_CARDINALITY; ++i) { | ||
for (; j < num_rows_per_value * (i + 1); ++j) { | ||
vals[j] = metricVal; | ||
} | ||
metricVal += 1.0; | ||
} | ||
df.addDoubleColumn("val", vals); | ||
} | ||
|
||
@Test | ||
public void testEquals() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", "==", 0.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == NUM_ROWS / METRIC_CARDINALITY); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val == 0.0); | ||
} | ||
} | ||
|
||
@Test | ||
public void testNotEquals() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", "!=", 0.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == (METRIC_CARDINALITY - 1) * NUM_ROWS / METRIC_CARDINALITY); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val != 0.0); | ||
} | ||
} | ||
|
||
@Test | ||
public void testLessThan() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", "<", 3.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == (METRIC_CARDINALITY - 1) * NUM_ROWS / METRIC_CARDINALITY); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val < 3.0); | ||
} | ||
} | ||
|
||
@Test | ||
public void testGreaterThan() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", ">", 1.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == NUM_ROWS / 2); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val > 1.0); | ||
} | ||
} | ||
|
||
|
||
@Test | ||
public void testLessThanOrEqual() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", "<=", 1.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == NUM_ROWS / 2); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val <= 1.0); | ||
} | ||
} | ||
|
||
|
||
@Test | ||
public void testGreaterThanOrEqual() throws Exception { | ||
assertEquals(NUM_ROWS, df.getNumRows()); | ||
RawThresholdClassifier pc = new RawThresholdClassifier("val", ">=", 3.0); | ||
pc.process(df); | ||
DataFrame output = pc.getResults(); | ||
assertEquals(df.getNumRows(), output.getNumRows()); | ||
assertEquals(1, df.getSchema().getNumColumns()); | ||
assertEquals(2, output.getSchema().getNumColumns()); | ||
|
||
DataFrame outliers = output.filter( | ||
pc.getOutputColumnName(), (double d) -> d != 0.0 | ||
); | ||
int numOutliers = outliers.getNumRows(); | ||
assertTrue(numOutliers == NUM_ROWS / METRIC_CARDINALITY); | ||
double[] vals = outliers.getDoubleColumnByName("val"); | ||
for (double val : vals) { | ||
assertTrue(val >= 3.0); | ||
} | ||
} | ||
} |