airframe-parquet: Support nested schema (#1917)
* Extract Parquet type resolver
* Add primitive test
* Support reading repeated types
* Write nested objects
* Support reading nested group
xerial committed Nov 16, 2021
1 parent ced33fe commit d2ab8fa
Showing 13 changed files with 718 additions and 200 deletions.
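
For context, a small usage sketch of what this change enables, assuming the Parquet.newWriter / Parquet.newReader entry points described in the airframe-parquet docs (the exact API surface at this commit may differ):

import wvlet.airframe.parquet.Parquet

// Hypothetical example: case classes with nested objects and repeated (Seq) fields
// can now be written to and read back from Parquet directly.
case class Address(city: String, zip: String)
case class Person(id: Int, name: String, addresses: Seq[Address])

val writer = Parquet.newWriter[Person](path = "target/person.parquet")
writer.write(Person(1, "leo", Seq(Address("Tokyo", "100-0001"))))
writer.close()

val reader = Parquet.newReader[Person](path = "target/person.parquet")
val p = reader.read() // Person(1, "leo", List(Address("Tokyo", "100-0001")))
reader.close()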
@@ -32,14 +32,18 @@ import wvlet.log.LogSupport
import java.util
import scala.collection.generic.Growable
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.{universe => ru}

object AirframeParquetReader {

def builder[A: ru.TypeTag](path: String, conf: Configuration, plan: Option[ParquetQueryPlan] = None): Builder[A] = {
def builder[A](
surface: Surface,
path: String,
conf: Configuration,
plan: Option[ParquetQueryPlan] = None
): Builder[A] = {
val fsPath = new Path(path)
val file = HadoopInputFile.fromPath(fsPath, conf)
val builder = new Builder[A](Surface.of[A], file, plan)
val builder = new Builder[A](surface, file, plan)
builder.withConf(conf).asInstanceOf[Builder[A]]
}
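
The builder no longer derives the Surface from a runtime TypeTag; callers resolve it themselves and pass it in. A caller-side sketch, for illustration only (assumes a Hadoop Configuration is available at the call site):

import org.apache.hadoop.conf.Configuration
import wvlet.airframe.surface.Surface

case class Sample(id: Int, name: String)

// The Surface is now resolved at the call site (e.g. via Surface.of[A]) and passed
// explicitly, instead of being derived from a scala.reflect.runtime TypeTag inside the builder.
val reader = AirframeParquetReader
  .builder[Sample](Surface.of[Sample], path = "sample.parquet", conf = new Configuration())
  .build()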

@@ -92,85 +96,93 @@ class AirframeParquetRecordMaterializer[A](surface: Surface, projectedSchema: Me
}

object ParquetRecordConverter {

type Holder = Growable[(String, Any)]

private class IntConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class IntConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addInt(value: Int): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class LongConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class LongConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addLong(value: Long): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class BooleanConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class BooleanConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addBoolean(value: Boolean): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class StringConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class StringConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter with LogSupport {
override def addBinary(value: Binary): Unit = {
holder += fieldName -> value.toStringUsingUTF8
holder.add(fieldName, value.toStringUsingUTF8)
}
}
private class FloatConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class FloatConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addFloat(value: Float): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class DoubleConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class DoubleConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addDouble(value: Double): Unit = {
holder += fieldName -> value
holder.add(fieldName, value)
}
}
private class MsgPackConverter(fieldName: String, holder: Holder) extends PrimitiveConverter {
private class MsgPackConverter(fieldName: String, holder: RecordBuilder) extends PrimitiveConverter {
override def addBinary(value: Binary): Unit = {
holder += fieldName -> ValueCodec.fromMsgPack(value.getBytes)
holder.add(fieldName, ValueCodec.fromMsgPack(value.getBytes))
}
}

}
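
The primitive converters now append values to a RecordBuilder instead of a raw Growable[(String, Any)]. The actual RecordBuilder class is added elsewhere in this commit and is not shown in this hunk; a plausible minimal shape, for illustration only:

// Illustration only -- not the class added in this commit.
// It accumulates (field name, value) pairs and is cleared between records.
trait RecordBuilder {
  def add(fieldName: String, value: Any): Unit
  def toMap: Map[String, Any]
  def clear(): Unit
}

object RecordBuilder {
  def newBuilder: RecordBuilder = new RecordBuilder {
    private val m = scala.collection.mutable.LinkedHashMap.empty[String, Any]
    override def add(fieldName: String, value: Any): Unit = { m += fieldName -> value }
    override def toMap: Map[String, Any]                   = m.toMap
    override def clear(): Unit                             = m.clear()
  }
}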

class ParquetRecordConverter[A](surface: Surface, projectedSchema: MessageType) extends GroupConverter with LogSupport {
private val codec = MessageCodec.ofSurface(surface)
private val recordHolder = Map.newBuilder[String, Any]
private val codec = MessageCodec.ofSurface(surface)
private val recordBuilder = RecordBuilder.newBuilder

import ParquetRecordConverter._

private val converters: Seq[Converter] = projectedSchema.getFields.asScala.map { f =>
val cv: Converter = f match {
case p if p.isPrimitive =>
p.asPrimitiveType().getPrimitiveTypeName match {
case PrimitiveTypeName.INT32 => new IntConverter(f.getName, recordHolder)
case PrimitiveTypeName.INT64 => new LongConverter(f.getName, recordHolder)
case PrimitiveTypeName.BOOLEAN => new BooleanConverter(f.getName, recordHolder)
case PrimitiveTypeName.FLOAT => new FloatConverter(f.getName, recordHolder)
case PrimitiveTypeName.DOUBLE => new DoubleConverter(f.getName, recordHolder)
case PrimitiveTypeName.INT32 => new IntConverter(f.getName, recordBuilder)
case PrimitiveTypeName.INT64 => new LongConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BOOLEAN => new BooleanConverter(f.getName, recordBuilder)
case PrimitiveTypeName.FLOAT => new FloatConverter(f.getName, recordBuilder)
case PrimitiveTypeName.DOUBLE => new DoubleConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BINARY if p.getLogicalTypeAnnotation == stringType =>
new StringConverter(f.getName, recordHolder)
new StringConverter(f.getName, recordBuilder)
case PrimitiveTypeName.BINARY =>
new MsgPackConverter(f.getName, recordHolder)
new MsgPackConverter(f.getName, recordBuilder)
case _ => ???
}
case _ =>
// TODO Support nested types
???
// GroupConverter for nested objects

surface.params.find(_.name == f.getName) match {
case Some(param) =>
if (param.surface.isOption || param.surface.isSeq || param.surface.isArray) {
// For Option[X], Seq[X] types, extract X
val elementSurface = param.surface.typeArgs(0)
new ParquetRecordConverter(param.surface, ParquetSchema.toParquetSchema(elementSurface))
} else {
new ParquetRecordConverter(param.surface, ParquetSchema.toParquetSchema(param.surface))
}
case None =>
???
}
}
cv
}.toIndexedSeq

def currentRecord: A = {
val m = recordHolder.result()
val m = recordBuilder.toMap
trace(m)
codec.fromMap(m).asInstanceOf[A]
}

override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)

override def start(): Unit = {
recordHolder.clear()
recordBuilder.clear()
}

override def end(): Unit = {}
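
To illustrate the nested-group path above: when a parameter's surface is Option[X], Seq[X], or Array[X], the element surface X is extracted and a child ParquetRecordConverter is built for it; otherwise the parameter's own surface is used. A hypothetical model and the rough shape of its Parquet schema:

// Hypothetical model, for illustration of the nested/repeated group handling.
case class Address(city: String, zip: String)
case class Person(id: Int, name: String, addresses: Seq[Address])

// The generated Parquet message schema for Person would look roughly like:
//
//   message Person {
//     required int32 id;
//     required binary name (STRING);
//     repeated group addresses {
//       required binary city (STRING);
//       required binary zip (STRING);
//     }
//   }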
@@ -22,38 +22,25 @@ import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.hadoop.util.HadoopOutputFile
import org.apache.parquet.io.OutputFile
import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.LogicalTypeAnnotation.stringType
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.{MessageType, Type}
import wvlet.airframe.codec.PrimitiveCodec.{
BooleanCodec,
DoubleCodec,
FloatCodec,
IntCodec,
LongCodec,
StringCodec,
ValueCodec
}
import wvlet.airframe.codec.PrimitiveCodec.ValueCodec
import wvlet.airframe.codec.{JSONCodec, MessageCodec, MessageCodecException, MessageCodecFactory}
import wvlet.airframe.json.JSONParseException
import wvlet.airframe.msgpack.spi.Value.{ArrayValue, BinaryValue, MapValue, StringValue}
import wvlet.airframe.msgpack.spi.{MessagePack, MsgPack, Value}
import wvlet.airframe.parquet.AirframeParquetWriter.ParquetCodec
import wvlet.airframe.surface.{CName, Parameter, Surface}
import wvlet.log.LogSupport

import scala.reflect.runtime.{universe => ru}
import scala.jdk.CollectionConverters._

/**
*/

object AirframeParquetWriter {
def builder[A: ru.TypeTag](path: String, conf: Configuration): Builder[A] = {
val s = Surface.of[A]
object AirframeParquetWriter extends LogSupport {
def builder[A](surface: Surface, path: String, conf: Configuration): Builder[A] = {
val fsPath = new Path(path)
val file = HadoopOutputFile.fromPath(fsPath, conf)
val b = new Builder[A](s, file).withConf(conf)
val b = new Builder[A](surface, file).withConf(conf)
// Use snappy by default
b.withCompressionCodec(CompressionCodecName.SNAPPY)
.withWriteMode(ParquetFileWriter.Mode.OVERWRITE)
@@ -83,97 +70,18 @@ object AirframeParquetWriter {
.withWriteMode(ParquetFileWriter.Mode.OVERWRITE)
}

/**
* Convert object --[MessageCodec]--> msgpack --[MessageCodec]--> Parquet type --> RecordConsumer
* @param tpe
* @param index
* @param codec
*/
abstract class ParquetCodec(tpe: Type, index: Int, codec: MessageCodec[_]) {
protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit

def write(recordConsumer: RecordConsumer, v: Any): Unit = {
val msgpack = codec.asInstanceOf[MessageCodec[Any]].toMsgPack(v)
writeMsgpack(recordConsumer, msgpack)
}

def writeMsgpack(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.startField(tpe.getName, index)
writeValue(recordConsumer, msgpack)
recordConsumer.endField(tpe.getName, index)
}
}
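
A tiny sketch of the conversion path named in the comment above (object to msgpack via a MessageCodec, then decoded with a primitive codec before being handed to the RecordConsumer):

import wvlet.airframe.codec.MessageCodec
import wvlet.airframe.codec.PrimitiveCodec.IntCodec

// Encode a value to msgpack, then decode it with the primitive codec,
// mirroring what writeValue does before calling recordConsumer.addInteger(...).
val msgpack: Array[Byte] = MessageCodec.of[Int].toMsgPack(42)
val decoded: Int         = IntCodec.fromMsgPack(msgpack) // 42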

private[parquet] def parquetCodecOf(tpe: Type, index: Int, codec: MessageCodec[_]): ParquetCodec = {
if (tpe.isPrimitive) {
tpe.asPrimitiveType().getPrimitiveTypeName match {
case PrimitiveTypeName.INT32 =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addInteger(IntCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.INT64 =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addLong(LongCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.BOOLEAN =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBoolean(BooleanCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.FLOAT =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addFloat(FloatCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.DOUBLE =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addDouble(DoubleCodec.fromMsgPack(msgpack))
}
}
case PrimitiveTypeName.BINARY if tpe.getLogicalTypeAnnotation == stringType =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromString(StringCodec.fromMsgPack(msgpack)))
}
}
case _ =>
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromConstantByteArray(msgpack))
}
}
}
} else {
new ParquetCodec(tpe, index, codec) {
override protected def writeValue(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.addBinary(Binary.fromConstantByteArray(msgpack))
}
}
}
}

}

class AirframeParquetWriteSupport[A](surface: Surface) extends WriteSupport[A] with LogSupport {
private lazy val schema: MessageType = Parquet.toParquetSchema(surface)
private val parquetCodec: Seq[(Parameter, ParquetCodec)] =
surface.params.zip(schema.getFields.asScala).map { case (param, tpe) =>
val codec = MessageCodec.ofSurface(param.surface)
(param, AirframeParquetWriter.parquetCodecOf(tpe, param.index, codec))
}
private lazy val schema = Parquet.toParquetSchema(surface)
private val objectCodec: ObjectParquetWriteCodec = {
ObjectParquetWriteCodec.buildFromSurface(surface, schema).asRoot
}

private var recordConsumer: RecordConsumer = null
import scala.jdk.CollectionConverters._

override def init(configuration: Configuration): WriteSupport.WriteContext = {
trace(s"schema: ${schema}")
val extraMetadata: Map[String, String] = Map.empty
new WriteContext(schema, extraMetadata.asJava)
}
@@ -184,27 +92,15 @@ class AirframeParquetWriteSupport[A](surface: Surface) extends WriteSupport[A] w

override def write(record: A): Unit = {
require(recordConsumer != null)
try {
recordConsumer.startMessage()
parquetCodec.foreach { case (param, pc) =>
val v = param.get(record)
v match {
case None if param.surface.isOption =>
// Skip writing Optional parameter
case _ =>
pc.write(recordConsumer, v)
}
}
} finally {
recordConsumer.endMessage()
}
objectCodec.write(recordConsumer, record)
}
}

class AirframeParquetRecordWriterSupport(schema: MessageType) extends WriteSupport[Any] with LogSupport {
private var recordConsumer: RecordConsumer = null

override def init(configuration: Configuration): WriteContext = {
trace(s"schema: ${schema}")
new WriteContext(schema, Map.empty[String, String].asJava)
}

@@ -223,24 +119,24 @@ class AirframeParquetRecordWriterSupport(schema: MessageType) extends WriteSuppo
} finally {
recordConsumer.endMessage()
}

}
}

/**
* Ajust any input objects into the shape of the Parquet schema
* Adjust any input objects into the shape of the Parquet schema
* @param schema
*/
class ParquetRecordCodec(schema: MessageType) extends LogSupport {

private val columnNames: IndexedSeq[String] =
schema.getFields.asScala.map(x => CName.toCanonicalName(x.getName)).toIndexedSeq
private val parquetCodecTable: Map[String, ParquetCodec] = {
private val parquetCodecTable: Map[String, ParameterCodec] = {
schema.getFields.asScala.zipWithIndex
.map { case (f, index) =>
val cKey = CName.toCanonicalName(f.getName)
cKey -> AirframeParquetWriter.parquetCodecOf(f, index, ValueCodec)
}.toMap[String, ParquetCodec]
val cKey = CName.toCanonicalName(f.getName)
val parquetCodec = ParquetWriteCodec.parquetCodecOf(f, Surface.of[Any], ValueCodec)
cKey -> ParameterCodec(index, f.getName, parquetCodec)
}.toMap[String, ParameterCodec]
}

private val anyCodec = MessageCodec.of[Any]
Expand All @@ -262,8 +158,8 @@ class ParquetRecordCodec(schema: MessageType) extends LogSupport {

def writeColumnValue(columnName: String, v: Value): Unit = {
parquetCodecTable.get(columnName) match {
case Some(parquetCodec) =>
parquetCodec.writeMsgpack(recordConsumer, v.toMsgpack)
case Some(parameterCodec) =>
parameterCodec.writeMsgPack(recordConsumer, v.toMsgpack)
case None =>
// No record. Skip the value
}
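
The ParquetRecordCodec above adjusts loosely-typed records (Map, JSON, MsgPack values) into the shape of a given Parquet schema. A usage sketch, assuming the Parquet.newRecordWriter entry point documented for airframe-parquet (the exact signature is an assumption):

import wvlet.airframe.parquet.Parquet
import wvlet.airframe.surface.Surface

case class MyEntry(id: Int, name: String)

// Assumed entry point: write schema-shaped records from Map or JSON input.
val schema = Parquet.toParquetSchema(Surface.of[MyEntry])
val writer = Parquet.newRecordWriter(path = "target/record.parquet", schema = schema)
writer.write(Map("id" -> 1, "name" -> "leo"))
writer.write("""{"id":2, "name":"yui"}""")
writer.close()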
