Skip to content

Commit

Permalink
Restructure collection encoders
Browse files Browse the repository at this point in the history
Resolves #193.
  • Loading branch information
Itamar Ravid committed Oct 14, 2017
1 parent 29b15e2 commit dff1dbc
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 170 deletions.
275 changes: 110 additions & 165 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,116 @@ object TypedEncoder {
)
}

/**
  * Encoder for `Array[T]`, derived from the element encoder `T`.
  *
  * Arrays whose element `jvmRepr` is `ByteType` are represented as `BinaryType`
  * on both the JVM and the catalyst side and are passed through untouched.
  * Arrays of the other listed primitive element types are converted with the
  * bulk primitive-array helpers; everything else is mapped element by element.
  */
implicit def arrayEncoder[T: ClassTag](
  implicit
  T: TypedEncoder[T]
): TypedEncoder[Array[T]] = new TypedEncoder[Array[T]] {
  // Same comparison a `case ByteType =>` pattern performs on `T.jvmRepr`.
  private def holdsBytes: Boolean = ByteType == T.jvmRepr

  def nullable: Boolean = false

  def jvmRepr: DataType =
    if (holdsBytes) BinaryType
    else FramelessInternals.objectTypeFor[Array[T]]

  def catalystRepr: DataType =
    if (holdsBytes) BinaryType
    else ArrayType(T.catalystRepr, T.nullable)

  def toCatalyst(path: Expression): Expression =
    T.jvmRepr match {
      // Byte arrays are already in their catalyst (binary) form.
      case ByteType => path

      case IntegerType | LongType | DoubleType | FloatType | ShortType | BooleanType =>
        StaticInvoke(classOf[UnsafeArrayData], catalystRepr, "fromPrimitiveArray", List(path))

      case _ => MapObjects(T.toCatalyst, path, T.jvmRepr, T.nullable)
    }

  def fromCatalyst(path: Expression): Expression = {
    val elementRepr = T.jvmRepr

    // Name of the bulk accessor invoked on the catalyst array, when one applies.
    val bulkAccessor: Option[String] = elementRepr match {
      case IntegerType => Some("toIntArray")
      case LongType => Some("toLongArray")
      case DoubleType => Some("toDoubleArray")
      case FloatType => Some("toFloatArray")
      case ShortType => Some("toShortArray")
      case BooleanType => Some("toBooleanArray")
      case _ => None
    }

    bulkAccessor match {
      case Some(method) => Invoke(path, method, jvmRepr)

      // Binary data needs no conversion.
      case None if ByteType == elementRepr => path

      // Decode each element, then expose the result through `array`.
      case None =>
        Invoke(MapObjects(T.fromCatalyst, path, T.catalystRepr, T.nullable), "array", jvmRepr)
    }
  }
}

/**
  * Encoder for any `Seq` subtype `C[T]`, derived from the element encoder `T`
  * and the class tag of the concrete collection.
  *
  * The catalyst representation is an array of the element's catalyst type whose
  * `containsNull` flag follows the element encoder's nullability.
  */
implicit def collectionEncoder[C[X] <: Seq[X], T](
  implicit
  T: TypedEncoder[T],
  CT: ClassTag[C[T]]
): TypedEncoder[C[T]] =
  new TypedEncoder[C[T]] {
    def nullable: Boolean = false

    def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](CT)

    def catalystRepr: DataType = ArrayType(T.catalystRepr, T.nullable)

    def toCatalyst(path: Expression): Expression = {
      val elementIsNative = ScalaReflection.isNativeType(T.jvmRepr)

      // Non-native elements are converted one by one; native elements are
      // wrapped directly in a `GenericArrayData`.
      if (!elementIsNative) MapObjects(T.toCatalyst, path, T.jvmRepr, T.nullable)
      else NewInstance(classOf[GenericArrayData], List(path), catalystRepr)
    }

    def fromCatalyst(path: Expression): Expression = {
      // Runtime class of the collection handed to `MapObjects` as its last argument.
      val targetCollection: Option[Class[_]] = Some(CT.runtimeClass)
      MapObjects(T.fromCatalyst, path, T.catalystRepr, T.nullable, targetCollection)
    }
  }

