diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index ff8322e43a..b9ef39b150 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -387,7 +387,8 @@ case object ModelInsights { val typeHints = FullTypeHints(List( classOf[Continuous], classOf[Discrete], classOf[DataBalancerSummary], classOf[DataCutterSummary], classOf[DataSplitterSummary], - classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics], classOf[ThresholdMetrics], + classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics], + classOf[BinaryClassificationBinMetrics], classOf[ThresholdMetrics], classOf[MultiClassificationMetrics], classOf[RegressionMetrics] )) val evalMetricsSerializer = new CustomSerializer[EvalMetric](_ => diff --git a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala index cf5c7fb48c..a60e97b51d 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala @@ -52,6 +52,13 @@ object Evaluators { */ def apply(): OpBinaryClassificationEvaluator = auROC() + /* + * Brier Score for the prediction + */ + def brierScore(): OpBinScoreEvaluator = + new OpBinScoreEvaluator( + name = BinaryClassEvalMetrics.brierScore, isLargerBetter = true) + /** * Area under ROC */ diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpBinScoreEvaluator.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpBinScoreEvaluator.scala new file mode 100644 index 0000000000..5be7948180 --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpBinScoreEvaluator.scala @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.salesforce.op.evaluators + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.salesforce.op.UID +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{Dataset, Row} +import org.slf4j.LoggerFactory +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.DoubleType +import com.twitter.algebird.Operators._ +import com.twitter.algebird.Monoid._ +import org.apache.spark.rdd.RDD + +/** + * + * Evaluator for Binary Classification which provides statistics about the predicted scores. + * This evaluator creates the specified number of bins and computes the statistics for each bin + * and returns BinaryClassificationBinMetrics, which contains + * + * Total number of data points per bin + * Average Score per bin + * Average Conversion rate per bin + * Bin Centers for each bin + * BrierScore for the overall dataset is also computed, which is a default metric as well. + * + * @param name name of default metric + * @param isLargerBetter is metric better if larger + * @param uid uid for instance + */ +private[op] class OpBinScoreEvaluator +( + override val name: EvalMetric = OpEvaluatorNames.BinScore, + override val isLargerBetter: Boolean = true, + override val uid: String = UID[OpBinScoreEvaluator], + val numBins: Int = 100 +) extends OpBinaryClassificationEvaluatorBase[BinaryClassificationBinMetrics](uid = uid) { + + require(numBins > 0, "numBins must be positive") + @transient private lazy val log = LoggerFactory.getLogger(this.getClass) + + def getDefaultMetric: BinaryClassificationBinMetrics => Double = _.brierScore + + override def evaluateAll(data: Dataset[_]): BinaryClassificationBinMetrics = { + val labelColumnName = getLabelCol + val dataProcessed = makeDataToUse(data, labelColumnName) + + val rdd = dataProcessed.select(col(getProbabilityCol), col(labelColumnName).cast(DoubleType)).rdd + if (rdd.isEmpty()) { + log.error("The dataset is empty. Returning empty metrics") + BinaryClassificationBinMetrics(0.0, Seq(), Seq(), Seq(), Seq()) + } else { + val scoreAndLabels = rdd.map { + case Row(prob: Vector, label: Double) => (prob(1), label) + case Row(prob: Double, label: Double) => (prob, label) + } + + val (maxScore, minScore) = scoreAndLabels.map { + case (score , _) => (score, score) + }.fold(1.0, 0.0) { + case((maxVal, minVal), (scoreMax, scoreMin)) => { + (math.max(maxVal, scoreMax), math.min(minVal, scoreMin)) + } + } + + // Finding stats per bin -> avg score, avg conv rate, + // total num of data points and overall brier score. + val stats = scoreAndLabels.map { + case (score, label) => + (getBinIndex(score, minScore, maxScore), (score, label, 1L, math.pow((score - label), 2))) + }.reduceByKey(_ + _).map { + case (bin, (scoreSum, labelSum, count, squaredError)) => + (bin, scoreSum / count, labelSum / count, count, squaredError) + }.collect() + + val (averageScore, averageConversionRate, numberOfDataPoints, brierScoreSum, numberOfPoints) = + stats.foldLeft((new Array[Double](numBins), new Array[Double](numBins), new Array[Long](numBins), 0.0, 0L)) { + case ((score, convRate, dataPoints, brierScoreSum, totalPoints), + (binIndex, avgScore, avgConvRate, counts, squaredError)) => { + + score(binIndex) = avgScore + convRate(binIndex) = avgConvRate + dataPoints(binIndex) = counts + + (score, convRate, dataPoints, brierScoreSum + squaredError, totalPoints + counts) + } + } + + // binCenters is the center point in each bin. + // e.g., for bins [(0.0 - 0.5), (0.5 - 1.0)], bin centers are [0.25, 0.75]. + val diff = maxScore - minScore + val binCenters = (for {i <- 0 to numBins-1} yield (minScore + ((diff * i) / numBins) + (diff / (2 * numBins)))) + + val metrics = BinaryClassificationBinMetrics( + brierScore = brierScoreSum / numberOfPoints, + binCenters = binCenters, + numberOfDataPoints = numberOfDataPoints, + averageScore = averageScore, + averageConversionRate = averageConversionRate + ) + + log.info("Evaluated metrics: {}", metrics.toString) + metrics + } + } + + // getBinIndex finds which bin the score associates with. + private def getBinIndex(score: Double, minScore: Double, maxScore: Double): Int = { + val binIndex = (numBins * (score - minScore) / (maxScore - minScore)).toInt + math.min(numBins - 1, binIndex) + } +} + +/** + * Metrics of BinaryClassificationBinMetrics + * + * @param binCenters center of each bin + * @param numberOfDataPoints total number of data points in each bin + * @param averageScore average score in each bin + * @param averageConversionRate average conversion rate in each bin + * @param brierScore brier score for overall dataset + */ +case class BinaryClassificationBinMetrics +( + brierScore: Double, + @JsonDeserialize(contentAs = classOf[java.lang.Double]) + binCenters: Seq[Double], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + numberOfDataPoints: Seq[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Double]) + averageScore: Seq[Double], + @JsonDeserialize(contentAs = classOf[java.lang.Double]) + averageConversionRate: Seq[Double] +) extends EvaluationMetrics diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala index 10cc86d9a0..bd68089643 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala @@ -319,6 +319,7 @@ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] { case object TN extends ClassificationEvalMetric("TN", "true negative") case object FP extends ClassificationEvalMetric("FP", "false positive") case object FN extends ClassificationEvalMetric("FN", "false negative") + case object brierScore extends ClassificationEvalMetric("brierscore", "brier score") } /** @@ -370,6 +371,7 @@ sealed abstract class OpEvaluatorNames object OpEvaluatorNames extends Enum[OpEvaluatorNames] { val values: Seq[OpEvaluatorNames] = findValues case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics") + case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics") case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics") case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics") case class Custom(name: String, humanName: String) extends OpEvaluatorNames(name, humanName) { diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/classification/BinaryClassificationModelSelector.scala b/core/src/main/scala/com/salesforce/op/stages/impl/classification/BinaryClassificationModelSelector.scala index 5ff7b96ee8..4986b3d5b8 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/classification/BinaryClassificationModelSelector.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/classification/BinaryClassificationModelSelector.scala @@ -152,7 +152,7 @@ case object BinaryClassificationModelSelector extends ModelSelectorFactory { numFolds = numFolds, seed = seed, validationMetric, stratify = stratify, parallelism = parallelism ) selector(cv, splitter = splitter, - trainTestEvaluators = Seq(new OpBinaryClassificationEvaluator) ++ trainTestEvaluators, + trainTestEvaluators = Seq(new OpBinaryClassificationEvaluator, new OpBinScoreEvaluator) ++ trainTestEvaluators, modelTypesToUse = modelTypesToUse, modelsAndParameters = modelsAndParameters) } diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummary.scala b/core/src/main/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummary.scala index ffe62d6406..43b4c470d7 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummary.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummary.scala @@ -241,6 +241,8 @@ case object ModelSelectorSummary { nm match { case OpEvaluatorNames.Binary.humanFriendlyName => nm -> JsonUtils.fromString[BinaryClassificationMetrics](valsJson).get + case OpEvaluatorNames.BinScore.humanFriendlyName => + nm -> JsonUtils.fromString[BinaryClassificationBinMetrics](valsJson).get case OpEvaluatorNames.Multi.humanFriendlyName => nm -> JsonUtils.fromString[MultiClassificationMetrics](valsJson).get case OpEvaluatorNames.Regression.humanFriendlyName => @@ -269,11 +271,13 @@ object ProblemType extends Enum[ProblemType] { def fromEvalMetrics(eval: EvaluationMetrics): ProblemType = { eval match { case _: BinaryClassificationMetrics => ProblemType.BinaryClassification + case _: BinaryClassificationBinMetrics => ProblemType.BinaryClassification case _: MultiClassificationMetrics => ProblemType.MultiClassification case _: RegressionMetrics => ProblemType.Regression case m: MultiMetrics => val keys = m.metrics.keySet if (keys.exists(_.contains(OpEvaluatorNames.Binary.humanFriendlyName))) ProblemType.BinaryClassification + else if (keys.exists(_.contains(OpEvaluatorNames.BinScore.humanFriendlyName))) ProblemType.BinaryClassification else if (keys.exists(_.contains(OpEvaluatorNames.Multi.humanFriendlyName))) ProblemType.MultiClassification else if (keys.exists(_.contains(OpEvaluatorNames.Regression.humanFriendlyName))) ProblemType.Regression else ProblemType.Unknown diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index 3c29c26f00..b3471cc855 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -375,7 +375,7 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { val prettySummary = fittedWorkflow.summaryPretty() log.info(prettySummary) prettySummary should include("Selected Model - OpLogisticRegression") - prettySummary should include("area under precision-recall | 1.0 | 0.0") + prettySummary should include("area under precision-recall | 1.0 | 0.0") prettySummary should include("Model Evaluation Metrics") prettySummary should include("Top Model Insights") prettySummary should include("Top Positive Correlations") diff --git a/core/src/test/scala/com/salesforce/op/evaluators/EvaluatorsTest.scala b/core/src/test/scala/com/salesforce/op/evaluators/EvaluatorsTest.scala index db341d60e3..0a53ebee1f 100644 --- a/core/src/test/scala/com/salesforce/op/evaluators/EvaluatorsTest.scala +++ b/core/src/test/scala/com/salesforce/op/evaluators/EvaluatorsTest.scala @@ -97,6 +97,9 @@ class EvaluatorsTest extends FlatSpec with TestSparkContext { val opBinaryMetrics = new OpBinaryClassificationEvaluator().setLabelCol(test_label) .setPredictionCol(pred).evaluateAll(transformedData) + val opBinScoreMetrics = new OpBinScoreEvaluator().setLabelCol(test_label) + .setPredictionCol(pred).evaluateAll(transformedData) + val sparkMultiEvaluator = new MulticlassClassificationEvaluator().setLabelCol(test_label.name) .setPredictionCol(predValue.name) @@ -115,6 +118,8 @@ class EvaluatorsTest extends FlatSpec with TestSparkContext { evaluateBinaryMetric(Evaluators.BinaryClassification.recall()) shouldBe opBinaryMetrics.Recall evaluateBinaryMetric(Evaluators.BinaryClassification.f1()) shouldBe opBinaryMetrics.F1 evaluateBinaryMetric(Evaluators.BinaryClassification.error()) shouldBe opBinaryMetrics.Error + + evaluateBinScoreMetric(Evaluators.BinaryClassification.brierScore()) shouldBe opBinScoreMetrics.brierScore } it should "have a multi classification factory" in { @@ -148,6 +153,9 @@ class EvaluatorsTest extends FlatSpec with TestSparkContext { def evaluateBinaryMetric(binEval: OpBinaryClassificationEvaluator): Double = binEval.setLabelCol(test_label) .setPredictionCol(pred).evaluate(transformedData3) + def evaluateBinScoreMetric(binEval: OpBinScoreEvaluator): Double = binEval.setLabelCol(test_label) + .setPredictionCol(pred).evaluate(transformedData3) + def evaluateSparkBinaryMetric(metricName: String): Double = sparkBinaryEvaluator.setMetricName(metricName) .evaluate(transformedData3) diff --git a/core/src/test/scala/com/salesforce/op/evaluators/OpBinScoreEvaluatorTest.scala b/core/src/test/scala/com/salesforce/op/evaluators/OpBinScoreEvaluatorTest.scala new file mode 100644 index 0000000000..acf8182bdd --- /dev/null +++ b/core/src/test/scala/com/salesforce/op/evaluators/OpBinScoreEvaluatorTest.scala @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.evaluators + +import com.salesforce.op.features.types.Prediction +import com.salesforce.op.features.types._ +import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class OpBinScoreEvaluatorTest extends FlatSpec with TestSparkContext { + + val (dataset, prediction, label) = TestFeatureBuilder( + Seq ( + Prediction(1.0, Array(10.0, 10.0), Array(0.0001, 0.99999)) -> 1.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.0001, 0.99999)) -> 1.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.99560, 0.00541)) -> 0.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.30, 0.70)) -> 0.0.toRealNN, + Prediction(0.0, Array(10.0, 10.0), Array(0.999, 0.001)) -> 0.0.toRealNN + ) + ) + + val (dataSkewed, predictionSkewedData, labelSkewedData) = TestFeatureBuilder( + Seq ( + Prediction(1.0, Array(10.0, 10.0), Array(0.0001, 0.99999)) -> 1.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.0001, 0.99999)) -> 1.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.001, 0.9987)) -> 1.0.toRealNN, + Prediction(1.0, Array(10.0, 10.0), Array(0.0541, 0.946)) -> 1.0.toRealNN + ) + ) + + val (emptyData, predictionEmptyData, labelEmptyData) = TestFeatureBuilder[Prediction, RealNN](Seq()) + + val (outOfBoundScoreDataset, outOfBoundScoreprediction, outOfBoundScorelabel) = TestFeatureBuilder( + Seq ( + Prediction(1.0, Array(0.0001, -0.99999), Array.emptyDoubleArray) -> 0.0.toRealNN, + Prediction(1.0, Array(0.0001, 1.99999), Array.emptyDoubleArray) -> 1.0.toRealNN, + Prediction(1.0, Array(0.0001, 12.0), Array.emptyDoubleArray) -> 1.0.toRealNN + ) + ) + + Spec[OpBinScoreEvaluator] should "return the bin metrics" in { + val metrics = new OpBinScoreEvaluator(numBins = 4) + .setLabelCol(label.name).setPredictionCol(prediction.name).evaluateAll(dataset) + + metrics shouldBe BinaryClassificationBinMetrics( + 0.09800605366, + Seq(0.125, 0.375, 0.625, 0.875), + Seq(2, 0, 1, 2), + Seq(0.003205, 0.0, 0.7, 0.99999), + Seq(0.0, 0.0, 0.0, 1.0)) + } + + it should "evaluate bin metrics for scores not between 0 and 1" in { + val metrics = new OpBinScoreEvaluator(numBins = 4) + .setLabelCol(outOfBoundScorelabel.name).setPredictionCol(outOfBoundScoreprediction.name) + .evaluateAll(outOfBoundScoreDataset) + + metrics shouldBe BinaryClassificationBinMetrics( + 40.999986666733335, + Seq(0.62500875, 3.87500625, 7.125003749999999, 10.37500125), + Seq(2, 0, 0, 1), + Seq(0.49999999999999994, 0.0, 0.0, 12.0), + Seq(0.5, 0.0, 0.0, 1.0)) + } + + it should "error on invalid number of bins" in { + assertThrows[IllegalArgumentException] { + new OpBinScoreEvaluator(numBins = 0) + .setLabelCol(label.name).setPredictionCol(prediction.name).evaluateAll(dataset) + } + } + + it should "evaluate the empty data" in { + val metrics = new OpBinScoreEvaluator(numBins = 10) + .setLabelCol(labelEmptyData.name).setPredictionCol(predictionEmptyData.name).evaluateAll(emptyData) + + metrics shouldBe BinaryClassificationBinMetrics(0.0, Seq(), Seq(), Seq(), Seq()) + } + + it should "evaluate bin metrics for skewed data" in { + val metrics = new OpBinScoreEvaluator(numBins = 5) + .setLabelCol(labelSkewedData.name).setPredictionCol(predictionSkewedData.name).evaluateAll(dataSkewed) + + metrics shouldBe BinaryClassificationBinMetrics( + 7.294225500000013E-4, + Seq(0.1, 0.30000000000000004, 0.5, 0.7, 0.9), + Seq(0, 0, 0, 0, 4), + Seq(0.0, 0.0, 0.0, 0.0, 0.98617), + Seq(0.0, 0.0, 0.0, 0.0, 1.0)) + } + + it should "evaluate the default metric as BrierScore" in { + val evaluator = new OpBinScoreEvaluator(numBins = 4) + .setLabelCol(label.name).setPredictionCol(prediction.name) + + evaluator.getDefaultMetric(evaluator.evaluateAll(dataset)) shouldBe 0.09800605366 + } +} diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorTest.scala index 56c916ddcd..ad1307386c 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorTest.scala @@ -120,7 +120,7 @@ class ModelSelectorTest extends OpEstimatorSpec[Prediction, SelectedModel, Model numFolds = 3, seed = seed, Evaluators.BinaryClassification.auPR(), stratify = false, parallelism = 1), splitter = Option(DataBalancer(sampleFraction = 0.5, seed = 11L)), models = Seq(lr -> Array.empty[ParamMap]), - evaluators = Seq(new OpBinaryClassificationEvaluator) + evaluators = Seq(new OpBinaryClassificationEvaluator, new OpBinScoreEvaluator) ).setInput(feature1, feature2) val expectedResult = Seq( @@ -144,7 +144,7 @@ class ModelSelectorTest extends OpEstimatorSpec[Prediction, SelectedModel, Model validator = validatorCV, splitter = Option(DataBalancer(sampleFraction = 0.5, seed = 11L)), models = Seq(lr -> lrParams, rf -> rfParams), - evaluators = Seq(new OpBinaryClassificationEvaluator) + evaluators = Seq(new OpBinaryClassificationEvaluator, new OpBinScoreEvaluator()) ).setInput(label, features) val model = testEstimator.fit(data) @@ -229,7 +229,7 @@ class ModelSelectorTest extends OpEstimatorSpec[Prediction, SelectedModel, Model validator = validatorCV, splitter = Option(DataBalancer(sampleFraction = 0.5, seed = 11L)), models = Seq(test -> testParams), - evaluators = Seq(new OpBinaryClassificationEvaluator) + evaluators = Seq(new OpBinaryClassificationEvaluator, new OpBinScoreEvaluator()) ).setInput(label, features) val model = testEstimator.fit(data)