Initial cut at removing shuffle on KV pairs
JoshRosen committed Jul 16, 2015
1 parent 43dac2c commit 3526868
Showing 3 changed files with 131 additions and 83 deletions.
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -112,61 +112,49 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

@transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf

private def getSerializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
numPartitions: Int): Serializer = {
private def getSerializer(rowSchema: Array[DataType]): Serializer = {
// It is true when there is no field that needs to be written out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
(keySchema == null || keySchema.length == 0) &&
(valueSchema == null || valueSchema.length == 0)
val noField = rowSchema == null || rowSchema.length == 0

val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
SparkSqlSerializer2.support(rowSchema) && // The schema of row is supported.
!noField

val serializer = if (useSqlSerializer2) {
if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
new SparkSqlSerializer2(rowSchema)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
}

serializer
}

protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
val rowSchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(rowSchema)
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(keySchema, valueSchema, numPartitions)
val part = new HashPartitioner(numPartitions)

val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[InternalRow, InternalRow]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
val rdd: RDD[Product2[Int, InternalRow]] = {
if (needToCopyObjectsBeforeShuffle(part, serializer)) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (part.getPartition(hashExpressions(r)), r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map(r => mutablePair.update(part.getPartition(hashExpressions(r)), r))
}
}
}
val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._2)
new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)

case RangePartitioning(sortingExpressions, numPartitions) =>
val keySchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(keySchema, null, numPartitions)

val childRdd = child.execute()
val part: Partitioner = {
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
@@ -180,37 +168,35 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
}

val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
childRdd.mapPartitions { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row, null))
val rdd: RDD[Product2[Int, InternalRow]] = {
if (needToCopyObjectsBeforeShuffle(part, serializer)) {
childRdd.mapPartitions { iter => iter.map(row => (part.getPartition(row), row.copy()))}
} else {
childRdd.mapPartitions { iter =>
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map(row => mutablePair.update(part.getPartition(row), row))
}
}
}

val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._1)
new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)

case SinglePartition =>
val valueSchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(null, valueSchema, numPartitions = 1)
val partitioner = new HashPartitioner(1)
val part = new HashPartitioner(1)

val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
child.execute().mapPartitions {
iter => iter.map(r => (null, r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Null, InternalRow]()
iter.map(r => mutablePair.update(null, r))
val rdd: RDD[Product2[Int, InternalRow]] = {
if (needToCopyObjectsBeforeShuffle(part, serializer)) {
child.execute().mapPartitions {
iter => iter.map(r => (0, r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map(r => mutablePair.update(0, r))
}
}
}
val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
shuffled.setSerializer(serializer)
shuffled.map(_._2)
new ShuffledRowRDD[InternalRow](rowSchema, rdd, serializer, part.numPartitions)

case _ => sys.error(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.types.DataType

import scala.reflect.ClassTag

private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
override val index: Int = idx
override def hashCode(): Int = idx
}

private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
* shuffling rows instead of Java key-value pairs. Note that something like this should eventually
* be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
* interfaces / internals.
*/
class ShuffledRowRDD[InternalRow: ClassTag](
rowSchema: Array[DataType],
@transient var prev: RDD[Product2[Int, InternalRow]],
serializer: Serializer,
numPartitions: Int)
extends RDD[InternalRow](prev.context, Nil) {

private val part: Partitioner = new PartitionIdPassthrough(numPartitions)

override def getDependencies: Seq[Dependency[_]] = {
List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
}

override val partitioner = Some(part)

override def getPartitions: Array[Partition] = {
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
}

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[Product2[Int, InternalRow]]]
.map(_._2)
}

override def clearDependencies() {
super.clearDependencies()
prev = null
}
}
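
To show how the new RDD above is meant to be driven, here is a small compilable sketch of the wiring that Exchange performs in this commit: the caller computes each row's target partition id on the map side, and ShuffledRowRDD only moves the rows. This is an illustration rather than code from the commit; the object name, the shuffleRows helper, and extractPartitionKey are hypothetical, and it assumes it lives in org.apache.spark.sql.execution so that it can see the private classes.

```scala
package org.apache.spark.sql.execution

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.DataType

// Hypothetical helper, not part of the commit.
object ShuffledRowRDDUsageSketch {

  /**
   * Shuffles `rows` into `numPartitions` partitions the way Exchange does above:
   * the partition id is computed before the shuffle, so ShuffledRowRDD never has
   * to hash keys itself. `extractPartitionKey` stands in for e.g. a projection of
   * the hash expressions.
   */
  def shuffleRows(
      rows: RDD[InternalRow],
      rowSchema: Array[DataType],
      serializer: Serializer,
      numPartitions: Int)(
      extractPartitionKey: InternalRow => Any): RDD[InternalRow] = {
    val part = new HashPartitioner(numPartitions)
    // The Int key of every record must already be a valid partition id in
    // [0, numPartitions); PartitionIdPassthrough simply echoes it back.
    val keyed: RDD[Product2[Int, InternalRow]] = rows.mapPartitions { iter =>
      // Rows are copied defensively here; Exchange itself only copies when
      // needToCopyObjectsBeforeShuffle says it must.
      iter.map(row => (part.getPartition(extractPartitionKey(row)), row.copy()))
    }
    new ShuffledRowRDD[InternalRow](rowSchema, keyed, serializer, numPartitions)
  }
}
```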
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
* the comment of the `serializer` method in [[Exchange]] for more information on it.
*/
private[sql] class Serializer2SerializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
rowSchema: Array[DataType],
out: OutputStream)
extends SerializationStream with Logging {

private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)

override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
}

override def writeKey[T: ClassTag](t: T): SerializationStream = {
writeKeyFunc(t.asInstanceOf[Row])
// No-op.
this
}

override def writeValue[T: ClassTag](t: T): SerializationStream = {
writeValueFunc(t.asInstanceOf[Row])
writeRowFunc(t.asInstanceOf[Row])
this
}

@@ -85,8 +83,7 @@
* The corresponding deserialization stream for [[Serializer2SerializationStream]].
*/
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
rowSchema: Array[DataType],
in: InputStream)
extends DeserializationStream with Logging {

@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
}

// Functions used to return rows for key and value.
private val getKey = rowGenerator(keySchema)
private val getValue = rowGenerator(valueSchema)
private val getRow = rowGenerator(rowSchema)
// Functions used to read a serialized row from the InputStream and deserialize it.
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)

override def readObject[T: ClassTag](): T = {
(readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
readValue()
}

override def readKey[T: ClassTag](): T = {
readKeyFunc(getKey()).asInstanceOf[T]
null.asInstanceOf[T] // intentionally left blank.
}

override def readValue[T: ClassTag](): T = {
readValueFunc(getValue()).asInstanceOf[T]
readRowFunc(getRow()).asInstanceOf[T]
}

override def close(): Unit = {
@@ -127,8 +122,7 @@
}

private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
rowSchema: Array[DataType])
extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
throw new UnsupportedOperationException("Not supported.")

def serializeStream(s: OutputStream): SerializationStream = {
new Serializer2SerializationStream(keySchema, valueSchema, s)
new Serializer2SerializationStream(rowSchema, s)
}

def deserializeStream(s: InputStream): DeserializationStream = {
new Serializer2DeserializationStream(keySchema, valueSchema, s)
new Serializer2DeserializationStream(rowSchema, s)
}
}

/**
* SparkSqlSerializer2 is a special serializer that creates serialization and deserialization
* functions based on the schema of the data. It assumes that the values passed in
* are key/value pairs and values returned from it are also key/value pairs.
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
* are Rows.
*/
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
valueSchema: Array[DataType])
private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
extends Serializer
with Logging
with Serializable{

def newInstance(): SerializerInstance =
new SparkSqlSerializer2Instance(keySchema, valueSchema)
def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)

override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers
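To make the narrowed serializer API concrete, below is a rough round-trip sketch: serialize one row with the new single rowSchema constructor, then read it back. It is an illustration only, not taken from the commit; it assumes it compiles inside org.apache.spark.sql.execution (the classes are private[sql]), that a plain Row built with Row(...) is acceptable to the generated write function (the stream above casts its input to Row), and it sticks to numeric types to sidestep string/UTF8String handling.

```scala
package org.apache.spark.sql.execution

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Hypothetical round-trip sketch, not part of the commit.
object Serializer2RoundTripSketch {
  def main(args: Array[String]): Unit = {
    // A single row schema now configures the serializer; there is no key schema.
    val rowSchema: Array[DataType] = Array(IntegerType, LongType)
    val instance = new SparkSqlSerializer2(rowSchema).newInstance()

    // Write one row. writeKey is a no-op after this commit, so only the value matters.
    val bytes = new ByteArrayOutputStream()
    val out = instance.serializeStream(bytes)
    out.writeValue(Row(42, 9999L))
    out.close()

    // Read it back as a row.
    val in = instance.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
    val row = in.readValue[Row]()
    assert(row.getInt(0) == 42 && row.getLong(1) == 9999L)
    in.close()
  }
}
```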