/**
  * Encoder for `Map[A, B]`, derived from the key encoder `encodeA` and the value
  * encoder `encodeB`. Keys additionally require a `NotCatalystNullable` instance.
  */
implicit def mapEncoder[A: NotCatalystNullable, B](
  implicit
  encodeA: TypedEncoder[A],
  encodeB: TypedEncoder[B]
): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]]

  def catalystRepr: DataType =
    MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)

  def toCatalyst(path: Expression): Expression =
    ExternalMapToCatalyst(
      path,
      encodeA.jvmRepr,
      encodeA.toCatalyst,
      encodeB.jvmRepr,
      encodeB.toCatalyst,
      encodeB.nullable
    )

  def fromCatalyst(path: Expression): Expression = {
    // Reads one side of the map (`accessor` is invoked on `path`), decodes every
    // element with `encoder`, and exposes the result through `array`.
    def decodeSide(encoder: TypedEncoder[_], accessor: String, sideType: DataType): Expression =
      Invoke(
        MapObjects(
          encoder.fromCatalyst,
          Invoke(path, accessor, sideType),
          encoder.catalystRepr
        ),
        "array",
        FramelessInternals.objectTypeFor[Array[Any]]
      )

    val decodedKeys =
      decodeSide(encodeA, "keyArray", ArrayType(encodeA.catalystRepr, containsNull = false))

    val decodedValues =
      decodeSide(encodeB, "valueArray", ArrayType(encodeB.catalystRepr, encodeB.nullable))

    StaticInvoke(
      ArrayBasedMapData.getClass,
      jvmRepr,
      "toScalaMap",
      List(decodedKeys, decodedValues)
    )
  }
}

implicit def optionEncoder[A](
implicit
underlying: TypedEncoder[A]
Expand Down Expand Up @@ -252,171 +362,6 @@ object TypedEncoder {
WrapOption(underlying.fromCatalyst(path), underlying.jvmRepr)
}

/**
  * Base class for encoders of a collection `F[A]` built on top of the element
  * encoder `underlying`.
  */
abstract class CollectionEncoder[F[_], A](implicit
  underlying: TypedEncoder[A],
  classTag: ClassTag[F[A]]
) extends TypedEncoder[F[A]] {

  /**
    * Builds the expression that turns the catalyst array at `path` into a JVM array.
    *
    * A bulk accessor is used when the element's `jvmRepr` is a native type, equals
    * its `catalystRepr`, and is one of the primitive types listed below. Otherwise
    * each element is decoded with `underlying.fromCatalyst` and the result is
    * exposed through `array`.
    */
  protected def arrayData(path: Expression): Expression = {
    val elementRepr = underlying.jvmRepr

    // The null check mirrors wrapping the representation in `Option(...)`.
    val readableInBulk =
      elementRepr != null &&
        ScalaReflection.isNativeType(elementRepr) &&
        elementRepr == underlying.catalystRepr

    val bulkAccessor: Option[(String, DataType)] =
      if (!readableInBulk) None
      else elementRepr match {
        case BooleanType => Some("toBooleanArray" -> ScalaReflection.dataTypeFor[Array[Boolean]])
        case ByteType => Some("toByteArray" -> ScalaReflection.dataTypeFor[Array[Byte]])
        case ShortType => Some("toShortArray" -> ScalaReflection.dataTypeFor[Array[Short]])
        case IntegerType => Some("toIntArray" -> ScalaReflection.dataTypeFor[Array[Int]])
        case LongType => Some("toLongArray" -> ScalaReflection.dataTypeFor[Array[Long]])
        case FloatType => Some("toFloatArray" -> ScalaReflection.dataTypeFor[Array[Float]])
        case DoubleType => Some("toDoubleArray" -> ScalaReflection.dataTypeFor[Array[Double]])
        case _ => None
      }

    bulkAccessor match {
      case Some((method, typ)) =>
        Invoke(path, method, typ)

      case None =>
        Invoke(
          MapObjects(underlying.fromCatalyst, path, underlying.catalystRepr),
          "array",
          ScalaReflection.dataTypeFor[Array[AnyRef]]
        )
    }
  }
}

/**
  * Encoder for `Vector[A]`, derived from the element encoder `underlying`.
  *
  * Decoding first produces a JVM array via `arrayData` and then hands it to
  * `TypedEncoderUtils.mkVector`.
  */
implicit def vectorEncoder[A](
  implicit
  underlying: TypedEncoder[A]
): TypedEncoder[Vector[A]] = new CollectionEncoder[Vector, A]() {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Vector[A]](classTag)

  def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

  def toCatalyst(path: Expression): Expression =
    if (!ScalaReflection.isNativeType(underlying.jvmRepr))
      MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
    else
      // Elements are already native for Spark, so they are wrapped without mapping.
      NewInstance(
        classOf[GenericArrayData],
        List(path),
        ArrayType(underlying.catalystRepr, underlying.nullable)
      )

  def fromCatalyst(path: Expression): Expression =
    StaticInvoke(TypedEncoderUtils.getClass, jvmRepr, "mkVector", List(arrayData(path)))
}

