Commit cbea80b: Add UnsafeRowSerializer
JoshRosen committed Jul 17, 2015 (1 parent: 0f2ac86)
Showing 2 changed files with 206 additions and 0 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import com.google.common.io.ByteStreams

import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.PlatformDependent


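/**
 * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored
 * as bytes, this serializer simply copies those bytes to the underlying output stream, framing
 * each row with its size so that the receiving side can slice the stream back into rows. The
 * framing never depends on a row's position in the stream, which is why this serializer can
 * declare support for relocation of serialized objects.
 *
 * @param numFields the number of fields in each row; needed to re-point deserialized rows.
 */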
class UnsafeRowSerializer(numFields: Int) extends Serializer {
  override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields)
  override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
}

private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {

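  // Sentinel written after the last row; -1 can never be a valid row size, so the
  // deserializing iterator uses it to detect the end of the stream.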
  private[this] val EOF: Int = -1

  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
    private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
    private[this] val dOut: DataOutputStream = new DataOutputStream(out)

    override def writeValue[T: ClassTag](value: T): SerializationStream = {
      val row = value.asInstanceOf[UnsafeRow]
      assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
      dOut.writeInt(row.getSizeInBytes)
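      // Copy the row's backing bytes to the output stream in writeBuffer-sized chunks.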
      var dataRemaining: Int = row.getSizeInBytes
      val baseObject = row.getBaseObject
      var rowReadPosition: Long = row.getBaseOffset
      while (dataRemaining > 0) {
        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
        PlatformDependent.copyMemory(
          baseObject,
          rowReadPosition,
          writeBuffer,
          PlatformDependent.BYTE_ARRAY_OFFSET,
          toTransfer)
        out.write(writeBuffer, 0, toTransfer)
        rowReadPosition += toTransfer
        dataRemaining -= toTransfer
      }
      this
    }
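    // The key (an Int partition id in shuffle) is never written to the stream; the
    // deserializing side always returns 0 as the key, so callers must not rely on keys
    // surviving a round trip through this serializer.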
    override def writeKey[T: ClassTag](key: T): SerializationStream = {
      assert(key.isInstanceOf[Int])
      this
    }
    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
      throw new UnsupportedOperationException
    override def writeObject[T: ClassTag](t: T): SerializationStream =
      throw new UnsupportedOperationException
    override def flush(): Unit = dOut.flush()
    override def close(): Unit = {
      writeBuffer = null
      dOut.writeInt(EOF)
      dOut.close()
    }
  }

  override def deserializeStream(in: InputStream): DeserializationStream = {
    new DeserializationStream {
      private[this] val dIn: DataInputStream = new DataInputStream(in)
      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
      private[this] var row: UnsafeRow = new UnsafeRow()
      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)

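      // Note: for performance, the iterator below mutates and returns the same (Int, UnsafeRow)
      // tuple on every call to next(); callers that need to retain a row across calls must copy
      // it first.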
      override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
        new Iterator[(Int, UnsafeRow)] {
          private[this] var rowSize: Int = dIn.readInt()

          override def hasNext: Boolean = rowSize != EOF

          override def next(): (Int, UnsafeRow) = {
            if (rowBuffer.length < rowSize) {
              rowBuffer = new Array[Byte](rowSize)
            }
            ByteStreams.readFully(in, rowBuffer, 0, rowSize)
            row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
            rowSize = dIn.readInt() // read the next row's size
            if (rowSize == EOF) { // We are returning the last row in this stream
              val _rowTuple = rowTuple
              // Null these out so that the byte array can be garbage collected once the entire
              // iterator has been consumed
              row = null
              rowBuffer = null
              rowTuple = null
              _rowTuple
            } else {
              rowTuple
            }
          }
        }
      }
      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
      override def close(): Unit = dIn.close()
    }
  }

  override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
    throw new UnsupportedOperationException
  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
    throw new UnsupportedOperationException
}
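
For reference, the byte stream produced by serializeStream is a simple length-prefixed framing:
[row size: Int][row bytes], repeated once per row, then the EOF marker (-1). A minimal sketch of
a standalone reader for this layout (illustrative only; `dIn` is an assumed DataInputStream over
the serialized bytes and is not part of the patch):

  var rowSize = dIn.readInt()
  while (rowSize != -1) {
    val bytes = new Array[Byte](rowSize)
    ByteStreams.readFully(dIn, bytes, 0, rowSize)
    // ... point an UnsafeRow at `bytes` via pointTo(), as asKeyValueIterator does ...
    rowSize = dIn.readInt()
  }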
sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,76 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
import org.apache.spark.sql.catalyst.util.ObjectPool
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent

class UnsafeRowSerializerSuite extends SparkFunSuite {

  private def toUnsafeRow(
      row: Row,
      schema: Array[DataType],
      objPool: ObjectPool = null): UnsafeRow = {
    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
    val rowConverter = new UnsafeRowConverter(schema)
    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
    val byteArray = new Array[Byte](rowSizeInBytes)
    rowConverter.writeRow(
      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool)
    val unsafeRow = new UnsafeRow()
    unsafeRow.pointTo(
      byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool)
    unsafeRow
  }

  test("toUnsafeRow() test helper method") {
    val row = Row("Hello", 123)
    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
    assert(row.getString(0) === unsafeRow.get(0).toString)
    assert(row.getInt(1) === unsafeRow.getInt(1))
  }

  test("basic row serialization") {
    val rows = Seq(Row("Hello", 1), Row("World", 2))
    val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType)))
    val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
    val baos = new ByteArrayOutputStream()
    val serializerStream = serializer.serializeStream(baos)
    for (unsafeRow <- unsafeRows) {
      serializerStream.writeKey(0)
      serializerStream.writeValue(unsafeRow)
    }
    serializerStream.close()
    val deserializerIter = serializer.deserializeStream(
      new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
    for (expectedRow <- unsafeRows) {
      val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2
      assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
      assert(expectedRow.getString(0) === actualRow.getString(0))
      assert(expectedRow.getInt(1) === actualRow.getInt(1))
    }
    assert(!deserializerIter.hasNext)
  }
}
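
As a usage note beyond this commit: a shuffle could select this serializer through its
ShuffleDependency. A minimal sketch under assumed names (`rddWithPartitionIds` pairs each row
with its target partition id, `partitioner` routes on that id, and `schema` is the output
schema; none of these appear in the patch):

  val serializer = new UnsafeRowSerializer(numFields = schema.length)
  val dependency = new ShuffleDependency[Int, InternalRow, InternalRow](
    rddWithPartitionIds, // RDD[Product2[Int, InternalRow]]
    partitioner,
    serializer = Some(serializer))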
