Skip to content

Commit

Permalink
Use proper test ranges in feature converter test (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm committed Oct 1, 2018
1 parent dc4f527 commit 449b0bc
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,23 @@ case object FeatureTypeSparkConverter {
case null => FeatureTypeDefaults.Date.value
case v: Int => Some(v.toLong)
case v: Long => Some(v)
case _ => throw new IllegalArgumentException(s"Date type mapping is not defined for ${value.getClass}")
case v => throw new IllegalArgumentException(s"Date type mapping is not defined for ${v.getClass}")
}

// Numerals
// Numerics
case wt if wt <:< weakTypeOf[t.RealNN] => (value: Any) =>
value match {
case null => None
case v: Float => Some(v.toDouble)
case v: Double => Some(v)
case _ => throw new IllegalArgumentException(s"RealNN type mapping is not defined for ${value.getClass}")
case v => throw new IllegalArgumentException(s"RealNN type mapping is not defined for ${v.getClass}")
}
case wt if wt <:< weakTypeOf[t.Real] => (value: Any) =>
value match {
case null => FeatureTypeDefaults.Real.value
case v: Float => Some(v.toDouble)
case v: Double => Some(v)
case _ => throw new IllegalArgumentException(s"Real type mapping is not defined for ${value.getClass}")
case v => throw new IllegalArgumentException(s"Real type mapping is not defined for ${v.getClass}")
}
case wt if wt <:< weakTypeOf[t.Integral] => (value: Any) =>
value match {
Expand All @@ -178,7 +178,7 @@ case object FeatureTypeSparkConverter {
case v: Short => Some(v.toLong)
case v: Int => Some(v.toLong)
case v: Long => Some(v)
case _ => throw new IllegalArgumentException(s"Integral type mapping is not defined for ${value.getClass}")
case v => throw new IllegalArgumentException(s"Integral type mapping is not defined for ${v.getClass}")
}
case wt if wt <:< weakTypeOf[t.Binary] => (value: Any) =>
if (value == null) FeatureTypeDefaults.Binary.value else Some(value.asInstanceOf[Boolean])
Expand All @@ -190,8 +190,7 @@ case object FeatureTypeSparkConverter {

// Sets
case wt if wt <:< weakTypeOf[t.MultiPickList] => (value: Any) =>
if (value == null) FeatureTypeDefaults.MultiPickList.value
else value.asInstanceOf[MWrappedArray[String]].toSet
if (value == null) FeatureTypeDefaults.MultiPickList.value else value.asInstanceOf[MWrappedArray[String]].toSet

// Everything else
case _ => identity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,70 +30,84 @@

package com.salesforce.op.features

import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.{Assertion, FlatSpec}
import org.scalatest.junit.JUnitRunner

import scala.reflect.runtime.universe._

@RunWith(classOf[JUnitRunner])
class FeatureSparkTypeTest extends FlatSpec with TestSparkContext {
val sparkTypeToTypeTagMappings = Seq(
(DoubleType, weakTypeTag[types.RealNN]), (FloatType, weakTypeTag[types.RealNN]),
(LongType, weakTypeTag[types.Integral]), (IntegerType, weakTypeTag[types.Integral]),
(ShortType, weakTypeTag[types.Integral]), (ByteType, weakTypeTag[types.Integral]),
(DateType, weakTypeTag[types.Date]), (TimestampType, weakTypeTag[types.DateTime]),
(StringType, weakTypeTag[types.Text]), (BooleanType, weakTypeTag[types.Binary]),
(VectorType, weakTypeTag[types.OPVector])
// Fixtures for primitive Spark types. Each entry is a triple of
// (input Spark type, expected feature type tag, expected Spark type) and is fed to assertTypes,
// which checks featureTypeTagOf on the first element and sparkTypeOf on the second.
val primitiveTypes = Seq(
(DoubleType, weakTypeTag[types.Real], DoubleType),
(FloatType, weakTypeTag[types.Real], DoubleType),
(LongType, weakTypeTag[types.Integral], LongType),
(IntegerType, weakTypeTag[types.Integral], LongType),
(ShortType, weakTypeTag[types.Integral], LongType),
(ByteType, weakTypeTag[types.Integral], LongType),
(DateType, weakTypeTag[types.Date], LongType),
(TimestampType, weakTypeTag[types.DateTime], LongType),
(StringType, weakTypeTag[types.Text], StringType),
(BooleanType, weakTypeTag[types.Binary], BooleanType),
(VectorType, weakTypeTag[types.OPVector], VectorType)
)

val sparkCollectionTypeToTypeTagMappings = Seq(
(ArrayType(LongType, containsNull = true), weakTypeTag[types.DateList]),
(ArrayType(DoubleType, containsNull = true), weakTypeTag[types.Geolocation]),
(MapType(StringType, StringType, valueContainsNull = true), weakTypeTag[types.TextMap]),
(MapType(StringType, DoubleType, valueContainsNull = true), weakTypeTag[types.RealMap]),
(MapType(StringType, LongType, valueContainsNull = true), weakTypeTag[types.IntegralMap]),
(MapType(StringType, BooleanType, valueContainsNull = true), weakTypeTag[types.BinaryMap]),
(MapType(StringType, ArrayType(StringType, containsNull = true), valueContainsNull = true),
weakTypeTag[types.MultiPickListMap]),
(MapType(StringType, ArrayType(DoubleType, containsNull = true), valueContainsNull = true),
weakTypeTag[types.GeolocationMap])
// Fixtures checked with assertTypes(isNullable = false): both Double and Float Spark columns
// are expected to map to RealNN, and RealNN is expected to map back to DoubleType.
val nonNullable = Seq(
(DoubleType, weakTypeTag[types.RealNN], DoubleType),
(FloatType, weakTypeTag[types.RealNN], DoubleType)
)

val sparkNonNullableTypeToTypeTagMappings = Seq(
(DoubleType, weakTypeTag[types.Real]), (FloatType, weakTypeTag[types.Real])
// Builds a Spark map type with String keys and nullable values of the given value type.
private def mapType(v: DataType): MapType =
  MapType(keyType = StringType, valueType = v, valueContainsNull = true)
// Builds a Spark array type whose elements have the given type and may be null.
private def arrType(v: DataType): ArrayType =
  ArrayType(elementType = v, containsNull = true)

// Fixtures for collection Spark types (arrays and String-keyed maps). Each entry is a triple of
// (input Spark type, expected feature type tag, expected Spark type); in every entry the
// expected Spark type is the same as the input type.
val collectionTypes = Seq(
(arrType(LongType), weakTypeTag[types.DateList], arrType(LongType)),
(arrType(DoubleType), weakTypeTag[types.Geolocation], arrType(DoubleType)),
(arrType(StringType), weakTypeTag[types.TextList], arrType(StringType)),
(mapType(StringType), weakTypeTag[types.TextMap], mapType(StringType)),
(mapType(DoubleType), weakTypeTag[types.RealMap], mapType(DoubleType)),
(mapType(LongType), weakTypeTag[types.IntegralMap], mapType(LongType)),
(mapType(BooleanType), weakTypeTag[types.BinaryMap], mapType(BooleanType)),
(mapType(arrType(StringType)), weakTypeTag[types.MultiPickListMap], mapType(arrType(StringType))),
(mapType(arrType(DoubleType)), weakTypeTag[types.GeolocationMap], mapType(arrType(DoubleType)))
)

Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types" in {
sparkTypeToTypeTagMappings.foreach(mapping =>
FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = false) shouldBe mapping._2
)
// Checks both directions for every primitive fixture: Spark type -> feature type tag and
// feature type tag -> Spark type (assertTypes is called with its default isNullable = true).
Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types and vice versa" in {
  // foreach, not map: the assertions run for their effect and no result collection is needed
  primitiveTypes.foreach(scala.Function.tupled(assertTypes()))
}

it should "assign appropriate feature type tags for valid non-nullable types" in {
sparkNonNullableTypeToTypeTagMappings.foreach(mapping =>
FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = true) shouldBe mapping._2
)
// Same two-way check as for primitive types, but resolving the tag with isNullable = false.
it should "assign appropriate feature type tags for valid non-nullable types and vice versa" in {
  // foreach, not map: the assertions run for their effect and no result collection is needed
  nonNullable.foreach(scala.Function.tupled(assertTypes(isNullable = false)))
}

it should "assign appropriate feature type tags for collection types" in {
sparkCollectionTypeToTypeTagMappings.foreach(mapping =>
FeatureSparkTypes.featureTypeTagOf(mapping._1, isNullable = true) shouldBe mapping._2
)
// Two-way check for array and map fixtures (assertTypes default isNullable = true).
it should "assign appropriate feature type tags for collection types and vice versa" in {
  // foreach, not map: the assertions run for their effect and no result collection is needed
  collectionTypes.foreach(scala.Function.tupled(assertTypes()))
}

it should "throw error for unsupported types" in {
// Resolving a tag for BinaryType must fail with the exact "not supported" message.
it should "error for unsupported types" in {
  val thrown = intercept[IllegalArgumentException] {
    FeatureSparkTypes.featureTypeTagOf(BinaryType, isNullable = false)
  }
  thrown.getMessage shouldBe "Spark BinaryType is currently not supported."
}

it should "throw error for unknown types" in {
// Resolving a tag for a Spark type without a mapping (NullType here) must fail and
// report the offending type in the message.
it should "error for unknown types" in {
  val unknownType = NullType
  val thrown = intercept[IllegalArgumentException] {
    FeatureSparkTypes.featureTypeTagOf(unknownType, isNullable = false)
  }
  thrown.getMessage shouldBe s"No feature type tag mapping for Spark type $unknownType"
}

/**
 * Asserts the Spark type <-> feature type tag mapping in both directions.
 *
 * @param isNullable        nullability flag passed to FeatureSparkTypes.featureTypeTagOf
 * @param sparkType         Spark type whose feature type tag is resolved
 * @param featureType       tag expected for sparkType; also the input of the reverse lookup
 * @param expectedSparkType Spark type expected from FeatureSparkTypes.sparkTypeOf(featureType)
 * @return the result of the second (reverse direction) assertion
 */
def assertTypes(
  isNullable: Boolean = true
)(
  sparkType: DataType,
  featureType: WeakTypeTag[_ <: FeatureType],
  expectedSparkType: DataType
): Assertion = {
  // Forward direction: Spark type -> feature type tag
  val resolvedTag = FeatureSparkTypes.featureTypeTagOf(sparkType, isNullable)
  resolvedTag shouldBe featureType

  // Reverse direction: feature type tag -> Spark type
  val resolvedSparkType = FeatureSparkTypes.sparkTypeOf(featureType)
  resolvedSparkType shouldBe expectedSparkType
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalacheck.Gen
import org.scalacheck.Arbitrary._
import org.scalatest.PropSpec
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.{PropertyChecks, TableFor1}

import scala.collection.mutable.{WrappedArray => MWrappedArray}
import scala.concurrent.duration._


Expand All @@ -50,13 +53,38 @@ class FeatureTypeSparkConverterTest
// Table with one row per key of FeatureTypeSparkConverter.featureTypeSparkConverters
// (presumably the registered feature type names - confirm against the converter object).
val featureTypeNames: TableFor1[String] = Table("ftnames",
FeatureTypeSparkConverter.featureTypeSparkConverters.keys.toSeq: _*
)
val bogusNames = Gen.alphaNumStr
// Generator of alphanumeric strings; used both as bogus feature type names and as Text values.
val strings = Gen.alphaNumStr

val naturalNumbers = Table("NaturalNumbers", Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue)
// Generator that picks one of the Long/Int/Short/Byte generators and exposes the value as a
// Number, so a single Gen covers all four integer widths.
// NOTE(review): despite the name, arbitrary[...] is presumably not limited to non-negative values.
val naturalNumbers = Gen.oneOf(
arbitrary[Long], arbitrary[Int], arbitrary[Short], arbitrary[Byte]
).map(_.asInstanceOf[Number])

val realNumbers = Table("NaturalNumbers", Float.MaxValue, Double.MaxValue)
// Generator that picks either the Float or the Double generator and exposes the value as a Number.
val realNumbers = Gen.oneOf(arbitrary[Float], arbitrary[Double]).map(_.asInstanceOf[Number])

val dateTimeValues = Table("DateTimeVals", 300, 300L)
// Date inputs: positive Int or Long values exposed as Number (Date conversion accepts both widths).
val dateValues = Gen.oneOf(Gen.posNum[Int], Gen.posNum[Long]).map(_.asInstanceOf[Number])
// DateTime inputs: positive Long values only.
val dateTimeValues = Gen.posNum[Long]

// Exhaustive table of Boolean inputs for the Binary conversion property.
val booleans = Table("booleans", true, false)

// Rows of (feature value -> Spark value expected from FeatureTypeSparkConverter.toSpark).
// The first row pairs an empty Real with null.
val featureTypeValues = Table("ft",
Real.empty -> null,
Text("abc") -> "abc",
Real(0.1) -> 0.1,
Integral(123) -> 123,
Array(1.0, 2.0).toOPVector -> Vectors.dense(Array(1.0, 2.0)),
Vectors.sparse(2, Array(0), Array(3.0)).toOPVector -> Vectors.sparse(2, Array(0), Array(3.0)),
Seq(1L, 2L, 3L).toDateList -> Array(1L, 2L, 3L),
Set("a", "b").toMultiPickList -> Array("a", "b")
)

// Rows of (map feature value -> Spark value expected from FeatureTypeSparkConverter.toSpark).
// The first row pairs an empty TextMap with null; collection-valued maps are expected to come
// back with wrapped-array values.
val featureTypeMapValues = Table("ftm",
TextMap.empty -> null,
Map("1" -> 1.0, "2" -> 2.0).toRealMap -> Map("1" -> 1.0, "2" -> 2.0),
Map("1" -> 3L, "2" -> 4L).toIntegralMap -> Map("1" -> 3L, "2" -> 4L),
Map("1" -> "one", "2" -> "two").toTextMap -> Map("1" -> "one", "2" -> "two"),
Map("1" -> Set("a", "b")).toMultiPickListMap -> Map("1" -> MWrappedArray.make(Array("a", "b"))),
Map("1" -> Seq(1.0, 5.0, 6.0)).toGeolocationMap -> Map("1" -> MWrappedArray.make(Array(1.0, 5.0, 6.0)))
)

property("is a feature type converter") {
forAll(featureTypeConverters) { ft => ft shouldBe a[FeatureTypeSparkConverter[_]] }
Expand All @@ -72,7 +100,7 @@ class FeatureTypeSparkConverterTest
}
}
property("error on making a converter on no existent feature type name") {
forAll(bogusNames) { bogusName =>
forAll(strings) { bogusName =>
intercept[IllegalArgumentException](
FeatureTypeSparkConverter.fromFeatureTypeName(bogusName)
).getMessage shouldBe s"Unknown feature type '$bogusName'"
Expand Down Expand Up @@ -100,100 +128,80 @@ class FeatureTypeSparkConverterTest
}
)
}

property("converts Natural Number of Byte/Short/Int/Long ranges to Integral valued feature type") {
forAll(naturalNumbers)(nn =>
FeatureTypeSparkConverter[Integral]().fromSpark(nn) shouldBe a[Integral]
)
}
property("converts Natural Number of Byte/Short/Int/Long ranges to Long range Integral feature") {
forAll(naturalNumbers)(nn =>
FeatureTypeSparkConverter[Integral]().fromSpark(nn).value.get shouldBe a[java.lang.Long]
)
// Round trip: a Byte/Short/Int/Long read from Spark must equal the Integral built from its
// Long value, and converting that Integral back to Spark must equal the generated number.
property("converts natural number of Byte/Short/Int/Long ranges to Integral feature type") {
  forAll(naturalNumbers) { number =>
    val expected = number.longValue().toIntegral
    FeatureTypeSparkConverter[Integral]().fromSpark(number) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual number
  }
}
property("raises error for bad Natural Number") {
property("raises error on invalid natural numbers") {
forAll(realNumbers)(nn =>
intercept[IllegalArgumentException](FeatureTypeSparkConverter[Integral]().fromSpark(nn)).getMessage
shouldBe s"Integral type mapping is not defined for class java.lang.${nn.getClass.toString.capitalize}"
)
}

property("converts Real Numbers in float/double ranges to Real valued feature type") {
forAll(realNumbers)(rn =>
FeatureTypeSparkConverter[Real]().fromSpark(rn) shouldBe a[Real]
// Use the `should startWith` matcher: a bare String.startsWith only yields a Boolean that is
// discarded, so the exception message would never actually be verified.
intercept[IllegalArgumentException](FeatureTypeSparkConverter[Integral]().fromSpark(nn))
  .getMessage should startWith("Integral type mapping is not defined")
)
}
property("converts Real Numbers in float/double ranges to Double range Real feature") {
forAll(realNumbers)(rn =>
FeatureTypeSparkConverter[Real]().fromSpark(rn).value.get shouldBe a[java.lang.Double]
)
// Round trip: a Float/Double read from Spark must equal the Real built from its Double value,
// and converting that Real back to Spark must equal the generated number.
property("converts real numbers in Float/Double ranges to Real feature type") {
  forAll(realNumbers) { number =>
    val expected = number.doubleValue().toReal
    FeatureTypeSparkConverter[Real]().fromSpark(number) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual number
  }
}
property("raises error for bad Real Number") {
forAll(naturalNumbers)(rn =>
property("raises error on invalid real numbers") {
forAll(naturalNumbers) { rn =>
intercept[IllegalArgumentException](FeatureTypeSparkConverter[Real]().fromSpark(rn))
.getMessage shouldBe s"Real type mapping is not defined for class java.lang.${rn.getClass.toString.capitalize}"
)
}

property("converts Real Numbers in float/double ranges to RealNN valued feature type") {
forAll(realNumbers)(rn =>
FeatureTypeSparkConverter[RealNN]().fromSpark(rn) shouldBe a[RealNN]
)
.getMessage should startWith("Real type mapping is not defined")
// Both message checks use the `should startWith` matcher: a bare String.startsWith only yields
// a discarded Boolean and asserts nothing about the message.
intercept[IllegalArgumentException](FeatureTypeSparkConverter[RealNN]().fromSpark(rn))
  .getMessage should startWith("RealNN type mapping is not defined")
}
}
property("converts Real Numbers in float/double ranges Double range RealNN feature") {
forAll(realNumbers)(rn =>
FeatureTypeSparkConverter[RealNN]().fromSpark(rn).value.get shouldBe a[java.lang.Double]
)
// Round trip: a Float/Double read from Spark must equal the RealNN built from its Double value,
// and converting that RealNN back to Spark must equal the generated number.
property("convert real numbers in Float/Double ranges to RealNN feature type") {
  forAll(realNumbers) { number =>
    val expected = number.doubleValue().toRealNN
    FeatureTypeSparkConverter[RealNN]().fromSpark(number) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual number
  }
}
property("raises error for empty RealNN Number") {
forAll(naturalNumbers)(rn =>
intercept[NonNullableEmptyException](FeatureTypeSparkConverter[RealNN]().fromSpark(null))
.getMessage shouldBe "RealNN cannot be empty"
)
// Reading a null Spark value as RealNN must fail with NonNullableEmptyException and the
// exact message asserted below.
property("error for an empty RealNN value") {
  val thrown = intercept[NonNullableEmptyException] {
    FeatureTypeSparkConverter[RealNN]().fromSpark(null)
  }
  thrown.getMessage shouldBe "RealNN cannot be empty"
}

property("converts date denoted using int/long ranges to date feature types") {
forAll(dateTimeValues)(dt =>
FeatureTypeSparkConverter[Date]().fromSpark(dt) shouldBe a[Date]
)
}
property("converts date denoted using int/long ranges to Long range date feature") {
forAll(dateTimeValues)(dt =>
FeatureTypeSparkConverter[Date]().fromSpark(dt).value.get shouldBe a[java.lang.Long]
)
// Round trip: a positive Int/Long read from Spark must equal the Date built from its Long
// value, and converting that Date back to Spark must equal the generated number.
property("convert date denoted using Int/Long ranges to Date feature type") {
  forAll(dateValues) { number =>
    val expected = number.longValue().toDate
    FeatureTypeSparkConverter[Date]().fromSpark(number) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual number
  }
}
property("raises error for bad date values") {
property("error on invalid date values") {
forAll(realNumbers)(rn =>
intercept[IllegalArgumentException](FeatureTypeSparkConverter[Date]().fromSpark(rn))
.getMessage shouldBe s"Date type mapping is not defined for class java.lang.${rn.getClass.toString.capitalize}"
// Use the `should startWith` matcher: a bare String.startsWith only yields a Boolean that is
// discarded, so the exception message would never actually be verified.
intercept[IllegalArgumentException](FeatureTypeSparkConverter[Date]().fromSpark(rn))
  .getMessage should startWith("Date type mapping is not defined")
)
}

property("converts timestamp denoted using long range to datetime feature type") {
forAll(dateTimeValues)(dt =>
FeatureTypeSparkConverter[DateTime]().fromSpark(dt) shouldBe a[DateTime]
)
// Round trip: a positive Long read from Spark must equal the DateTime built from it, and
// converting that DateTime back to Spark must equal the generated Long.
property("convert timestamp denoted using Long range to Datetime feature type") {
  forAll(dateTimeValues) { millis =>
    val expected = millis.toDateTime
    FeatureTypeSparkConverter[DateTime]().fromSpark(millis) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual millis
  }
}
property("converts timestamp denoted using long range to date super feature type") {
forAll(dateTimeValues)(dt =>
FeatureTypeSparkConverter[DateTime]().fromSpark(dt) shouldBe a[Date]
)
// Round trip: a generated string read from Spark must equal the Text built from it, and
// converting that Text back to Spark must equal the original string.
property("convert string to text feature type") {
  forAll(strings) { str =>
    val expected = str.toText
    FeatureTypeSparkConverter[Text]().fromSpark(str) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual str
  }
}
property("converts timestamp denoted using long ranges to long range datetime feature") {
forAll(dateTimeValues)(dt =>
FeatureTypeSparkConverter[DateTime]().fromSpark(dt).value.get shouldBe a[java.lang.Long]
)
// Round trip for both Boolean values: reading from Spark must equal the Binary built from the
// flag, and converting that Binary back to Spark must equal the flag.
property("convert boolean to Binary feature type") {
  forAll(booleans) { flag =>
    val expected = flag.toBinary
    FeatureTypeSparkConverter[Binary]().fromSpark(flag) shouldBe expected
    FeatureTypeSparkConverter.toSpark(expected) shouldEqual flag
  }
}

property("converts string to text feature type") {
val text = FeatureTypeSparkConverter[Text]().fromSpark("Simple")
text shouldBe a[Text]
text.value.get shouldBe a[String]
// Each table row pairs a feature value with the Spark value that toSpark must produce for it.
property("convert feature type values to spark values") {
  forAll(featureTypeValues) { row =>
    val (feature, expectedSparkValue) = row
    FeatureTypeSparkConverter.toSpark(feature) shouldEqual expectedSparkValue
  }
}

property("converts a Boolean to Binary feature type") {
val bool = FeatureTypeSparkConverter[Binary]().fromSpark(false)
bool shouldBe a[Binary]
bool.value.get shouldBe a[java.lang.Boolean]
// Each table row pairs a map feature value with the Spark value that toSpark must produce for it.
property("convert feature type map values to spark values") {
  forAll(featureTypeMapValues) { row =>
    val (feature, expectedSparkValue) = row
    FeatureTypeSparkConverter.toSpark(feature) shouldEqual expectedSparkValue
  }
}
}

0 comments on commit 449b0bc

Please sign in to comment.