Skip to content

Commit

Permalink
Merge pull request #195 from iravid/issue-193
Browse files: browse the repository at this point in the history
Restructure collection encoders to resolve encoding issues in the REPL
  • Loading branch information
imarios committed Oct 15, 2017
2 parents f81be1d + b29d884 commit 17078b6
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 178 deletions.
284 changes: 119 additions & 165 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,125 @@ object TypedEncoder {
)
}

/** Derives a [[TypedEncoder]] for `Array[T]` from the encoder of its element type.
  *
  * When the element encoder's JVM representation is `ByteType`, the array is represented
  * as `BinaryType` on both sides and the expression is passed through unchanged.
  * Every other element type uses a Catalyst `ArrayType`.
  */
implicit def arrayEncoder[T: ClassTag](
  implicit
  encodeT: TypedEncoder[T]
): TypedEncoder[Array[T]] = new TypedEncoder[Array[T]] {
  def nullable: Boolean = false

  // Selects `BinaryType` for byte elements; `other` is only evaluated for non-byte elements.
  private def binaryOr(other: => DataType): DataType = encodeT.jvmRepr match {
    case ByteType => BinaryType
    case _        => other
  }

  def jvmRepr: DataType =
    binaryOr(FramelessInternals.objectTypeFor[Array[T]])

  def catalystRepr: DataType =
    binaryOr(ArrayType(encodeT.catalystRepr, encodeT.nullable))

  def toCatalyst(path: Expression): Expression =
    encodeT.jvmRepr match {
      // Byte arrays are already in their Catalyst form.
      case ByteType => path

      // Remaining primitive element types go through UnsafeArrayData.fromPrimitiveArray.
      case IntegerType | LongType | DoubleType | FloatType | ShortType | BooleanType =>
        StaticInvoke(classOf[UnsafeArrayData], catalystRepr, "fromPrimitiveArray", path :: Nil)

      // Anything else is converted element by element with the element encoder.
      case _ =>
        MapObjects(encodeT.toCatalyst, path, encodeT.jvmRepr, encodeT.nullable)
    }

  def fromCatalyst(path: Expression): Expression = {
    // Invokes the named accessor on `path`, producing this encoder's JVM representation.
    def unpack(accessor: String): Expression = Invoke(path, accessor, jvmRepr)

    encodeT.jvmRepr match {
      case ByteType    => path
      case IntegerType => unpack("toIntArray")
      case LongType    => unpack("toLongArray")
      case DoubleType  => unpack("toDoubleArray")
      case FloatType   => unpack("toFloatArray")
      case ShortType   => unpack("toShortArray")
      case BooleanType => unpack("toBooleanArray")

      case _ =>
        // Decode every element first, then invoke "array" on the mapped result.
        val decoded =
          MapObjects(encodeT.fromCatalyst, path, encodeT.catalystRepr, encodeT.nullable)
        Invoke(decoded, "array", jvmRepr)
    }
  }
}

/** Derives a [[TypedEncoder]] for any `Seq` subtype `C[T]` from the element encoder.
  *
  * The Catalyst representation is an `ArrayType` whose element nullability follows
  * the element encoder.
  */
implicit def collectionEncoder[C[X] <: Seq[X], T](
  implicit
  encodeT: TypedEncoder[T],
  CT: ClassTag[C[T]]): TypedEncoder[C[T]] =
  new TypedEncoder[C[T]] {
    def nullable: Boolean = false

    def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](CT)

    def catalystRepr: DataType = ArrayType(encodeT.catalystRepr, encodeT.nullable)

    def toCatalyst(path: Expression): Expression = {
      val elementIsNative = ScalaReflection.isNativeType(encodeT.jvmRepr)

      if (!elementIsNative)
        // Non-native elements have to be converted one by one.
        MapObjects(encodeT.toCatalyst, path, encodeT.jvmRepr, encodeT.nullable)
      else
        // Native elements need no per-element conversion: wrap the collection directly.
        NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
    }

    def fromCatalyst(path: Expression): Expression = {
      // This will cause MapObjects to build a collection of type C[_] directly
      val targetCollection: Option[Class[_]] = Some(CT.runtimeClass)

      MapObjects(
        encodeT.fromCatalyst,
        path,
        encodeT.catalystRepr,
        encodeT.nullable,
        targetCollection
      )
    }
  }

/** Derives a [[TypedEncoder]] for `Map[A, B]` from the key and value encoders.
  *
  * Keys require a `NotCatalystNullable` instance; values take their nullability from
  * the value encoder.
  */
