diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index feea4f239c04d..2750053594f99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}

@@ -44,6 +43,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

   override def output: Seq[Attribute] = child.output

+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+  override def canProcessSafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = true
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -112,109 +117,70 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

   @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf

-  private def getSerializer(
-      keySchema: Array[DataType],
-      valueSchema: Array[DataType],
-      numPartitions: Int): Serializer = {
+  private val serializer: Serializer = {
+    val rowDataTypes = child.output.map(_.dataType).toArray
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
-    val noField =
-      (keySchema == null || keySchema.length == 0) &&
-      (valueSchema == null || valueSchema.length == 0)
+    val noField = rowDataTypes == null || rowDataTypes.length == 0

     val useSqlSerializer2 =
      child.sqlContext.conf.useSqlSerializer2 &&  // SparkSqlSerializer2 is enabled.
-      SparkSqlSerializer2.support(keySchema) &&  // The schema of key is supported.
-      SparkSqlSerializer2.support(valueSchema) &&  // The schema of value is supported.
+      SparkSqlSerializer2.support(rowDataTypes) &&  // The schema of row is supported.
       !noField

-    val serializer = if (useSqlSerializer2) {
+    if (child.outputsUnsafeRows) {
+      logInfo("Using UnsafeRowSerializer.")
+      new UnsafeRowSerializer(child.output.size)
+    } else if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(rowDataTypes)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
     }
-
-    serializer
   }

   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
-    newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) =>
-        val keySchema = expressions.map(_.dataType).toArray
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, valueSchema, numPartitions)
-        val part = new HashPartitioner(numPartitions)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[InternalRow, InternalRow]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
-          }
-        }
-        val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
+    val rdd = child.execute()
+    val part: Partitioner = newPartitioning match {
+      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val keySchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, null, numPartitions)
-
-        val childRdd = child.execute()
-        val part: Partitioner = {
-          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
-          // partition bounds. To get accurate samples, we need to copy the mutable keys.
-          val rddForSampling = childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row.copy(), null))
-          }
-          // TODO: RangePartitioner should take an Ordering.
-          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
-        }
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
-        } else {
-          childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row, null))
-          }
+        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+        // partition bounds. To get accurate samples, we need to copy the mutable keys.
+        val rddForSampling = rdd.mapPartitions { iter =>
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          iter.map(row => mutablePair.update(row.copy(), null))
         }
-
-        val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._1)
-
+        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+        new RangePartitioner(numPartitions, rddForSampling, ascending = true)
       case SinglePartition =>
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(null, valueSchema, numPartitions = 1)
-        val partitioner = new HashPartitioner(1)
-
-        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
-          child.execute().mapPartitions {
-            iter => iter.map(r => (null, r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Null, InternalRow]()
-            iter.map(r => mutablePair.update(null, r))
-          }
+        new Partitioner {
+          override def numPartitions: Int = 1
+          override def getPartition(key: Any): Int = 0
         }
-        val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
-
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
+    def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
+      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      case RangePartitioning(_, _) | SinglePartition => identity
+      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
+    }
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+        }
+      } else {
+        rdd.mapPartitions { iter =>
+          val getPartitionKey = getPartitionKeyExtractor()
+          val mutablePair = new MutablePair[Int, InternalRow]()
+          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+        }
+      }
+    }
+    new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
new file mode 100644
index 0000000000000..88f5b13c8f248
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+
+private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+/**
+ * A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
+ * use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
+ */
+private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
+ * shuffling rows instead of Java key-value pairs. Note that something like this should eventually
+ * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
+ * interfaces / internals.
+ *
+ * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs.
+ *             Partition ids should be in the range [0, numPartitions - 1].
+ * @param serializer the serializer used during the shuffle.
+ * @param numPartitions the number of post-shuffle partitions.
+ */
+class ShuffledRowRDD(
+    @transient var prev: RDD[Product2[Int, InternalRow]],
+    serializer: Serializer,
+    numPartitions: Int)
+  extends RDD[InternalRow](prev.context, Nil) {
+
+  private val part: Partitioner = new PartitionIdPassthrough(numPartitions)
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
+  }
+
+  override val partitioner = Some(part)
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
+    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+      .read()
+      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
+      .map(_._2)
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 6ed822dc70d68..c87e2064a8f33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
  * the comment of the `serializer` method in [[Exchange]] for more information on it.
  */
 private[sql] class Serializer2SerializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     out: OutputStream)
   extends SerializationStream with Logging {

   private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)

   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
   }

   override def writeKey[T: ClassTag](t: T): SerializationStream = {
-    writeKeyFunc(t.asInstanceOf[Row])
+    // No-op.
     this
   }

   override def writeValue[T: ClassTag](t: T): SerializationStream = {
-    writeValueFunc(t.asInstanceOf[Row])
+    writeRowFunc(t.asInstanceOf[Row])
     this
   }

@@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream(
  * The corresponding deserialization stream for [[Serializer2SerializationStream]].
  */
 private[sql] class Serializer2DeserializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     in: InputStream)
   extends DeserializationStream with Logging {

@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
   }

   // Functions used to return rows for key and value.
-  private val getKey = rowGenerator(keySchema)
-  private val getValue = rowGenerator(valueSchema)
+  private val getRow = rowGenerator(rowSchema)

   // Functions used to read a serialized row from the InputStream and deserialize it.
-  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)

   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
+    readValue()
   }

   override def readKey[T: ClassTag](): T = {
-    readKeyFunc(getKey()).asInstanceOf[T]
+    null.asInstanceOf[T] // intentionally left blank.
   }

   override def readValue[T: ClassTag](): T = {
-    readValueFunc(getValue()).asInstanceOf[T]
+    readRowFunc(getRow()).asInstanceOf[T]
   }

   override def close(): Unit = {
@@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream(
 }

 private[sql] class SparkSqlSerializer2Instance(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    rowSchema: Array[DataType])
   extends SerializerInstance {

   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
     throw new UnsupportedOperationException("Not supported.")

   def serializeStream(s: OutputStream): SerializationStream = {
-    new Serializer2SerializationStream(keySchema, valueSchema, s)
+    new Serializer2SerializationStream(rowSchema, s)
   }

   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(rowSchema, s)
   }
 }

 /**
  * SparkSqlSerializer2 is a special serializer that creates serialization function and
  * deserialization function based on the schema of data. It assumes that values passed in
- * are key/value pairs and values returned from it are also key/value pairs.
- * The schema of keys is represented by `keySchema` and that of values is represented by
- * `valueSchema`.
+ * are Rows.
  */
-private[sql] class SparkSqlSerializer2(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
   extends Serializer with Logging with Serializable{

-  def newInstance(): SerializerInstance =
-    new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)

   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
new file mode 100644
index 0000000000000..19503ed00056c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.PlatformDependent
+
+/**
+ * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as
+ * bytes, this serializer simply copies those bytes to the underlying output stream. When
+ * deserializing a stream of rows, instances of this serializer mutate and return a single UnsafeRow
+ * instance that is backed by an on-heap byte array.
+ *
+ * Note that this serializer implements only the [[Serializer]] methods that are used during
+ * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException.
+ *
+ * This serializer does not support UnsafeRows that use
+ * [[org.apache.spark.sql.catalyst.util.ObjectPool]].
+ *
+ * @param numFields the number of fields in the row being serialized.
+ */
+private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable {
+  override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields)
+  override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
+}
+
+private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
+
+  private[this] val EOF: Int = -1
+
+  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
+    private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+    private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+
+    override def writeValue[T: ClassTag](value: T): SerializationStream = {
+      val row = value.asInstanceOf[UnsafeRow]
+      assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
+      dOut.writeInt(row.getSizeInBytes)
+      var dataRemaining: Int = row.getSizeInBytes
+      val baseObject = row.getBaseObject
+      var rowReadPosition: Long = row.getBaseOffset
+      while (dataRemaining > 0) {
+        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer)
+        out.write(writeBuffer, 0, toTransfer)
+        rowReadPosition += toTransfer
+        dataRemaining -= toTransfer
+      }
+      this
+    }
+    override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      assert(key.isInstanceOf[Int])
+      this
+    }
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+      throw new UnsupportedOperationException
+    override def writeObject[T: ClassTag](t: T): SerializationStream =
+      throw new UnsupportedOperationException
+    override def flush(): Unit = dOut.flush()
+    override def close(): Unit = {
+      writeBuffer = null
+      dOut.writeInt(EOF)
+      dOut.close()
+    }
+  }
+
+  override def deserializeStream(in: InputStream): DeserializationStream = {
+    new DeserializationStream {
+      private[this] val dIn: DataInputStream = new DataInputStream(in)
+      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
+      private[this] var row: UnsafeRow = new UnsafeRow()
+      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
+
+      override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
+        new Iterator[(Int, UnsafeRow)] {
+          private[this] var rowSize: Int = dIn.readInt()
+
+          override def hasNext: Boolean = rowSize != EOF
+
+          override def next(): (Int, UnsafeRow) = {
+            if (rowBuffer.length < rowSize) {
+              rowBuffer = new Array[Byte](rowSize)
+            }
+            ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+            row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+            rowSize = dIn.readInt() // read the next row's size
+            if (rowSize == EOF) { // We are returning the last row in this stream
+              val _rowTuple = rowTuple
+              // Null these out so that the byte array can be garbage collected once the entire
+              // iterator has been consumed
+              row = null
+              rowBuffer = null
+              rowTuple = null
+              _rowTuple
+            } else {
+              rowTuple
+            }
+          }
+        }
+      }
+      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
+      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
+      override def close(): Unit = dIn.close()
+    }
+  }
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+    throw new UnsupportedOperationException
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 82bef269b069f..fdd7ad59aba50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -56,11 +56,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output

-  @transient lazy val conditionEvaluator: (InternalRow) => Boolean =
-    newPredicate(condition, child.output)
-
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
-    iter.filter(conditionEvaluator)
+    iter.filter(newPredicate(condition, child.output))
   }

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
new file mode 100644
index 0000000000000..79e903c2bbd40
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+
+class ExchangeSuite extends SparkPlanTest {
+  test("shuffling UnsafeRows in exchange") {
+    val input = (1 to 1000).map(Tuple1.apply)
+    checkAnswer(
+      input.toDF(),
+      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
+      input.map(Row.fromTuple)
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 71f6b26bcd01a..4a53fadd7e099 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -132,8 +132,8 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       expectedSerializerClass: Class[T]): Unit = {
     executedPlan.foreach {
       case exchange: Exchange =>
-        val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]]
-        val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+        val shuffledRDD = exchange.execute()
+        val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
         val serializerNotSetMessage =
           s"Expected $expectedSerializerClass as the serializer of Exchange. " +
             s"However, the serializer was not set."
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
new file mode 100644
index 0000000000000..bd788ec8c14b1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.util.ObjectPool
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+
+class UnsafeRowSerializerSuite extends SparkFunSuite {
+
+  private def toUnsafeRow(
+      row: Row,
+      schema: Array[DataType],
+      objPool: ObjectPool = null): UnsafeRow = {
+    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val rowConverter = new UnsafeRowConverter(schema)
+    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
+    val byteArray = new Array[Byte](rowSizeInBytes)
+    rowConverter.writeRow(
+      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.pointTo(
+      byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool)
+    unsafeRow
+  }
+
+  test("toUnsafeRow() test helper method") {
+    val row = Row("Hello", 123)
+    val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
+    assert(row.getString(0) === unsafeRow.get(0).toString)
+    assert(row.getInt(1) === unsafeRow.getInt(1))
+  }
+
+  test("basic row serialization") {
+    val rows = Seq(Row("Hello", 1), Row("World", 2))
+    val unsafeRows = rows.map(row => toUnsafeRow(row, Array(StringType, IntegerType)))
+    val serializer = new UnsafeRowSerializer(numFields = 2).newInstance()
+    val baos = new ByteArrayOutputStream()
+    val serializerStream = serializer.serializeStream(baos)
+    for (unsafeRow <- unsafeRows) {
+      serializerStream.writeKey(0)
+      serializerStream.writeValue(unsafeRow)
+    }
+    serializerStream.close()
+    val deserializerIter = serializer.deserializeStream(
+      new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator
+    for (expectedRow <- unsafeRows) {
+      val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2
+      assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes)
+      assert(expectedRow.getString(0) === actualRow.getString(0))
+      assert(expectedRow.getInt(1) === actualRow.getInt(1))
+    }
+    assert(!deserializerIter.hasNext)
+  }
+}
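Note on the design: the rewritten Exchange computes each row's target partition id on the map side and shuffles (partitionId, row) pairs, and ShuffledRowRDD's PartitionIdPassthrough partitioner simply echoes that id back, so no keys have to be compared on the reduce side. The Scala sketch below illustrates that pattern with the plain RDD API outside of SQL; it is illustrative only, not part of the patch (the object name PartitionIdPassthroughSketch, the modulo-based ids, and the local-mode setup are invented for the example).

import org.apache.spark.{Partitioner, SparkConf, SparkContext}

object PartitionIdPassthroughSketch {
  // Same trick as the patch's PartitionIdPassthrough: the shuffle "key" is already
  // the target partition id, so getPartition() just casts and returns it.
  class PassthroughPartitioner(override val numPartitions: Int) extends Partitioner {
    override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    try {
      val numPartitions = 4
      // Map side: pair every record with a pre-computed partition id (here, a simple modulo).
      val withIds = sc.parallelize(1 to 100).map(x => (x % numPartitions, x))
      // Shuffle side: the partitioner passes the pre-computed id straight through.
      val shuffled = withIds.partitionBy(new PassthroughPartitioner(numPartitions))
      // Every post-shuffle partition should contain records with a single partition id.
      assert(shuffled.glom().collect().forall(_.map(_._1).distinct.length <= 1))
    } finally {
      sc.stop()
    }
  }
}

This is also why writeKey can be a no-op in UnsafeRowSerializer and SparkSqlSerializer2 after this patch: the Int key only routes the record to a shuffle bucket, so nothing needs to be written for it, and only the row bytes travel through the shuffle.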