/**
  * Encoder for `List[A]`, derived from the element encoder `underlying`.
  *
  * Decoding first produces a JVM array via `arrayData` and then hands it to
  * `TypedEncoderUtils.mkList`.
  */
implicit def listEncoder[A](
  implicit
  underlying: TypedEncoder[A]
): TypedEncoder[List[A]] = new CollectionEncoder[List, A]() {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[List[A]](classTag)

  def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

  def toCatalyst(path: Expression): Expression =
    if (!ScalaReflection.isNativeType(underlying.jvmRepr))
      MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
    else
      // Elements are already native for Spark, so they are wrapped without mapping.
      NewInstance(
        classOf[GenericArrayData],
        List(path),
        ArrayType(underlying.catalystRepr, underlying.nullable)
      )

  def fromCatalyst(path: Expression): Expression =
    StaticInvoke(TypedEncoderUtils.getClass, jvmRepr, "mkList", List(arrayData(path)))
}

/**
  * Encoder for `Array[A]`, derived from the element encoder `underlying`.
  *
  * Decoding delegates entirely to `arrayData`; encoding either wraps the array
  * in a `GenericArrayData` (native element types) or maps every element through
  * `underlying.toCatalyst`.
  */
implicit def arrayEncoder[A](
implicit
underlying: TypedEncoder[A]
): TypedEncoder[Array[A]] = {
// Brings the element encoder's `classTag` into scope.
// NOTE(review): presumably this is what lets the implicit `ClassTag[Array[A]]`
// required by `CollectionEncoder` be derived — confirm against TypedEncoder.
import underlying.classTag

new CollectionEncoder[Array, A]() {
def nullable: Boolean = false
// NOTE(review): which `classTag` this resolves to (the import above or a member
// inherited from the parent class) is not visible from this block — verify
// before refactoring.
def jvmRepr: DataType = FramelessInternals.objectTypeFor[Array[A]](classTag)
def catalystRepr: DataType = DataTypes.createArrayType(underlying.catalystRepr)

// The JVM array produced by `arrayData` is already the target type.
def fromCatalyst(path: Expression): Expression = arrayData(path)

def toCatalyst(path: Expression): Expression = {
// if source `path` is already native for Spark, no need to `map`
if (ScalaReflection.isNativeType(underlying.jvmRepr)) {
NewInstance(
classOf[GenericArrayData],
path :: Nil,
dataType = ArrayType(underlying.catalystRepr, underlying.nullable)
)
} else {
// Non-native elements are converted one at a time.
MapObjects(underlying.toCatalyst, path, underlying.jvmRepr)
}
}
}
}

/**
  * Encoder for `Map[A, B]`, derived from the key encoder `encodeA` and the value
  * encoder `encodeB`. Keys additionally require a `NotCatalystNullable` instance.
  *
  * Decoding reuses `arrayEncoder` for the key and value arrays, wraps each
  * decoded array with `WrappedArray.make`, and passes both to
  * `ArrayBasedMapData.toScalaMap`.
  */
