Skip to content

Commit

Permalink
Speed up QTree
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Apr 16, 2015
1 parent adab132 commit 5cc967d
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import java.lang.Math._
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong

import com.google.caliper.{Param, SimpleBenchmark}
import com.twitter.algebird.{HyperLogLogMonoid, _}
import com.google.caliper.{ Param, SimpleBenchmark }
import com.twitter.algebird.{ HyperLogLogMonoid, _ }
import com.twitter.algebird.util.summer._
import com.twitter.bijection._
import com.twitter.util.{Await, Duration, FuturePool}
import com.twitter.util.{ Await, Duration, FuturePool }

import scala.util.Random

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.twitter.algebird.caliper

import com.google.caliper.{Param, SimpleBenchmark}
import com.google.caliper.{ Param, SimpleBenchmark }

/**
* Benchmarks the hashing algorithms used by Count-Min sketch for CMS[BigInt].
Expand Down Expand Up @@ -33,7 +33,7 @@ class CMSHashingBenchmark extends SimpleBenchmark {
/**
* Width of the counting table.
*/
@Param(Array("11" /* eps = 0.271 */ , "544" /* eps = 0.005 */ , "2719" /* eps = 1E-3 */ , "271829" /* eps = 1E-5 */))
@Param(Array("11" /* eps = 0.271 */ , "544" /* eps = 0.005 */ , "2719" /* eps = 1E-3 */ , "271829" /* eps = 1E-5 */ ))
val width: Int = 0

/**
Expand All @@ -54,7 +54,7 @@ class CMSHashingBenchmark extends SimpleBenchmark {
override def setUp() {
random = new scala.util.Random
// We draw numbers randomly from a 2^maxBits address space.
inputs = (1 to operations).view.map { _ => scala.math.BigInt(maxBits, random)}
inputs = (1 to operations).view.map { _ => scala.math.BigInt(maxBits, random) }
}

private def murmurHashScala(a: Int, b: Int, width: Int)(x: BigInt) = {
Expand Down Expand Up @@ -82,7 +82,7 @@ class CMSHashingBenchmark extends SimpleBenchmark {
def timeBrokenCurrentHashWithRandomMaxBitsNumbers(operations: Int): Int = {
var dummy = 0
while (dummy < operations) {
inputs.foreach { input => brokenCurrentHash(a, b, width)(input)}
inputs.foreach { input => brokenCurrentHash(a, b, width)(input) }
dummy += 1
}
dummy
Expand All @@ -91,7 +91,7 @@ class CMSHashingBenchmark extends SimpleBenchmark {
def timeMurmurHashScalaWithRandomMaxBitsNumbers(operations: Int): Int = {
var dummy = 0
while (dummy < operations) {
inputs.foreach { input => murmurHashScala(a, b, width)(input)}
inputs.foreach { input => murmurHashScala(a, b, width)(input) }
dummy += 1
}
dummy
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.twitter.algebird.caliper

import com.twitter.algebird._
import scala.util.Random
import com.twitter.bijection._

import java.util.concurrent.Executors
import com.twitter.algebird.util._
import com.google.caliper.{ Param, SimpleBenchmark }
import java.nio.ByteBuffer

import scala.math._

class OldQTreeSemigroup[A: Monoid](k: Int) extends QTreeSemigroup[A](k) {
override def sumOption(items: TraversableOnce[QTree[A]]) =
if (items.isEmpty) None
else Some(items.reduce(plus))
}

class QTreeBenchmark extends SimpleBenchmark {
var qtree: QTreeSemigroup[Long] = _
var oldqtree: QTreeSemigroup[Long] = _

@Param(Array("5", "10", "12"))
val depthK: Int = 0

@Param(Array("100", "10000"))
val numElements: Int = 0

var inputData: Seq[QTree[Long]] = _

override def setUp {
qtree = new QTreeSemigroup[Long](depthK)
oldqtree = new OldQTreeSemigroup(depthK)

val rng = new Random("qtree".hashCode)

inputData = (0L until numElements).map { _ =>
QTree(rng.nextInt(1000).toLong)
}
}

def timeSumOption(reps: Int): Int = {
var dummy = 0
while (dummy < reps) {
qtree.sumOption(inputData)
dummy += 1
}
dummy
}

/*
def timeOldSumOption(reps: Int): Int = {
var dummy = 0
while (dummy < reps) {
oldqtree.sumOption(inputData)
dummy += 1
}
dummy
} */
}
114 changes: 73 additions & 41 deletions algebird-core/src/main/scala/com/twitter/algebird/QTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,21 @@ package com.twitter.algebird
*/