implicit def mapEncoder[A: NotCatalystNullable, B](
  implicit
  encodeA: TypedEncoder[A],
  encodeB: TypedEncoder[B]
): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]]

  def catalystRepr: DataType =
    MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)

  def fromCatalyst(path: Expression): Expression = {
    // Reads one side of the map by invoking `accessor` on `path`, decodes each element
    // with `decode`, and invokes "array" on the mapped result.
    def decodedArray(
      accessor: String,
      arrayType: DataType,
      decode: Expression => Expression,
      elementType: DataType
    ): Expression =
      Invoke(
        MapObjects(decode, Invoke(path, accessor, arrayType), elementType),
        "array",
        FramelessInternals.objectTypeFor[Array[Any]]
      )

    val keys = decodedArray(
      "keyArray",
      ArrayType(encodeA.catalystRepr, containsNull = false),
      encodeA.fromCatalyst,
      encodeA.catalystRepr
    )

    val values = decodedArray(
      "valueArray",
      ArrayType(encodeB.catalystRepr, encodeB.nullable),
      encodeB.fromCatalyst,
      encodeB.catalystRepr
    )

    // Both decoded arrays are handed to ArrayBasedMapData.toScalaMap.
    StaticInvoke(ArrayBasedMapData.getClass, jvmRepr, "toScalaMap", keys :: values :: Nil)
  }

  def toCatalyst(path: Expression): Expression =
    ExternalMapToCatalyst(
      path,
      encodeA.jvmRepr,
      encodeA.toCatalyst,
      encodeB.jvmRepr,
      encodeB.toCatalyst,
      encodeB.nullable
    )
}

implicit def optionEncoder[A](
implicit
underlying: TypedEncoder[A]
Expand Down Expand Up @@ -252,171 +371,6 @@ object TypedEncoder {
WrapOption(underlying.fromCatalyst(path), underlying.jvmRepr)
}

/** Base class for encoders of collection-like types `F[A]`.
  *
  * Provides [[arrayData]], which turns a Catalyst array expression into an expression
  * producing a JVM array of decoded elements.
  */
abstract class CollectionEncoder[F[_], A](implicit
  underlying: TypedEncoder[A],
  classTag: ClassTag[F[A]]
) extends TypedEncoder[F[A]] {
  protected def arrayData(path: Expression): Expression = {
    // The direct accessors are only used when the element's JVM representation is a
    // native type and is identical to its Catalyst representation.
    val directlyReadable = Option(underlying.jvmRepr).filter { repr =>
      ScalaReflection.isNativeType(repr) && repr == underlying.catalystRepr
    }

    val primitiveRead: Option[Expression] = directlyReadable.collect {
      case BooleanType =>
        Invoke(path, "toBooleanArray", ScalaReflection.dataTypeFor[Array[Boolean]])
      case ByteType =>
        Invoke(path, "toByteArray", ScalaReflection.dataTypeFor[Array[Byte]])
      case ShortType =>
        Invoke(path, "toShortArray", ScalaReflection.dataTypeFor[Array[Short]])
      case IntegerType =>
        Invoke(path, "toIntArray", ScalaReflection.dataTypeFor[Array[Int]])
      case LongType =>
        Invoke(path, "toLongArray", ScalaReflection.dataTypeFor[Array[Long]])
      case FloatType =>
        Invoke(path, "toFloatArray", ScalaReflection.dataTypeFor[Array[Float]])
      case DoubleType =>
        Invoke(path, "toDoubleArray", ScalaReflection.dataTypeFor[Array[Double]])
    }

    primitiveRead.getOrElse {
      // Fallback: decode each element, then invoke "array" on the mapped result.
      val decoded = MapObjects(underlying.fromCatalyst, path, underlying.catalystRepr)
      Invoke(decoded, "array", ScalaReflection.dataTypeFor[Array[AnyRef]])
    }
  }
}

/** Derives a [[TypedEncoder]] for `Vector[A]` from the element encoder. */
implicit def vectorEncoder[A](
  implicit
  underlying: TypedEncoder[A]
): TypedEncoder[Vector[A]] = new CollectionEncoder[Vector, A]() {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag)

  def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

  def fromCatalyst(path: Expression): Expression = {
    // Decode into a JVM array first, then pass it to TypedEncoderUtils.mkVector.
    val elements = arrayData(path)
    StaticInvoke(TypedEncoderUtils.getClass, jvmRepr, "mkVector", elements :: Nil)
  }

  def toCatalyst(path: Expression): Expression = {
    val needsMapping = !ScalaReflection.isNativeType(underlying.jvmRepr)

    if (needsMapping)
      MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
    else
      // if source `path` is already native for Spark, no need to `map`
      NewInstance(
        classOf[GenericArrayData],
        path :: Nil,
        dataType = ArrayType(underlying.catalystRepr, underlying.nullable)
      )
  }
}

