diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index fdede2ad39109..157f2dbf5d7a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -177,7 +177,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
       ))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Matrix): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index cecfd067bd874..0f0c3a2df556b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 59c7e7db2e1e2..ffc6fa0599c16 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -292,6 +292,8 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
       ) ++ Seq(
+        // SPARK-11011 UserDefinedType serialization should be strongly typed
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
         // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 2ec0ff53c89c0..9bfc381639140 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -136,16 +136,16 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
   }
 
-  private case class UDTConverter(
-      udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+  private case class UDTConverter[A >: Null](
+      udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
     // toCatalyst (it calls toCatalystImpl) will do null check.
-    override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+    override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)
 
-    override def toScala(catalystValue: Any): Any = {
+    override def toScala(catalystValue: Any): A = {
       if (catalystValue == null) null else udt.deserialize(catalystValue)
     }
 
-    override def toScalaImpl(row: InternalRow, column: Int): Any =
+    override def toScalaImpl(row: InternalRow, column: Int): A =
       toScala(row.get(column, udt.sqlType))
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 9d2449f3b729e..dabf9a2fc063a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
  * The conversion via `deserialize` occurs when reading from a `DataFrame`.
  */
 @DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
+abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT */
   def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   /**
    * Convert the user type to a SQL datum
-   *
-   * TODO: Can we make this take obj: UserType? The issue is in
-   * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
    */
-  def serialize(obj: Any): Any
+  def serialize(obj: UserType): Any
 
   /** Convert a SQL datum to the user type */
   def deserialize(datum: Any): UserType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e7bbc38d60ea..1b297525bdafb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
 
   override def sqlType: DataType = IntegerType
 
-  override def serialize(obj: Any): Int = {
-    obj match {
-      case groupableData: GroupableData => groupableData.data
-    }
-  }
+  override def serialize(groupableData: GroupableData): Int = groupableData.data
 
   override def deserialize(datum: Any): GroupableData = {
     datum match {
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
 
   override def sqlType: DataType = MapType(IntegerType, IntegerType)
 
-  override def serialize(obj: Any): MapData = {
-    obj match {
-      case groupableData: UngroupableData =>
-        val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
-        val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
-        new ArrayBasedMapData(keyArray, valueArray)
-    }
+  override def serialize(ungroupableData: UngroupableData): MapData = {
+    val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+    val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+    new ArrayBasedMapData(keyArray, valueArray)
   }
 
   override def deserialize(datum: Any): UngroupableData = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index f119c6f4f73c4..bf0360c5e1967 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index e2c9fc421b83f..695a5ad78adc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9081bc722ad04..8c4afb605b01f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
   override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): ArrayData = {
-    obj match {
-      case features: MyDenseVector =>
-        new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-    }
+  override def serialize(features: MyDenseVector): ArrayData = {
+    new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
   }
 
   override def deserialize(datum: Any): MyDenseVector = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index fb99b0c7e2acd..f8166c7ddc4da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -590,14 +590,11 @@ object TestingUDT {
       .add("b", LongType, nullable = false)
       .add("c", DoubleType, nullable = false)
 
-    override def serialize(obj: Any): Any = {
+    override def serialize(n: NestedStruct): Any = {
       val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
-      obj match {
-        case n: NestedStruct =>
-          row.setInt(0, n.a)
-          row.setLong(1, n.b)
-          row.setDouble(2, n.c)
-      }
+      row.setInt(0, n.a)
+      row.setLong(1, n.b)
+      row.setDouble(2, n.c)
     }
 
     override def userClass: Class[NestedStruct] = classOf[NestedStruct]
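
For reference, a minimal sketch of a user-defined type written against the new strongly typed `serialize(obj: UserType)` signature introduced above. The `Point2D` class, field names, and import paths are illustrative assumptions and not part of this patch; the import paths follow the ones used by the test UDTs in this diff. The `UserType >: Null` bound exists so that `UDTConverter.toScala` can return `null` for null Catalyst values, which is why any ordinary reference type like `Point2D` qualifies.

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class; any reference type satisfies the `UserType >: Null` bound.
class Point2D(val x: Double, val y: Double) extends Serializable

class Point2DUDT extends UserDefinedType[Point2D] {

  // Stored as a fixed-length array of doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // serialize now receives a Point2D directly, so the old
  // `obj match { case p: Point2D => ... }` pattern match on Any is no longer needed.
  override def serialize(p: Point2D): GenericArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  // deserialize still takes Any, because its input comes from the untyped SQL representation.
  override def deserialize(datum: Any): Point2D = datum match {
    case values: ArrayData => new Point2D(values.getDouble(0), values.getDouble(1))
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}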