From 9ee9992f8581557ca410cf38a88557d3fd3fe21a Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 23 Jul 2014 11:16:01 -0700
Subject: [PATCH] update range partitioner to run only one job on roughly
 balanced data

---
 .../scala/org/apache/spark/Partitioner.scala  | 93 ++++++++++++++++---
 .../spark/util/random/RandomSampler.scala     | 25 +++++
 .../org/apache/spark/PartitioningSuite.scala  |  6 +-
 3 files changed, 107 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 52c018baa5f7b..f269dc9abcd55 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,7 +19,11 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
-import scala.reflect.ClassTag
+import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
+import scala.util.hashing.byteswap32
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.JavaSerializer
@@ -108,21 +112,84 @@ class RangePartitioner[K : Ordering : ClassTag, V](
   // An array of upper bounds for the first (partitions - 1) partitions
   private var rangeBounds: Array[K] = {
     if (partitions == 1) {
-      Array()
+      Array.empty
     } else {
-      val rddSize = rdd.count()
-      val maxSampleSize = partitions * 20.0
-      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
-      if (rddSample.length == 0) {
-        Array()
+      // This is the sample size we need to have roughly balanced output partitions.
+      val sampleSize = 20.0 * partitions
+      // Assume the input partitions are roughly balanced and over-sample a little bit.
+      val sampleSizePerPartition = math.ceil(5.0 * sampleSize / rdd.partitions.size).toInt
+      val shift = rdd.id
+      val classTagK = classTag[K]
+      val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
+        val seed = byteswap32(idx + shift)
+        val (sample, n) = SamplingUtils.reservoirSampleAndCount(
+          iter.map(_._1), sampleSizePerPartition, seed)(classTagK)
+        Iterator((idx, n, sample))
+      }.collect()
+      var numItems = 0L
+      sketch.foreach { case (_, n, _) =>
+        numItems += n
+      }
+      if (numItems == 0L) {
+        Array.empty
       } else {
-        val bounds = new Array[K](partitions - 1)
-        for (i <- 0 until partitions - 1) {
-          val index = (rddSample.length - 1) * (i + 1) / partitions
-          bounds(i) = rddSample(index)
+        // If a partition contains much more than the average number of items, we re-sample from it
+        // to ensure that enough items are collected from that partition.
+        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
+        val candidates = ArrayBuffer.empty[(K, Float)]
+        val imbalancedPartitions = ArrayBuffer.empty[Int]
+        sketch.foreach { case (idx, n, sample) =>
+          if (fraction * n > sampleSizePerPartition) {
+            imbalancedPartitions += idx
+          } else {
+            // The weight is 1 over the sampling probability.
+            val weight = (n.toDouble / sample.size).toFloat
+            sample.foreach { key =>
+              candidates += ((key, weight))
+            }
+          }
+        }
+        if (imbalancedPartitions.nonEmpty) {
+          val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
+            val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
+            iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray
+          }
+          val weight = (1.0 / fraction).toFloat
+          val resultHandler: (Int, Array[K]) => Unit = { (index, sample) =>
+            sample.foreach { key =>
+              candidates += ((key, weight))
+            }
+          }
+          rdd.context.runJob(
+            rdd, sampleFunc, imbalancedPartitions, allowLocal = false, resultHandler)
+        }
+        var sumWeights: Double = 0.0
+        candidates.foreach { case (_, weight) =>
+          sumWeights += weight
+        }
+        val step = sumWeights / partitions
+        var cumWeight = 0.0
+        var target = step
+        val bounds = ArrayBuffer.empty[K]
+        val sorted = candidates.sortBy(_._1)
+        var i = 0
+        var j = 0
+        var previousBound = Option.empty[K]
+        while ((i < sorted.length) && (j < partitions - 1)) {
+          val (key, weight) = sorted(i)
+          cumWeight += weight
+          if (cumWeight > target) {
+            // Skip duplicate values.
+            if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
+              bounds += key
+              target += step
+              j += 1
+              previousBound = Some(key)
+            }
+          }
+          i += 1
         }
-        bounds
+        bounds.toArray
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 32c5fdad75e58..04c4d78b1de6b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -106,3 +106,28 @@ class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
 
   override def clone = new PoissonSampler[T](mean)
 }
+
+/**
+ * :: DeveloperApi ::
+ * A sampler selects items based on their importance scores defined in the keys.
+ *
+ * The importance score should be within range `[0, 1]`. Items with scores less than or equal to 0
+ * would never get selected, while items with scores greater than or equal to 1 would always get
+ * selected.
+ *
+ * @param ratio sampling probability
+ * @tparam T item type
+ */
+@DeveloperApi
+class ImportanceSampler[T](ratio: Double) extends RandomSampler[(Double, T), (Double, T)] {
+
+  private[random] var rng: Random = new XORShiftRandom
+
+  override def setSeed(seed: Long) = rng.setSeed(seed)
+
+  override def sample(items: Iterator[(Double, T)]): Iterator[(Double, T)] = {
+    items.filter(item => rng.nextDouble() < ratio)
+  }
+
+  override def clone = new ImportanceSampler[T](ratio)
+}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 4658a08064280..2138704ff6fb6 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -52,14 +52,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(p2 === p2)
     assert(p4 === p4)
-    assert(p2 != p4)
-    assert(p4 != p2)
+    assert(p2 === p4)
     assert(p4 === anotherP4)
     assert(anotherP4 === p4)
     assert(descendingP2 === descendingP2)
     assert(descendingP4 === descendingP4)
-    assert(descendingP2 != descendingP4)
-    assert(descendingP4 != descendingP2)
+    assert(descendingP2 === descendingP4)
     assert(p2 != descendingP2)
     assert(p4 != descendingP4)
     assert(descendingP2 != p2)
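Note on the sketch pass in the Partitioner.scala hunk: each partition is summarized in a single job with SamplingUtils.reservoirSampleAndCount, which draws a fixed-size uniform sample while counting the items it sees. The snippet below is only an illustrative stand-in for that helper (the real one lives in org.apache.spark.util.random and its exact signature may differ); it sketches the classic reservoir-sampling idea the single pass relies on.

import scala.reflect.ClassTag
import scala.util.Random

object ReservoirSketch {
  // Single-pass uniform sample of at most k items from an iterator of unknown
  // length, returned together with the number of items seen.
  def reservoirSampleAndCount[T: ClassTag](input: Iterator[T], k: Int, seed: Long): (Array[T], Int) = {
    val reservoir = new Array[T](k)
    var i = 0
    while (i < k && input.hasNext) {
      reservoir(i) = input.next()
      i += 1
    }
    if (i < k) {
      // The input had fewer than k items; trim the unused slots.
      (reservoir.take(i), i)
    } else {
      val rand = new Random(seed)
      while (input.hasNext) {
        val item = input.next()
        // Replace a random slot with probability k / (i + 1), keeping the sample uniform.
        val pos = rand.nextInt(i + 1)
        if (pos < k) {
          reservoir(pos) = item
        }
        i += 1
      }
      (reservoir, i)
    }
  }

  def main(args: Array[String]): Unit = {
    val (sample, n) = reservoirSampleAndCount(Iterator.range(0, 1000), 20, seed = 42L)
    println(s"sampled ${sample.length} of $n items")  // sampled 20 of 1000 items
  }
}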
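Note on the bounds computation: once weighted candidates are collected, the loop added to rangeBounds walks the keys in sorted order and emits a split each time the cumulative weight passes another multiple of step, skipping repeated keys so the bounds stay strictly increasing. The sketch below restates that loop as a standalone method under two assumptions: RangeBoundsSketch and determineBounds are illustrative names that are not part of this patch, and each candidate carries weight = 1 / its sampling probability, as in the hunk above.

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object RangeBoundsSketch {
  // Pick (partitions - 1) upper bounds so that every range covers roughly the
  // same total weight.
  def determineBounds[K: Ordering: ClassTag](
      candidates: ArrayBuffer[(K, Float)],
      partitions: Int): Array[K] = {
    val ordering = implicitly[Ordering[K]]
    val sorted = candidates.sortBy(_._1)
    val sumWeights = sorted.map(_._2.toDouble).sum
    val step = sumWeights / partitions
    var cumWeight = 0.0
    var target = step
    val bounds = ArrayBuffer.empty[K]
    var i = 0
    var j = 0
    var previousBound = Option.empty[K]
    while ((i < sorted.length) && (j < partitions - 1)) {
      val (key, weight) = sorted(i)
      cumWeight += weight
      if (cumWeight > target) {
        // Skip duplicate keys so that consecutive bounds differ.
        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
          bounds += key
          target += step
          j += 1
          previousBound = Some(key)
        }
      }
      i += 1
    }
    bounds.toArray
  }

  def main(args: Array[String]): Unit = {
    // 100 distinct keys with unit weights split four ways.
    val bounds = determineBounds(ArrayBuffer.tabulate(100)(i => (i + 1, 1.0f)), 4)
    println(bounds.mkString(", "))  // 26, 51, 76
  }
}

With unit weights the bounds land just past the even quantiles, which is why the example prints 26, 51 and 76 for keys 1 to 100.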