SPARK-2045 Sort-based shuffle
This adds a new ShuffleManager based on sorting, as described in https://issues.apache.org/jira/browse/SPARK-2045. The bulk of the code is in an ExternalSorter class that is similar to ExternalAppendOnlyMap, but sorts key-value pairs by partition ID and can be used to write a map task's output to a single sorted file. (Longer term, I think it can take on the remaining functionality of ExternalAppendOnlyMap and replace it, so we don't have duplicated code.)
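To make the idea concrete, here is a minimal sketch (illustrative only; these are hypothetical names, not the actual ExternalSorter API in this patch): map output records are tagged with their target partition ID and sorted by that ID, so a map task can write one consolidated output file instead of one file per reducer.

```scala
import org.apache.spark.{HashPartitioner, Partitioner}

// Illustrative sketch only, not the ExternalSorter API added by this patch.
object SortByPartitionSketch {
  // Tag each record with its partition ID and sort by that ID, so a single
  // output file can be written with all partitions laid out contiguously.
  def sortByPartition[K, V](
      records: Iterator[(K, V)],
      partitioner: Partitioner): Seq[(Int, (K, V))] = {
    records
      .map { case (k, v) => (partitioner.getPartition(k), (k, v)) }
      .toSeq
      .sortBy(_._1) // order by partition ID only; keys within a partition stay unordered
  }

  def main(args: Array[String]): Unit = {
    val records = Iterator("apple" -> 1, "banana" -> 2, "cherry" -> 3, "date" -> 4)
    val sorted = sortByPartition(records, new HashPartitioner(2))
    sorted.foreach { case (pid, (k, v)) => println(s"partition=$pid key=$k value=$v") }
  }
}
```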

The main TODOs still left are:
- [x] enabling ExternalSorter to merge across spilled files
  - [x] with an Ordering
  - [x] without an Ordering, using the keys' hash codes
- [x] adding more tests (e.g. a version of our shuffle suite that runs on this)
- [x] rebasing on top of the size-tracking refactoring in apache#1165 when that is merged
- [x] disabling spilling if spark.shuffle.spill is set to false

Despite these remaining TODOs, this seems to work pretty well: it runs successfully in cases where the hash-based shuffle would OOM, such as 1000 reduce tasks on executors with only 1 GB of memory, and it appears to be comparable in speed to, or faster than, hash-based shuffle (it creates far fewer files for the OS to keep track of). So I'm posting it to get some early feedback.
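For anyone who wants to try it, a configuration sketch follows. The exact property value is an assumption based on the pluggable ShuffleManager interface and is not spelled out in this commit; later Spark releases also accept the short name "sort".

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Assumption: the shuffle implementation is selected through spark.shuffle.manager,
// and the fully qualified class name is the safe value at the time of this patch.
val conf = new SparkConf()
  .setAppName("sort-shuffle-smoke-test")
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
  .set("spark.shuffle.spill", "true") // spilling can be turned off, per the TODO list above
val sc = new SparkContext(conf)
```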

After these TODOs are done, I'd also like to enable ExternalSorter to sort data within each partition by a key, which will let us use it to implement external spilling in the reduce tasks of `sortByKey`.
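As a rough sketch of that follow-up idea (an assumption about the eventual design, not code in this patch), sorting within partitions amounts to comparing by partition ID first and only then by key:

```scala
// Sketch of the ordering a reduce-side sortByKey would need: records are grouped
// by partition ID first, then sorted by key within each partition. This is
// illustrative; it is not the comparator used inside ExternalSorter.
object PartitionKeyOrderingSketch {
  def partitionThenKeyOrdering[K](keyOrdering: Ordering[K]): Ordering[(Int, K)] =
    new Ordering[(Int, K)] {
      override def compare(a: (Int, K), b: (Int, K)): Int = {
        val byPartition = java.lang.Integer.compare(a._1, b._1)
        if (byPartition != 0) byPartition else keyOrdering.compare(a._2, b._2)
      }
    }
}
```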

Author: Matei Zaharia <matei@databricks.com>

Closes apache#1499 from mateiz/sort-based-shuffle and squashes the following commits:

bd841f9 [Matei Zaharia] Various review comments
d1c137f [Matei Zaharia] Various review comments
a611159 [Matei Zaharia] Compile fixes due to rebase
62c56c8 [Matei Zaharia] Fix ShuffledRDD sometimes not returning Tuple2s.
f617432 [Matei Zaharia] Fix a failing test (seems to be due to change in SizeTracker logic)
9464d5f [Matei Zaharia] Simplify code and fix conflicts after latest rebase
0174149 [Matei Zaharia] Add cleanup behavior and cleanup tests for sort-based shuffle
eb4ee0d [Matei Zaharia] Remove customizable element type in ShuffledRDD
fa2e8db [Matei Zaharia] Allow nextBatchStream to be called after we're done looking at all streams
a34b352 [Matei Zaharia] Fix tracking of indices within a partition in SpillReader, and add test
03e1006 [Matei Zaharia] Add a SortShuffleSuite that runs ShuffleSuite with sort-based shuffle
3c7ff1f [Matei Zaharia] Obey the spark.shuffle.spill setting in ExternalSorter
ad65fbd [Matei Zaharia] Rebase on top of Aaron's Sorter change, and use Sorter in our buffer
44d2a93 [Matei Zaharia] Use estimateSize instead of atGrowThreshold to test collection sizes
5686f71 [Matei Zaharia] Optimize merging phase for in-memory only data:
5461cbb [Matei Zaharia] Review comments and more tests (e.g. tests with 1 element per partition)
e9ad356 [Matei Zaharia] Update ContextCleanerSuite to make sure shuffle cleanup tests use hash shuffle (since they were written for it)
c72362a [Matei Zaharia] Added bug fix and test for when iterators are empty
de1fb40 [Matei Zaharia] Make trait SizeTrackingCollection private[spark]
4988d16 [Matei Zaharia] tweak
c1b7572 [Matei Zaharia] Small optimization
ba7db7f [Matei Zaharia] Handle null keys in hash-based comparator, and add tests for collisions
ef4e397 [Matei Zaharia] Support for partial aggregation even without an Ordering
4b7a5ce [Matei Zaharia] More tests, and ability to sort data if a total ordering is given
e1f84be [Matei Zaharia] Fix disk block manager test
5a40a1c [Matei Zaharia] More tests
614f1b4 [Matei Zaharia] Add spill metrics to map tasks
cc52caf [Matei Zaharia] Add more error handling and tests for error cases
bbf359d [Matei Zaharia] More work
3a56341 [Matei Zaharia] More partial work towards sort-based shuffle
7a0895d [Matei Zaharia] Some more partial work towards sort-based shuffle
b615476 [Matei Zaharia] Scaffolding for sort-based shuffle
mateiz authored and rxin committed Jul 31, 2014
1 parent da50176 commit e966284
Showing 35 changed files with 1,969 additions and 159 deletions.
24 changes: 16 additions & 8 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -56,18 +56,23 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
// TODO: Make this non optional in a future release
Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
// Update task metrics if context is not null
// TODO: Make context non optional in a future release
Option(context).foreach { c =>
c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
}
combiners.iterator
}
}

@deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0")
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] =
combineCombinersByKey(iter, null)

def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
: Iterator[(K, C)] =
{
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
@@ -85,9 +90,12 @@ case class Aggregator[K, V, C] (
val pair = iter.next()
combiners.insert(pair._1, pair._2)
}
// TODO: Make this non optional in a future release
Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
// Update task metrics if context is not null
// TODO: Make context non-optional in a future release
Option(context).foreach { c =>
c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
}
combiners.iterator
}
}
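A side note on the Aggregator change above: the spill counters now accumulate with `+=` instead of being overwritten, so a task that spills from more than one collection reports the total. A minimal sketch of the difference, with hypothetical numbers:

```scala
// Hypothetical numbers: a task whose map-side combine spilled 100 MB and whose
// external sorter later spilled another 50 MB.
var memoryBytesSpilled = 0L               // stands in for the TaskMetrics field
memoryBytesSpilled += 100L * 1024 * 1024  // first spilling collection
memoryBytesSpilled += 50L * 1024 * 1024   // second one; "=" here would have lost the first 100 MB
```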
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging {
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
executorEnvs(envKey) = value
}
Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
executorEnvs("SPARK_PREPEND_CLASSES") = v
}
// The Mesos scheduler backend relies on this environment variable to set executor memory.
@@ -1203,10 +1203,10 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
* If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
* check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
* If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
* check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
* if not.
*
*
* @param f the closure to clean
* @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
* @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
sample(withReplacement, fraction, Utils.random.nextLong)

/**
* Return a sampled subset of this RDD.
*/
7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -17,10 +17,11 @@

package org.apache.spark.rdd

import scala.language.existentials

import java.io.{IOException, ObjectOutputStream}

import scala.collection.mutable.ArrayBuffer
import scala.language.existentials

import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag

import org.apache.spark.{Logging, RangePartitioner}
import org.apache.spark.annotation.DeveloperApi

/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner}
*/
class OrderedRDDFunctions[K : Ordering : ClassTag,
V: ClassTag,
P <: Product2[K, V] : ClassTag](
P <: Product2[K, V] : ClassTag] @DeveloperApi() (
self: RDD[P])
extends Logging with Serializable {

extends Logging with Serializable
{
private val ordering = implicitly[Ordering[K]]

/**
@@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
// TODO: this currently doesn't work on P other than Tuple2!
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
: RDD[(K, V)] =
{
val part = new RangePartitioner(numPartitions, self, ascending)
new ShuffledRDD[K, V, V, P](self, part)
new ShuffledRDD[K, V, V](self, part)
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
}
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else {
new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
new ShuffledRDD[K, V, C](self, partitioner)
.setSerializer(serializer)
.setAggregator(aggregator)
.setMapSideCombine(mapSideCombine)
@@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
if (self.partitioner == Some(partitioner)) {
self
} else {
new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
new ShuffledRDD[K, V, V](self, partitioner)
}
}

8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
val distributePartition = (index: Int, items: Iterator[T]) => {
var position = (new Random(index)).nextInt(numPartitions)
items.map { t =>
// Note that the hash code of the key will just be the key itself. The HashPartitioner
// Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
position = position + 1
(position, t)
@@ -341,7 +341,7 @@

// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
numPartitions).values
} else {
@@ -352,8 +352,8 @@
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean,
fraction: Double,
def sample(withReplacement: Boolean,
fraction: Double,
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Negative fraction value: " + fraction)
if (withReplacement) {
17 changes: 9 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam V the value class.
* @tparam C the combiner class.
*/
// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
@DeveloperApi
class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
class ShuffledRDD[K, V, C](
@transient var prev: RDD[_ <: Product2[K, V]],
part: Partitioner)
extends RDD[P](prev.context, Nil) {
extends RDD[(K, C)](prev.context, Nil) {

private var serializer: Option[Serializer] = None

@@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
private var mapSideCombine: Boolean = false

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
this.serializer = Option(serializer)
this
}

/** Set key ordering for RDD's shuffle. */
def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
this.keyOrdering = Option(keyOrdering)
this
}

/** Set aggregator for RDD's shuffle. */
def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
this.aggregator = Option(aggregator)
this
}

/** Set mapSideCombine flag for RDD's shuffle. */
def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
this.mapSideCombine = mapSideCombine
this
}
@@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}

override def compute(split: Partition, context: TaskContext): Iterator[P] = {
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[P]]
.asInstanceOf[Iterator[(K, C)]]
}

override def clearDependencies() {
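With the fourth type parameter gone, callers construct a ShuffledRDD with the chained setters alone. A sketch based on the updated signatures above; `prev`, `partitioner`, and `aggregator` are assumed to be defined by the caller:

```scala
// Sketch of building the simplified three-type-parameter ShuffledRDD.
// Assumed in scope: prev: RDD[(String, Int)], partitioner: Partitioner,
// aggregator: Aggregator[String, Int, Int].
val shuffled: ShuffledRDD[String, Int, Int] =
  new ShuffledRDD[String, Int, Int](prev, partitioner)
    .setSerializer(null)             // null falls back to spark.serializer
    .setAggregator(aggregator)
    .setMapSideCombine(true)
    .setKeyOrdering(Ordering.String) // optional: sort keys on the reduce side
```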
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
* A ShuffleManager using hashing, that creates one output file per reduce partition on each
* mapper (possibly reusing these across waves of tasks).
*/
class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
/* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
override def registerShuffle[K, V, C](
shuffleId: Int,
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}

class HashShuffleReader[K, C](
private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
@@ -47,7 +47,8 @@ class HashShuffleReader[K, C](
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
iter
// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}

// Sort the output if there is a sort ordering defined.
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus

class HashShuffleWriter[K, V](
private[spark] class HashShuffleWriter[K, V](
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext)
@@ -33,6 +33,10 @@ class HashShuffleWriter[K, V](
private val dep = handle.dependency
private val numOutputSplits = dep.partitioner.numPartitions
private val metrics = context.taskMetrics

// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
// we don't try deleting files, etc twice.
private var stopping = false

private val blockManager = SparkEnv.get.blockManager
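The new `stopping` flag guards against the double-cleanup scenario the comment above describes. A standalone sketch of that pattern follows; the `Status` type and the two helper methods are hypothetical placeholders, not the real writer internals:

```scala
// Sketch of an idempotent stop(): a second call must not repeat cleanup such as
// deleting shuffle files. The Status type and the helpers are hypothetical.
abstract class StoppableWriterSketch[Status] {
  private var stopping = false

  protected def commitAndBuildStatus(): Status // hypothetical helper
  protected def revertWrites(): Unit           // hypothetical helper

  def stop(success: Boolean): Option[Status] = {
    if (stopping) {
      None // stop() already ran once; don't commit or delete files twice
    } else {
      stopping = true
      if (success) {
        Some(commitAndBuildStatus())
      } else {
        revertWrites()
        None
      }
    }
  }
}
```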
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.sort

import java.io.{DataInputStream, FileInputStream}

import org.apache.spark.shuffle._
import org.apache.spark.{TaskContext, ShuffleDependency}
import org.apache.spark.shuffle.hash.HashShuffleReader
import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}

private[spark] class SortShuffleManager extends ShuffleManager {
/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
override def registerShuffle[K, V, C](
shuffleId: Int,
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
* Called on executors by reduce tasks.
*/
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
// We currently use the same block store shuffle fetcher as the hash-based shuffle.
new HashShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
: ShuffleWriter[K, V] = {
new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
}

/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Unit = {}

/** Shut down this ShuffleManager. */
override def stop(): Unit = {}

/** Get the location of a block in a map output file. Uses the index file we create for it. */
def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
// The block is actually going to be a range of a single map output file for this map, so
// figure out the ID of the consolidated file, then the offset within that from our index
val consolidatedId = blockId.copy(reduceId = 0)
val indexFile = diskManager.getFile(consolidatedId.name + ".index")
val in = new DataInputStream(new FileInputStream(indexFile))
try {
in.skip(blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
} finally {
in.close()
}
}
}
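getBlockLocation expects an index file containing one 8-byte offset per partition boundary. The writer side is not included in this excerpt, so the sketch below is an assumption about that layout (with a hypothetical helper name): the index holds numPartitions + 1 longs starting at 0, each partition's length accumulated onto the previous offset.

```scala
import java.io.{DataOutputStream, File, FileOutputStream}

// Assumption about the index layout implied by getBlockLocation above: the file
// holds numPartitions + 1 longs, and partition i spans [offsets(i), offsets(i + 1)).
// writeIndexFile is a hypothetical helper name, not the actual writer code.
object IndexFileSketch {
  def writeIndexFile(indexFile: File, partitionLengths: Array[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    try {
      var offset = 0L
      out.writeLong(offset) // start of partition 0
      for (length <- partitionLengths) {
        offset += length
        out.writeLong(offset) // end of this partition == start of the next
      }
    } finally {
      out.close()
    }
  }
}
```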