airframe-parquet: Support nested schema (#1917)
* Extract Parquet type resolver
* Add primitive test
* Support reading repeated types
* Write nested objects
* Support reading nested group
xerial committed Nov 16, 2021
1 parent ced33fe commit d2ab8fa
Showing 13 changed files with 718 additions and 200 deletions.
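The headline change is nested schema support: case classes containing other case classes, Option, or Seq fields can now be written to and read back from Parquet. A minimal usage sketch, assuming the existing Parquet.newWriter/newReader facade of this module (the Person/Address model is hypothetical):

import wvlet.airframe.parquet.Parquet

// Hypothetical nested model: Address becomes a nested group, phones a repeated field
case class Address(city: String, zip: String)
case class Person(id: Int, name: String, address: Option[Address], phones: Seq[String])

val writer = Parquet.newWriter[Person](path = "persons.parquet")
writer.write(Person(1, "Ann", Some(Address("Tokyo", "100-0001")), Seq("03-1234-5678")))
writer.close()

val reader = Parquet.newReader[Person](path = "persons.parquet")
val p = reader.read() // returns null after the last record
reader.close()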
@@ -32,14 +32,18 @@ import wvlet.log.LogSupport
import java.util
import scala.collection.generic.Growable
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.{universe => ru}

object AirframeParquetReader {

def builder[A: ru.TypeTag](path: String, conf: Configuration, plan: Option[ParquetQueryPlan] = None): Builder[A] = {
def builder[A](
surface: Surface,
path: String,
conf: Configuration,
plan: Option[ParquetQueryPlan] = None
): Builder[A] = {
val fsPath = new Path(path)
val file = HadoopInputFile.fromPath(fsPath, conf)
val builder = new Builder[A](Surface.of[A], file, plan)
val builder = new Builder[A](surface, file, plan)
builder.withConf(conf).asInstanceOf[Builder[A]]
}
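The builder no longer resolves the Surface from a TypeTag itself; the caller passes it in. A sketch of direct use with the new signature (MyRecord and the path are hypothetical):

import org.apache.hadoop.conf.Configuration
import wvlet.airframe.surface.Surface

case class MyRecord(id: Int, name: String)
val reader = AirframeParquetReader
  .builder[MyRecord](Surface.of[MyRecord], path = "data.parquet", conf = new Configuration())
  .build()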

@@ -92,85 +96,93 @@ class AirframeParquetRecordMaterializer[A](surface: Surface, projectedSchema: Me
}

object ParquetRecordConverter {

type Holder = Growable[(String, Any)]

private class IntConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class IntConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addInt(value: Int): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class LongConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class LongConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addLong(value: Long): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class BooleanConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class BooleanConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addBoolean(value: Boolean): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class StringConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class StringConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter with LogSupport {
override def addBinary(value: Binary): Unit = {
holder += fieldName -> value.toStringUsingUTF8
holder.add(fieldName, value.toStringUsingUTF8)
}
}
private class FloatConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class FloatConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addFloat(value: Float): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class DoubleConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class DoubleConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addDouble(value: Double): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class MsgPackConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class MsgPackConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addBinary(value: Binary): Unit = {
holder += fieldName -> ValueCodec.fromMsgPack(value.getBytes)
holder.add(fieldName, ValueCodec.fromMsgPack(value.getBytes))
}
}

}
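RecordBuilder is new in this commit and replaces the Growable[(String, Any)] holder; only its add/toMap/clear calls are visible in this diff. A rough sketch of the contract these converters assume (not the actual implementation):

// Assumed shape of RecordBuilder, inferred from the calls above (sketch only)
trait RecordBuilder {
  def add(fieldName: String, value: Any): Unit // collect one column value of the current record
  def toMap: Map[String, Any]                  // snapshot handed to codec.fromMap(...)
  def clear(): Unit                            // reset between records (called from start())
}
object RecordBuilder {
  def newBuilder: RecordBuilder = new RecordBuilder {
    private val m = scala.collection.mutable.Map.empty[String, Any]
    def add(fieldName: String, value: Any): Unit = { m += fieldName -> value }
    def toMap: Map[String, Any]                  = m.toMap
    def clear(): Unit                            = m.clear()
  }
}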

class ParquetRecordConverter[A](surface: Surface, projectedSchema: MessageType) extends GroupConverter with LogSupport {
private val codec = MessageCodec.ofSurface(surface)
private val recordHolder = Map.newBuilder[String, Any]
private val codec = MessageCodec.ofSurface(surface)
private val recordBuilder = RecordBuilder.newBuilder

import ParquetRecordConverter._

private val converters: Seq[Converter] = projectedSchema.getFields.asScala.map { f =>
val cv: Converter = f match {
case p if p.isPrimitive =>
p.asPrimitiveType().getPrimitiveTypeName match {
case PrimitiveTypeName.INT32 => new IntConverter(f.getName, recordHolder)
case PrimitiveTypeName.INT64 => new LongConverter(f.getName, recordHolder)
case PrimitiveTypeName.BOOLEAN => new BooleanConverter(f.getName, recordHolder)
case PrimitiveTypeName.FLOAT => new FloatConverter(f.getName, recordHolder)
case PrimitiveTypeName.DOUBLE => new DoubleConverter(f.getName, recordHolder)
case PrimitiveTypeName.INT32 => new IntConverter(f.getName, recordBuilder)
case PrimitiveTypeName.INT64 => new LongConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BOOLEAN => new BooleanConverter(f.getName, recordBuilder)
case PrimitiveTypeName.FLOAT => new FloatConverter(f.getName, recordBuilder)
case PrimitiveTypeName.DOUBLE => new DoubleConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BINARY if p.getLogicalTypeAnnotation == stringType =>
new StringConverter(f.getName, recordHolder)
new StringConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BINARY =>
new MsgPackConverter(f.getName, recordHolder)
new MsgPackConverter(f.getName, recordBuilder)
case _ => ???
}
case _ =>
// TODO Support nested types
???
// GroupConverter for nested objects

surface.params.find(_.name == f.getName) match {
case Some(param) =>
if (param.surface.isOption || param.surface.isSeq || param.surface.isArray) {
// For Option[X], Seq[X] types, extract X
val elementSurface = param.surface.typeArgs(0)
new ParquetRecordConverter(param.surface, ParquetSchema.toParquetSchema(elementSurface))
} else {
new ParquetRecordConverter(param.surface, ParquetSchema.toParquetSchema(param.surface))
}
case None =>
???
}
}
cv
}.toIndexedSeq

def currentRecord: A = {
val m = recordHolder.result()
val m = recordBuilder.toMap
trace(m)
codec.fromMap(m).asInstanceOf[A]
}

override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)

override def start(): Unit = {
recordHolder.clear()
recordBuilder.clear()
}

override def end(): Unit = {}
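For non-primitive fields the converter recurses: the matching Surface parameter is looked up by field name, and Option/Seq/Array parameters are first unwrapped to their element type. A small illustration of that unwrapping (Address is a hypothetical nested case class):

import wvlet.airframe.surface.Surface

case class Address(city: String, zip: String)

val optSurface = Surface.of[Option[Address]]
optSurface.isOption                         // true, so the element type is extracted
val elementSurface = optSurface.typeArgs(0) // Surface of Address, used to build the nested group schema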
@@ -22,38 +22,25 @@ import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.hadoop.util.HadoopOutputFile
import org.apache.parquet.io.OutputFile
import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.LogicalTypeAnnotation.stringType
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.{MessageType, Type}
import wvlet.airframe.codec.PrimitiveCodec.{
BooleanCodec,
DoubleCodec,
FloatCodec,
IntCodec,
LongCodec,
StringCodec,
ValueCodec
}
import wvlet.airframe.codec.PrimitiveCodec.ValueCodec
import wvlet.airframe.codec.{JSONCodec, MessageCodec, MessageCodecException, MessageCodecFactory}
import wvlet.airframe.json.JSONParseException
import wvlet.airframe.msgpack.spi.Value.{ArrayValue, BinaryValue, MapValue, StringValue}
import wvlet.airframe.msgpack.spi.{MessagePack, MsgPack, Value}
import wvlet.airframe.parquet.AirframeParquetWriter.ParquetCodec
import wvlet.airframe.surface.{CName, Parameter, Surface}
import wvlet.log.LogSupport

import scala.reflect.runtime.{universe => ru}
import scala.jdk.CollectionConverters._

/**
*/

object AirframeParquetWriter {
def builder[A: ru.TypeTag](path: String, conf: Configuration): Builder[A] = {
val s = Surface.of[A]
object AirframeParquetWriter extends LogSupport {
def builder[A](surface: Surface, path: String, conf: Configuration): Builder[A] = {
val fsPath = new Path(path)
val file = HadoopOutputFile.fromPath(fsPath, conf)
val b = new Builder[A](s, file).withConf(conf)
val b = new Builder[A](surface, file).withConf(conf)
// Use snappy by default
b.withCompressionCodec(CompressionCodecName.SNAPPY)
.withWriteMode(ParquetFileWriter.Mode.OVERWRITE)
@@ -83,97 +70,18 @@ object AirframeParquetWriter {
.withWriteMode(ParquetFileWriter.Mode.OVERWRITE)
}

/**
* Convert object --[MessageCodec]--> msgpack --[MessageCodec]--> Parquet type --> RecordConsumer
* @param tpe
* @param index
* @param codec
*/
abstract class ParquetCodec(tpe: Type, index: Int, codec: MessageCodec[_]) {
protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit

def write(recordConsumer: RecordConsumer, v: Any): Unit = {
val msgpack = codec.asInstanceOf[MessageCodec[Any]].toMsgPack(v)
writeMsgpack(recordConsumer, msgpack)
}

def writeMsgpack(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.startField(tpe.getName, index)
writeValue(recordConsumer, msgpack)
recordConsumer.endField(tpe.getName, index)
}
}
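The comment above summarizes the write path that this commit extracts into ParquetWriteCodec: a value is first encoded to MessagePack with its MessageCodec and then decoded into the concrete Parquet primitive before being handed to the RecordConsumer. A runnable illustration of that round trip for a single Int value:

import wvlet.airframe.codec.MessageCodec
import wvlet.airframe.codec.PrimitiveCodec.IntCodec

val msgpack = MessageCodec.of[Int].toMsgPack(42)      // object --[MessageCodec]--> msgpack
val parquetValue: Int = IntCodec.fromMsgPack(msgpack) // msgpack --> value passed to recordConsumer.addInteger(...)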

private[parquet] def parquetCodecOf(tpe: Type, index: Int, codec: MessageCodec[_]): ParquetCodec = {
if (tpe.isPrimitive) {
tpe.asPrimitiveType().getPrimitiveTypeName match {
case PrimitiveTypeName.INT32 =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addInteger(IntCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.INT64 =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addLong(LongCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.BOOLEAN =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBoolean(BooleanCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.FLOAT =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addFloat(FloatCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.DOUBLE =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addDouble(DoubleCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.BINARY if tpe.getLogicalTypeAnnotation == stringType =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromString(StringCodec.fromMsgPack(msgpack)))
}
}
case _ =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromConstantByteArray(msgpack))
}
}
}
} else {
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromConstantByteArray(msgpack))
}
}
}
}

}

class AirframeParquetWriteSupport[A](surface: Surface) extends WriteSupport[A] with LogSupport {
private lazy val schema: MessageType = Parquet.toParquetSchema(surface)
private val parquetCodec: Seq[(Parameter, ParquetCodec)] =
surface.params.zip(schema.getFields.asScala).map { case (param, tpe) =>
val codec = MessageCodec.ofSurface(param.surface)
(param, AirframeParquetWriter.parquetCodecOf(tpe, param.index, codec))
}
private lazy val schema = Parquet.toParquetSchema(surface)
private val objectCodec: ObjectParquetWriteCodec = {
ObjectParquetWriteCodec.buildFromSurface(surface, schema).asRoot
}

private var recordConsumer: RecordConsumer = null
import scala.jdk.CollectionConverters._

override def init(configuration: Configuration): WriteSupport.WriteContext = {
trace(s"schema: ${schema}")
val extraMetadata: Map[String, String] = Map.empty
new WriteContext(schema, extraMetadata.asJava)
}
@@ -184,27 +92,15 @@ class AirframeParquetWriteSupport[A](surface: Surface) extends WriteSupport[A] with LogSupport {

override def write(record: A): Unit = {
require(recordConsumer != null)
try {
recordConsumer.startMessage()
parquetCodec.foreach { case (param, pc) =>
val v = param.get(record)
v match {
case None if param.surface.isOption =>
// Skip writing Optional parameter
case _ =>
pc.write(recordConsumer, v)
}
}
} finally {
recordConsumer.endMessage()
}
objectCodec.write(recordConsumer, record)
}
}
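The per-parameter codec loop is replaced by a single ObjectParquetWriteCodec built from the Surface, with asRoot presumably taking over the startMessage/endMessage bracketing that the old write method did by hand. End to end, the writer is still constructed through the Surface-based builder; a sketch (Person is hypothetical):

import org.apache.hadoop.conf.Configuration
import wvlet.airframe.surface.Surface

case class Person(id: Int, name: String)
val writer = AirframeParquetWriter
  .builder[Person](Surface.of[Person], "persons.parquet", new Configuration())
  .build()
writer.write(Person(1, "Ann"))
writer.close()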

class AirframeParquetRecordWriterSupport(schema: MessageType) extends WriteSupport[Any] with LogSupport {
private var recordConsumer: RecordConsumer = null

override def init(configuration: Configuration): WriteContext = {
trace(s"schema: ${schema}")
new WriteContext(schema, Map.empty[String, String].asJava)
}

@@ -223,24 +119,24 @@ class AirframeParquetRecordWriterSupport(schema: MessageType) extends WriteSupport[Any] with LogSupport {
} finally {
recordConsumer.endMessage()
}

}
}
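AirframeParquetRecordWriterSupport takes a pre-defined MessageType and writes schema-free records (e.g. Map or JSON input), which the ParquetRecordCodec below adjusts to the schema shape. A sketch of defining such a schema with parquet-mr; the newRecordWriter facade in the comments is an assumption about this module's entry point:

import org.apache.parquet.schema.MessageTypeParser

val schema = MessageTypeParser.parseMessageType(
  """message Person {
    |  required int32 id;
    |  required binary name (UTF8);
    |}""".stripMargin)

// Hypothetical facade (entry point name is assumed):
// val writer = Parquet.newRecordWriter(path = "persons.parquet", schema = schema)
// writer.write(Map("id" -> 1, "name" -> "Ann"))
// writer.close()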

/**
* Ajust any input objects into the shape of the Parquet schema
* Adjust any input objects into the shape of the Parquet schema
* @param schema
*/
class ParquetRecordCodec(schema: MessageType) extends LogSupport {

private val columnNames: IndexedSeq[String] =
schema.getFields.asScala.map(x => CName.toCanonicalName(x.getName)).toIndexedSeq
private val parquetCodecTable: Map[String, ParquetCodec] = {
private val parquetCodecTable: Map[String, ParameterCodec] = {
schema.getFields.asScala.zipWithIndex
.map { case (f, index) =>
val cKey = CName.toCanonicalName(f.getName)
cKey -> AirframeParquetWriter.parquetCodecOf(f, index, ValueCodec)
}.toMap[String, ParquetCodec]
val cKey = CName.toCanonicalName(f.getName)
val parquetCodec = ParquetWriteCodec.parquetCodecOf(f, Surface.of[Any], ValueCodec)
cKey -> ParameterCodec(index, f.getName, parquetCodec)
}.toMap[String, ParameterCodec]
}
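Column lookup is keyed by canonical names on both sides, so the input record's field naming does not have to match the Parquet column names character for character. Assuming CName canonicalization folds case and separators (a sketch of that behavior, not verified here):

import wvlet.airframe.surface.CName

CName.toCanonicalName("user_name") // "username"
CName.toCanonicalName("userName")  // "username"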

private val anyCodec = MessageCodec.of[Any]
@@ -262,8 +158,8 @@ class ParquetRecordCodec(schema: MessageType) extends LogSupport {

def writeColumnValue(columnName: String, v: Value): Unit = {
parquetCodecTable.get(columnName) match {
case Some(parquetCodec) =>
parquetCodec.writeMsgpack(recordConsumer, v.toMsgpack)
case Some(parameterCodec) =>
parameterCodec.writeMsgPack(recordConsumer, v.toMsgpack)
case None =>
// No record. Skip the value
}