object QTree {
def apply[A: Monoid](kv: (Double, A), level: Int = -16): QTree[A] = {
def apply[A](kv: (Double, A), level: Int = -16): QTree[A] =
QTree(math.floor(kv._1 / math.pow(2.0, level)).toLong,
level,
1,
kv._2,
None,
None)
}

def apply[A: Monoid](kv: (Long, A)): QTree[A] = {
def apply[A](kv: (Long, A)): QTree[A] =
QTree(kv._1,
0,
1,
kv._2,
None,
None)
}

/**
* The common case of wanting a count and sum for the same value
Expand All @@ -63,8 +61,26 @@ object QTree {
def apply(k: Double): QTree[Double] = apply(k -> k)
}

class QTreeSemigroup[A: Monoid](k: Int) extends Semigroup[QTree[A]] {
class QTreeSemigroup[A](k: Int)(implicit val underlyingMonoid: Monoid[A]) extends Semigroup[QTree[A]] {
/** Override this if you want to change how frequently sumOption calls compress */
def compressBatchSize: Int = 25
def plus(left: QTree[A], right: QTree[A]) = left.merge(right).compress(k)
override def sumOption(items: TraversableOnce[QTree[A]]): Option[QTree[A]] = if (items.isEmpty) None
else {
// only call compressBatchSize once
val batchSize = compressBatchSize
var count = 1 // start at 1, so we only compress after batchSize items
val iter = items.toIterator
var result = iter.next // due to not being empty, this does not throw
while (iter.hasNext) {
result = result.merge(iter.next)
count += 1
if (count % batchSize == 0) {
result = result.compress(k)
}
}
Some(result.compress(k))
}
}

