Skip to content

Commit

Permalink
Add a test case for the new TypedRow encoder
Browse files Browse the repository at this point in the history
implemented the proposal
  • Loading branch information
tribbloid committed Nov 26, 2023
1 parent f58ccd8 commit 5b65761
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 51 deletions.
119 changes: 89 additions & 30 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,19 @@ object DropUnitValues {
}
}

class RecordEncoder[F, G <: HList, H <: HList](
abstract class RecordEncoder[F, G <: HList, H <: HList](
implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
fields: Lazy[RecordEncoderFields[H]],
newInstanceExprs: Lazy[NewInstanceExprs[G]],
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[F])
extends TypedEncoder[F] {

import stage1._

def nullable: Boolean = false

def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F]

def catalystRepr: DataType = {
lazy val catalystRepr: DataType = {
val structFields = fields.value.value.map { field =>
StructField(
name = field.name,
Expand All @@ -169,39 +168,99 @@ class RecordEncoder[F, G <: HList, H <: HList](
StructType(structFields)
}

def toCatalyst(path: Expression): Expression = {
val nameExprs = fields.value.value.map { field => Literal(field.name) }
}

val valueExprs = fields.value.value.map { field =>
val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
field.encoder.toCatalyst(fieldPath)
}
object RecordEncoder {

case class ForGeneric[F, G <: HList, H <: HList](
)(implicit
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[F])
extends RecordEncoder[F, G, H] {

import stage1._

def toCatalyst(path: Expression): Expression = {

val valueExprs = fields.value.value.map { field =>
val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
field.encoder.toCatalyst(fieldPath)
}

val createExpr = stage1.cellsToCatalyst(valueExprs)

// the way exprs are encoded in CreateNamedStruct
val exprs = nameExprs.zip(valueExprs).flatMap {
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
val nullExpr = Literal.create(null, createExpr.dataType)

If(IsNull(path), nullExpr, createExpr)
}

val createExpr = CreateNamedStruct(exprs)
val nullExpr = Literal.create(null, createExpr.dataType)
def fromCatalyst(path: Expression): Expression = {

val newArgs = stage1.fromCatalystToCells(path)

val newExpr =
NewInstance(
classTag.runtimeClass,
newArgs,
jvmRepr,
propagateNull = true
)

val nullExpr = Literal.create(null, jvmRepr)

If(IsNull(path), nullExpr, createExpr)
If(IsNull(path), nullExpr, newExpr)
}
}

def fromCatalyst(path: Expression): Expression = {
val exprs = fields.value.value.map { field =>
field.encoder.fromCatalyst(
GetStructField(path, field.ordinal, Some(field.name))
)
case class ForTypedRow[G <: HList, H <: HList](
)(implicit
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[TypedRow[G]])
extends RecordEncoder[TypedRow[G], G, H] {

import stage1._

private final val _apply = "apply"
private final val _fromInternalRow = "fromInternalRow"

def toCatalyst(path: Expression): Expression = {

val valueExprs = fields.value.value.zipWithIndex.map {
case (field, i) =>
val fieldPath = Invoke(
path,
_apply,
field.encoder.jvmRepr,
Seq(Literal.create(i, IntegerType))
)
field.encoder.toCatalyst(fieldPath)
}

val createExpr = stage1.cellsToCatalyst(valueExprs)

val nullExpr = Literal.create(null, createExpr.dataType)

If(IsNull(path), nullExpr, createExpr)
}

val newArgs = newInstanceExprs.value.from(exprs)
val newExpr =
NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
def fromCatalyst(path: Expression): Expression = {

val nullExpr = Literal.create(null, jvmRepr)
val newArgs = stage1.fromCatalystToCells(path)
val aggregated = CreateStruct(newArgs)

If(IsNull(path), nullExpr, newExpr)
val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType))

val newExpr = Invoke(
Literal.fromObject(partial),
_fromInternalRow,
TypedRow.catalystType,
Seq(aggregated)
)

val nullExpr = Literal.create(null, jvmRepr)

If(IsNull(path), nullExpr, newExpr)
}
}
}

Expand Down
49 changes: 49 additions & 0 deletions dataset/src/main/scala/frameless/RecordEncoderStage1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.{
CreateNamedStruct,
Expression,
GetStructField,
Literal
}
import shapeless.{ HList, Lazy }

case class RecordEncoderStage1[G <: HList, H <: HList](
)(implicit
// i1: DropUnitValues.Aux[G, H],
// i2: IsHCons[H],
val fields: Lazy[RecordEncoderFields[H]],
val newInstanceExprs: Lazy[NewInstanceExprs[G]]) {

def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = {
val nameExprs = fields.value.value.map { field => Literal(field.name) }

// the way exprs are encoded in CreateNamedStruct
val exprs = nameExprs.zip(valueExprs).flatMap {
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
}

val createExpr = CreateNamedStruct(exprs)
createExpr
}

def fromCatalystToCells(path: Expression): Seq[Expression] = {
val exprs = fields.value.value.map { field =>
field.encoder.fromCatalyst(
GetStructField(path, field.ordinal, Some(field.name))
)
}

val newArgs = newInstanceExprs.value.from(exprs)
newArgs
}
}

object RecordEncoderStage1 {

implicit def usingDerivation[G <: HList, H <: HList](
implicit
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]]
): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]()
}
12 changes: 10 additions & 2 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -727,15 +727,23 @@ object TypedEncoder {
}

/** Encodes things as records if there is no Injection defined */
implicit def usingDerivation[F, G <: HList, H <: HList](
implicit def deriveForGeneric[F, G <: HList, H <: HList](
implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]],
i5: ClassTag[F]
): TypedEncoder[F] = new RecordEncoder[F, G, H]
): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]()

