[SPARK-11011][SQL] Narrow type of UDT serialization
## What changes were proposed in this pull request?

Narrow down the parameter type of `UserDefinedType#serialize()`. Currently the parameter type is `Any`; it logically makes more sense to narrow it down to the actual user-defined type.
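
For illustration only, here is a minimal sketch of what the narrowed signature means for a UDT implementor; the `Point` and `PointUDT` names are hypothetical and not part of this patch. With `serialize(obj: Any)`, an implementation had to pattern-match to recover its own type; with the typed parameter the compiler already guarantees it:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class; any reference type satisfies the new `UserType >: Null` bound.
case class Point(x: Double, y: Double)

// Sketch of a UDT against the narrowed API: serialize receives a Point directly,
// so no match on Any is needed to recover the user type.
class PointUDT extends UserDefinedType[Point] {

  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  override def serialize(p: Point): GenericArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  override def deserialize(datum: Any): Point = datum match {
    case arr: ArrayData => Point(arr.getDouble(0), arr.getDouble(1))
  }

  override def userClass: Class[Point] = classOf[Point]
}
```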

## How was this patch tested?

Existing tests were run successfully on a local machine.

Author: Jakob Odersky <jakob@odersky.com>

Closes apache#11379 from jodersky/SPARK-11011-udt-types.
jodersky authored and roygao94 committed Mar 22, 2016
1 parent 12c8023 commit 55fc9cc
Showing 10 changed files with 32 additions and 52 deletions.
@@ -177,7 +177,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
     ))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Matrix): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>

@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)

project/MimaExcludes.scala (2 changes: 2 additions & 0 deletions)
@@ -292,6 +292,8 @@ object MimaExcludes {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
     ) ++ Seq(
+      //SPARK-11011 UserDefinedType serialization should be strongly typed
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
       // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")

@@ -136,16 +136,16 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
   }
 
-  private case class UDTConverter(
-      udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+  private case class UDTConverter[A >: Null](
+      udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
     // toCatalyst (it calls toCatalystImpl) will do null check.
-    override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+    override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)
 
-    override def toScala(catalystValue: Any): Any = {
+    override def toScala(catalystValue: Any): A = {
       if (catalystValue == null) null else udt.deserialize(catalystValue)
     }
 
-    override def toScalaImpl(row: InternalRow, column: Int): Any =
+    override def toScalaImpl(row: InternalRow, column: Int): A =
       toScala(row.get(column, udt.sqlType))
   }
 

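One detail worth noting: the new `>: Null` lower bound (on `A` here and on `UserType` below) is what lets `toScala` keep returning `null` for a null Catalyst value, since `null` is only a valid value of a type parameter that admits `Null`. A standalone sketch, not part of the patch, with hypothetical names:

```scala
// Minimal illustration of the `>: Null` lower bound.
object NullBoundSketch {
  // With an unbounded `A`, returning `null` here would not type-check;
  // `A >: Null` guarantees Null <: A, so `null: A` is valid.
  def orNull[A >: Null](value: A): A =
    if (value == null) null else value

  def main(args: Array[String]): Unit = {
    println(orNull[String](null))   // prints: null
    println(orNull("deserialized")) // prints: deserialized
  }
}
```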
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
  * The conversion via `deserialize` occurs when reading from a `DataFrame`.
  */
 @DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
+abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT */
   def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   /**
    * Convert the user type to a SQL datum
-   *
-   * TODO: Can we make this take obj: UserType? The issue is in
-   *   CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
    */
-  def serialize(obj: Any): Any
+  def serialize(obj: UserType): Any
 
   /** Convert a SQL datum to the user type */
   def deserialize(datum: Any): UserType

@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
 
   override def sqlType: DataType = IntegerType
 
-  override def serialize(obj: Any): Int = {
-    obj match {
-      case groupableData: GroupableData => groupableData.data
-    }
-  }
+  override def serialize(groupableData: GroupableData): Int = groupableData.data
 
   override def deserialize(datum: Any): GroupableData = {
     datum match {
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
 
   override def sqlType: DataType = MapType(IntegerType, IntegerType)
 
-  override def serialize(obj: Any): MapData = {
-    obj match {
-      case groupableData: UngroupableData =>
-        val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
-        val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
-        new ArrayBasedMapData(keyArray, valueArray)
-    }
+  override def serialize(ungroupableData: UngroupableData): MapData = {
+    val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+    val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+    new ArrayBasedMapData(keyArray, valueArray)
   }
 
   override def deserialize(datum: Any): UngroupableData = {

@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {

@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {

@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
   override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): ArrayData = {
-    obj match {
-      case features: MyDenseVector =>
-        new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-    }
+  override def serialize(features: MyDenseVector): ArrayData = {
+    new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
   }
 
   override def deserialize(datum: Any): MyDenseVector = {

@@ -590,14 +590,11 @@ object TestingUDT {
       .add("b", LongType, nullable = false)
       .add("c", DoubleType, nullable = false)
 
-    override def serialize(obj: Any): Any = {
+    override def serialize(n: NestedStruct): Any = {
       val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
-      obj match {
-        case n: NestedStruct =>
-          row.setInt(0, n.a)
-          row.setLong(1, n.b)
-          row.setDouble(2, n.c)
-      }
+      row.setInt(0, n.a)
+      row.setLong(1, n.b)
+      row.setDouble(2, n.c)
     }
 
     override def userClass: Class[NestedStruct] = classOf[NestedStruct]

