Skip to content

Commit

Permalink
Support avro complex types (#70)
Browse files Browse the repository at this point in the history
Master issue: #39 

This PR adds SerDe support for converting the following Spark types to and from Pulsar Avro-typed messages:
 - Map
 - Union
 - Array
 - BigDecimal

After this PR, the remaining unsupported Spark types are:
- CalendarInterval
- UserDefinedType
- ObjectType
- HiveStringType
  • Loading branch information
yjshen authored and sijie committed Jul 1, 2019
1 parent c26b612 commit 112425b
Show file tree
Hide file tree
Showing 5 changed files with 507 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,37 @@
*/
package org.apache.spark.sql.pulsar

import java.math.BigDecimal
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import java.util.Date

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.pulsar.shade.org.apache.avro.{SchemaBuilder, Schema}
import org.apache.pulsar.shade.org.apache.avro.Schema.Type._
import org.apache.pulsar.shade.org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.pulsar.shade.org.apache.avro.Conversions.DecimalConversion
import org.apache.pulsar.shade.org.apache.avro.LogicalTypes
import org.apache.pulsar.shade.org.apache.avro.generic.{GenericData, GenericFixed, GenericRecord}
import org.apache.pulsar.shade.org.apache.avro.util.Utf8

import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericFixed
import org.apache.avro.util.Utf8
import org.apache.pulsar.client.api.Message
import org.apache.pulsar.client.api.schema.{GenericRecord => PGenericRecord}
import org.apache.pulsar.client.impl.schema.generic.GenericAvroRecord
import org.apache.pulsar.common.schema.{SchemaInfo, SchemaType}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.mutable.ArrayBuffer

class PulsarDeserializer(schemaInfo: SchemaInfo) {
private lazy val decimalConversions = new DecimalConversion()

val rootDataType: DataType = SchemaUtils.si2SqlType(schemaInfo)
val rootDataType: DataType = SchemaUtils.si2SqlType(schemaInfo).dataType

import SchemaUtils._

Expand All @@ -52,12 +56,12 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
val st = rootDataType.asInstanceOf[StructType]
val resultRow = new SpecificInternalRow(st.map(_.dataType) ++ metaDataFields.map(_.dataType))
val fieldUpdater = new RowUpdater(resultRow)
val avroSchema = new org.apache.avro.Schema.Parser().parse(
val avroSchema = new Schema.Parser().parse(
new String(schemaInfo.getSchema, StandardCharsets.UTF_8))
val writer = getRecordWriter(avroSchema, st, Nil)
(msg: Message[_]) => {
val value = msg.getValue
writer(fieldUpdater, value.asInstanceOf[PGenericRecord])
writer(fieldUpdater, value.asInstanceOf[GenericAvroRecord].getAvroRecord)
writeMetadataFields(msg, resultRow)
resultRow
}
Expand Down Expand Up @@ -154,7 +158,7 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
private def newWriter(
avroType: Schema,
catalystType: DataType,
path: List[String]): (RowUpdater, Int, Any) => Unit =
path: List[String]): (CatalystDataUpdater, Int, Any) => Unit =
(avroType.getType, catalystType) match {
case (NULL, NullType) => (updater, ordinal, _) =>
updater.setNullAt(ordinal)
Expand Down Expand Up @@ -213,34 +217,124 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
}
updater.set(ordinal, bytes)

case (FIXED, d: DecimalType) =>
throw new NotImplementedError(s"$d not supported for now")
case (FIXED, d: DecimalType) => (updater, ordinal, value) =>
val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType,
LogicalTypes.decimal(d.precision, d.scale))
val decimal = createDecimal(bigDecimal, d.precision, d.scale)
updater.setDecimal(ordinal, decimal)

case (BYTES, d: DecimalType) =>
throw new NotImplementedError(s"$d not supported for now")
case (BYTES, d: DecimalType) => (updater, ordinal, value) =>
val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType,
LogicalTypes.decimal(d.precision, d.scale))
val decimal = createDecimal(bigDecimal, d.precision, d.scale)
updater.setDecimal(ordinal, decimal)

