diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/GeolocationVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/GeolocationVectorizer.scala index 7192842b4b..489f010015 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/GeolocationVectorizer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/GeolocationVectorizer.scala @@ -139,7 +139,10 @@ final class GeolocationVectorizerModel private[op] def transformFn: Seq[Geolocation] => OPVector = row => { val replaced = if (!trackNulls) { - row.zip(fillValues).flatMap { case (r, m) => if (r.isEmpty) m else r.value } + row.zip(fillValues).flatMap { case (r, m) => + val meanToUse: Seq[Double] = if (m.isEmpty) RepresentationOfEmpty else m + if (r.isEmpty) meanToUse else r.value + } } else { row.zip(fillValues).flatMap { case (r, m) => diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/AttributeAsserts.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/AttributeAsserts.scala index 8e84d54a5a..0cb29643ab 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/AttributeAsserts.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/AttributeAsserts.scala @@ -37,19 +37,27 @@ import org.scalatest.{Assertion, Matchers} trait AttributeAsserts { self: Matchers => + /** * Assert if attributes are nominal or not * - * @param schema + * @param schema field schema with attributes attached * @param expectedNominal Expected array of booleans. True if the field is nominal, false if not. * @param output the output OPVector associated with the column */ final def assertNominal(schema: StructField, expectedNominal: Array[Boolean], output: Array[OPVector]): Assertion = { - val attributes = AttributeGroup.fromStructField(schema).attributes for { - x <- output + (x, i) <- output.zipWithIndex + _ = withClue(s"Output vector $i and expectedNominal arrays are not of the same length:") { + x.value.size shouldBe expectedNominal.length + } (value, nominal) <- x.value.toArray.zip(expectedNominal) - } if (nominal) value should (be (0.0) or be (1.0)) - attributes.map(_.map(_.isNominal).toSeq) shouldBe Some(expectedNominal.toSeq) + } if (nominal) value should (be(0.0) or be(1.0)) + + val attributes = AttributeGroup.fromStructField(schema).attributes + withClue("Field attributes were not set or not as expected:") { + attributes.map(_.map(_.isNominal).toSeq) shouldBe Some(expectedNominal.toSeq) + } } + } diff --git a/features/src/main/scala/com/salesforce/op/aggregators/Geolocation.scala b/features/src/main/scala/com/salesforce/op/aggregators/Geolocation.scala index 9b923b4c8d..83bf01f67b 100644 --- a/features/src/main/scala/com/salesforce/op/aggregators/Geolocation.scala +++ b/features/src/main/scala/com/salesforce/op/aggregators/Geolocation.scala @@ -78,9 +78,9 @@ case object GeolocationMidpoint trait GeolocationFunctions { - val Zero: Array[Double] = Array.fill[Double](4)(0.0) + val Zero: Array[Double] = new Array[Double](4) - def isNone(data: Array[Double]): Boolean = data(3) == 0 + def isNone(data: Array[Double]): Boolean = data(3) == 0.0 /** * Prepare method to be used in the MonoidAggregator for Geolocation objects @@ -93,7 +93,7 @@ trait GeolocationFunctions { if (input.isEmpty) Zero else { val g = input.toGeoPoint - val d = input.accuracy.rangeInUnits / 2 + val d = input.accuracy.rangeInUnits / 2.0 Array[Double]( g.x, g.y, g.z, 1.0, diff --git a/features/src/main/scala/com/salesforce/op/features/types/Geolocation.scala b/features/src/main/scala/com/salesforce/op/features/types/Geolocation.scala index ee171f11dd..225d3a77a2 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/Geolocation.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/Geolocation.scala @@ -97,8 +97,8 @@ object Geolocation { private[types] def geolocationData( lat: Double, lon: Double, - accuracy: GeolocationAccuracy): Seq[Double] = - geolocationData(lat, lon, accuracy.value) + accuracy: GeolocationAccuracy + ): Seq[Double] = geolocationData(lat, lon, accuracy.value) private[types] def geolocationData( lat: Double,