Skip to content

Commit

Permalink
Make decision tree numeric bucketizer tests less flaky (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jauntbox committed Feb 12, 2019
1 parent e28122c commit a0af563
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ class DecisionTreeNumericBucketizerTest extends OpEstimatorSpec[OPVector,
).map(_.toOPVector)

trait NormalData {
val numericData: Seq[Real] = RandomReal.normal[Real]().withProbabilityOfEmpty(0.2).limit(1000)
val labelData: Seq[RealNN] = RandomBinary(probabilityOfSuccess = 0.4).limit(1000).map(_.toDouble.toRealNN(0.0))
val total = 1000
val numericData: Seq[Real] = RandomReal.normal[Real]().withProbabilityOfEmpty(0.2).limit(total)
val labelData: Seq[RealNN] = RandomBinary(probabilityOfSuccess = 0.4).limit(total).map(_.toDouble.toRealNN(0.0))
val (ds, numeric, label) = TestFeatureBuilder[Real, RealNN](numericData zip labelData)
val expectedSplits = Array.empty[Double]
lazy val modelLocation = tempDir + "/dt-buck-test-model-" + org.joda.time.DateTime.now().getMillis
Expand All @@ -83,10 +84,14 @@ class DecisionTreeNumericBucketizerTest extends OpEstimatorSpec[OPVector,
val expectedSplits = Array.empty[Double]
}

// Generate uniformly spaced data so that the splits found by the decision tree will be deterministic. We still
// won't get splits exactly at the midpoints between data points (e.g. 14.95, 35.95, 90.95) because of the way Spark
// calculates splits by binning. The default number of bins is 32, which limits the resolution of the splits.
trait UniformData {
val total = 1000
val (min, max) = (0.0, 100.0)
val currencyData: Seq[Currency] =
RandomReal.uniform[Currency](minValue = min, maxValue = max).withProbabilityOfEmpty(0.1).limit(1000)
val currencyData: Seq[Currency] = (0 until total).map(x => (x * max/total).toCurrency)

val labelData = currencyData.map(c => {
c.value.map {
case v if v < 15 => 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class DecisionTreeNumericMapBucketizerTest extends OpEstimatorSpec[OPVector,
val (min, max) = (0.0, 100.0)
val currencies: RandomReal[Currency] =
RandomReal.uniform[Currency](minValue = min, maxValue = max).withProbabilityOfEmpty(0.1)
val correlated = currencies.limit(total)
val correlated: Seq[Currency] = (0 until total).map(x => (x * max/total).toCurrency)

val labelData = correlated.map(c => {
c.value.map {
case v if v < 15 => 0.0
Expand Down

0 comments on commit a0af563

Please sign in to comment.