separate sketching and determining bounds impl
mengxr committed Jul 27, 2014
1 parent c436d30 commit eb95dd8
Showing 3 changed files with 120 additions and 64 deletions.
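
For context: RangePartitioner samples the input RDD to pick upper bounds that split the key space into roughly equal-sized ranges; sortByKey builds one under the hood. A minimal usage sketch, not part of this commit, assuming an initialized SparkContext named sc and Spark 1.x-era implicits:

    import org.apache.spark.RangePartitioner
    import org.apache.spark.SparkContext._  // Spark 1.x implicits for pair-RDD operations

    val pairs = sc.parallelize(1 to 1000).map(x => (x, x))
    val partitioner = new RangePartitioner(8, pairs)  // samples `pairs` to pick up to 7 bounds
    val partitioned = pairs.partitionBy(partitioner)  // each partition covers a contiguous key range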
core/src/main/scala/org/apache/spark/Partitioner.scala (+81, -60)
@@ -19,11 +19,12 @@ package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.{ClassTag, classTag}
import scala.util.hashing.byteswap32

-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}
@@ -106,40 +107,30 @@ class RangePartitioner[K : Ordering : ClassTag, V](
private var ascending: Boolean = true)
extends Partitioner {

-private var ordering = implicitly[Ordering[K]]
+// We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
+require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

-@transient private[spark] var singlePass = true // for unit tests
+private var ordering = implicitly[Ordering[K]]

// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
-if (partitions == 1) {
+if (partitions <= 1) {
Array.empty
} else {
-// This is the sample size we need to have roughly balanced output partitions.
-val sampleSize = 20.0 * partitions
+// This is the sample size we need to have roughly balanced output partitions, capped at 1M.
+val sampleSize = math.min(20.0 * partitions, 1e6)
// Assume the input partitions are roughly balanced and over-sample a little bit.
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
-val shift = rdd.id
-val classTagK = classTag[K]
-val sketch = rdd.mapPartitionsWithIndex { (idx, iter) =>
-val seed = byteswap32(idx + shift)
-val (sample, n) = SamplingUtils.reservoirSampleAndCount(
-iter.map(_._1), sampleSizePerPartition, seed)(classTagK)
-Iterator((idx, n, sample))
-}.collect()
-var numItems = 0L
-sketch.foreach { case (_, n, _) =>
-numItems += n
-}
+val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
if (numItems == 0L) {
Array.empty
} else {
// If a partition contains much more than the average number of items, we re-sample from it
// to ensure that enough items are collected from that partition.
val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
val candidates = ArrayBuffer.empty[(K, Float)]
-val imbalancedPartitions = ArrayBuffer.empty[Int]
-sketch.foreach { case (idx, n, sample) =>
+val imbalancedPartitions = mutable.Set.empty[Int]
+sketched.foreach { case (idx, n, sample) =>
if (fraction * n > sampleSizePerPartition) {
imbalancedPartitions += idx
} else {
@@ -151,48 +142,14 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
}
if (imbalancedPartitions.nonEmpty) {
-singlePass = false
-val sampleFunc: (TaskContext, Iterator[Product2[K, V]]) => Array[K] = { (context, iter) =>
-val random = new XORShiftRandom(byteswap32(context.partitionId - shift))
-iter.map(_._1).filter(t => random.nextDouble() < fraction).toArray(classTagK)
-}
+// Re-sample imbalanced partitions with the desired sampling probability.
+val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
+val seed = byteswap32(-rdd.id - 1)
+val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
val weight = (1.0 / fraction).toFloat
-val resultHandler: (Int, Array[K]) => Unit = { (_, sample) =>
-sample.foreach { key =>
-candidates += ((key, weight))
-}
-}
-rdd.context.runJob(
-rdd, sampleFunc, imbalancedPartitions, allowLocal = false, resultHandler)
+candidates ++= reSampled.map(x => (x, weight))
}
-val numCandidates = candidates.size
-var sumWeights = 0.0
-candidates.foreach { case (_, weight) =>
-sumWeights += weight
-}
-val step = sumWeights / partitions
-var cumWeight = 0.0
-var target = step
-val bounds = ArrayBuffer.empty[K]
-val orderedCandidates = candidates.sortBy(_._1)
-var i = 0
-var j = 0
-var previousBound = Option.empty[K]
-while ((i < numCandidates) && (j < partitions - 1)) {
-val (key, weight) = orderedCandidates(i)
-cumWeight += weight
-if (cumWeight > target) {
-// Skip duplicate values.
-if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
-bounds += key
-target += step
-j += 1
-previousBound = Some(key)
-}
-}
-i += 1
-}
-bounds.toArray
+RangePartitioner.determineBounds(candidates, partitions)
}
}
}
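
To make the sampling arithmetic above concrete, a worked example under assumed numbers (40 output partitions, an input RDD with 20 partitions and 100,000 items, one partition holding 50,000 of them):

    val partitions = 40                                // requested output partitions (assumed)
    val inputPartitions = 20                           // rdd.partitions.size (assumed)
    val sampleSize = math.min(20.0 * partitions, 1e6)  // 800.0 sampled keys wanted overall
    val sampleSizePerPartition =
      math.ceil(3.0 * sampleSize / inputPartitions).toInt  // ceil(120.0) = 120
    val numItems = 100000L                             // total count returned by the sketch (assumed)
    val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)  // 0.008
    // The oversized partition has n = 50000, so fraction * n = 400 > 120: its reservoir
    // sample of 120 keys under-represents it, and it is re-sampled with probability 0.008.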
@@ -282,3 +239,67 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
}
}

private[spark] object RangePartitioner {

/**
* Sketches the input RDD via reservoir sampling on each partition.
*
* @param rdd the input RDD to sketch
* @param sampleSizePerPartition max sample size per partition
* @return (total number of items, an array of (partitionId, number of items, sample))
*/
def sketch[K:ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
val seed = byteswap32(idx ^ (shift << 16))
val (sample, n) = SamplingUtils.reservoirSampleAndCount(
iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample))
}.collect()
val numItems = sketched.map(_._2.toLong).sum
(numItems, sketched)
}
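
// Illustration only, not from this commit: `SamplingUtils.reservoirSampleAndCount`,
// used by `sketch` above, keeps a fixed-size uniform sample in a single pass while
// counting every item. A simplified stand-in showing the idea (classic reservoir sampling):
private def reservoirSampleAndCountSketch[T: ClassTag](
    iter: Iterator[T],
    k: Int,
    seed: Long): (Array[T], Long) = {
  val rng = new scala.util.Random(seed)
  val reservoir = ArrayBuffer.empty[T]
  var n = 0L
  iter.foreach { item =>
    n += 1
    if (reservoir.size < k) {
      reservoir += item  // fill the reservoir with the first k items
    } else {
      // Keep `item` with probability k / n by evicting a uniformly random slot, which
      // leaves every item seen so far equally likely to be in the sample.
      val r = (rng.nextDouble() * n).toLong
      if (r < k) reservoir(r.toInt) = item
    }
  }
  (reservoir.toArray, n)
}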

/**
* Determines the bounds for range partitioning from candidates with weights indicating how many
* items each represents. Usually this is 1 over the probability used to sample this candidate.
*
* @param candidates unordered candidates with weights
* @param partitions number of partitions
* @return selected bounds
*/
def determineBounds[K:Ordering:ClassTag](
candidates: ArrayBuffer[(K, Float)],
partitions: Int): Array[K] = {
val ordering = implicitly[Ordering[K]]
val ordered = candidates.sortBy(_._1)
val numCandidates = ordered.size
val sumWeights = ordered.map(_._2.toDouble).sum
val step = sumWeights / partitions
var cumWeight = 0.0
var target = step
val bounds = ArrayBuffer.empty[K]
var i = 0
var j = 0
var previousBound = Option.empty[K]
while ((i < numCandidates) && (j < partitions - 1)) {
val (key, weight) = ordered(i)
cumWeight += weight
if (cumWeight > target) {
// Skip duplicate values.
if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
bounds += key
target += step
j += 1
previousBound = Some(key)
}
}
i += 1
}
bounds.toArray
}
}
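
A worked trace of RangePartitioner.determineBounds on the candidate set used in the PartitioningSuite test below:

    // sorted candidates (key, weight): (0.1,1) (0.2,1) (0.3,1) (0.4,1) (0.5,1) (0.7,2) (1.0,3)
    // sumWeights = 10.0, partitions = 3, so step = 10.0 / 3 ≈ 3.33 and target starts at 3.33
    // cumWeight: 1.0, 2.0, 3.0, 4.0 -> 4.0 > 3.33, emit bound 0.4, target = 6.67
    // cumWeight: 5.0, 7.0           -> 7.0 > 6.67, emit bound 0.7, target = 10.0
    // j == partitions - 1 stops the loop: result is Array(0.4, 0.7)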
core/src/test/scala/org/apache/spark/PartitioningSuite.scala (+34, -4)
@@ -17,6 +17,7 @@

package org.apache.spark

import scala.collection.mutable.ArrayBuffer
import scala.math.abs

import org.scalatest.{FunSuite, PrivateMethodTester}
@@ -100,6 +101,28 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester
partitioner.getPartition(Row(100))
}

test("RangPartitioner.sketch") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
Iterator.fill(i)(random.nextDouble())
}.cache()
val sampleSizePerPartition = 10
val (count, sketched) = RangePartitioner.sketch(rdd, sampleSizePerPartition)
assert(count === rdd.count())
sketched.foreach { case (idx, n, sample) =>
assert(n === idx)
assert(sample.size === math.min(n, sampleSizePerPartition))
}
}

test("RangePartitioner.determineBounds") {
assert(RangePartitioner.determineBounds(ArrayBuffer.empty[(Int, Float)], 10).isEmpty,
"Bounds on an empty candidates set should be empty.")
val candidates = ArrayBuffer(
(0.7, 2.0f), (0.1, 1.0f), (0.4, 1.0f), (0.3, 1.0f), (0.2, 1.0f), (0.5, 1.0f), (1.0, 3.0f))
assert(RangePartitioner.determineBounds(candidates, 3) === Array(0.4, 0.7))
}

test("RangePartitioner should run only one job if data is roughly balanced") {
val rdd = sc.makeRDD(0 until 20, 20).flatMap { i =>
val random = new java.util.Random(i)
@@ -108,9 +131,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester
for (numPartitions <- Seq(10, 20, 40)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
-assert(partitioner.singlePass === true)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
-assert(counts.max < 2.0 * counts.min)
+assert(counts.max < 3.0 * counts.min)
}
}

@@ -122,12 +144,20 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester
for (numPartitions <- Seq(2, 4, 8)) {
val partitioner = new RangePartitioner(numPartitions, rdd)
assert(partitioner.numPartitions === numPartitions)
-assert(partitioner.singlePass === false)
val counts = rdd.keys.map(key => partitioner.getPartition(key)).countByValue().values
-assert(counts.max < 2.0 * counts.min)
+assert(counts.max < 3.0 * counts.min)
}
}

test("RangePartitioner should return a single partition for empty RDDs") {
val empty1 = sc.emptyRDD[(Int, Double)]
val partitioner1 = new RangePartitioner(0, empty1)
assert(partitioner1.numPartitions === 1)
val empty2 = sc.makeRDD(0 until 2, 2).flatMap(i => Seq.empty[(Int, Double)])
val partitioner2 = new RangePartitioner(2, empty2)
assert(partitioner2.numPartitions === 1)
}
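
// Note, not from this commit: the two assertions above hold because RangePartitioner
// derives its partition count from the bounds array elsewhere in Partitioner.scala
// (`def numPartitions = rangeBounds.length + 1`), and an empty RDD yields empty bounds.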

test("HashPartitioner not equal to RangePartitioner") {
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (+5, -0)
@@ -604,6 +604,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}

test("sort an empty RDD") {
val data = sc.emptyRDD[Int]
assert(data.sortBy(x => x).collect() === Array.empty)
}

test("sortByKey") {
val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))