case (RECORD, st: StructType) =>
val writeRecord = getRecordWriter(avroType, st, path)
(updater, ordinal, value) =>
val row = new SpecificInternalRow(st)
writeRecord(new RowUpdater(row), value.asInstanceOf[PGenericRecord])
writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
updater.set(ordinal, row)

case (ARRAY, ArrayType(elementType, containsNull)) =>
throw new NotImplementedError("arrayType not supported for now")
val elementWriter = newWriter(avroType.getElementType, elementType, path)
(updater, ordinal, value) =>
val array = value.asInstanceOf[GenericData.Array[Any]]
val len = array.size()
val result = createArrayData(elementType, len)
val elementUpdater = new ArrayDataUpdater(result)

var i = 0
while (i < len) {
val element = array.get(i)
if (element == null) {
if (!containsNull) {
throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " +
"allowed to be null")
} else {
elementUpdater.setNullAt(i)
}
} else {
elementWriter(elementUpdater, i, element)
}
i += 1
}

updater.set(ordinal, result)

case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType =>
throw new NotImplementedError("mapType not supported for now")
val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path)
val valueWriter = newWriter(avroType.getValueType, valueType, path)
(updater, ordinal, value) =>
val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
val keyArray = createArrayData(keyType, map.size())
val keyUpdater = new ArrayDataUpdater(keyArray)
val valueArray = createArrayData(valueType, map.size())
val valueUpdater = new ArrayDataUpdater(valueArray)
val iter = map.entrySet().iterator()
var i = 0
while (iter.hasNext) {
val entry = iter.next()
assert(entry.getKey != null)
keyWriter(keyUpdater, i, entry.getKey)
if (entry.getValue == null) {
if (!valueContainsNull) {
throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " +
"allowed to be null")
} else {
valueUpdater.setNullAt(i)
}
} else {
valueWriter(valueUpdater, i, entry.getValue)
}
i += 1
}

updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))

// avro uses Union type to represent a nullable field
case (UNION, _) =>
val allTypes = avroType.getTypes.asScala
val nonNullTypes = allTypes.filter(_.getType != NULL)
if (nonNullTypes.nonEmpty) {
if (nonNullTypes.length == 1) {
newWriter(nonNullTypes.head, catalystType, path)
} else {
throw new NotImplementedError("UnionType not supported for now")
nonNullTypes.map(_.getType) match {
case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType =>
(updater, ordinal, value) => value match {
case null => updater.setNullAt(ordinal)
case l: java.lang.Long => updater.setLong(ordinal, l)
case i: java.lang.Integer => updater.setLong(ordinal, i.longValue())
}

case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType =>
(updater, ordinal, value) => value match {
case null => updater.setNullAt(ordinal)
case d: java.lang.Double => updater.setDouble(ordinal, d)
case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue())
}

case _ =>
catalystType match {
case st: StructType if st.length == nonNullTypes.size =>
val fieldWriters = nonNullTypes.zip(st.fields).map {
case (schema, field) => newWriter(schema, field.dataType, path :+ field.name)
}.toArray
(updater, ordinal, value) => {
val row = new SpecificInternalRow(st)
val fieldUpdater = new RowUpdater(row)
val i = GenericData.get().resolveUnion(avroType, value)
fieldWriters(i)(fieldUpdater, i, value)
updater.set(ordinal, row)
}

case _ =>
throw new IncompatibleSchemaException(
s"Cannot convert Avro to catalyst because schema at path " +
s"${path.mkString(".")} is not compatible " +
s"(avroType = $avroType, sqlType = $catalystType).\n")
}
}
}
} else {
(updater, ordinal, value) => updater.setNullAt(ordinal)
Expand All @@ -249,17 +343,14 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
case _ =>
throw new IncompatibleSchemaException(
s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " +
s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" +
s"Source Avro schema: $avroType.\n" +
s"Target Catalyst type: $catalystType")
s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n")
}