implicit def mapEncoder[A: NotCatalystNullable, B](
  implicit
  encodeA: TypedEncoder[A],
  encodeB: TypedEncoder[B]
): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] {
  def nullable: Boolean = false

  def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]]

  def catalystRepr: DataType =
    MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)

  // Wraps a decoded JVM array expression via `WrappedArray.make`.
  private def asSeq(decodedArray: Expression) =
    StaticInvoke(
      scala.collection.mutable.WrappedArray.getClass,
      FramelessInternals.objectTypeFor[Seq[_]],
      "make",
      List(decodedArray)
    )

  def toCatalyst(path: Expression): Expression =
    ExternalMapToCatalyst(
      path,
      encodeA.jvmRepr,
      encodeA.toCatalyst,
      encodeB.jvmRepr,
      encodeB.toCatalyst,
      encodeB.nullable
    )

  def fromCatalyst(path: Expression): Expression = {
    val keysType = ArrayType(encodeA.catalystRepr, containsNull = false)
    val keys = asSeq(arrayEncoder[A].fromCatalyst(Invoke(path, "keyArray", keysType)))

    val valuesType = ArrayType(encodeB.catalystRepr, encodeB.nullable)
    val values = asSeq(arrayEncoder[B].fromCatalyst(Invoke(path, "valueArray", valuesType)))

    StaticInvoke(ArrayBasedMapData.getClass, jvmRepr, "toScalaMap", List(keys, values))
  }
}

/** Encodes things using injection if there is one defined */
implicit def usingInjection[A: ClassTag, B]
(implicit inj: Injection[A, B], trb: TypedEncoder[B]): TypedEncoder[A] =
Expand Down
1 change: 1 addition & 0 deletions dataset/src/test/scala/frameless/CollectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.reflect.ClassTag
class CollectTests extends TypedDatasetSuite {
test("collect()") {
check(forAll(propArray[String] _))
check(forAll(propArray[Byte] _))

check(forAll(prop[X2[Int, Int]] _))
check(forAll(prop[X2[String, String]] _))
Expand Down
4 changes: 3 additions & 1 deletion dataset/src/test/scala/frameless/forward/TakeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._
import scala.reflect.ClassTag

class TakeTests extends TypedDatasetSuite {
test("take") {
def prop[A: TypedEncoder](n: Int, data: Vector[A]): Prop =
(n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data.take(n))

def propArray[A: TypedEncoder](n: Int, data: Vector[X1[Array[A]]]): Prop =
def propArray[A: TypedEncoder: ClassTag](n: Int, data: Vector[X1[Array[A]]]): Prop =
(n >= 0) ==> {
Prop {
TypedDataset.create(data).take(n).run().toVector.zip(data.take(n)).forall {
Expand All @@ -21,5 +22,6 @@ class TakeTests extends TypedDatasetSuite {
check(forAll(prop[String] _))
check(forAll(propArray[Int] _))
check(forAll(propArray[String] _))
check(forAll(propArray[Byte] _))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.scalacheck.Prop._
import scala.collection.SeqLike

import scala.math.Ordering
import scala.reflect.ClassTag

class UnaryFunctionsTest extends TypedDatasetSuite {
test("size tests") {
Expand All @@ -27,7 +28,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}

test("size on array test") {
def prop[A: TypedEncoder](xs: List[X1[Array[A]]]): Prop = {
def prop[A: TypedEncoder: ClassTag](xs: List[X1[Array[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(size(tds('a))).collect().run().toVector
Expand Down Expand Up @@ -82,7 +83,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}

test("sort on array test: ascending order") {
def prop[A: TypedEncoder : Ordering](xs: List[X1[Array[A]]]): Prop = {
def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector
Expand All @@ -103,7 +104,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}

test("sort on array test: descending order") {
def prop[A: TypedEncoder : Ordering](xs: List[X1[Array[A]]]): Prop = {
def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector
Expand Down Expand Up @@ -144,7 +145,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}

test("explode on arrays") {
def prop[A: TypedEncoder](xs: List[X1[Array[A]]]): Prop = {
def prop[A: TypedEncoder: ClassTag](xs: List[X1[Array[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(explode(tds('a))).collect().run().toSet
Expand Down

0 comments on commit dff1dbc

Please sign in to comment.