diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/PercentileCalibratorTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/PercentileCalibratorTest.scala index 48f3ad5746..cc2378f08a 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/PercentileCalibratorTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/PercentileCalibratorTest.scala @@ -34,7 +34,7 @@ import com.salesforce.op._ import com.salesforce.op.features.types._ import com.salesforce.op.features.Feature import com.salesforce.op.stages.base.unary.UnaryModel -import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder, TestSparkContext} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ import org.apache.spark.ml.{Estimator, Transformer} @@ -49,10 +49,25 @@ import org.scalatest.FlatSpec import scala.util.Random @RunWith(classOf[JUnitRunner]) -class PercentileCalibratorTest extends FlatSpec with TestSparkContext { +class PercentileCalibratorTest extends OpEstimatorSpec[RealNN, UnaryModel[RealNN, RealNN], PercentileCalibrator] { + import spark.implicits._ - Spec[PercentileCalibrator] should "return a minimum calibrated score of 0 and max of 99 when buckets is 100" in { + val testData = Seq(10, 100, 1000).map(_.toRealNN) + + val (inputData, testF) = TestFeatureBuilder(testData) + + /** + * Estimator instance to be tested + */ + override val estimator: PercentileCalibrator = new PercentileCalibrator().setInput(testF) + + /** + * Expected result of the transformer applied on the Input Dataset + */ + override val expectedResult: Seq[RealNN] = Seq(33.toRealNN, 66.toRealNN, 99.toRealNN) + + it should "return a minimum calibrated score of 0 and max of 99 when buckets is 100" in { val data = (0 until 1000).map(i => i.toLong.toIntegral -> Random.nextDouble.toRealNN) val (scoresDF, f1, f2): (DataFrame, Feature[Integral], Feature[RealNN]) = TestFeatureBuilder(data) val percentile = f2.toPercentile() @@ -61,8 +76,8 @@ class PercentileCalibratorTest extends FlatSpec with TestSparkContext { val scoresTransformed = model.asInstanceOf[Transformer].transform(scoresDF) percentile.name shouldBe percentile.originStage.getOutputFeatureName - scoresTransformed.select(min(percentile.name)).first.getDouble(0) should equal (0.0) - scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal (99.0) + scoresTransformed.select(min(percentile.name)).first.getDouble(0) should equal(0.0) + scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal(99.0) } it should "produce the calibration map metadata" in { @@ -89,7 +104,7 @@ class PercentileCalibratorTest extends FlatSpec with TestSparkContext { val model = percentile.originStage.asInstanceOf[Estimator[_]].fit(scoresDF) val scoresTransformed = model.asInstanceOf[Transformer].transform(scoresDF) - scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal (99.0) + scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal(99.0) } it should "return a maximum calibrated score of 99 when calibrating with less than 100" in { @@ -100,7 +115,7 @@ class PercentileCalibratorTest extends FlatSpec with TestSparkContext { val model = percentile.originStage.asInstanceOf[Estimator[_]].fit(scoresDF) val scoresTransformed = model.asInstanceOf[Transformer].transform(scoresDF) - scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal (99.0) + scoresTransformed.select(max(percentile.name)).first.getDouble(0) should equal(99.0) } it should "return all scores from 0 to 99 in increments of 1" in { @@ -117,7 +132,7 @@ class PercentileCalibratorTest extends FlatSpec with TestSparkContext { val checkSet = (0 to 99).map(_.toReal).toSet - scoreCounts.collect(percentile).toSet should equal (checkSet) + scoreCounts.collect(percentile).toSet should equal(checkSet) } it should "return a uniform distribution of scores" in { @@ -149,6 +164,7 @@ class PercentileCalibratorTest extends FlatSpec with TestSparkContext { val indicesByProb = scoresTransformed.orderBy(f2.name).collect(f1).deep val indicesByPerc = scoresTransformed.orderBy(percentile.name, f2.name).collect(f1).deep - indicesByProb should equal (indicesByPerc) + indicesByProb should equal(indicesByPerc) } + }