implicit def deriveForTypedRow[G <: HList, H <: HList](
implicit
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]]
): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]()

/** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */
implicit def usingUserDefinedType[
Expand Down
45 changes: 45 additions & 0 deletions dataset/src/main/scala/frameless/TypedRow.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package frameless

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{ DataType, ObjectType }
import shapeless.HList

case class TypedRow[T <: HList](row: Row) {

def apply(i: Int): Any = row.apply(i)
}

object TypedRow {

def apply(values: Any*): TypedRow[HList] = {

val row = Row.fromSeq(values)
TypedRow(row)
}

case class WithCatalystTypes(schema: Seq[DataType]) {

def fromInternalRow(row: InternalRow): TypedRow[HList] = {
val data = row.toSeq(schema).toArray

apply(data: _*)
}

}

object WithCatalystTypes {}

def fromHList[T <: HList](
hlist: T
): TypedRow[T] = {

val cells = hlist.runtimeList

val row = Row.fromSeq(cells)
TypedRow(row)
}

lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]])

}
2 changes: 1 addition & 1 deletion dataset/src/test/scala/frameless/InjectionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class InjectionTests extends TypedDatasetSuite {
}

test("Resolve ambiguity by importing usingDerivation") {
import TypedEncoder.usingDerivation
import TypedEncoder.deriveForGeneric
assert(
implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]]
)
Expand Down
38 changes: 22 additions & 16 deletions dataset/src/test/scala/frameless/RecordEncoderTests.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
package frameless

import frameless.RecordEncoderTests.{ A, B, E }
import org.apache.spark.sql.types._
import org.apache.spark.sql.{ Row, functions => F }
import org.apache.spark.sql.types.{
ArrayType,
BinaryType,
DecimalType,
IntegerType,
LongType,
MapType,
ObjectType,
StringType,
StructField,
StructType
}

import shapeless.{ HList, LabelledGeneric }
import shapeless.test.illTyped

import org.scalatest.matchers.should.Matchers
import shapeless.record.Record
import shapeless.test.illTyped
import shapeless.{ HList, LabelledGeneric }

final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
test("Unable to encode products made from units only") {
Expand Down Expand Up @@ -101,6 +90,20 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
ds.collect.head shouldBe obj
}

test("shapeless Record") {

val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc")
val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1)

val rdd = sc.parallelize(Seq(r2))
val ds =
session.createDataset(rdd)(
TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]]
)

ds.collect.head shouldBe r2
}

test("Scalar value class") {
import RecordEncoderTests._

Expand Down Expand Up @@ -632,6 +635,9 @@ object RecordEncoderTests {
case class D(m: Map[String, Int])
case class E(b: Set[B])

val RR = Record.`'x -> Int, 'y -> String`
type RR = RR.T

final class Subject(val name: String) extends AnyVal with Serializable

final class Grade(val value: BigDecimal) extends AnyVal with Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ object RefinedTypesTests {

import frameless.refined._ // implicit instances for refined

implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation
implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric

implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation
implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric
}

0 comments on commit 5b65761

Please sign in to comment.