diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index feea4f239c04d..aa7e67efaca5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -112,61 +112,49 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
 
-  private def getSerializer(
-      keySchema: Array[DataType],
-      valueSchema: Array[DataType],
-      numPartitions: Int): Serializer = {
+  private def getSerializer(rowSchema: Array[DataType]): Serializer = {
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
-    val noField =
-      (keySchema == null || keySchema.length == 0) &&
-      (valueSchema == null || valueSchema.length == 0)
+    val noField = rowSchema == null || rowSchema.length == 0
 
     val useSqlSerializer2 =
       child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
-        SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
-        SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
+        SparkSqlSerializer2.support(rowSchema) && // The schema of row is supported.
         !noField
 
-    val serializer = if (useSqlSerializer2) {
+    if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(rowSchema)
     } else {
      logInfo("Using SparkSqlSerializer.")
      new SparkSqlSerializer(sparkConf)
    }
-
-    serializer
  }
 
  protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
+    val rowSchema = child.output.map(_.dataType).toArray
+    val serializer = getSerializer(rowSchema)
    newPartitioning match {
      case HashPartitioning(expressions, numPartitions) =>
-        val keySchema = expressions.map(_.dataType).toArray
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, valueSchema, numPartitions)
        val part = new HashPartitioner(numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            iter.map(r => (hashExpressions(r).copy(), r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val hashExpressions = newMutableProjection(expressions, child.output)()
-            val mutablePair = new MutablePair[InternalRow, InternalRow]()
-            iter.map(r => mutablePair.update(hashExpressions(r), r))
+        val rdd: RDD[Product2[Int, InternalRow]] = {
+          if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+            child.execute().mapPartitions { iter =>
+              val hashExpressions = newMutableProjection(expressions, child.output)()
+              iter.map(r => (part.getPartition(hashExpressions(r)), r.copy()))
+            }
+          } else {
+            child.execute().mapPartitions { iter =>
+              val hashExpressions = newMutableProjection(expressions, child.output)()
+              val mutablePair = new MutablePair[Int, InternalRow]()
+              iter.map(r => mutablePair.update(part.getPartition(hashExpressions(r)), r))
+            }
          }
        }
-        val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
+        new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)
 
      case RangePartitioning(sortingExpressions, numPartitions) =>
-        val keySchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(keySchema, null, numPartitions)
-
        val childRdd = child.execute()
        val part: Partitioner = {
          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
@@ -180,37 +168,35 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
        }
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
-        } else {
-          childRdd.mapPartitions { iter =>
-            val mutablePair = new MutablePair[InternalRow, Null]()
-            iter.map(row => mutablePair.update(row, null))
+        val rdd: RDD[Product2[Int, InternalRow]] = {
+          if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+            childRdd.mapPartitions { iter => iter.map(row => (part.getPartition(row), row.copy()))}
+          } else {
+            childRdd.mapPartitions { iter =>
+              val mutablePair = new MutablePair[Int, InternalRow]()
+              iter.map(row => mutablePair.update(part.getPartition(row), row))
+            }
          }
        }
-        val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._1)
+        new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)
 
      case SinglePartition =>
-        val valueSchema = child.output.map(_.dataType).toArray
-        val serializer = getSerializer(null, valueSchema, numPartitions = 1)
-        val partitioner = new HashPartitioner(1)
+        val part = new HashPartitioner(1)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
-          child.execute().mapPartitions {
-            iter => iter.map(r => (null, r.copy()))
-          }
-        } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Null, InternalRow]()
-            iter.map(r => mutablePair.update(null, r))
+        val rdd: RDD[Product2[Int, InternalRow]] = {
+          if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+            child.execute().mapPartitions {
+              iter => iter.map(r => (0, r.copy()))
+            }
+          } else {
+            child.execute().mapPartitions { iter =>
+              val mutablePair = new MutablePair[Int, InternalRow]()
+              iter.map(r => mutablePair.update(0, r))
+            }
          }
        }
-        val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
-        shuffled.setSerializer(serializer)
-        shuffled.map(_._2)
+        new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)
 
      case _ => sys.error(s"Exchange not implemented for $newPartitioning")
      // TODO: Handle BroadcastPartitioning.
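A minimal, self-contained sketch (not part of the patch) of the re-keying idea in the Exchange changes above: the map side now computes the target partition id eagerly via part.getPartition and ships (partitionId, row) pairs, so the shuffle-side partitioner can be a simple pass-through and the original key never has to be serialized. PassthroughPartitioner, ExchangeKeyingSketch, and the plain Scala tuples standing in for InternalRows are illustrative assumptions, not Spark APIs.

import org.apache.spark.{HashPartitioner, Partitioner}

object ExchangeKeyingSketch {
  // Same trick as the PartitionIdPassthrough partitioner introduced in ShuffledRowRDD.scala below.
  class PassthroughPartitioner(override val numPartitions: Int) extends Partitioner {
    override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  }

  def main(args: Array[String]): Unit = {
    val part = new HashPartitioner(4)
    val rows = Seq(("alice", 1), ("bob", 2), ("carol", 3))   // stand-ins for InternalRows
    val hashKey = (r: (String, Int)) => r._1                 // stand-in for hashExpressions(r)

    // Old code path: shuffle (key, row) pairs and let the shuffle hash the key later.
    val oldShape = rows.map(r => (hashKey(r), r))
    // New code path: the partition id is computed up front; only the row still travels.
    val newShape = rows.map(r => (part.getPartition(hashKey(r)), r))

    val passthrough = new PassthroughPartitioner(part.numPartitions)
    newShape.foreach { case (pid, r) =>
      assert(passthrough.getPartition(pid) == pid)           // the pass-through just echoes the id
      println(s"row $r -> partition $pid")
    }
    println(s"keys that no longer need to be shuffled: ${oldShape.map(_._1)}")
  }
}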
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
new file mode 100644
index 0000000000000..b06f42b17cd5a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.types.DataType
+
+import scala.reflect.ClassTag
+
+private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
+  override val index: Int = idx
+  override def hashCode(): Int = idx
+}
+
+private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
+ * shuffling rows instead of Java key-value pairs. Note that something like this should eventually
+ * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
+ * interfaces / internals.
+ */
+class ShuffledRowRDD[InternalRow: ClassTag](
+    rowSchema: Array[DataType],
+    @transient var prev: RDD[Product2[Int, InternalRow]],
+    serializer: Serializer,
+    numPartitions: Int)
+  extends RDD[InternalRow](prev.context, Nil) {
+
+  private val part: Partitioner = new PartitionIdPassthrough(numPartitions)
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
+  }
+
+  override val partitioner = Some(part)
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
+    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+      .read()
+      .asInstanceOf[Iterator[Product2[Int, InternalRow]]]
+      .map(_._2)
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
+}
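A hedged usage sketch (again, not part of the patch) of the new ShuffledRowRDD: it wraps an RDD of (partitionId, value) pairs the same way Exchange.doExecute does above. Running in local mode with String values instead of real InternalRows, and with JavaSerializer instead of SparkSqlSerializer2, are assumptions made only to keep the example self-contained; ShuffledRowRDDSketch is an illustrative name, and the one-column rowSchema is simply forwarded the way Exchange forwards the child output schema.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.execution.ShuffledRowRDD
import org.apache.spark.sql.types.{DataType, StringType}

object ShuffledRowRDDSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("ShuffledRowRDDSketch")
    val sc = new SparkContext(conf)
    try {
      val numPartitions = 3
      // Map side: tag every value with the partition it should land in.
      val keyed: RDD[Product2[Int, String]] =
        sc.parallelize(Seq("a", "bb", "ccc", "dddd")).map(v => (v.length % numPartitions, v))
      val rowSchema: Array[DataType] = Array(StringType)
      val shuffled =
        new ShuffledRowRDD[String](rowSchema, keyed, new JavaSerializer(conf), numPartitions)
      // Each output partition holds exactly the values whose precomputed id matched it.
      shuffled.glom().collect().zipWithIndex.foreach { case (values, i) =>
        println(s"partition $i: ${values.mkString(", ")}")
      }
    } finally {
      sc.stop()
    }
  }
}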
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 6ed822dc70d68..c87e2064a8f33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
  * the comment of the `serializer` method in [[Exchange]] for more information on it.
  */
 private[sql] class Serializer2SerializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     out: OutputStream)
   extends SerializationStream with Logging {
 
   private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
 
   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
   }
 
   override def writeKey[T: ClassTag](t: T): SerializationStream = {
-    writeKeyFunc(t.asInstanceOf[Row])
+    // No-op.
     this
   }
 
   override def writeValue[T: ClassTag](t: T): SerializationStream = {
-    writeValueFunc(t.asInstanceOf[Row])
+    writeRowFunc(t.asInstanceOf[Row])
     this
   }
 
@@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream(
  * The corresponding deserialization stream for [[Serializer2SerializationStream]].
  */
 private[sql] class Serializer2DeserializationStream(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType],
+    rowSchema: Array[DataType],
     in: InputStream)
   extends DeserializationStream with Logging {
 
@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
   }
 
   // Functions used to return rows for key and value.
-  private val getKey = rowGenerator(keySchema)
-  private val getValue = rowGenerator(valueSchema)
+  private val getRow = rowGenerator(rowSchema)
 
   // Functions used to read a serialized row from the InputStream and deserialize it.
-  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
-  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
+  private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
 
   override def readObject[T: ClassTag](): T = {
-    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
+    readValue()
   }
 
   override def readKey[T: ClassTag](): T = {
-    readKeyFunc(getKey()).asInstanceOf[T]
+    null.asInstanceOf[T] // intentionally left blank.
   }
 
   override def readValue[T: ClassTag](): T = {
-    readValueFunc(getValue()).asInstanceOf[T]
+    readRowFunc(getRow()).asInstanceOf[T]
   }
 
   override def close(): Unit = {
@@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream(
 }
 
 private[sql] class SparkSqlSerializer2Instance(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    rowSchema: Array[DataType])
   extends SerializerInstance {
 
   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
     throw new UnsupportedOperationException("Not supported.")
 
   def serializeStream(s: OutputStream): SerializationStream = {
-    new Serializer2SerializationStream(keySchema, valueSchema, s)
+    new Serializer2SerializationStream(rowSchema, s)
   }
 
   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(rowSchema, s)
   }
 }
 
 /**
  * SparkSqlSerializer2 is a special serializer that creates serialization function and
  * deserialization function based on the schema of data. It assumes that values passed in
- * are key/value pairs and values returned from it are also key/value pairs.
- * The schema of keys is represented by `keySchema` and that of values is represented by
- * `valueSchema`.
+ * are Rows.
  */
-private[sql] class SparkSqlSerializer2(
-    keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
   extends Serializer
   with Logging
   with Serializable{
 
-  def newInstance(): SerializerInstance =
-    new SparkSqlSerializer2Instance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
 
   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
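Finally, a hedged sketch (not part of the patch) of the value-only stream contract that the Serializer2 changes above rely on: the shuffle writer calls writeKey/writeValue for every record, but since the key is now just the partition id consumed by the shuffle machinery, a stream that skips keys entirely and only round-trips values is sufficient, and the reader can hand back a dummy key. The ValueOnly* class names and the use of plain Strings are assumptions for illustration; only SerializationStream and DeserializationStream are real Spark developer APIs.

import java.io._
import scala.reflect.ClassTag
import org.apache.spark.serializer.{DeserializationStream, SerializationStream}

class ValueOnlySerializationStream(out: OutputStream) extends SerializationStream {
  private val dataOut = new DataOutputStream(new BufferedOutputStream(out))
  override def writeObject[T: ClassTag](t: T): SerializationStream = writeValue(t)
  override def writeKey[T: ClassTag](t: T): SerializationStream = this        // no-op, like Serializer2SerializationStream
  override def writeValue[T: ClassTag](t: T): SerializationStream = {
    dataOut.writeUTF(t.asInstanceOf[String])
    this
  }
  override def flush(): Unit = dataOut.flush()
  override def close(): Unit = dataOut.close()
}

class ValueOnlyDeserializationStream(in: InputStream) extends DeserializationStream {
  private val dataIn = new DataInputStream(new BufferedInputStream(in))
  override def readObject[T: ClassTag](): T = readValue()
  override def readKey[T: ClassTag](): T = null.asInstanceOf[T]               // the key is never materialized
  override def readValue[T: ClassTag](): T = dataIn.readUTF().asInstanceOf[T]
  override def close(): Unit = dataIn.close()
}

object ValueOnlyStreamSketch {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val ser = new ValueOnlySerializationStream(bytes)
    Seq((7, "row1"), (3, "row2")).foreach { case (partitionId, row) =>
      ser.writeKey(partitionId).writeValue(row)                               // key bytes are never written
    }
    ser.close()
    val deser = new ValueOnlyDeserializationStream(new ByteArrayInputStream(bytes.toByteArray))
    // Keys come back as null, the row payloads come back intact and in order.
    println(Seq((deser.readKey[Integer](), deser.readValue[String]()),
                (deser.readKey[Integer](), deser.readValue[String]())))
  }
}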