From 5cc967ded746ce8f6d1de44eaf03c82d0534785d Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 15 Apr 2015 15:14:57 -1000 Subject: [PATCH] Speed up QTree --- .../caliper/AsyncSummerBenchmark.scala | 6 +- .../caliper/CMSHashingBenchmark.scala | 10 +- .../algebird/caliper/QTreeBenchmark.scala | 61 ++++++++++ .../scala/com/twitter/algebird/QTree.scala | 114 +++++++++++------- 4 files changed, 142 insertions(+), 49 deletions(-) create mode 100644 algebird-caliper/src/test/scala/com/twitter/algebird/caliper/QTreeBenchmark.scala diff --git a/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/AsyncSummerBenchmark.scala b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/AsyncSummerBenchmark.scala index 0e9e46bef..0fe206674 100644 --- a/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/AsyncSummerBenchmark.scala +++ b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/AsyncSummerBenchmark.scala @@ -4,11 +4,11 @@ import java.lang.Math._ import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicLong -import com.google.caliper.{Param, SimpleBenchmark} -import com.twitter.algebird.{HyperLogLogMonoid, _} +import com.google.caliper.{ Param, SimpleBenchmark } +import com.twitter.algebird.{ HyperLogLogMonoid, _ } import com.twitter.algebird.util.summer._ import com.twitter.bijection._ -import com.twitter.util.{Await, Duration, FuturePool} +import com.twitter.util.{ Await, Duration, FuturePool } import scala.util.Random diff --git a/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/CMSHashingBenchmark.scala b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/CMSHashingBenchmark.scala index c355a8ac8..45ae58f4d 100644 --- a/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/CMSHashingBenchmark.scala +++ b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/CMSHashingBenchmark.scala @@ -1,6 +1,6 @@ package com.twitter.algebird.caliper -import com.google.caliper.{Param, SimpleBenchmark} +import com.google.caliper.{ Param, SimpleBenchmark } /** * Benchmarks the hashing algorithms used by Count-Min sketch for CMS[BigInt]. @@ -33,7 +33,7 @@ class CMSHashingBenchmark extends SimpleBenchmark { /** * Width of the counting table. */ - @Param(Array("11" /* eps = 0.271 */ , "544" /* eps = 0.005 */ , "2719" /* eps = 1E-3 */ , "271829" /* eps = 1E-5 */)) + @Param(Array("11" /* eps = 0.271 */ , "544" /* eps = 0.005 */ , "2719" /* eps = 1E-3 */ , "271829" /* eps = 1E-5 */ )) val width: Int = 0 /** @@ -54,7 +54,7 @@ class CMSHashingBenchmark extends SimpleBenchmark { override def setUp() { random = new scala.util.Random // We draw numbers randomly from a 2^maxBits address space. - inputs = (1 to operations).view.map { _ => scala.math.BigInt(maxBits, random)} + inputs = (1 to operations).view.map { _ => scala.math.BigInt(maxBits, random) } } private def murmurHashScala(a: Int, b: Int, width: Int)(x: BigInt) = { @@ -82,7 +82,7 @@ class CMSHashingBenchmark extends SimpleBenchmark { def timeBrokenCurrentHashWithRandomMaxBitsNumbers(operations: Int): Int = { var dummy = 0 while (dummy < operations) { - inputs.foreach { input => brokenCurrentHash(a, b, width)(input)} + inputs.foreach { input => brokenCurrentHash(a, b, width)(input) } dummy += 1 } dummy @@ -91,7 +91,7 @@ class CMSHashingBenchmark extends SimpleBenchmark { def timeMurmurHashScalaWithRandomMaxBitsNumbers(operations: Int): Int = { var dummy = 0 while (dummy < operations) { - inputs.foreach { input => murmurHashScala(a, b, width)(input)} + inputs.foreach { input => murmurHashScala(a, b, width)(input) } dummy += 1 } dummy diff --git a/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/QTreeBenchmark.scala b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/QTreeBenchmark.scala new file mode 100644 index 000000000..c55e5addc --- /dev/null +++ b/algebird-caliper/src/test/scala/com/twitter/algebird/caliper/QTreeBenchmark.scala @@ -0,0 +1,61 @@ +package com.twitter.algebird.caliper + +import com.twitter.algebird._ +import scala.util.Random +import com.twitter.bijection._ + +import java.util.concurrent.Executors +import com.twitter.algebird.util._ +import com.google.caliper.{ Param, SimpleBenchmark } +import java.nio.ByteBuffer + +import scala.math._ + +class OldQTreeSemigroup[A: Monoid](k: Int) extends QTreeSemigroup[A](k) { + override def sumOption(items: TraversableOnce[QTree[A]]) = + if (items.isEmpty) None + else Some(items.reduce(plus)) +} + +class QTreeBenchmark extends SimpleBenchmark { + var qtree: QTreeSemigroup[Long] = _ + var oldqtree: QTreeSemigroup[Long] = _ + + @Param(Array("5", "10", "12")) + val depthK: Int = 0 + + @Param(Array("100", "10000")) + val numElements: Int = 0 + + var inputData: Seq[QTree[Long]] = _ + + override def setUp { + qtree = new QTreeSemigroup[Long](depthK) + oldqtree = new OldQTreeSemigroup(depthK) + + val rng = new Random("qtree".hashCode) + + inputData = (0L until numElements).map { _ => + QTree(rng.nextInt(1000).toLong) + } + } + + def timeSumOption(reps: Int): Int = { + var dummy = 0 + while (dummy < reps) { + qtree.sumOption(inputData) + dummy += 1 + } + dummy + } + + /* + def timeOldSumOption(reps: Int): Int = { + var dummy = 0 + while (dummy < reps) { + oldqtree.sumOption(inputData) + dummy += 1 + } + dummy + } */ +} diff --git a/algebird-core/src/main/scala/com/twitter/algebird/QTree.scala b/algebird-core/src/main/scala/com/twitter/algebird/QTree.scala index d45588095..024346b0c 100644 --- a/algebird-core/src/main/scala/com/twitter/algebird/QTree.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/QTree.scala @@ -38,23 +38,21 @@ package com.twitter.algebird */ object QTree { - def apply[A: Monoid](kv: (Double, A), level: Int = -16): QTree[A] = { + def apply[A](kv: (Double, A), level: Int = -16): QTree[A] = QTree(math.floor(kv._1 / math.pow(2.0, level)).toLong, level, 1, kv._2, None, None) - } - def apply[A: Monoid](kv: (Long, A)): QTree[A] = { + def apply[A](kv: (Long, A)): QTree[A] = QTree(kv._1, 0, 1, kv._2, None, None) - } /** * The common case of wanting a count and sum for the same value @@ -63,8 +61,26 @@ object QTree { def apply(k: Double): QTree[Double] = apply(k -> k) } -class QTreeSemigroup[A: Monoid](k: Int) extends Semigroup[QTree[A]] { +class QTreeSemigroup[A](k: Int)(implicit val underlyingMonoid: Monoid[A]) extends Semigroup[QTree[A]] { + /** Override this if you want to change how frequently sumOption calls compress */ + def compressBatchSize: Int = 25 def plus(left: QTree[A], right: QTree[A]) = left.merge(right).compress(k) + override def sumOption(items: TraversableOnce[QTree[A]]): Option[QTree[A]] = if (items.isEmpty) None + else { + // only call compressBatchSize once + val batchSize = compressBatchSize + var count = 1 // start at 1, so we only compress after batchSize items + val iter = items.toIterator + var result = iter.next // due to not being empty, this does not throw + while (iter.hasNext) { + result = result.merge(iter.next) + count += 1 + if (count % batchSize == 0) { + result = result.compress(k) + } + } + Some(result.compress(k)) + } } case class QTree[A]( @@ -73,7 +89,7 @@ case class QTree[A]( count: Long, //the total count for this node and all of its children sum: A, //the sum at just this node (*not* including its children) lowerChild: Option[QTree[A]], - upperChild: Option[QTree[A]])(implicit monoid: Monoid[A]) { + upperChild: Option[QTree[A]]) { require(offset >= 0, "QTree can not accept negative values") @@ -81,7 +97,7 @@ case class QTree[A]( def lowerBound: Double = range * offset def upperBound: Double = range * (offset + 1) - private def extendToLevel(n: Int): QTree[A] = { + private def extendToLevel(n: Int)(implicit monoid: Monoid[A]): QTree[A] = { if (n <= level) this else { @@ -93,6 +109,7 @@ case class QTree[A]( QTree[A](nextOffset, nextLevel, count, monoid.zero, Some(this), None) else QTree[A](nextOffset, nextLevel, count, monoid.zero, None, Some(this)) + parent.extendToLevel(n) } } @@ -117,14 +134,14 @@ case class QTree[A]( ancestorLevel.max(level).max(other.level) } - def merge(other: QTree[A]) = { + def merge(other: QTree[A])(implicit monoid: Monoid[A]) = { val commonAncestor = commonAncestorLevel(other) val left = extendToLevel(commonAncestor) val right = other.extendToLevel(commonAncestor) left.mergeWithPeer(right) } - private def mergeWithPeer(other: QTree[A]): QTree[A] = { + private def mergeWithPeer(other: QTree[A])(implicit monoid: Monoid[A]): QTree[A] = { assert(other.lowerBound == lowerBound, "lowerBound " + other.lowerBound + " != " + lowerBound) assert(other.level == level, "level " + other.level + " != " + level) @@ -134,34 +151,33 @@ case class QTree[A]( upperChild = mergeOptions(upperChild, other.upperChild)) } - private def mergeOptions(a: Option[QTree[A]], b: Option[QTree[A]]): Option[QTree[A]] = { + private def mergeOptions(a: Option[QTree[A]], b: Option[QTree[A]])(implicit monoid: Monoid[A]): Option[QTree[A]] = (a, b) match { case (Some(qa), Some(qb)) => Some(qa.mergeWithPeer(qb)) - case (None, _) => b - case (_, None) => a + case (None, right) => right + case (left, None) => left } - } def quantileBounds(p: Double): (Double, Double) = { val rank = math.floor(count * p).toLong (findRankLowerBound(rank).get, findRankUpperBound(rank).get) } - private def findRankLowerBound(rank: Long): Option[Double] = { + private def findRankLowerBound(rank: Long): Option[Double] = if (rank > count) None else { - val childCounts = mapChildrenWithDefault(0L){ _.count } + val childCounts = mapChildrenWithDefault(0L)(_.count) val parentCount = count - childCounts._1 - childCounts._2 - lowerChild.flatMap{ _.findRankLowerBound(rank - parentCount) }.orElse { - val newRank = rank - childCounts._1 - parentCount - if (newRank <= 0) - Some(lowerBound) - else - upperChild.flatMap{ _.findRankLowerBound(newRank) } - } + lowerChild.flatMap { _.findRankLowerBound(rank - parentCount) } + .orElse { + val newRank = rank - childCounts._1 - parentCount + if (newRank <= 0) + Some(lowerBound) + else + upperChild.flatMap{ _.findRankLowerBound(newRank) } + } } - } private def findRankUpperBound(rank: Long): Option[Double] = { if (rank > count) @@ -174,7 +190,7 @@ case class QTree[A]( } } - def rangeSumBounds(from: Double, to: Double): (A, A) = { + def rangeSumBounds(from: Double, to: Double)(implicit monoid: Monoid[A]): (A, A) = { if (from <= lowerBound && to >= upperBound) { val s = totalSum (s, s) @@ -201,28 +217,43 @@ case class QTree[A]( } } - def compress(k: Int) = { + def compress(k: Int)(implicit m: Monoid[A]): QTree[A] = { val minCount = count >> k - val (newTree, pruned) = pruneChildrenWhere{ _.count < minCount } - newTree + if ((minCount > 1L) || (count < 1L)) { + pruneChildren(minCount) + } else { + // count > 0, so for all nodes, if minCount <= 1, then count >= minCount + // so we don't need to traverse + // this is common when you only add few items together, which happens + // on map-side aggregation commonly + this + } } - private def pruneChildrenWhere(fn: QTree[A] => Boolean): (QTree[A], Boolean) = { - if (fn(this)) { - (copy(sum = totalSum, lowerChild = None, upperChild = None), true) + // If we don't prune we MUST return this + private def pruneChildren(minCount: Long)(implicit m: Monoid[A]): QTree[A] = + if (count < minCount) { + copy(sum = totalSum, lowerChild = None, upperChild = None) } else { - val (newLower, lowerPruned) = pruneChildWhere(lowerChild, fn) - val (newUpper, upperPruned) = pruneChildWhere(upperChild, fn) - if (!lowerPruned && !upperPruned) - (this, false) + val newLower = pruneChild(minCount, lowerChild) + val lowerNotPruned = newLower eq lowerChild + val newUpper = pruneChild(minCount, upperChild) + val upperNotPruned = newUpper eq upperChild + if (lowerNotPruned && upperNotPruned) + this else - (copy(lowerChild = newLower, upperChild = newUpper), true) + copy(lowerChild = newLower, upperChild = newUpper) } - } - private def pruneChildWhere(child: Option[QTree[A]], fn: QTree[A] => Boolean): (Option[QTree[A]], Boolean) = { - val result = child.map{ _.pruneChildrenWhere(fn) } - (result.map{ _._1 }, result.map{ _._2 }.getOrElse(false)) + // If we don't prune we MUST return child + @inline + private def pruneChild(minCount: Long, + child: Option[QTree[A]])(implicit m: Monoid[A]): Option[QTree[A]] = child match { + case exists @ Some(oldChild) => + val newChild = oldChild.pruneChildren(minCount) + if (newChild eq oldChild) exists // need to pass the same reference if we don't change + else Some(newChild) + case n @ None => n // make sure we pass the same ref out } def size: Int = { @@ -230,7 +261,7 @@ case class QTree[A]( 1 + childSizes._1 + childSizes._2 } - def totalSum: A = { + def totalSum(implicit monoid: Monoid[A]): A = { val childSums = mapChildrenWithDefault(monoid.zero){ _.totalSum } monoid.plus(sum, monoid.plus(childSums._1, childSums._2)) } @@ -258,6 +289,7 @@ case class QTree[A]( } def interQuartileMean(implicit n: Numeric[A]): (Double, Double) = { + implicit val monoid: Monoid[A] = Ring.numericRing[A] val (l25, u25) = quantileBounds(0.25) val (l75, u75) = quantileBounds(0.75) val (ll, lu) = rangeSumBounds(l25, l75) @@ -267,4 +299,4 @@ case class QTree[A]( (n.toDouble(ll) / luc, n.toDouble(uu) / ulc) } -} \ No newline at end of file +}