diff --git a/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeSparkConverter.scala b/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeSparkConverter.scala
index 77d887d84d..bd138b1fd0 100644
--- a/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeSparkConverter.scala
+++ b/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeSparkConverter.scala
@@ -153,23 +153,23 @@ case object FeatureTypeSparkConverter {
         case null => FeatureTypeDefaults.Date.value
         case v: Int => Some(v.toLong)
         case v: Long => Some(v)
-        case _ => throw new IllegalArgumentException(s"Date type mapping is not defined for ${value.getClass}")
+        case v => throw new IllegalArgumentException(s"Date type mapping is not defined for ${v.getClass}")
       }
 
-    // Numerals
+    // Numerics
     case wt if wt <:< weakTypeOf[t.RealNN] => (value: Any) =>
       value match {
         case null => None
         case v: Float => Some(v.toDouble)
         case v: Double => Some(v)
-        case _ => throw new IllegalArgumentException(s"RealNN type mapping is not defined for ${value.getClass}")
+        case v => throw new IllegalArgumentException(s"RealNN type mapping is not defined for ${v.getClass}")
       }
     case wt if wt <:< weakTypeOf[t.Real] => (value: Any) =>
       value match {
         case null => FeatureTypeDefaults.Real.value
         case v: Float => Some(v.toDouble)
         case v: Double => Some(v)
-        case _ => throw new IllegalArgumentException(s"Real type mapping is not defined for ${value.getClass}")
+        case v => throw new IllegalArgumentException(s"Real type mapping is not defined for ${v.getClass}")
       }
     case wt if wt <:< weakTypeOf[t.Integral] => (value: Any) =>
       value match {
@@ -178,7 +178,7 @@ case object FeatureTypeSparkConverter {
         case v: Short => Some(v.toLong)
         case v: Int => Some(v.toLong)
         case v: Long => Some(v)
-        case _ => throw new IllegalArgumentException(s"Integral type mapping is not defined for ${value.getClass}")
+        case v => throw new IllegalArgumentException(s"Integral type mapping is not defined for ${v.getClass}")
       }
     case wt if wt <:< weakTypeOf[t.Binary] => (value: Any) =>
       if (value == null) FeatureTypeDefaults.Binary.value else Some(value.asInstanceOf[Boolean])
@@ -190,8 +190,7 @@ case object FeatureTypeSparkConverter {
 
     // Sets
     case wt if wt <:< weakTypeOf[t.MultiPickList] => (value: Any) =>
-      if (value == null) FeatureTypeDefaults.MultiPickList.value
-      else value.asInstanceOf[MWrappedArray[String]].toSet
+      if (value == null) FeatureTypeDefaults.MultiPickList.value else value.asInstanceOf[MWrappedArray[String]].toSet
 
     // Everything else
     case _ => identity
diff --git a/features/src/test/scala/com/salesforce/op/features/FeatureSparkTypeTest.scala b/features/src/test/scala/com/salesforce/op/features/FeatureSparkTypeTest.scala
index 7cbfd8f1d3..69af58d6ba 100644
--- a/features/src/test/scala/com/salesforce/op/features/FeatureSparkTypeTest.scala
+++ b/features/src/test/scala/com/salesforce/op/features/FeatureSparkTypeTest.scala
@@ -30,70 +30,92 @@
 
 package com.salesforce.op.features
 
+import com.salesforce.op.features.types.FeatureType
 import com.salesforce.op.test.TestSparkContext
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.sql.types._
 import org.junit.runner.RunWith
-import org.scalatest.FlatSpec
+import org.scalatest.{Assertion, FlatSpec}
 import org.scalatest.junit.JUnitRunner
 
 import scala.reflect.runtime.universe._
 
 @RunWith(classOf[JUnitRunner])
 class FeatureSparkTypeTest extends FlatSpec with TestSparkContext {
-  val sparkTypeToTypeTagMappings = Seq(
-    (DoubleType, weakTypeTag[types.RealNN]), (FloatType, weakTypeTag[types.RealNN]),
-    (LongType, weakTypeTag[types.Integral]), (IntegerType, weakTypeTag[types.Integral]),
-    (ShortType, weakTypeTag[types.Integral]), (ByteType, weakTypeTag[types.Integral]),
-    (DateType, weakTypeTag[types.Date]), (TimestampType, weakTypeTag[types.DateTime]),
-    (StringType, weakTypeTag[types.Text]), (BooleanType, weakTypeTag[types.Binary]),
-    (VectorType, weakTypeTag[types.OPVector])
+  val primitiveTypes = Seq(
+    (DoubleType, weakTypeTag[types.Real], DoubleType),
+    (FloatType, weakTypeTag[types.Real], DoubleType),
+    (LongType, weakTypeTag[types.Integral], LongType),
+    (IntegerType, weakTypeTag[types.Integral], LongType),
+    (ShortType, weakTypeTag[types.Integral], LongType),
+    (ByteType, weakTypeTag[types.Integral], LongType),
+    (DateType, weakTypeTag[types.Date], LongType),
+    (TimestampType, weakTypeTag[types.DateTime], LongType),
+    (StringType, weakTypeTag[types.Text], StringType),
+    (BooleanType, weakTypeTag[types.Binary], BooleanType),
+    (VectorType, weakTypeTag[types.OPVector], VectorType)
   )
 
-  val sparkCollectionTypeToTypeTagMappings = Seq(
-    (ArrayType(LongType, containsNull = true), weakTypeTag[types.DateList]),
-    (ArrayType(DoubleType, containsNull = true), weakTypeTag[types.Geolocation]),
-    (MapType(StringType, StringType, valueContainsNull = true), weakTypeTag[types.TextMap]),
-    (MapType(StringType, DoubleType, valueContainsNull = true), weakTypeTag[types.RealMap]),
-    (MapType(StringType, LongType, valueContainsNull = true), weakTypeTag[types.IntegralMap]),
-    (MapType(StringType, BooleanType, valueContainsNull = true), weakTypeTag[types.BinaryMap]),
-    (MapType(StringType, ArrayType(StringType, containsNull = true), valueContainsNull = true),
-      weakTypeTag[types.MultiPickListMap]),
-    (MapType(StringType, ArrayType(DoubleType, containsNull = true), valueContainsNull = true),
-      weakTypeTag[types.GeolocationMap])
+  val nonNullable = Seq(
+    (DoubleType, weakTypeTag[types.RealNN], DoubleType),
+    (FloatType, weakTypeTag[types.RealNN], DoubleType)
   )
 
-  val sparkNonNullableTypeToTypeTagMappings = Seq(
-    (DoubleType, weakTypeTag[types.Real]), (FloatType, weakTypeTag[types.Real])
+  private def mapType(v: DataType) = MapType(StringType, v, valueContainsNull = true)
+  private def arrType(v: DataType) = ArrayType(v, containsNull = true)
+
+  val collectionTypes = Seq(
+    (arrType(LongType), weakTypeTag[types.DateList], arrType(LongType)),
+    (arrType(DoubleType), weakTypeTag[types.Geolocation], arrType(DoubleType)),
+    (arrType(StringType), weakTypeTag[types.TextList], arrType(StringType)),
+    (mapType(StringType), weakTypeTag[types.TextMap], mapType(StringType)),
+    (mapType(DoubleType), weakTypeTag[types.RealMap], mapType(DoubleType)),
+    (mapType(LongType), weakTypeTag[types.IntegralMap], mapType(LongType)),
+    (mapType(BooleanType), weakTypeTag[types.BinaryMap], mapType(BooleanType)),
+    (mapType(arrType(StringType)), weakTypeTag[types.MultiPickListMap], mapType(arrType(StringType))),
+    (mapType(arrType(DoubleType)), weakTypeTag[types.GeolocationMap], mapType(arrType(DoubleType)))
   )
 
-  Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types" in {
-    sparkTypeToTypeTagMappings.foreach(mapping =>
-      FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = false) shouldBe mapping._2
-    )
+  Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types and vice versa" in {
+    primitiveTypes.foreach(scala.Function.tupled(assertTypes()))
   }
 
-  it should "assign appropriate feature type tags for valid non-nullable types" in {
-    sparkNonNullableTypeToTypeTagMappings.foreach(mapping =>
-      FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = true) shouldBe mapping._2
-    )
+  it should "assign appropriate feature type tags for valid non-nullable types and vice versa" in {
+    nonNullable.foreach(scala.Function.tupled(assertTypes(isNullable = false)))
   }
 
-  it should "assign appropriate feature type tags for collection types" in {
-    sparkCollectionTypeToTypeTagMappings.foreach(mapping =>
-      FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = true) shouldBe mapping._2
-    )
+  it should "assign appropriate feature type tags for collection types and vice versa" in {
+    collectionTypes.foreach(scala.Function.tupled(assertTypes()))
   }
 
-  it should "throw error for unsupported types" in {
+  it should "error for unsupported types" in {
     val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(BinaryType, isNullable = false))
     error.getMessage shouldBe "Spark BinaryType is currently not supported."
   }
 
-  it should "throw error for unknown types" in {
+  it should "error for unknown types" in {
     val unknownType = NullType
     val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(unknownType, isNullable = false))
     error.getMessage shouldBe s"No feature type tag mapping for Spark type $unknownType"
   }
 
+  /**
+   * Asserts the Spark type <-> feature type mapping in both directions
+   *
+   * @param isNullable        nullability flag passed to FeatureSparkTypes.featureTypeTagOf
+   * @param sparkType         Spark type to resolve a feature type tag for
+   * @param featureType       feature type tag expected for sparkType
+   * @param expectedSparkType Spark type expected when mapping featureType back via FeatureSparkTypes.sparkTypeOf
+   */
+  def assertTypes(
+    isNullable: Boolean = true
+  )(
+    sparkType: DataType,
+    featureType: WeakTypeTag[_ <: FeatureType],
+    expectedSparkType: DataType
+  ): Assertion = {
+    FeatureSparkTypes.featureTypeTagOf(sparkType, isNullable) shouldBe featureType
+    FeatureSparkTypes.sparkTypeOf(featureType) shouldBe expectedSparkType
+  }
+
 }
diff --git a/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeSparkConverterTest.scala b/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeSparkConverterTest.scala
index 8e2230e301..100b11d36e 100644
--- a/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeSparkConverterTest.scala
+++ b/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeSparkConverterTest.scala
@@ -31,12 +31,15 @@
 package com.salesforce.op.features.types
 
 import com.salesforce.op.test.TestCommon
+import org.apache.spark.ml.linalg.Vectors
 import org.junit.runner.RunWith
 import org.scalacheck.Gen
+import org.scalacheck.Arbitrary._
 import org.scalatest.PropSpec
 import org.scalatest.junit.JUnitRunner
 import org.scalatest.prop.{PropertyChecks, TableFor1}
 
+import scala.collection.mutable.{WrappedArray => MWrappedArray}
 import scala.concurrent.duration._
 
 
@@ -50,13 +53,38 @@ class FeatureTypeSparkConverterTest
   val featureTypeNames: TableFor1[String] = Table("ftnames",
     FeatureTypeSparkConverter.featureTypeSparkConverters.keys.toSeq: _*
   )
-  val bogusNames = Gen.alphaNumStr
+  val strings = Gen.alphaNumStr
 
-  val naturalNumbers = Table("NaturalNumbers", Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue)
+  val naturalNumbers = Gen.oneOf(
+    arbitrary[Long], arbitrary[Int], arbitrary[Short], arbitrary[Byte]
+  ).map(_.asInstanceOf[Number])
 
-  val realNumbers = Table("NaturalNumbers", Float.MaxValue, Double.MaxValue)
+  val realNumbers = Gen.oneOf(arbitrary[Float], arbitrary[Double]).map(_.asInstanceOf[Number])
 
-  val dateTimeValues = Table("DateTimeVals", 300, 300L)
+  val dateValues = Gen.oneOf(Gen.posNum[Int], Gen.posNum[Long]).map(_.asInstanceOf[Number])
+  val dateTimeValues = Gen.posNum[Long]
+
+  val booleans = Table("booleans", true, false)
+
+  val featureTypeValues = Table("ft",
+    Real.empty -> null,
+    Text("abc") -> "abc",
+    Real(0.1) -> 0.1,
+    Integral(123) -> 123,
+    Array(1.0, 2.0).toOPVector -> Vectors.dense(Array(1.0, 2.0)),
+    Vectors.sparse(2, Array(0), Array(3.0)).toOPVector -> Vectors.sparse(2, Array(0), Array(3.0)),
+    Seq(1L, 2L, 3L).toDateList -> Array(1L, 2L, 3L),
+    Set("a", "b").toMultiPickList -> Array("a", "b")
+  )
+
+  val featureTypeMapValues = Table("ftm",
+    TextMap.empty -> null,
+    Map("1" -> 1.0, "2" -> 2.0).toRealMap -> Map("1" -> 1.0, "2" -> 2.0),
+    Map("1" -> 3L, "2" -> 4L).toIntegralMap -> Map("1" -> 3L, "2" -> 4L),
+    Map("1" -> "one", "2" -> "two").toTextMap -> Map("1" -> "one", "2" -> "two"),
+    Map("1" -> Set("a", "b")).toMultiPickListMap -> Map("1" -> MWrappedArray.make(Array("a", "b"))),
+    Map("1" -> Seq(1.0, 5.0, 6.0)).toGeolocationMap -> Map("1" -> MWrappedArray.make(Array(1.0, 5.0, 6.0)))
+  )
 
   property("is a feature type converter") {
     forAll(featureTypeConverters) { ft => ft shouldBe a[FeatureTypeSparkConverter[_]] }
@@ -72,7 +100,7 @@ class FeatureTypeSparkConverterTest
     }
   }
   property("error on making a converter on no existent feature type name") {
-    forAll(bogusNames) { bogusName =>
+    forAll(strings) { bogusName =>
       intercept[IllegalArgumentException](
         FeatureTypeSparkConverter.fromFeatureTypeName(bogusName)
       ).getMessage shouldBe s"Unknown feature type '$bogusName'"
@@ -100,100 +128,80 @@ class FeatureTypeSparkConverterTest
       }
     )
   }
-
-  property("converts Natural Number of Byte/Short/Int/Long ranges to Integral valued feature type") {
-    forAll(naturalNumbers)(nn =>
-      FeatureTypeSparkConverter[Integral]().fromSpark(nn) shouldBe a[Integral]
-    )
-  }
-  property("converts Natural Number of Byte/Short/Int/Long ranges to Long range Integral feature") {
-    forAll(naturalNumbers)(nn =>
-      FeatureTypeSparkConverter[Integral]().fromSpark(nn).value.get shouldBe a[java.lang.Long]
-    )
+  property("converts natural number of Byte/Short/Int/Long ranges to Integral feature type") {
+    forAll(naturalNumbers) { nn =>
+      FeatureTypeSparkConverter[Integral]().fromSpark(nn) shouldBe nn.longValue().toIntegral
+      FeatureTypeSparkConverter.toSpark(nn.longValue().toIntegral) shouldEqual nn
+    }
   }
-  property("raises error for bad Natural Number") {
+  property("raises error on invalid natural numbers") {
     forAll(realNumbers)(nn =>
-      intercept[IllegalArgumentException](FeatureTypeSparkConverter[Integral]().fromSpark(nn)).getMessage
-        shouldBe s"Integral type mapping is not defined for class java.lang.${nn.getClass.toString.capitalize}"
-    )
-  }
-
-  property("converts Real Numbers in float/double ranges to Real valued feature type") {
-    forAll(realNumbers)(rn =>
-      FeatureTypeSparkConverter[Real]().fromSpark(rn) shouldBe a[Real]
+      intercept[IllegalArgumentException](FeatureTypeSparkConverter[Integral]().fromSpark(nn))
+        .getMessage should startWith("Integral type mapping is not defined")
     )
   }
-  property("converts Real Numbers in float/double ranges to Double range Real feature") {
-    forAll(realNumbers)(rn =>
-      FeatureTypeSparkConverter[Real]().fromSpark(rn).value.get shouldBe a[java.lang.Double]
-    )
+  property("converts real numbers in Float/Double ranges to Real feature type") {
+    forAll(realNumbers) { rn =>
+      FeatureTypeSparkConverter[Real]().fromSpark(rn) shouldBe rn.doubleValue().toReal
+      FeatureTypeSparkConverter.toSpark(rn.doubleValue().toReal) shouldEqual rn
+    }
   }
-  property("raises error for bad Real Number") {
-    forAll(naturalNumbers)(rn =>
+  property("raises error on invalid real numbers") {
+    forAll(naturalNumbers) { rn =>
       intercept[IllegalArgumentException](FeatureTypeSparkConverter[Real]().fromSpark(rn))
-        .getMessage shouldBe s"Real type mapping is not defined for class java.lang.${rn.getClass.toString.capitalize}"
-    )
-  }
-
-  property("converts Real Numbers in float/double ranges to RealNN valued feature type") {
-    forAll(realNumbers)(rn =>
-      FeatureTypeSparkConverter[RealNN]().fromSpark(rn) shouldBe a[RealNN]
-    )
+        .getMessage should startWith("Real type mapping is not defined")
+      intercept[IllegalArgumentException](FeatureTypeSparkConverter[RealNN]().fromSpark(rn))
+        .getMessage should startWith("RealNN type mapping is not defined")
+    }
   }
-  property("converts Real Numbers in float/double ranges Double range RealNN feature") {
-    forAll(realNumbers)(rn =>
-      FeatureTypeSparkConverter[RealNN]().fromSpark(rn).value.get shouldBe a[java.lang.Double]
-    )
+  property("convert real numbers in Float/Double ranges to RealNN feature type") {
+    forAll(realNumbers) { rn =>
+      FeatureTypeSparkConverter[RealNN]().fromSpark(rn) shouldBe rn.doubleValue().toRealNN
+      FeatureTypeSparkConverter.toSpark(rn.doubleValue().toRealNN) shouldEqual rn
+    }
   }
-  property("raises error for empty RealNN Number") {
-    forAll(naturalNumbers)(rn =>
-      intercept[NonNullableEmptyException](FeatureTypeSparkConverter[RealNN]().fromSpark(null))
-        .getMessage shouldBe "RealNN cannot be empty"
-    )
+  property("error for an empty RealNN value") {
+    intercept[NonNullableEmptyException](FeatureTypeSparkConverter[RealNN]().fromSpark(null))
+      .getMessage shouldBe "RealNN cannot be empty"
   }
-
-  property("converts date denoted using int/long ranges to date feature types") {
-    forAll(dateTimeValues)(dt =>
-      FeatureTypeSparkConverter[Date]().fromSpark(dt) shouldBe a[Date]
-    )
-  }
-  property("converts date denoted using int/long ranges to Long range date feature") {
-    forAll(dateTimeValues)(dt =>
-      FeatureTypeSparkConverter[Date]().fromSpark(dt).value.get shouldBe a[java.lang.Long]
-    )
+  property("convert date denoted using Int/Long ranges to Date feature type") {
+    forAll(dateValues) { dt =>
+      FeatureTypeSparkConverter[Date]().fromSpark(dt) shouldBe dt.longValue().toDate
+      FeatureTypeSparkConverter.toSpark(dt.longValue().toDate) shouldEqual dt
+    }
   }
-  property("raises error for bad date values") {
+  property("error on invalid date values") {
     forAll(realNumbers)(rn =>
-    intercept[IllegalArgumentException](FeatureTypeSparkConverter[Date]().fromSpark(rn))
-      .getMessage shouldBe s"Date type mapping is not defined for class java.lang.${rn.getClass.toString.capitalize}"
+      intercept[IllegalArgumentException](FeatureTypeSparkConverter[Date]().fromSpark(rn))
+        .getMessage should startWith("Date type mapping is not defined")
     )
   }
-
-  property("converts timestamp denoted using long range to datetime feature type") {
-    forAll(dateTimeValues)(dt =>
-      FeatureTypeSparkConverter[DateTime]().fromSpark(dt) shouldBe a[DateTime]
-    )
+  property("convert timestamp denoted using Long range to Datetime feature type") {
+    forAll(dateTimeValues) { dt =>
+      FeatureTypeSparkConverter[DateTime]().fromSpark(dt) shouldBe dt.toDateTime
+      FeatureTypeSparkConverter.toSpark(dt.toDateTime) shouldEqual dt
+    }
   }
-  property("converts timestamp denoted using long range to date super feature type") {
-    forAll(dateTimeValues)(dt =>
-      FeatureTypeSparkConverter[DateTime]().fromSpark(dt) shouldBe a[Date]
-    )
+  property("convert string to text feature type") {
+    forAll(strings) { s =>
+      FeatureTypeSparkConverter[Text]().fromSpark(s) shouldBe s.toText
+      FeatureTypeSparkConverter.toSpark(s.toText) shouldEqual s
+    }
   }
-  property("converts timestamp denoted using long ranges to long range datetime feature") {
-    forAll(dateTimeValues)(dt =>
-      FeatureTypeSparkConverter[DateTime]().fromSpark(dt).value.get shouldBe a[java.lang.Long]
-    )
+  property("convert boolean to Binary feature type") {
+    forAll(booleans) { b =>
+      FeatureTypeSparkConverter[Binary]().fromSpark(b) shouldBe b.toBinary
+      FeatureTypeSparkConverter.toSpark(b.toBinary) shouldEqual b
+    }
   }
-
-  property("converts string to text feature type") {
-    val text = FeatureTypeSparkConverter[Text]().fromSpark("Simple")
-    text shouldBe a[Text]
-    text.value.get shouldBe a[String]
+  property("convert feature type values to spark values") {
+    forAll(featureTypeValues) { case (featureValue, sparkValue) =>
+      FeatureTypeSparkConverter.toSpark(featureValue) shouldEqual sparkValue
+    }
   }
-
-  property("converts a Boolean to Binary feature type") {
-    val bool = FeatureTypeSparkConverter[Binary]().fromSpark(false)
-    bool shouldBe a[Binary]
-    bool.value.get shouldBe a[java.lang.Boolean]
+  property("convert feature type map values to spark values") {
+    forAll(featureTypeMapValues) { case (featureValue, sparkValue) =>
+      FeatureTypeSparkConverter.toSpark(featureValue) shouldEqual sparkValue
+    }
   }
 }