[SPARK-9023] [SQL] Efficiency improvements for UnsafeRows in Exchange
This pull request aims to improve the performance of SQL's Exchange operator when shuffling UnsafeRows.  It also makes several general efficiency improvements to Exchange.

Key changes:

- When performing hash partitioning, the old Exchange projected the partitioning columns into a new row and then passed a `(partitioningColumnRow: InternalRow, row: InternalRow)` pair into the shuffle. This is very inefficient because it redundantly serializes the partitioning columns only to discard them immediately after the shuffle. After this patch's changes, Exchange shuffles `(partitionId: Int, row: InternalRow)` pairs instead (see the first sketch after this list). This still isn't optimal, since we are still shuffling a small amount of data we don't need, but it is significantly more efficient than the old implementation; in the future, we may be able to optimize this further once we implement a new shuffle write interface that accepts non-key-value-pair inputs.
- Exchange's `compute()` method has been significantly simplified; the new code has less duplication and is therefore easier to understand.
- When Exchange's input operator produces UnsafeRows, Exchange now uses a specialized `UnsafeRowSerializer` to serialize those rows. This serializer is significantly more efficient because it simply copies each UnsafeRow's underlying bytes (the second sketch below illustrates the idea). Note that this approach does not work for UnsafeRows that use the ObjectPool mechanism; I did not add support for this because we are planning to remove ObjectPool in the next few weeks.
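
The record-shape change in the first bullet can be illustrated with a small, self-contained sketch. This is an editorial toy model, not Spark code: `Row`, `keyOf`, and `numPartitions` are hypothetical stand-ins for `InternalRow`, the projected partitioning columns, and the shuffle's partitioner.

```scala
// Toy model of the two shuffle record shapes; nothing here is a Spark class.
case class Row(fields: Seq[Any])

object ShuffleShapes {
  val numPartitions = 4

  // Hypothetical partitioning key: the first column of each row.
  def keyOf(row: Row): Any = row.fields.head

  def partitionOf(key: Any): Int =
    math.abs(key.hashCode % numPartitions)

  // Old shape: a projected key row is built and shuffled alongside the data row,
  // only to be deserialized and thrown away on the read side.
  def oldShuffleInput(rows: Iterator[Row]): Iterator[(Row, Row)] =
    rows.map(r => (Row(Seq(keyOf(r))), r))

  // New shape: the partition id is computed up front, so only an Int rides
  // along with each row through the shuffle.
  def newShuffleInput(rows: Iterator[Row]): Iterator[(Int, Row)] =
    rows.map(r => (partitionOf(keyOf(r)), r))
}
```

In the old shape the key row exists only for routing and is discarded after the shuffle; in the new shape, `newShuffleInput(rows)` tags each row with the partition it will land in, so routing needs nothing more than a primitive `Int`.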

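A second minimal sketch shows the byte-copying idea behind the third bullet, again with hypothetical names: `BinaryRow` stands in for a row whose entire contents live in one byte array (as with UnsafeRow), and the length-prefix framing with a `-1` end marker is illustrative, not the actual `UnsafeRowSerializer` wire format.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// A hypothetical "row" that, like UnsafeRow, is just a slab of bytes.
final case class BinaryRow(bytes: Array[Byte])

object ByteCopySerialization {
  // Serialization is a length prefix plus a raw byte copy -- no per-field work.
  def write(rows: Iterator[BinaryRow], out: DataOutputStream): Unit = {
    rows.foreach { row =>
      out.writeInt(row.bytes.length)
      out.write(row.bytes)
    }
    out.writeInt(-1) // hypothetical end-of-stream marker
  }

  // Deserialization reads each length prefix and copies the bytes straight back
  // into a new row, stopping at the end marker.
  def read(in: DataInputStream): Iterator[BinaryRow] =
    Iterator.continually(in.readInt()).takeWhile(_ >= 0).map { size =>
      val buf = new Array[Byte](size)
      in.readFully(buf)
      BinaryRow(buf)
    }

  // Round-trip demo.
  def main(args: Array[String]): Unit = {
    val rows = Seq(BinaryRow(Array[Byte](1, 2, 3)), BinaryRow(Array[Byte](4, 5)))
    val bytesOut = new ByteArrayOutputStream()
    write(rows.iterator, new DataOutputStream(bytesOut))
    val back = read(new DataInputStream(new ByteArrayInputStream(bytesOut.toByteArray))).toSeq
    assert(back.map(_.bytes.toSeq) == rows.map(_.bytes.toSeq))
  }
}
```

Because the serializer never looks inside the row, the per-row cost is one length write plus a bulk byte copy, which is what makes the UnsafeRow path in Exchange cheap compared with field-by-field generic serialization.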
Author: Josh Rosen <joshrosen@databricks.com>

Closes apache#7456 from JoshRosen/unsafe-exchange and squashes the following commits:

7e75259 [Josh Rosen] Fix cast in SparkSqlSerializer2Suite
0082515 [Josh Rosen] Some additional comments + small cleanup to remove an unused parameter
a27cfc1 [Josh Rosen] Add missing newline
741973c [Josh Rosen] Add simple test of UnsafeRow shuffling in Exchange.
359c6a4 [Josh Rosen] Remove println() and add comments
93904e7 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-exchange
8dd3ff2 [Josh Rosen] Exchange outputs UnsafeRows when its child outputs them
dd9c66d [Josh Rosen] Fix for copying logic
035af21 [Josh Rosen] Add logic for choosing when to use UnsafeRowSerializer
7876f31 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-shuffle
cbea80b [Josh Rosen] Add UnsafeRowSerializer
0f2ac86 [Josh Rosen] Import ordering
3ca8515 [Josh Rosen] Big code simplification in Exchange
3526868 [Josh Rosen] Initial cut at removing shuffle on KV pairs
JoshRosen authored and rxin committed Jul 20, 2015
1 parent 972d890 commit 79ec072
Showing 8 changed files with 398 additions and 116 deletions.
132 changes: 49 additions & 83 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.MutablePair
import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}

@@ -44,6 +43,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

override def output: Seq[Attribute] = child.output

override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows

override def canProcessSafeRows: Boolean = true

override def canProcessUnsafeRows: Boolean = true

/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -112,109 +117,70 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

@transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf

private def getSerializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
numPartitions: Int): Serializer = {
private val serializer: Serializer = {
val rowDataTypes = child.output.map(_.dataType).toArray
// It is true when there is no field that needs to be write out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
(keySchema == null || keySchema.length == 0) &&
(valueSchema == null || valueSchema.length == 0)
val noField = rowDataTypes == null || rowDataTypes.length == 0

val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported.
!noField

val serializer = if (useSqlSerializer2) {
if (child.outputsUnsafeRows) {
logInfo("Using UnsafeRowSerializer.")
new UnsafeRowSerializer(child.output.size)
} else if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
new SparkSqlSerializer2(rowDataTypes)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
}

serializer
}

protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(keySchema, valueSchema, numPartitions)
val part = new HashPartitioner(numPartitions)

val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[InternalRow, InternalRow]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
}
val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._2)

val rdd = child.execute()
val part: Partitioner = newPartitioning match {
case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
case RangePartitioning(sortingExpressions, numPartitions) =>
val keySchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(keySchema, null, numPartitions)

val childRdd = child.execute()
val part: Partitioner = {
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
val rddForSampling = childRdd.mapPartitions { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}
// TODO: RangePartitioner should take an Ordering.
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
}

val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
} else {
childRdd.mapPartitions { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row, null))
}
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
val rddForSampling = rdd.mapPartitions { iter =>
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}

val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part)
shuffled.setSerializer(serializer)
shuffled.map(_._1)

implicit val ordering = new RowOrdering(sortingExpressions, child.output)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
val valueSchema = child.output.map(_.dataType).toArray
val serializer = getSerializer(null, valueSchema, numPartitions = 1)
val partitioner = new HashPartitioner(1)

val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
child.execute().mapPartitions {
iter => iter.map(r => (null, r.copy()))
}
} else {
child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Null, InternalRow]()
iter.map(r => mutablePair.update(null, r))
}
new Partitioner {
override def numPartitions: Int = 1
override def getPartition(key: Any): Int = 0
}
val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner)
shuffled.setSerializer(serializer)
shuffled.map(_._2)

case _ => sys.error(s"Exchange not implemented for $newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
if (needToCopyObjectsBeforeShuffle(part, serializer)) {
rdd.mapPartitions { iter =>
val getPartitionKey = getPartitionKeyExtractor()
iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
}
} else {
rdd.mapPartitions { iter =>
val getPartitionKey = getPartitionKeyExtractor()
val mutablePair = new MutablePair[Int, InternalRow]()
iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
}
}
}
new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions)
}
}

80 changes: 80 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.DataType

private class ShuffledRowRDDPartition(val idx: Int) extends Partition {
override val index: Int = idx
override def hashCode(): Int = idx
}

/**
* A dummy partitioner for use with records whose partition ids have been pre-computed (i.e. for
* use on RDDs of (Int, Row) pairs where the Int is a partition id in the expected range).
*/
private class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for
* shuffling rows instead of Java key-value pairs. Note that something like this should eventually
* be implemented in Spark core, but that is blocked by some more general refactorings to shuffle
* interfaces / internals.
*
* @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs.
* Partition ids should be in the range [0, numPartitions - 1].
* @param serializer the serializer used during the shuffle.
* @param numPartitions the number of post-shuffle partitions.
*/
class ShuffledRowRDD(
@transient var prev: RDD[Product2[Int, InternalRow]],
serializer: Serializer,
numPartitions: Int)
extends RDD[InternalRow](prev.context, Nil) {

private val part: Partitioner = new PartitionIdPassthrough(numPartitions)

override def getDependencies: Seq[Dependency[_]] = {
List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer)))
}

override val partitioner = Some(part)

override def getPartitions: Array[Partition] = {
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i))
}

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[Product2[Int, InternalRow]]]
.map(_._2)
}

override def clearDependencies() {
super.clearDependencies()
prev = null
}
}
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -45,14 +45,12 @@ import org.apache.spark.unsafe.types.UTF8String
* the comment of the `serializer` method in [[Exchange]] for more information on it.
*/
private[sql] class Serializer2SerializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
rowSchema: Array[DataType],
out: OutputStream)
extends SerializationStream with Logging {

private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)

override def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -63,12 +61,12 @@ private[sql] class Serializer2SerializationStream(
}

override def writeKey[T: ClassTag](t: T): SerializationStream = {
writeKeyFunc(t.asInstanceOf[Row])
// No-op.
this
}

override def writeValue[T: ClassTag](t: T): SerializationStream = {
writeValueFunc(t.asInstanceOf[Row])
writeRowFunc(t.asInstanceOf[Row])
this
}

@@ -85,8 +83,7 @@ private[sql] class Serializer2SerializationStream(
* The corresponding deserialization stream for [[Serializer2SerializationStream]].
*/
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
rowSchema: Array[DataType],
in: InputStream)
extends DeserializationStream with Logging {

@@ -103,22 +100,20 @@ private[sql] class Serializer2DeserializationStream(
}

// Functions used to return rows for key and value.
private val getKey = rowGenerator(keySchema)
private val getValue = rowGenerator(valueSchema)
private val getRow = rowGenerator(rowSchema)
// Functions used to read a serialized row from the InputStream and deserialize it.
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)

override def readObject[T: ClassTag](): T = {
(readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
readValue()
}

override def readKey[T: ClassTag](): T = {
readKeyFunc(getKey()).asInstanceOf[T]
null.asInstanceOf[T] // intentionally left blank.
}

override def readValue[T: ClassTag](): T = {
readValueFunc(getValue()).asInstanceOf[T]
readRowFunc(getRow()).asInstanceOf[T]
}

override def close(): Unit = {
@@ -127,8 +122,7 @@ private[sql] class Serializer2DeserializationStream(
}

private[sql] class SparkSqlSerializer2Instance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
rowSchema: Array[DataType])
extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -141,30 +135,25 @@ private[sql] class SparkSqlSerializer2Instance(
throw new UnsupportedOperationException("Not supported.")

def serializeStream(s: OutputStream): SerializationStream = {
new Serializer2SerializationStream(keySchema, valueSchema, s)
new Serializer2SerializationStream(rowSchema, s)
}

def deserializeStream(s: InputStream): DeserializationStream = {
new Serializer2DeserializationStream(keySchema, valueSchema, s)
new Serializer2DeserializationStream(rowSchema, s)
}
}

/**
* SparkSqlSerializer2 is a special serializer that creates serialization function and
* deserialization function based on the schema of data. It assumes that values passed in
* are key/value pairs and values returned from it are also key/value pairs.
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
* are Rows.
*/
private[sql] class SparkSqlSerializer2(
keySchema: Array[DataType],
valueSchema: Array[DataType])
private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
extends Serializer
with Logging
with Serializable{

def newInstance(): SerializerInstance =
new SparkSqlSerializer2Instance(keySchema, valueSchema)
def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)

override def supportsRelocationOfSerializedObjects: Boolean = {
// SparkSqlSerializer2 is stateless and writes no stream headers