airframe-parquet: Add record writer (#1785)
* Add record writer interface
* Support various inputs
* Add exception handling
* Add usage note
xerial committed Aug 14, 2021
1 parent cdd3f95 commit 555960b
Showing 4 changed files with 306 additions and 4 deletions.
@@ -25,11 +25,21 @@ import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.LogicalTypeAnnotation.stringType
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.{MessageType, Type}
import wvlet.airframe.codec.PrimitiveCodec.{BooleanCodec, DoubleCodec, FloatCodec, IntCodec, LongCodec, StringCodec}
import wvlet.airframe.codec.{MessageCodec, MessageCodecFactory}
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.codec.PrimitiveCodec.{
BooleanCodec,
DoubleCodec,
FloatCodec,
IntCodec,
LongCodec,
StringCodec,
ValueCodec
}
import wvlet.airframe.codec.{JSONCodec, MessageCodec, MessageCodecException, MessageCodecFactory}
import wvlet.airframe.json.JSONParseException
import wvlet.airframe.msgpack.spi.Value.{ArrayValue, BinaryValue, MapValue, StringValue}
import wvlet.airframe.msgpack.spi.{MessagePack, MsgPack, Value}
import wvlet.airframe.parquet.AirframeParquetWriter.ParquetCodec
import wvlet.airframe.surface.{Parameter, Surface}
import wvlet.airframe.surface.{CName, Parameter, Surface}
import wvlet.log.LogSupport

import scala.reflect.runtime.{universe => ru}
@@ -56,6 +66,23 @@ object AirframeParquetWriter {
}
}

class RecordWriterBuilder(schema: MessageType, file: OutputFile)
extends ParquetWriter.Builder[Any, RecordWriterBuilder](file: OutputFile) {
override def self(): RecordWriterBuilder = this
override def getWriteSupport(conf: Configuration): WriteSupport[Any] = {
new AirframeParquetRecordWriterSupport(schema)
}
}

def recordWriterBuilder(path: String, schema: MessageType, conf: Configuration): RecordWriterBuilder = {
val fsPath = new Path(path)
val file = HadoopOutputFile.fromPath(fsPath, conf)
val b = new RecordWriterBuilder(schema, file).withConf(conf)
// Use snappy by default
b.withCompressionCodec(CompressionCodecName.SNAPPY)
.withWriteMode(ParquetFileWriter.Mode.OVERWRITE)
}

/**
* Convert object --[MessageCodec]--> msgpack --[MessageCodec]--> Parquet type --> RecordConsumer
* @param tpe
@@ -67,6 +94,10 @@ object AirframeParquetWriter {

def write(recordConsumer: RecordConsumer, v: Any): Unit = {
val msgpack = codec.asInstanceOf[MessageCodec[Any]].toMsgPack(v)
writeMsgpack(recordConsumer, msgpack)
}

def writeMsgpack(recordConsumer: RecordConsumer, msgpack: MsgPack): Unit = {
recordConsumer.startField(tpe.getName, index)
writeValue(recordConsumer, msgpack)
recordConsumer.endField(tpe.getName, index)
@@ -169,3 +200,128 @@ class AirframeParquetWriteSupport[A](surface: Surface) extends WriteSupport[A]
}
}
}

class AirframeParquetRecordWriterSupport(schema: MessageType) extends WriteSupport[Any] with LogSupport {
private var recordConsumer: RecordConsumer = null

override def init(configuration: Configuration): WriteContext = {
new WriteContext(schema, Map.empty[String, String].asJava)
}

override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
this.recordConsumer = recordConsumer
}

private val codec = new ParquetRecordCodec(schema)

override def write(record: Any): Unit = {
require(recordConsumer != null)

try {
recordConsumer.startMessage()
codec.pack(record, recordConsumer)
} finally {
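// Always close the message, even if packing fails, so the RecordConsumer stays in a consistent state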
recordConsumer.endMessage()
}

}
}

/**
* Adjust any input object into the shape of the Parquet schema
* @param schema
*/
class ParquetRecordCodec(schema: MessageType) extends LogSupport {

private val columnNames: IndexedSeq[String] =
schema.getFields.asScala.map(x => CName.toCanonicalName(x.getName)).toIndexedSeq
private val parquetCodecTable: Map[String, ParquetCodec] = {
schema.getFields.asScala.zipWithIndex
.map { case (f, index) =>
val cKey = CName.toCanonicalName(f.getName)
cKey -> AirframeParquetWriter.parquetCodecOf(f, index, ValueCodec)
}.toMap[String, ParquetCodec]
}

private val anyCodec = MessageCodec.of[Any]

def pack(obj: Any, recordConsumer: RecordConsumer): Unit = {
val msgpack =
try {
anyCodec.toMsgPack(obj)
} catch {
case e: MessageCodecException =>
throw new IllegalArgumentException(s"Cannot convert the input into MsgPack: ${obj}", e)
}
val value = ValueCodec.fromMsgPack(msgpack)
packValue(value, recordConsumer)
}

def packValue(value: Value, recordConsumer: RecordConsumer): Unit = {
trace(s"packValue: ${value}")

def writeColumnValue(columnName: String, v: Value): Unit = {
parquetCodecTable.get(columnName) match {
case Some(parquetCodec) =>
parquetCodec.writeMsgpack(recordConsumer, v.toMsgpack)
case None =>
// No matching column in the schema. Skip this value
}
}

value match {
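// Dispatch on the type of the normalized input value (Array, Map, MsgPack binary, or string/JSON)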
case arr: ArrayValue =>
// Array input: one value per column, in the schema field order
if (arr.size == schema.getFieldCount) {
for ((e, colIndex) <- arr.elems.zipWithIndex) {
val colName = columnNames(colIndex)
writeColumnValue(colName, e)
}
} else {
// Invalid shape
throw new IllegalArgumentException(s"${arr} size doesn't match with ${schema}")
}
case m: MapValue =>
for ((k, v) <- m.entries) {
val keyValue = k.toString
val cKey = CName.toCanonicalName(keyValue)
writeColumnValue(cKey, v)
}
case b: BinaryValue =>
// Assume it's a MessagePack value
try {
val v = ValueCodec.fromMsgPack(b.v)
packValue(v, recordConsumer)
} catch {
case e: MessageCodecException =>
invalidInput(b, e)
}
case s: StringValue =>
val str = s.toString
if (str.startsWith("{") || str.startsWith("[")) {
// Assume the input is a JSON object or array
try {
val msgpack = JSONCodec.toMsgPack(str)
val value = ValueCodec.fromMsgPack(msgpack)
packValue(value, recordConsumer)
} catch {
case e: JSONParseException =>
// Not a JSON value
invalidInput(s, e)
}
} else {
invalidInput(s, null)
}
case _ =>
invalidInput(value, null)
}
}

private def invalidInput(v: Value, cause: Throwable): Nothing = {
throw new IllegalArgumentException(
s"The input for ${schema} must be Map[String, Any], Array, MsgPack, or JSON strings: ${v}",
cause
)
}

}
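The `pack` method above normalizes arbitrary inputs by round-tripping them through MessagePack before matching them against the schema. A minimal standalone sketch of that normalization step (the sample record is illustrative):

```scala
import wvlet.airframe.codec.MessageCodec
import wvlet.airframe.codec.PrimitiveCodec.ValueCodec
import wvlet.airframe.msgpack.spi.Value

// Encode an arbitrary input (Map, Seq, case class, ...) into MessagePack bytes ...
val anyCodec = MessageCodec.of[Any]
val msgpack  = anyCodec.toMsgPack(Map("id" -> 1, "name" -> "leo"))
// ... then decode it into a generic Value. packValue matches on the Value type
// (a MapValue here) and maps each entry to the Parquet column with the same canonical name.
val value: Value = ValueCodec.fromMsgPack(msgpack)
```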
@@ -17,6 +17,26 @@ import scala.reflect.runtime.{universe => ru}

object Parquet extends LogSupport {

/**
* Create a Parquet writer that accepts records represented as Map, Array, JSON, MsgPack, etc.
* @param path
* @param schema
* @param hadoopConf
* @param config
* @return
*/
def newRecordWriter(
path: String,
schema: MessageType,
hadoopConf: Configuration = new Configuration(),
config: AirframeParquetWriter.RecordWriterBuilder => AirframeParquetWriter.RecordWriterBuilder =
identity[AirframeParquetWriter.RecordWriterBuilder](_)
): ParquetWriter[Any] = {
val b = AirframeParquetWriter.recordWriterBuilder(path, schema, hadoopConf)
val builder = config(b)
builder.build()
}

def newWriter[A: ru.TypeTag](
path: String,
// Hadoop filesystem specific configuration, e.g., fs.s3a.access.key
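A usage sketch of the `newRecordWriter` entry point added above. The output path and the compression override passed through `config` are illustrative; Snappy remains the default when `config` is omitted:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.schema.{MessageType, Types}
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import wvlet.airframe.parquet.Parquet

// A hand-built schema for the records to write
val schema = new MessageType(
  "MyEntry",
  Types.required(PrimitiveTypeName.INT32).named("id")
)

// config lets the caller tweak the underlying ParquetWriter.Builder, e.g., the compression codec
val writer = Parquet.newRecordWriter(
  path = "/tmp/my-entries.parquet", // illustrative output path
  schema = schema,
  hadoopConf = new Configuration(),
  config = _.withCompressionCodec(CompressionCodecName.GZIP)
)
try {
  writer.write(Map("id" -> 1))
} finally {
  writer.close()
}
```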
@@ -0,0 +1,103 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.parquet

import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.{MessageType, PrimitiveType, Types}
import wvlet.airframe.codec.PrimitiveCodec.AnyCodec
import wvlet.airframe.control.Control.withResource
import wvlet.airframe.surface.Surface
import wvlet.airspec.AirSpec
import wvlet.log.io.IOUtil

/**
*/
object ParquetRecordWriterTest extends AirSpec {
case class MyRecord(id: Int, name: String)
private val schema = Parquet.toParquetSchema(Surface.of[MyRecord])

test("write generic records with a schema") {
IOUtil.withTempFile("target/tmp-record", ".parquet") { file =>
withResource(Parquet.newRecordWriter(file.getPath, schema)) { writer =>
writer.write(Map("id" -> 1, "name" -> "leo"))
writer.write(Array(2, "yui"))
writer.write("""{"id":3, "name":"aina"}""")
writer.write("""[4, "ruri"]""")
writer.write(AnyCodec.toMsgPack(Map("id" -> 5, "name" -> "xxx")))
}

withResource(Parquet.newReader[Map[String, Any]](file.getPath)) { reader =>
reader.read() shouldBe Map("id" -> 1, "name" -> "leo")
reader.read() shouldBe Map("id" -> 2, "name" -> "yui")
reader.read() shouldBe Map("id" -> 3, "name" -> "aina")
reader.read() shouldBe Map("id" -> 4, "name" -> "ruri")
reader.read() shouldBe Map("id" -> 5, "name" -> "xxx")
reader.read() shouldBe null
}
}
}

test("throw an exception for an invalid input") {
IOUtil.withTempFile("target/tmp-record-invalid", ".parquet") { file =>
withResource(Parquet.newRecordWriter(file.getPath, schema)) { writer =>
intercept[IllegalArgumentException] {
writer.write("{broken json data}")
}
intercept[IllegalArgumentException] {
writer.write("not a json data")
}
intercept[IllegalArgumentException] {
// Broken MessagePack data
writer.write(Array[Byte](0x1))
}
intercept[IllegalArgumentException] {
// Insufficient array size
writer.write(Array(1))
}
intercept[IllegalArgumentException] {
// Too large array size
writer.write(Array(1, 2, 3))
}
intercept[IllegalArgumentException] {
writer.write(null)
}
}
}
}

case class RecordOpt(id: Int, flag: Option[Int] = None)
private val schema2 = new MessageType(
"my record",
Types.required(PrimitiveTypeName.INT32).named("id"),
Types.optional(PrimitiveTypeName.INT32).named("flag")
)

test("write records with Option") {
IOUtil.withTempFile("target/tmp-record-opt", ".parquet") { file =>
withResource(Parquet.newRecordWriter(file.getPath, schema2)) { writer =>
writer.write(RecordOpt(1, Some(1)))
writer.write(RecordOpt(2, None))
writer.write("""{"id":"3"}""")
}

withResource(Parquet.newReader[Map[String, Any]](file.getPath)) { reader =>
reader.read() shouldBe Map("id" -> 1, "flag" -> 1)
reader.read() shouldBe Map("id" -> 2)
reader.read() shouldBe Map("id" -> 3)
reader.read() shouldBe null
}
}

}
}
23 changes: 23 additions & 0 deletions docs/airframe-parquet.md
@@ -56,6 +56,29 @@ val j1 = jsonReader.read() // {"id":1,"name":"leo"}
val j2 = jsonReader.read() // {"id":2,"name":"yui"}
jsonReader.read() // null
jsonReader.close()

// Writing dynamically generated records
import org.apache.parquet.schema._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.LogicalTypeAnnotation.stringType
// Create a Parquet schema
val schema = new MessageType(
"MyEntry",
Types.required(PrimitiveTypeName.INT32).named("id"),
Types.optional(PrimitiveTypeName.BINARY).as(stringType).named("name")
)
// Create a record writer for the given schema
val recordWriter = Parquet.newRecordWriter(path = "record.parquet", schema = schema)
// Write a record using Map (column name -> value)
recordWriter.write(Map("id" -> 1, "name" -> "leo"))
// Write a record using JSON object
recordWriter.write("""{"id":2, "name":"yui"}""")
// Write a record using a Seq of column values (in schema field order)
recordWriter.write(Seq(3, "aina"))
// Write a record using JSON array
recordWriter.write("""[4, "xxx"]""")
// You can use case classes as input as well
recordWriter.write(MyEntry(5, "yyy"))

recordWriter.close()
```
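
Records written this way can be read back with the generic reader, for example as `Map[String, Any]`; a minimal sketch:

```scala
// Read the dynamically written records back as Map[String, Any]
val recordReader = Parquet.newReader[Map[String, Any]]("record.parquet")
recordReader.read() // Map("id" -> 1, "name" -> "leo")
recordReader.read() // Map("id" -> 2, "name" -> "yui")
recordReader.close()
```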

### Using with AWS S3