/** Derives a [[TypedEncoder]] for `List[A]` from the element encoder. */
implicit def listEncoder[A](
  implicit
  underlying: TypedEncoder[A]
): TypedEncoder[List[A]] = new CollectionEncoder[List, A]() {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[List[A]](classTag)

  def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

  def fromCatalyst(path: Expression): Expression = {
    // The decoded JVM array is the single argument of TypedEncoderUtils.mkList.
    val arguments = arrayData(path) :: Nil
    StaticInvoke(TypedEncoderUtils.getClass, jvmRepr, "mkList", arguments)
  }

  def toCatalyst(path: Expression): Expression =
    ScalaReflection.isNativeType(underlying.jvmRepr) match {
      // if source `path` is already native for Spark, no need to `map`
      case true =>
        val arrayType = ArrayType(underlying.catalystRepr, underlying.nullable)
        NewInstance(classOf[GenericArrayData], path :: Nil, dataType = arrayType)

      case false =>
        MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
    }
}

/** Derives a [[TypedEncoder]] for `Array[A]` from the element encoder. */
implicit def arrayEncoder[A](
  implicit
  underlying: TypedEncoder[A]
): TypedEncoder[Array[A]] = {
  // NOTE(review): presumably supplies the element ClassTag needed to build the
  // encoder below; confirm against the TypedEncoder definition.
  import underlying.classTag

  new CollectionEncoder[Array, A]() {
    def nullable: Boolean = false

    def jvmRepr: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag)

    def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

    // The JVM array produced by `arrayData` is already the final value.
    def fromCatalyst(path: Expression): Expression = arrayData(path)

    def toCatalyst(path: Expression): Expression = {
      val elementIsNative = ScalaReflection.isNativeType(underlying.jvmRepr)

      if (!elementIsNative) {
        MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
      } else {
        // if source `path` is already native for Spark, no need to `map`
        val arrayType = ArrayType(underlying.catalystRepr, underlying.nullable)
        NewInstance(classOf[GenericArrayData], path :: Nil, dataType = arrayType)
      }
    }
  }
}

/** Derives a [[TypedEncoder]] for `Map[A, B]` from the key and value encoders.
  *
  * Keys and values are decoded through `arrayEncoder`, and each decoded array is
  * passed to `WrappedArray.make` before the map is assembled.
  */
implicit def mapEncoder[A: NotCatalystNullable, B](
  implicit
  encodeA: TypedEncoder[A],
  encodeB: TypedEncoder[B]
): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]]

  def catalystRepr: DataType =
    MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)

  // Statically invokes WrappedArray.make on the given array expression.
  private def wrap(arrayData: Expression) =
    StaticInvoke(
      scala.collection.mutable.WrappedArray.getClass,
      FramelessInternals.objectTypeFor[Seq[_]],
      "make",
      arrayData :: Nil
    )

  def fromCatalyst(path: Expression): Expression = {
    // Key side: the key array type never contains nulls.
    val rawKeys = Invoke(path, "keyArray", ArrayType(encodeA.catalystRepr, containsNull = false))
    val keySeq = wrap(arrayEncoder[A].fromCatalyst(rawKeys))

    // Value side: nullability follows the value encoder.
    val rawValues = Invoke(path, "valueArray", ArrayType(encodeB.catalystRepr, encodeB.nullable))
    val valueSeq = wrap(arrayEncoder[B].fromCatalyst(rawValues))

    StaticInvoke(ArrayBasedMapData.getClass, jvmRepr, "toScalaMap", keySeq :: valueSeq :: Nil)
  }

  def toCatalyst(path: Expression): Expression =
    ExternalMapToCatalyst(
      path,
      encodeA.jvmRepr,
      encodeA.toCatalyst,
      encodeB.jvmRepr,
      encodeB.toCatalyst,
      encodeB.nullable
    )
}

/** Encodes things using injection if there is one defined */
implicit def usingInjection[A: ClassTag, B]
(implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] =
Expand Down
7 changes: 0 additions & 7 deletions dataset/src/main/scala/frameless/TypedEncoderUtils.scala

