update range partitioner to run only one job on roughly balanced data
mengxr committed Jul 24, 2014
1 parent cc12f47 commit 9ee9992
Showing 3 changed files with 107 additions and 17 deletions.
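
The previous implementation needed two passes over the data: a full count() followed by a separate sample() job. The new code runs a single job that returns, for every input partition, its item count together with a fixed-size reservoir sample, and re-samples only the partitions that turn out to be badly imbalanced. Below is a minimal, self-contained sketch of that per-partition step; the helper name ReservoirSketch.sampleAndCount is illustrative, while the committed code delegates to SamplingUtils.reservoirSampleAndCount.

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.Random

// Illustrative sketch, not the committed helper: one pass over a partition's
// keys yields both the partition size and a bounded-size uniform sample,
// which is what the mapPartitionsWithIndex job in Partitioner.scala below
// collects from every partition at once.
object ReservoirSketch {
  def sampleAndCount[T: ClassTag](iter: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
    val rng = new Random(seed)
    val reservoir = ArrayBuffer.empty[T]
    var n = 0L
    iter.foreach { item =>
      n += 1
      if (reservoir.size < k) {
        reservoir += item
      } else {
        // Keep the current item with probability k/n by overwriting a random slot.
        val j = (rng.nextDouble() * n).toLong
        if (j < k) reservoir(j.toInt) = item
      }
    }
    (reservoir.toArray, n)
  }

  def main(args: Array[String]): Unit = {
    val (sample, count) = sampleAndCount((1 to 100000).iterator, k = 20, seed = 42L)
    println(s"counted $count items, kept a sample of ${sample.length}")
  }
}
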
93 changes: 80 additions & 13 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,7 +19,11 @@ package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.reflect.ClassTag
import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.{ClassTag, classTag}
import scala.util.hashing.byteswap32

import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.JavaSerializer
@@ -108,21 +112,84 @@ class RangePartitioner[K : Ordering : ClassTag, V](
// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
if (partitions == 1) {
Array()
Array.empty
} else {
val rddSize = rdd.count()
val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
if (rddSample.length == 0) {
Array()
// This is the sample size we need to have roughly balanced output partitions.
val sampleSize = 20.0 * partitions
// Assume the input partitions are roughly balanced and over-sample a little bit.
val sampleSizePerPartition = math.ceil(5.0 * sampleSize / rdd.partitions.size).toInt
val shift = rdd.id
val classTagK = classTag[K]
val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
val seed = byteswap32(idx + shift)
val (sample, n) = SamplingUtils.reservoirSampleAndCount(
iter.map(_._1), sampleSizePerPartition, seed)(classTagK)
Iterator((idx, n, sample))
}.collect()
var numItems = 0L
sketch.foreach { case (_, n, _) =>
numItems += n
}
if (numItems == 0L) {
Array.empty
} else {
val bounds = new Array[K](partitions - 1)
for (i <- 0 until partitions - 1) {
val index = (rddSample.length - 1) * (i + 1) / partitions
bounds(i) = rddSample(index)
// If a partition contains much more than the average number of items, we re-sample from it
// to ensure that enough items are collected from that partition.
val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
val candidates = ArrayBuffer.empty[(K, Float)]
val imbalancedPartitions = ArrayBuffer.empty[Int]
sketch.foreach { case (idx, n, sample) =>
if (fraction * n > sampleSizePerPartition) {
imbalancedPartitions += idx
} else {
// The weight is 1 over the sampling probability.
val weight = (n.toDouble / sample.size).toFloat
sample.foreach { key =>
candidates += ((key, weight))
}
}
}
if (imbalancedPartitions.nonEmpty) {
val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray
}
val weight = (1.0 / fraction).toFloat
val resultHandler: (Int, Array[K]) => Unit = { (index, sample) =>
sample.foreach { key =>
candidates += ((key, weight))
}
}
rdd.context.runJob(
rdd, sampleFunc, imbalancedPartitions, allowLocal = false, resultHandler)
}
var sumWeights: Double = 0.0
candidates.foreach { case (_, weight) =>
sumWeights += weight
}
val step = sumWeights / partitions
var cumWeight = 0.0
var target = step
val bounds = ArrayBuffer.empty[K]
val sorted = candidates.sortBy(_._1)
var i = 0
var j = 0
var previousBound = Option.empty[K]
while ((i < sorted.length) && (j < partitions - 1)) {
val (key, weight) = sorted(i)
cumWeight += weight
if (cumWeight > target) {
// Skip duplicate values.
if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
bounds += key
target += step
j += 1
previousBound = Some(key)
}
}
i += 1
}
bounds
bounds.toArray
}
}
}
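
For context, a hedged usage sketch of how the updated partitioner is exercised; the data set, partition counts, and master URL here are made up for illustration. sortByKey builds a RangePartitioner the same way internally, so it benefits from the same single sampling job.

import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object RangePartitionerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("range-example"))
    // Roughly balanced input: 8 partitions of comparable size.
    val pairs = sc.parallelize(1 to 100000, 8).map(x => (x, x.toString))

    // Constructing the partitioner computes rangeBounds eagerly: one sampling
    // job over all partitions, plus a follow-up job only for partitions whose
    // sample came up short of the target fraction.
    val partitioner = new RangePartitioner(4, pairs)
    val byRange = pairs.partitionBy(partitioner)

    println(byRange.partitions.length)
    sc.stop()
  }
}
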
25 changes: 25 additions & 0 deletions core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -106,3 +106,28 @@ class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {

override def clone = new PoissonSampler[T](mean)
}

/**
* :: DeveloperApi ::
* A sampler that selects items based on their importance scores, which are defined in the keys.
*
* The importance score should be within range `[0, 1]`. Items with scores less than or equal to 0
* would never get selected, while items with scores greater than or equal to 1 would always get
* selected.
*
* @param ratio sampling probability
* @tparam T item type
*/
@DeveloperApi
class ImportanceSampler[T](ratio: Double) extends RandomSampler[(Double, T), (Double, T)] {

private[random] var rng: Random = new XORShiftRandom

override def setSeed(seed: Long) = rng.setSeed(seed)

override def sample(items: Iterator[(Double, T)]): Iterator[(Double, T)] = {
items.filter(item => rng.nextDouble() < ratio)
}

override def clone = new ImportanceSampler[T](ratio)
}
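
A short usage sketch of the sampler added above, assuming only the API shown in this diff (constructor, setSeed, sample); note that with the implementation as committed, each (score, item) pair is kept with probability ratio.

import org.apache.spark.util.random.ImportanceSampler

object ImportanceSamplerExample {
  def main(args: Array[String]): Unit = {
    // Importance scores in [0, 1] paired with items; the values are arbitrary.
    val scored = Iterator((0.9, "a"), (0.1, "b"), (0.5, "c"), (1.0, "d"))

    val sampler = new ImportanceSampler[String](0.5)
    sampler.setSeed(42L)

    sampler.sample(scored).foreach { case (score, item) =>
      println(s"kept $item (score $score)")
    }
  }
}
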
6 changes: 2 additions & 4 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -52,14 +52,12 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet

assert(p2 === p2)
assert(p4 === p4)
assert(p2 != p4)
assert(p4 != p2)
assert(p2 === p4)
assert(p4 === anotherP4)
assert(anotherP4 === p4)
assert(descendingP2 === descendingP2)
assert(descendingP4 === descendingP4)
assert(descendingP2 != descendingP4)
assert(descendingP4 != descendingP2)
assert(descendingP2 === descendingP4)
assert(p2 != descendingP2)
assert(p4 != descendingP4)
assert(descendingP2 != p2)