case class QTree[A](
Expand All @@ -73,15 +89,15 @@ case class QTree[A](
count: Long, //the total count for this node and all of its children
sum: A, //the sum at just this node (*not* including its children)
lowerChild: Option[QTree[A]],
upperChild: Option[QTree[A]])(implicit monoid: Monoid[A]) {
upperChild: Option[QTree[A]]) {

require(offset >= 0, "QTree can not accept negative values")

def range: Double = math.pow(2.0, level)
def lowerBound: Double = range * offset
def upperBound: Double = range * (offset + 1)

private def extendToLevel(n: Int): QTree[A] = {
private def extendToLevel(n: Int)(implicit monoid: Monoid[A]): QTree[A] = {
if (n <= level)
this
else {
Expand All @@ -93,6 +109,7 @@ case class QTree[A](
QTree[A](nextOffset, nextLevel, count, monoid.zero, Some(this), None)
else
QTree[A](nextOffset, nextLevel, count, monoid.zero, None, Some(this))

parent.extendToLevel(n)
}
}
Expand All @@ -117,14 +134,14 @@ case class QTree[A](
ancestorLevel.max(level).max(other.level)
}

def merge(other: QTree[A]) = {
def merge(other: QTree[A])(implicit monoid: Monoid[A]) = {
val commonAncestor = commonAncestorLevel(other)
val left = extendToLevel(commonAncestor)
val right = other.extendToLevel(commonAncestor)
left.mergeWithPeer(right)
}

private def mergeWithPeer(other: QTree[A]): QTree[A] = {
private def mergeWithPeer(other: QTree[A])(implicit monoid: Monoid[A]): QTree[A] = {
assert(other.lowerBound == lowerBound, "lowerBound " + other.lowerBound + " != " + lowerBound)
assert(other.level == level, "level " + other.level + " != " + level)

Expand All @@ -134,34 +151,33 @@ case class QTree[A](
upperChild = mergeOptions(upperChild, other.upperChild))
}

private def mergeOptions(a: Option[QTree[A]], b: Option[QTree[A]]): Option[QTree[A]] = {
private def mergeOptions(a: Option[QTree[A]], b: Option[QTree[A]])(implicit monoid: Monoid[A]): Option[QTree[A]] =
(a, b) match {
case (Some(qa), Some(qb)) => Some(qa.mergeWithPeer(qb))
case (None, _) => b
case (_, None) => a
case (None, right) => right
case (left, None) => left
}
}

def quantileBounds(p: Double): (Double, Double) = {
val rank = math.floor(count * p).toLong
(findRankLowerBound(rank).get, findRankUpperBound(rank).get)
}

private def findRankLowerBound(rank: Long): Option[Double] = {
private def findRankLowerBound(rank: Long): Option[Double] =
if (rank > count)
None
else {
val childCounts = mapChildrenWithDefault(0L){ _.count }
val childCounts = mapChildrenWithDefault(0L)(_.count)
val parentCount = count - childCounts._1 - childCounts._2
lowerChild.flatMap{ _.findRankLowerBound(rank - parentCount) }.orElse {
val newRank = rank - childCounts._1 - parentCount
if (newRank <= 0)
Some(lowerBound)
else
upperChild.flatMap{ _.findRankLowerBound(newRank) }
}
lowerChild.flatMap { _.findRankLowerBound(rank - parentCount) }
.orElse {
val newRank = rank - childCounts._1 - parentCount
if (newRank <= 0)
Some(lowerBound)
else
upperChild.flatMap{ _.findRankLowerBound(newRank) }
}
}
}

private def findRankUpperBound(rank: Long): Option[Double] = {
if (rank > count)
Expand All @@ -174,7 +190,7 @@ case class QTree[A](
}
}

def rangeSumBounds(from: Double, to: Double): (A, A) = {
def rangeSumBounds(from: Double, to: Double)(implicit monoid: Monoid[A]): (A, A) = {
if (from <= lowerBound && to >= upperBound) {
val s = totalSum
(s, s)
Expand All @@ -201,36 +217,51 @@ case class QTree[A](
}
}

def compress(k: Int) = {
def compress(k: Int)(implicit m: Monoid[A]): QTree[A] = {
val minCount = count >> k
val (newTree, pruned) = pruneChildrenWhere{ _.count < minCount }
newTree
if ((minCount > 1L) || (count < 1L)) {
pruneChildren(minCount)
} else {
// count > 0, so for all nodes, if minCount <= 1, then count >= minCount
// so we don't need to traverse
// this is common when you only add few items together, which happens
// on map-side aggregation commonly
this
}
}

private def pruneChildrenWhere(fn: QTree[A] => Boolean): (QTree[A], Boolean) = {
if (fn(this)) {
(copy(sum = totalSum, lowerChild = None, upperChild = None), true)
// If we don't prune we MUST return this
private def pruneChildren(minCount: Long)(implicit m: Monoid[A]): QTree[A] =
if (count < minCount) {
copy(sum = totalSum, lowerChild = None, upperChild = None)
} else {
val (newLower, lowerPruned) = pruneChildWhere(lowerChild, fn)
val (newUpper, upperPruned) = pruneChildWhere(upperChild, fn)
if (!lowerPruned && !upperPruned)
(this, false)
val newLower = pruneChild(minCount, lowerChild)
val lowerNotPruned = newLower eq lowerChild
val newUpper = pruneChild(minCount, upperChild)
val upperNotPruned = newUpper eq upperChild
if (lowerNotPruned && upperNotPruned)
this
else
(copy(lowerChild = newLower, upperChild = newUpper), true)
copy(lowerChild = newLower, upperChild = newUpper)
}
}

private def pruneChildWhere(child: Option[QTree[A]], fn: QTree[A] => Boolean): (Option[QTree[A]], Boolean) = {
val result = child.map{ _.pruneChildrenWhere(fn) }
(result.map{ _._1 }, result.map{ _._2 }.getOrElse(false))
// If we don't prune we MUST return child
@inline
private def pruneChild(minCount: Long,
child: Option[QTree[A]])(implicit m: Monoid[A]): Option[QTree[A]] = child match {
case exists @ Some(oldChild) =>
val newChild = oldChild.pruneChildren(minCount)
if (newChild eq oldChild) exists // need to pass the same reference if we don't change
else Some(newChild)
case n @ None => n // make sure we pass the same ref out
}

def size: Int = {
val childSizes = mapChildrenWithDefault(0){ _.size }
1 + childSizes._1 + childSizes._2
}

def totalSum: A = {
def totalSum(implicit monoid: Monoid[A]): A = {
val childSums = mapChildrenWithDefault(monoid.zero){ _.totalSum }
monoid.plus(sum, monoid.plus(childSums._1, childSums._2))
}
Expand Down Expand Up @@ -258,6 +289,7 @@ case class QTree[A](
}

def interQuartileMean(implicit n: Numeric[A]): (Double, Double) = {
implicit val monoid: Monoid[A] = Ring.numericRing[A]
val (l25, u25) = quantileBounds(0.25)
val (l75, u75) = quantileBounds(0.75)
val (ll, lu) = rangeSumBounds(l25, l75)
Expand All @@ -267,4 +299,4 @@ case class QTree[A](

(n.toDouble(ll) / luc, n.toDouble(uu) / ulc)
}
}
}

0 comments on commit 5cc967d

Please sign in to comment.