This file was deleted.

22 changes: 21 additions & 1 deletion dataset/src/test/scala/frameless/CollectTests.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
package frameless

import frameless.CollectTests.prop
import frameless.CollectTests.{ prop, propArray }
import org.apache.spark.sql.SQLContext
import org.scalacheck.Prop
import org.scalacheck.Prop._
import scala.reflect.ClassTag

class CollectTests extends TypedDatasetSuite {
test("collect()") {
check(forAll(propArray[Int] _))
check(forAll(propArray[Long] _))
check(forAll(propArray[Boolean] _))
check(forAll(propArray[Float] _))
check(forAll(propArray[String] _))
check(forAll(propArray[Byte] _))
check(forAll(propArray[Option[Int]] _))
check(forAll(propArray[Option[Long]] _))
check(forAll(propArray[Option[Double]] _))
check(forAll(propArray[Option[Float]] _))
check(forAll(propArray[Option[Short]] _))
check(forAll(propArray[Option[Byte]] _))
check(forAll(propArray[Option[Boolean]] _))
check(forAll(propArray[Option[String]] _))

check(forAll(prop[X2[Int, Int]] _))
check(forAll(prop[X2[String, String]] _))
check(forAll(prop[X2[String, Int]] _))
Expand Down Expand Up @@ -62,4 +77,9 @@ object CollectTests {

/** Property: collecting a dataset created from `data` yields `data` back, in order. */
def prop[A: TypedEncoder : ClassTag](data: Vector[A])(implicit c: SQLContext): Prop = {
  val collected = TypedDataset.create(data).collect().run().toVector
  collected ?= data
}

/** Property: collecting a dataset of array-holding records yields the same records.
  *
  * Arrays are compared with `sameElements` because `==` on Scala arrays is reference
  * equality. The collected size is checked explicitly: `zip` stops at the shorter side,
  * so without it a result with missing rows (even an empty one) would pass.
  */
def propArray[A: TypedEncoder : ClassTag](data: Vector[X1[Array[A]]])(implicit c: SQLContext): Prop = {
  val collected = TypedDataset.create(data).collect().run().toVector

  Prop(
    collected.length == data.length &&
      collected.zip(data).forall {
        case (X1(l), X1(r)) => l.sameElements(r)
      }
  )
}
}
7 changes: 6 additions & 1 deletion dataset/src/test/scala/frameless/forward/FirstTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import org.scalatest.Matchers

class FirstTests extends TypedDatasetSuite {
/** Tests for `TypedDataset.firstOption`. */
class FirstTests extends TypedDatasetSuite with Matchers {
  test("first") {
    // firstOption must agree with headOption of the source vector.
    def prop[A: TypedEncoder](data: Vector[A]): Prop = {
      val first = TypedDataset.create(data).firstOption().run()
      first =? data.headOption
    }

    check(forAll(prop[Int] _))
    check(forAll(prop[String] _))
  }

  test("first on empty dataset should return None") {
    val emptyDataset = TypedDataset.create(Vector.empty[Int])
    emptyDataset.firstOption().run() shouldBe None
  }
}
13 changes: 13 additions & 0 deletions dataset/src/test/scala/frameless/forward/TakeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@ package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import scala.reflect.ClassTag

/** Tests for `TypedDataset.take`. */
class TakeTests extends TypedDatasetSuite {
  test("take") {
    // For non-negative n, take(n) on the dataset must agree with Vector.take(n).
    def prop[A: TypedEncoder](n: Int, data: Vector[A]): Prop =
      (n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data.take(n))

    // Same property for array-holding records. Arrays are compared with `sameElements`
    // because `==` on Scala arrays is reference equality. The sizes are compared
    // explicitly: `zip` stops at the shorter side, so without the check a result with
    // too few rows would pass.
    def propArray[A: TypedEncoder: ClassTag](n: Int, data: Vector[X1[Array[A]]]): Prop =
      (n >= 0) ==> {
        val taken = TypedDataset.create(data).take(n).run().toVector
        val expected = data.take(n)

        Prop {
          taken.length == expected.length &&
            taken.zip(expected).forall {
              case (X1(l), X1(r)) => l sameElements r
            }
        }
      }

    check(forAll(prop[Int] _))
    check(forAll(prop[String] _))
    check(forAll(propArray[Int] _))
    check(forAll(propArray[String] _))
    check(forAll(propArray[Byte] _))
  }
}
Loading

0 comments on commit 17078b6

Please sign in to comment.