diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
new file mode 100644
index 0000000000000..4c25852862df3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.PlatformDependent
+
+/**
+ * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored
+ * as bytes, this serializer simply copies those bytes to the underlying output stream. On the
+ * read path, a single mutable UnsafeRow instance is reused for every deserialized row.
+ */
+class UnsafeRowSerializer(numFields: Int) extends Serializer {
+  override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields)
+  override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
+}
+
+private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
+
+  /** Marks the end of a stream written with [[serializeStream()]]. */
+  private[this] val EOF: Int = -1
+
+  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
+    private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+    private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+
+    override def writeValue[T: ClassTag](value: T): SerializationStream = {
+      val row = value.asInstanceOf[UnsafeRow]
+      assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
+      dOut.writeInt(row.getSizeInBytes)
+      var dataRemaining: Int = row.getSizeInBytes
+      val baseObject = row.getBaseObject
+      var rowReadPosition: Long = row.getBaseOffset
+      while (dataRemaining > 0) {
+        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer)
+        out.write(writeBuffer, 0, toTransfer)
+        rowReadPosition += toTransfer
+        dataRemaining -= toTransfer
+      }
+      this
+    }
+    override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      // The key is only needed on the map side when computing partition ids;
+      // it does not need to be written to the stream.
+      assert(key.isInstanceOf[Int])
+      this
+    }
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+      throw new UnsupportedOperationException
+    override def writeObject[T: ClassTag](t: T): SerializationStream =
+      throw new UnsupportedOperationException
+    override def flush(): Unit = dOut.flush()
+    override def close(): Unit = {
+      writeBuffer = null
+      dOut.writeInt(EOF)
+      dOut.close()
+    }
+  }
+
+  override def deserializeStream(in: InputStream): DeserializationStream = {
+    new DeserializationStream {
+      private[this] val dIn: DataInputStream = new DataInputStream(in)
+      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
+      private[this] var row: UnsafeRow = new UnsafeRow()
+      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+
+      // Each call to next() re-points the same mutable UnsafeRow at freshly read bytes,
+      // so callers must consume (or copy) a row before advancing the iterator.
+      override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
+        new Iterator[(Int, UnsafeRow)] {
+          private[this] var rowSize: Int = dIn.readInt()
+
+          override def hasNext: Boolean = rowSize != EOF
+
+          override def next(): (Int, UnsafeRow) = {
+            if (rowBuffer.length < rowSize) {
+              rowBuffer = new Array[Byte](rowSize)
+            }
+            ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+            row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+            rowSize = dIn.readInt() // read the next row's size
+            if (rowSize == EOF) { // We are returning the last row in this stream
+              val _rowTuple = rowTuple
+              // Null these out so that the byte array can be garbage collected once the entire
+              // iterator has been consumed
+              row = null
+              rowBuffer = null
+              rowTuple = null
+              _rowTuple
+            } else {
+              rowTuple
+            }
+          }
+        }
+      }
+      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
+      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def close(): Unit = dIn.close()
+    }
+  }
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+    throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException
+}
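The value stream written above is framed as `[4-byte size][row bytes]` per row and terminated by a `-1` sentinel; keys are asserted to be `Int` partition ids and are never written. For context, here is a rough sketch of how a shuffle operator might opt into this serializer. The selection logic (`allRowsAreUnsafe`, the `SparkSqlSerializer` fallback, and `chooseSerializer` itself) is illustrative only and not part of this patch:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.execution.{SparkSqlSerializer, UnsafeRowSerializer}

// Sketch: pick the row serializer based on whether every row in the shuffle
// is an UnsafeRow with a known field count. `allRowsAreUnsafe` is a
// hypothetical flag, not defined anywhere in this patch.
def chooseSerializer(allRowsAreUnsafe: Boolean, numFields: Int, conf: SparkConf): Serializer = {
  if (allRowsAreUnsafe) {
    // supportsRelocationOfSerializedObjects is true, so shuffle paths that
    // move serialized bytes without deserializing them remain applicable.
    new UnsafeRowSerializer(numFields)
  } else {
    new SparkSqlSerializer(conf) // general-purpose Kryo-based fallback
  }
}
```
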
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
new file mode 100644
index 0000000000000..bd788ec8c14b1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.util.ObjectPool
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+
+class UnsafeRowSerializerSuite extends SparkFunSuite {
+
+  private def toUnsafeRow(
+      row: Row,
+      schema: Array[DataType],
+      objPool: ObjectPool = null): UnsafeRow = {
+    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val rowConverter = new UnsafeRowConverter(schema)
+    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
+    val byteArray = new Array[Byte](rowSizeInBytes)
+    rowConverter.writeRow(
+      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.pointTo(
+      byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool)
+    unsafeRow
+  }
+
+  test("toUnsafeRow() test helper method") {
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    assert(row.getString(0) === unsafeRow.get(0).toString)
+    assert(row.getInt(1) === unsafeRow.getInt(1))
+  }
+
+  test("basic row serialization") {
+    val rows = Seq(Row("Hello", 1), Row("World", 2))
+    val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType)))
+    val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
+    val baos = new ByteArrayOutputStream()
+    val serializerStream = serializer.serializeStream(baos)
+    for (unsafeRow <- unsafeRows) {
+      serializerStream.writeKey(0)
+      serializerStream.writeValue(unsafeRow)
+    }
+    serializerStream.close()
+    val deserializerIter = serializer.deserializeStream(
+      new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
+    for (expectedRow <- unsafeRows) {
+      // The deserializer reuses a single mutable UnsafeRow, so each row must be
+      // checked before the iterator is advanced.
+      val actualRow = deserializerIter.next().asInstanceOf[(Int, UnsafeRow)]._2
+      assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
+      assert(expectedRow.getString(0) === actualRow.getString(0))
+      assert(expectedRow.getInt(1) === actualRow.getInt(1))
+    }
+    assert(!deserializerIter.hasNext)
+  }
+}
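A possible follow-up test (a sketch only, not part of this patch) could pin down the single-instance reuse contract that `asKeyValueIterator` relies on, using only the helpers already defined in this suite:

```scala
test("deserialized rows reuse a single mutable UnsafeRow instance") {
  val unsafeRows = Seq(Row("Hello", 1), Row("World", 2))
    .map(row => toUnsafeRow(row, Array(StringType, IntegerType)))
  val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
  val baos = new ByteArrayOutputStream()
  val stream = serializer.serializeStream(baos)
  for (row <- unsafeRows) {
    stream.writeKey(0)
    stream.writeValue(row)
  }
  stream.close()
  val iter = serializer.deserializeStream(
    new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
  val first = iter.next().asInstanceOf[(Int, UnsafeRow)]._2
  val firstString = first.getString(0) // must be captured before advancing
  val second = iter.next().asInstanceOf[(Int, UnsafeRow)]._2
  assert(first eq second) // same instance, re-pointed at the second row's bytes
  assert(firstString === unsafeRows.head.getString(0))
  assert(second.getString(0) === unsafeRows(1).getString(0))
}
```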