diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index d99067fb5fefb..d5ab7931433b1 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,11 +19,12 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.{ClassTag, classTag}
 import scala.util.hashing.byteswap32
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.{CollectionsUtils, Utils}
 import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
@@ -106,31 +107,21 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     private var ascending: Boolean = true)
   extends Partitioner {
 
-  private var ordering = implicitly[Ordering[K]]
+  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
 
-  @transient private[spark] var singlePass = true // for unit tests
+  private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
-    if (partitions == 1) {
+    if (partitions <= 1) {
       Array.empty
     } else {
-      // This is the sample size we need to have roughly balanced output partitions.
-      val sampleSize = 20.0 * partitions
+      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+      val sampleSize = math.min(20.0 * partitions, 1e6)
       // Assume the input partitions are roughly balanced and over-sample a little bit.
       val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
-      val shift = rdd.id
-      val classTagK = classTag[K]
-      val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
-        val seed = byteswap32(idx + shift)
-        val (sample, n) = SamplingUtils.reservoirSampleAndCount(
-          iter.map(_._1), sampleSizePerPartition, seed)(classTagK)
-        Iterator((idx, n, sample))
-      }.collect()
-      var numItems = 0L
-      sketch.foreach { case (_, n, _) =>
-        numItems += n
-      }
+      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
       if (numItems == 0L) {
         Array.empty
       } else {
@@ -138,8 +129,8 @@ class RangePartitioner[K : Ordering : ClassTag, V](
         // to ensure that enough items are collected from that partition.
         val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
         val candidates = ArrayBuffer.empty[(K, Float)]
-        val imbalancedPartitions = ArrayBuffer.empty[Int]
-        sketch.foreach { case (idx, n, sample) =>
+        val imbalancedPartitions = mutable.Set.empty[Int]
+        sketched.foreach { case (idx, n, sample) =>
           if (fraction * n > sampleSizePerPartition) {
             imbalancedPartitions += idx
           } else {
@@ -151,48 +142,14 @@ class RangePartitioner[K : Ordering : ClassTag, V](
           }
         }
         if (imbalancedPartitions.nonEmpty) {
-          singlePass = false
-          val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
-            val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
-            iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray(classTagK)
-          }
+          // Re-sample imbalanced partitions with the desired sampling probability.
+          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+          val seed = byteswap32(-rdd.id - 1)
+          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
           val weight = (1.0 / fraction).toFloat
-          val resultHandler: (Int, Array[K]) => Unit = { (_, sample) =>
-            sample.foreach { key =>
-              candidates += ((key, weight))
-            }
-          }
-          rdd.context.runJob(
-            rdd, sampleFunc, imbalancedPartitions, allowLocal = false, resultHandler)
+          candidates ++= reSampled.map(x => (x, weight))
         }
-        val numCandidates = candidates.size
-        var sumWeights = 0.0
-        candidates.foreach { case (_, weight) =>
-          sumWeights += weight
-        }
-        val step = sumWeights / partitions
-        var cumWeight = 0.0
-        var target = step
-        val bounds = ArrayBuffer.empty[K]
-        val orderedCandidates = candidates.sortBy(_._1)
-        var i = 0
-        var j = 0
-        var previousBound = Option.empty[K]
-        while ((i < numCandidates) && (j < partitions - 1)) {
-          val (key, weight) = orderedCandidates(i)
-          cumWeight += weight
-          if (cumWeight > target) {
-            // Skip duplicate values.
-            if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
-              bounds += key
-              target += step
-              j += 1
-              previousBound = Some(key)
-            }
-          }
-          i += 1
-        }
-        bounds.toArray
+        RangePartitioner.determineBounds(candidates, partitions)
       }
     }
   }
@@ -282,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     }
   }
 }
+
+private[spark] object RangePartitioner {
+
+  /**
+   * Sketches the input RDD via reservoir sampling on each partition.
+   *
+   * @param rdd the input RDD to sketch
+   * @param sampleSizePerPartition max sample size per partition
+   * @return (total number of items, an array of (partitionId, number of items, sample))
+   */
+  def sketch[K : ClassTag](
+      rdd: RDD[K],
+      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
+    val shift = rdd.id
+    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
+    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
+      val seed = byteswap32(idx ^ (shift << 16))
+      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+        iter, sampleSizePerPartition, seed)
+      Iterator((idx, n, sample))
+    }.collect()
+    val numItems = sketched.map(_._2.toLong).sum
+    (numItems, sketched)
+  }
+
+  /**
+   * Determines the bounds for range partitioning from candidates with weights indicating how many
+   * items each represents. Usually this is 1 over the probability used to sample this candidate.
+   *
+   * @param candidates unordered candidates with weights
+   * @param partitions number of partitions
+   * @return selected bounds
+   */
+  def determineBounds[K : Ordering : ClassTag](
+      candidates: ArrayBuffer[(K, Float)],
+      partitions: Int): Array[K] = {
+    val ordering = implicitly[Ordering[K]]
+    val ordered = candidates.sortBy(_._1)
+    val numCandidates = ordered.size
+    val sumWeights = ordered.map(_._2.toDouble).sum
+    val step = sumWeights / partitions
+    var cumWeight = 0.0
+    var target = step
+    val bounds = ArrayBuffer.empty[K]
+    var i = 0
+    var j = 0
+    var previousBound = Option.empty[K]
+    while ((i < numCandidates) && (j < partitions - 1)) {
+      val (key, weight) = ordered(i)
+      cumWeight += weight
+      if (cumWeight > target) {
+        // Skip duplicate values.
+        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+          bounds += key
+          target += step
+          j += 1
+          previousBound = Some(key)
+        }
+      }
+      i += 1
+    }
+    bounds.toArray
+  }
+}
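 
Illustration (not part of the patch): the refactor splits bound computation into two reusable
helpers. sketch makes a single pass of per-partition reservoir sampling; determineBounds turns
weighted sample keys into partition bounds. A minimal sketch of how they compose, with
hypothetical sizes, assuming a SparkContext named sc and code living in the org.apache.spark
package (the companion object is private[spark], as exercised by PartitioningSuite below):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.rdd.RDD

    // Sampling math for partitions = 1000 over a 100-partition RDD:
    //   sampleSize = min(20.0 * 1000, 1e6) = 20000 sampled keys overall
    //   sampleSizePerPartition = ceil(3.0 * 20000 / 100) = 600 keys kept per partition
    val keys: RDD[Int] = sc.parallelize(0 until 100000, 100)
    val (numItems, sketched) = RangePartitioner.sketch(keys, sampleSizePerPartition = 600)
    // numItems == 100000L; sketched holds one (partitionId, itemCount, sample) per partition.

    // Each sampled key carries weight 1 / samplingFraction. Here the total weight is 10.0,
    // so with 3 partitions each range targets ~3.33 weight and the bounds fall at 0.4 and
    // 0.7: the same data and result asserted in PartitioningSuite below.
    val candidates = ArrayBuffer(
      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
    val bounds = RangePartitioner.determineBounds(candidates, 3) // Array(0.4, 0.7)
 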
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 44f800bd61dbe..fc0cee3e8749d 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark
 
+import scala.collection.mutable.ArrayBuffer
 import scala.math.abs
 
 import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -100,6 +101,28 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     partitioner.getPartition(Row(100))
   }
 
+  test("RangePartitioner.sketch") {
+    val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
+      val random = new java.util.Random(i)
+      Iterator.fill(i)(random.nextDouble())
+    }.cache()
+    val sampleSizePerPartition = 10
+    val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
+    assert(count === rdd.count())
+    sketched.foreach { case (idx, n, sample) =>
+      assert(n === idx)
+      assert(sample.size === math.min(n, sampleSizePerPartition))
+    }
+  }
+
+  test("RangePartitioner.determineBounds") {
+    assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
+      "Bounds on an empty candidates set should be empty.")
+    val candidates = ArrayBuffer(
+      (0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
+    assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
+  }
+
   test("RangePartitioner should run only one job if data is roughly balanced") {
     val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
       val random = new java.util.Random(i)
@@ -108,9 +131,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     for (numPartitions <- Seq(10, 20, 40)) {
       val partitioner = new RangePartitioner(numPartitions, rdd)
       assert(partitioner.numPartitions === numPartitions)
-      assert(partitioner.singlePass === true)
       val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
-      assert(counts.max < 2.0 * counts.min)
+      assert(counts.max < 3.0 * counts.min)
     }
   }
 
@@ -122,12 +144,20 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     for (numPartitions <- Seq(2, 4, 8)) {
       val partitioner = new RangePartitioner(numPartitions, rdd)
       assert(partitioner.numPartitions === numPartitions)
-      assert(partitioner.singlePass === false)
       val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
-      assert(counts.max < 2.0 * counts.min)
+      assert(counts.max < 3.0 * counts.min)
     }
   }
 
+  test("RangePartitioner should return a single partition for empty RDDs") {
+    val empty1 = sc.emptyRDD[(Int, Double)]
+    val partitioner1 = new RangePartitioner(0, empty1)
+    assert(partitioner1.numPartitions === 1)
+    val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
+    val partitioner2 = new RangePartitioner(2, empty2)
+    assert(partitioner2.numPartitions === 1)
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
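 
Illustration (not part of the patch): with the singlePass hook removed, the suite now asserts
output balance (counts.max < 3.0 * counts.min) rather than the number of jobs. If one wanted
to observe job counts directly, a sketch using the listener API (assumes a SparkContext sc
and a pair RDD rdd like the one in the balanced test above; listener events are posted
asynchronously, so a real test would wait for the listener bus to drain before reading the
counter):

    import java.util.concurrent.atomic.AtomicInteger
    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

    val jobCount = new AtomicInteger(0)
    sc.addSparkListener(new SparkListener {
      // Incremented once per job submitted on this SparkContext.
      override def onJobStart(jobStart: SparkListenerJobStart) {
        jobCount.incrementAndGet()
      }
    })
    // Roughly balanced input triggers only the sketch job; the re-sampling job is
    // submitted only when some partition's expected sample, fraction * itemCount,
    // exceeds sampleSizePerPartition.
    val partitioner = new RangePartitioner(10, rdd)
 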
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2924de112934c..e9dcb7fe1b959 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -604,6 +604,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("sort an empty RDD") {
+    val data = sc.emptyRDD[Int]
+    assert(data.sortBy(x => x).collect() === Array.empty)
+  }
+
   test("sortByKey") {
     val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))
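 
Illustration (not part of the patch): the new RDDSuite test exercises the partitions <= 1
guard end to end. sortBy is built on sortByKey, whose default partition count is the input's
partition count, which is 0 for an empty RDD; the guard then yields no bounds and a single
output partition, matching the comment above the require check. A sketch of the same path
through the pair-RDD API:

    import org.apache.spark.SparkContext._ // pair-RDD implicits, as at the top of RDDSuite

    val sorted = sc.emptyRDD[Int].map(x => (x, x)).sortByKey()
    assert(sorted.partitioner.exists(_.numPartitions == 1)) // single partition for empty input
    assert(sorted.collect().isEmpty)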