private def getRecordWriter(
avroType: Schema,
sqlType: StructType,
path: List[String]): (RowUpdater, PGenericRecord) => Unit = {
path: List[String]): (RowUpdater, GenericRecord) => Unit = {
val validFieldIndexes = ArrayBuffer.empty[Int]
val validFieldNames = ArrayBuffer.empty[String]
val fieldWriters = ArrayBuffer.empty[(RowUpdater, Any) => Unit]

val length = sqlType.length
Expand All @@ -269,7 +360,6 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
val avroField = avroType.getField(sqlField.name)
if (avroField != null) {
validFieldIndexes += avroField.pos()
validFieldNames += sqlField.name

val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
val ordinal = i
Expand All @@ -295,22 +385,77 @@ class PulsarDeserializer(schemaInfo: SchemaInfo) {
(fieldUpdater, record) => {
var i = 0
while (i < validFieldIndexes.length) {
fieldWriters(i)(fieldUpdater, record.getField(validFieldNames(i)))
fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
i += 1
}
}
}
}

class RowUpdater(row: InternalRow) {
def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)

def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
/**
 * Converts a `java.math.BigDecimal` produced by Avro's `DecimalConversion`
 * into Spark's `Decimal` representation.
 *
 * When the precision fits in a Long (`<= Decimal.MAX_LONG_DIGITS`), the compact
 * unscaled-Long constructor is used to avoid carrying a `BigInteger`; otherwise
 * the `BigDecimal`-backed form is kept.
 */
private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal =
  if (precision > Decimal.MAX_LONG_DIGITS) {
    // Too wide for a Long: keep the BigInteger-backed representation.
    Decimal(decimal, precision, scale)
  } else {
    // Compact path: the unscaled value is guaranteed to fit in a Long.
    Decimal(decimal.unscaledValue().longValue(), precision, scale)
  }

/**
 * Pre-allocates an `ArrayData` of `length` slots for the given element type.
 *
 * Primitive element types are backed by `UnsafeArrayData` over a primitive
 * array, avoiding per-element boxing; all other types fall back to a
 * `GenericArrayData` of object references.
 */
private def createArrayData(elementType: DataType, length: Int): ArrayData =
  elementType match {
    case DoubleType  => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
    case FloatType   => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
    case LongType    => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
    case ShortType   => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
    case ByteType    => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
    case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
    case _           => new GenericArrayData(new Array[Any](length))
  }

/**
 * A base interface for updating values inside catalyst data structure like `InternalRow` and
 * `ArrayData`.
 *
 * Every typed setter defaults to delegating to the generic `set`; concrete updaters
 * override the ones for which the underlying structure has a dedicated typed mutator.
 */
sealed trait CatalystDataUpdater {
// Generic write: implementations route this to the wrapped row/array.
def set(ordinal: Int, value: Any): Unit

// Default implementations funnel every typed write through `set`.
def setNullAt(ordinal: Int): Unit = set(ordinal, null)
def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
}

/**
 * `CatalystDataUpdater` backed by an `InternalRow`: each setter writes the value
 * at `ordinal` via the row's typed mutators, avoiding boxing for primitives.
 */
final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)

override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
// InternalRow.setDecimal needs an explicit precision; forward the value's own.
override def setDecimal(ordinal: Int, value: Decimal): Unit =
row.setDecimal(ordinal, value, value.precision)
}

/**
 * `CatalystDataUpdater` backed by an `ArrayData`: each setter writes the value
 * at index `ordinal` via the array's typed mutators, avoiding boxing for primitives.
 */
final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)

override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
// ArrayData has no decimal-specific mutator, so decimals go through the generic update.
override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
}
}
Loading

0 comments on commit 112425b

Please sign in to comment.