From 79ec07290d0b4d16f1643af83824d926304c8f46 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 19 Jul 2015 23:41:28 -0700
Subject: [PATCH] [SPARK-9023] [SQL] Efficiency improvements for UnsafeRows in Exchange

This pull request aims to improve the performance of SQL's Exchange operator when shuffling UnsafeRows. It also makes several general efficiency improvements to Exchange.

Key changes:

- When performing hash partitioning, the old Exchange projected the partitioning columns into a new row and then passed a `(partitioningColumnRow: InternalRow, row: InternalRow)` pair into the shuffle. This is very inefficient because it ends up redundantly serializing the partitioning columns only to immediately discard them after the shuffle. After this patch's changes, Exchange now shuffles `(partitionId: Int, row: InternalRow)` pairs. This still isn't optimal, since we are shuffling extra data that we don't need, but it is significantly more efficient than the old implementation; in the future, we may be able to optimize this further once we implement a new shuffle write interface that accepts non-key-value-pair inputs.
- Exchange's `compute()` method has been significantly simplified; the new code has less duplication and is therefore easier to understand.
- When Exchange's input operator produces UnsafeRows, Exchange will use a specialized `UnsafeRowSerializer` to serialize these rows. This serializer is significantly more efficient since it simply copies the UnsafeRow's underlying bytes. Note that this approach does not work for UnsafeRows that use the ObjectPool mechanism; I did not add support for this because we are planning to remove ObjectPool in the next few weeks.

Author: Josh Rosen

Closes #7456 from JoshRosen/unsafe-exchange and squashes the following commits:

7e75259 [Josh Rosen] Fix cast in SparkSqlSerializer2Suite
0082515 [Josh Rosen] Some additional comments + small cleanup to remove an unused parameter
a27cfc1 [Josh Rosen] Add missing newline
741973c [Josh Rosen] Add simple test of UnsafeRow shuffling in Exchange.
359c6a4 [Josh Rosen] Remove println() and add comments
93904e7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-exchange
8dd3ff2 [Josh Rosen] Exchange outputs UnsafeRows when its child outputs them
dd9c66d [Josh Rosen] Fix for copying logic
035af21 [Josh Rosen] Add logic for choosing when to use UnsafeRowSerializer
7876f31 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-shuffle
cbea80b [Josh Rosen] Add UnsafeRowSerializer
0f2ac86 [Josh Rosen] Import ordering
3ca8515 [Josh Rosen] Big code simplification in Exchange
3526868 [Josh Rosen] Initial cut at removing shuffle on KV pairs
---
 .../apache/spark/sql/execution/Exchange.scala      | 132 ++++----
 .../spark/sql/execution/ShuffledRowRDD.scala       |  80 ++++++++
 .../sql/execution/SparkSqlSerializer2.scala        |  43 ++----
 .../sql/execution/UnsafeRowSerializer.scala        | 142 ++++++++++++
 .../spark/sql/execution/basicOperators.scala       |   5 +-
 .../spark/sql/execution/ExchangeSuite.scala        |  32 ++++
 .../execution/SparkSqlSerializer2Suite.scala       |   4 +-
 .../execution/UnsafeRowSerializerSuite.scala       |  76 ++++++++++
 8 files changed, 398 insertions(+), 116 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index feea4f239c04d..2750053594f99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -44,6 +43,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def output: Seq[Attribute] = child.output + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + + override def canProcessSafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory.
The @@ -112,109 +117,70 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - private def getSerializer( - keySchema: Array[DataType], - valueSchema: Array[DataType], - numPartitions: Int): Serializer = { + private val serializer: Serializer = { + val rowDataTypes = child.output.map(_.dataType).toArray // It is true when there is no field that needs to be write out. // For now, we will not use SparkSqlSerializer2 when noField is true. - val noField = - (keySchema == null || keySchema.length == 0) && - (valueSchema == null || valueSchema.length == 0) + val noField = rowDataTypes == null || rowDataTypes.length == 0 val useSqlSerializer2 = child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. - SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. !noField - val serializer = if (useSqlSerializer2) { + if (child.outputsUnsafeRows) { + logInfo("Using UnsafeRowSerializer.") + new UnsafeRowSerializer(child.output.size) + } else if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(keySchema, valueSchema) + new SparkSqlSerializer2(rowDataTypes) } else { logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } - - serializer } protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { - newPartitioning match { - case HashPartitioning(expressions, numPartitions) => - val keySchema = expressions.map(_.dataType).toArray - val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, valueSchema, numPartitions) - val part = new HashPartitioner(numPartitions) - - val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - iter.map(r => (hashExpressions(r).copy(), r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[InternalRow, InternalRow]() - iter.map(r => mutablePair.update(hashExpressions(r), r)) - } - } - val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(serializer) - shuffled.map(_._2) - + val rdd = child.execute() + val part: Partitioner = newPartitioning match { + case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) case RangePartitioning(sortingExpressions, numPartitions) => - val keySchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, null, numPartitions) - - val childRdd = child.execute() - val part: Partitioner = { - // Internally, RangePartitioner runs a job on the RDD that samples keys to compute - // partition bounds. To get accurate samples, we need to copy the mutable keys. - val rddForSampling = childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row.copy(), null)) - } - // TODO: RangePartitioner should take an Ordering. 
- implicit val ordering = new RowOrdering(sortingExpressions, child.output) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) - } - - val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { - childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))} - } else { - childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row, null)) - } + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. + val rddForSampling = rdd.mapPartitions { iter => + val mutablePair = new MutablePair[InternalRow, Null]() + iter.map(row => mutablePair.update(row.copy(), null)) } - - val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part) - shuffled.setSerializer(serializer) - shuffled.map(_._1) - + implicit val ordering = new RowOrdering(sortingExpressions, child.output) + new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => - val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(null, valueSchema, numPartitions = 1) - val partitioner = new HashPartitioner(1) - - val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) { - child.execute().mapPartitions { - iter => iter.map(r => (null, r.copy())) - } - } else { - child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, InternalRow]() - iter.map(r => mutablePair.update(null, r)) - } + new Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 } - val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner) - shuffled.setSerializer(serializer) - shuffled.map(_._2) - case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } + def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match { + case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() + case RangePartitioning(_, _) | SinglePartition => identity + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + rdd.mapPartitions { iter => + val getPartitionKey = getPartitionKeyExtractor() + iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + } + } else { + rdd.mapPartitions { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + } + } + new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala new file mode 100644 index 0000000000000..88f5b13c8f248 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.DataType + +private class ShuffledRowRDDPartition(val idx: Int) extends Partition { + override val index: Int = idx + override def hashCode(): Int = idx +} + +/** + * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for + * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range). + */ +private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = key.asInstanceOf[Int] +} + +/** + * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for + * shuffling rows instead of Java key-value pairs. Note that something like this should eventually + * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle + * interfaces / internals. + * + * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * @param serializer the serializer used during the shuffle. + * @param numPartitions the number of post-shuffle partitions. 
+ */ +class ShuffledRowRDD( + @transient var prev: RDD[Product2[Int, InternalRow]], + serializer: Serializer, + numPartitions: Int) + extends RDD[InternalRow](prev.context, Nil) { + + private val part: Partitioner = new PartitionIdPassthrough(numPartitions) + + override def getDependencies: Seq[Dependency[_]] = { + List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer))) + } + + override val partitioner = Some(part) + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i)) + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[Product2[Int, InternalRow]]] + .map(_._2) + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 6ed822dc70d68..c87e2064a8f33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String * the comment of the `serializer` method in [[Exchange]] for more information on it. */ private[sql] class Serializer2SerializationStream( - keySchema: Array[DataType], - valueSchema: Array[DataType], + rowSchema: Array[DataType], out: OutputStream) extends SerializationStream with Logging { private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) - private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { val kv = t.asInstanceOf[Product2[Row, Row]] @@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream( } override def writeKey[T: ClassTag](t: T): SerializationStream = { - writeKeyFunc(t.asInstanceOf[Row]) + // No-op. this } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeValueFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[Row]) this } @@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream( * The corresponding deserialization stream for [[Serializer2SerializationStream]]. */ private[sql] class Serializer2DeserializationStream( - keySchema: Array[DataType], - valueSchema: Array[DataType], + rowSchema: Array[DataType], in: InputStream) extends DeserializationStream with Logging { @@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream( } // Functions used to return rows for key and value. - private val getKey = rowGenerator(keySchema) - private val getValue = rowGenerator(valueSchema) + private val getRow = rowGenerator(rowSchema) // Functions used to read a serialized row from the InputStream and deserialize it. 
- private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn) - private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn) + private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) override def readObject[T: ClassTag](): T = { - (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T] + readValue() } override def readKey[T: ClassTag](): T = { - readKeyFunc(getKey()).asInstanceOf[T] + null.asInstanceOf[T] // intentionally left blank. } override def readValue[T: ClassTag](): T = { - readValueFunc(getValue()).asInstanceOf[T] + readRowFunc(getRow()).asInstanceOf[T] } override def close(): Unit = { @@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream( } private[sql] class SparkSqlSerializer2Instance( - keySchema: Array[DataType], - valueSchema: Array[DataType]) + rowSchema: Array[DataType]) extends SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer = @@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance( throw new UnsupportedOperationException("Not supported.") def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(keySchema, valueSchema, s) + new Serializer2SerializationStream(rowSchema, s) } def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(keySchema, valueSchema, s) + new Serializer2DeserializationStream(rowSchema, s) } } /** * SparkSqlSerializer2 is a special serializer that creates serialization function and * deserialization function based on the schema of data. It assumes that values passed in - * are key/value pairs and values returned from it are also key/value pairs. - * The schema of keys is represented by `keySchema` and that of values is represented by - * `valueSchema`. + * are Rows. */ -private[sql] class SparkSqlSerializer2( - keySchema: Array[DataType], - valueSchema: Array[DataType]) +private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) extends Serializer with Logging with Serializable{ - def newInstance(): SerializerInstance = - new SparkSqlSerializer2Instance(keySchema, valueSchema) + def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) override def supportsRelocationOfSerializedObjects: Boolean = { // SparkSqlSerializer2 is stateless and writes no stream headers diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala new file mode 100644 index 0000000000000..19503ed00056c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import com.google.common.io.ByteStreams + +import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.PlatformDependent + +/** + * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as + * bytes, this serializer simply copies those bytes to the underlying output stream. When + * deserializing a stream of rows, instances of this serializer mutate and return a single UnsafeRow + * instance that is backed by an on-heap byte array. + * + * Note that this serializer implements only the [[Serializer]] methods that are used during + * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException. + * + * This serializer does not support UnsafeRows that use + * [[org.apache.spark.sql.catalyst.util.ObjectPool]]. + * + * @param numFields the number of fields in the row being serialized. + */ +private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable { + override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields) + override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true +} + +private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { + + private[this] val EOF: Int = -1 + + override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { + private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + private[this] val dOut: DataOutputStream = new DataOutputStream(out) + + override def writeValue[T: ClassTag](value: T): SerializationStream = { + val row = value.asInstanceOf[UnsafeRow] + assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool") + dOut.writeInt(row.getSizeInBytes) + var dataRemaining: Int = row.getSizeInBytes + val baseObject = row.getBaseObject + var rowReadPosition: Long = row.getBaseOffset + while (dataRemaining > 0) { + val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining) + PlatformDependent.copyMemory( + baseObject, + rowReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer) + out.write(writeBuffer, 0, toTransfer) + rowReadPosition += toTransfer + dataRemaining -= toTransfer + } + this + } + override def writeKey[T: ClassTag](key: T): SerializationStream = { + assert(key.isInstanceOf[Int]) + this + } + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = + throw new UnsupportedOperationException + override def writeObject[T: ClassTag](t: T): SerializationStream = + throw new UnsupportedOperationException + override def flush(): Unit = dOut.flush() + override def close(): Unit = { + writeBuffer = null + dOut.writeInt(EOF) + dOut.close() + } + } + + override def deserializeStream(in: InputStream): DeserializationStream = { + new DeserializationStream { + private[this] val dIn: DataInputStream = new DataInputStream(in) + private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) + private[this] var row: UnsafeRow = new UnsafeRow() + private[this] var rowTuple: (Int, UnsafeRow) = (0, row) + + override def asKeyValueIterator: 
Iterator[(Int, UnsafeRow)] = { + new Iterator[(Int, UnsafeRow)] { + private[this] var rowSize: Int = dIn.readInt() + + override def hasNext: Boolean = rowSize != EOF + + override def next(): (Int, UnsafeRow) = { + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(in, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + rowSize = dIn.readInt() // read the next row's size + if (rowSize == EOF) { // We are returning the last row in this stream + val _rowTuple = rowTuple + // Null these out so that the byte array can be garbage collected once the entire + // iterator has been consumed + row = null + rowBuffer = null + rowTuple = null + _rowTuple + } else { + rowTuple + } + } + } + } + override def asIterator: Iterator[Any] = throw new UnsupportedOperationException + override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException + override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException + override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException + override def close(): Unit = dIn.close() + } + } + + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 82bef269b069f..fdd7ad59aba50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -56,11 +56,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator: (InternalRow) => Boolean = - newPredicate(condition, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - iter.filter(conditionEvaluator) + iter.filter(newPredicate(condition, child.output)) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala new file mode 100644 index 0000000000000..79e903c2bbd40 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition + +class ExchangeSuite extends SparkPlanTest { + test("shuffling UnsafeRows in exchange") { + val input = (1 to 1000).map(Tuple1.apply) + checkAnswer( + input.toDF(), + plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + input.map(Row.fromTuple) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 71f6b26bcd01a..4a53fadd7e099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -132,8 +132,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll expectedSerializerClass: Class[T]): Unit = { executedPlan.foreach { case exchange: Exchange => - val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] - val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val shuffledRDD = exchange.execute() + val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] val serializerNotSetMessage = s"Expected $expectedSerializerClass as the serializer of Exchange. " + s"However, the serializer was not set." diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala new file mode 100644 index 0000000000000..bd788ec8c14b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter} +import org.apache.spark.sql.catalyst.util.ObjectPool +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent + +class UnsafeRowSerializerSuite extends SparkFunSuite { + + private def toUnsafeRow( + row: Row, + schema: Array[DataType], + objPool: ObjectPool = null): UnsafeRow = { + val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val rowConverter = new UnsafeRowConverter(schema) + val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow) + val byteArray = new Array[Byte](rowSizeInBytes) + rowConverter.writeRow( + internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool) + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo( + byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool) + unsafeRow + } + + test("toUnsafeRow() test helper method") { + val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + assert(row.getString(0) === unsafeRow.get(0).toString) + assert(row.getInt(1) === unsafeRow.getInt(1)) + } + + test("basic row serialization") { + val rows = Seq(Row("Hello", 1), Row("World", 2)) + val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType))) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val baos = new ByteArrayOutputStream() + val serializerStream = serializer.serializeStream(baos) + for (unsafeRow <- unsafeRows) { + serializerStream.writeKey(0) + serializerStream.writeValue(unsafeRow) + } + serializerStream.close() + val deserializerIter = serializer.deserializeStream( + new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + for (expectedRow <- unsafeRows) { + val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 + assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) + assert(expectedRow.getString(0) === actualRow.getString(0)) + assert(expectedRow.getInt(1) === actualRow.getInt(1)) + } + assert(!deserializerIter.hasNext) + } +}
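
Editor's note (not part of the patch): to make the UnsafeRowSerializer wire format easier to follow, below is a minimal, self-contained Scala sketch of the same framing idea used in the serializer above: each record is written as a 4-byte length followed by its raw bytes, and a length of -1 (EOF) marks the end of the stream. It uses only java.io; the object and helper names are hypothetical, and it deliberately omits the real serializer's optimizations (reusing a single write buffer and pointing one mutable UnsafeRow at a reused byte array).

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object RowFramingSketch {
  private val EOF = -1

  // Write each payload as <length><bytes>, then terminate the stream with the EOF marker,
  // mirroring what serializeStream's writeValue() and close() do in the patch.
  def writeFramed(payloads: Seq[Array[Byte]]): Array[Byte] = {
    val bytesOut = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytesOut)
    payloads.foreach { bytes =>
      out.writeInt(bytes.length)
      out.write(bytes)
    }
    out.writeInt(EOF)
    out.close()
    bytesOut.toByteArray
  }

  // Read length-prefixed payloads until the EOF marker is seen, mirroring asKeyValueIterator.
  def readFramed(serialized: Array[Byte]): Seq[Array[Byte]] = {
    val in = new DataInputStream(new ByteArrayInputStream(serialized))
    val result = Seq.newBuilder[Array[Byte]]
    var size = in.readInt()
    while (size != EOF) {
      val buf = new Array[Byte](size)
      in.readFully(buf)
      result += buf
      size = in.readInt()
    }
    result.result()
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("Hello", "World").map(_.getBytes("UTF-8"))
    val roundTripped = readFramed(writeFramed(rows))
    assert(roundTripped.map(new String(_, "UTF-8")) == Seq("Hello", "World"))
  }
}

Because every record is a self-contained, length-prefixed blob with no stream-level state, serialized records can be safely relocated relative to one another, which is why the serializer in the patch returns true from supportsRelocationOfSerializedObjects and stays compatible with the serialized-shuffle code paths.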