Merge pull request #530 from non/topic/batching-and-efficiency
Add Batched[A] type for efficient lazy addition
johnynek committed Jun 15, 2016
2 parents b8fa8e0 + 7782d87 commit 6ee6f92
Showing 3 changed files with 410 additions and 3 deletions.
@@ -218,7 +218,7 @@ object Aggregator extends java.io.Serializable {
* selected. This assumes that all sampled records can fit in memory, so use this only when the
* expected number of sampled values is small.
*/
- def randomSample[T](prob: Double, seed: Int = DefaultSeed): MonoidAggregator[T, List[T], List[T]] = {
+ def randomSample[T](prob: Double, seed: Int = DefaultSeed): MonoidAggregator[T, Option[Batched[T]], List[T]] = {
assert(prob >= 0 && prob <= 1, "randomSample.prob must lie in [0, 1]")
val rng = new java.util.Random(seed)
Preparer[T]
@@ -240,8 +240,13 @@ object Aggregator extends java.io.Serializable {
/**
* Put everything in a List. Note, this could fill the memory if the List is very large.
*/
- def toList[T]: MonoidAggregator[T, List[T], List[T]] =
- prepareMonoid { t: T => List(t) }
+ def toList[T]: MonoidAggregator[T, Option[Batched[T]], List[T]] =
+ new MonoidAggregator[T, Option[Batched[T]], List[T]] {
+ def prepare(t: T): Option[Batched[T]] = Some(Batched(t))
+ def monoid: Monoid[Option[Batched[T]]] = Monoid.optionMonoid(Batched.semigroup)
+ def present(o: Option[Batched[T]]): List[T] = o.map(_.toList).getOrElse(Nil)
+ }

/**
* Put everything in a Set. Note, this could fill the memory if the Set is very large.
*/
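
The reworked `toList` keeps the same result type but switches the intermediate (`B`) type from `List[T]` to `Option[Batched[T]]`, so each intermediate `plus` adds a small tree node instead of concatenating lists; the single list is only built in `present`. A rough sketch of what that means at the call site (illustrative only; `agg`, `prepared` and `summed` are just local names, not part of the API):

val agg = Aggregator.toList[Int]
val prepared = List(1, 2, 3).map(agg.prepare)        // three Some(Batched(...)) values
val summed = prepared.reduce(agg.monoid.plus(_, _))  // combined lazily, no list built yet
agg.present(summed)                                  // List(1, 2, 3)
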
algebird-core/src/main/scala/com/twitter/algebird/Batched.scala (new file: 331 additions, 0 deletions)
@@ -0,0 +1,331 @@
package com.twitter.algebird

import scala.annotation.tailrec

/**
* Batched: the free semigroup.
*
* For any type `T`, `Batched[T]` represents a way to lazily combine T
* values as a semigroup would (i.e. associatively). A `Semigroup[T]`
* instance can be used to recover a `T` value from a `Batched[T]`.
*
* Like other free structures, Batched trades space for time. A sum of
* batched values defers the underlying semigroup action, instead
* storing all values in memory (in a tree structure). If an
* underlying semigroup is available, `Batched.compactingSemigroup` and
* `Batched.compactingMonoid` can be used to periodically sum the tree to
* keep the overall size below `batchSize`.
*
* `Batched[T]` values are guaranteed not to be empty -- that is, they
* will contain at least one `T` value.
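*
* For example (an illustrative sketch, assuming the usual implicit
* `Semigroup[Int]` is in scope):
*
* {{{
* val batch = Batched(1).combine(Batched(2)).combine(Batched(3))
* batch.size    // 3 -- no Int additions have happened yet
* batch.sum     // 6 -- the deferred additions run here
* batch.toList  // List(1, 2, 3)
* }}}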
*/
sealed abstract class Batched[T] extends Serializable {

/**
* Sum all the `T` values in this batch using the given semigroup.
*/
def sum(implicit sg: Semigroup[T]): T

/**
* Combine two batched values.
*
* As mentioned above, this just creates a new tree structure
* containing `this` and `that`.
*/
def combine(that: Batched[T]): Batched[T] =
Batched.Items(this, that)

/**
* Compact this batch if it exceeds `batchSize`.
*
* Compacting a branch means summing it, and then storing the summed
* value in a new single-item batch.
*/
def compact(batchSize: Int)(implicit s: Semigroup[T]): Batched[T] =
if (size < batchSize) this else Batched.Item(sum(s))

/**
* Add more values to a batched value.
*
* Each appended value is combined on the right, producing a left-leaning
* tree with the existing batch as its deepest left subtree.
*/
def append(that: TraversableOnce[T]): Batched[T] =
that.foldLeft(this)((b, t) => b.combine(Batched(t)))

/**
* Provide an iterator over the underlying tree structure.
*
* This is the order used by `.sum`.
*
* This iterator traverses the tree from left-to-right. If the
* original expression was (w + x + y + z), this iterator returns w,
* x, y, and then z.
*/
def iterator: Iterator[T] =
this match {
case Batched.Item(t) => Iterator.single(t)
case b => new Batched.ForwardItemsIterator(b)
}

/**
* Convert the batch to a `List[T]`.
*/
def toList: List[T] =
reverseIterator.foldLeft(List.empty[T])((ts, t) => t :: ts)

/**
* Provide a reversed iterator over the underlying tree structure.
*
* This iterator traverses the tree from right-to-left. If the
* original expression was (w + x + y + z), this iterator returns z,
* y, x, and then w.
*/
def reverseIterator: Iterator[T] =
this match {
case Batched.Item(t) => Iterator.single(t)
case b => new Batched.ReverseItemsIterator(b)
}

/**
* Report the size of the underlying tree structure.
*
* This is an O(1) operation -- each subtree knows how big it is.
*/
def size: Int
}

object Batched {

/**
* Construct a batch from a single value.
*/
def apply[T](t: T): Batched[T] =
Item(t)

/**
* Construct an optional batch from a collection of values.
*
* Since batches cannot be empty, this method returns `None` if `ts`
* is empty, and `Some(batch)` otherwise.
*/
def items[T](ts: TraversableOnce[T]): Option[Batched[T]] =
if (ts.isEmpty) None else {
val it = ts.toIterator
val t0 = it.next
Some(Item(t0).append(it))
}

/**
* Equivalence for batches.
*
* Batches are equivalent if they sum to the same value. Since the
* free semigroup is associative, it's not correct to take tree
* structure into account when determining equality.
*
* One thing to note here is that two equivalent batches might
* produce different lists (for instance, if one of the batches has
* more zeros in it than another one).
*/
implicit def equiv[A](implicit e: Equiv[A], s: Semigroup[A]): Equiv[Batched[A]] =
new Equiv[Batched[A]] {
def equiv(x: Batched[A], y: Batched[A]): Boolean =
e.equiv(x.sum(s), y.sum(s))
}
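
// An illustrative sketch of the equivalence above (the values here are only
// an example): two batches with different tree shapes but the same sum are
// considered equivalent.
//
//   val x = (Batched(1) combine Batched(2)) combine Batched(3)
//   val y = Batched(1) combine (Batched(2) combine Batched(3))
//   implicitly[Equiv[Batched[Int]]].equiv(x, y)   // true: both sum to 6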

/**
* The free semigroup for batched values.
*
* This semigroup just accumulates batches and doesn't ever evaluate
* them to flatten the tree.
*/
implicit def semigroup[A]: Semigroup[Batched[A]] =
new Semigroup[Batched[A]] {
def plus(x: Batched[A], y: Batched[A]): Batched[A] = x combine y
}

/**
* Compacting semigroup for batched values.
*
* This semigroup ensures that the batch's tree structure has fewer
* than `batchSize` values in it. When more values are added, the
* tree is compacted using `s`.
*/
def compactingSemigroup[A: Semigroup](batchSize: Int): Semigroup[Batched[A]] =
new BatchedSemigroup[A](batchSize)

/**
* Compacting monoid for batched values.
*
* This monoid ensures that the batch's tree structure has fewer
* than `batchSize` values in it. When more values are added, the
* tree is compacted using `m`.
*
* It's worth noting that `x + 0` here will produce the same sum as
* `x`, but `.toList` will produce different lists (one will have an
* extra zero).
*/
def compactingMonoid[A: Monoid](batchSize: Int): Monoid[Batched[A]] =
new BatchedMonoid[A](batchSize)
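
// An illustrative sketch of compaction (the values here are only an
// example): with batchSize = 2, every plus that reaches two leaves is
// immediately summed back down to a single Item.
//
//   val m = Batched.compactingMonoid[Int](2)
//   val b = List(1, 2, 3, 4).map(Batched(_)).reduce(m.plus(_, _))
//   b.size   // 1
//   b.sum    // 10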

/**
* This aggregator batches up `agg` so that all the addition can be
* performed at once.
*
* It is useful when `sumOption` is much faster than using `plus`
* (e.g. when there is temporary mutable state used to make
* summation fast).
*/
def aggregator[A, B, C](batchSize: Int, agg: Aggregator[A, B, C]): Aggregator[A, Batched[B], C] =
new Aggregator[A, Batched[B], C] {
def prepare(a: A): Batched[B] = Item(agg.prepare(a))
def semigroup: Semigroup[Batched[B]] = new BatchedSemigroup(batchSize)(agg.semigroup)
def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
}

/**
* This monoid aggregator batches up `agg` so that all the addition
* can be performed at once.
*
* It is useful when `sumOption` is much faster than using `plus`
* (e.g. when there is temporary mutable state used to make
* summation fast).
*/
def monoidAggregator[A, B, C](batchSize: Int, agg: MonoidAggregator[A, B, C]): MonoidAggregator[A, Batched[B], C] =
new MonoidAggregator[A, Batched[B], C] {
def prepare(a: A): Batched[B] = Item(agg.prepare(a))
def monoid: Monoid[Batched[B]] = new BatchedMonoid(batchSize)(agg.monoid)
def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
}
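
// An illustrative sketch of wrapping an existing aggregator (the batch size
// and the choice of Aggregator.toSet are only an example): prepared values
// accumulate in a Batched tree and the underlying monoid only sees them via
// sumOption, in batches, rather than one plus call per element.
//
//   val batched = Batched.monoidAggregator(1000, Aggregator.toSet[Int])
//   // batched: MonoidAggregator[Int, Batched[Set[Int]], Set[Int]]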

def foldOption[T: Semigroup](batchSize: Int): Fold[T, Option[T]] =
Fold.foldLeft[T, Option[Batched[T]]](Option.empty[Batched[T]]) {
case (Some(b), t) => Some(b.combine(Item(t)).compact(batchSize))
case (None, t) => Some(Item(t))
}.map(_.map(_.sum))

def fold[T](batchSize: Int)(implicit m: Monoid[T]): Fold[T, T] =
Fold.foldLeft[T, Batched[T]](Batched(m.zero)) { (b, t) =>
b.combine(Item(t)).compact(batchSize)
}.map(_.sum)

/**
* This represents a single (unbatched) value.
*/
private[algebird] case class Item[T](t: T) extends Batched[T] {
def size: Int = 1
def sum(implicit sg: Semigroup[T]): T = t
}

/**
* This represents two (or more) batched values being added.
*
* The actual addition is deferred until the `.sum` method is called.
*/
private[algebird] case class Items[T](left: Batched[T], right: Batched[T]) extends Batched[T] {
// Items#size will always be >= 2.
val size: Int = left.size + right.size

def sum(implicit sg: Semigroup[T]): T =
sg.sumOption(new ForwardItemsIterator(this)).get
}

/**
* Abstract iterator through a batch's tree.
*
* This class is agnostic about whether the traversal is
* left-to-right or right-to-left. The abstract method `descend`
* controls which direction the iterator moves.
*/
private[algebird] abstract class ItemsIterator[A](root: Batched[A]) extends Iterator[A] {
var stack: List[Batched[A]] = Nil
var running: Boolean = true
var ready: A = descend(root)

def ascend(): Unit =
stack match {
case Nil =>
running = false
case h :: t =>
stack = t
ready = descend(h)
}

def descend(v: Batched[A]): A

def hasNext: Boolean =
running

def next(): A =
if (running) {
val result = ready
ascend()
result
} else {
throw new NoSuchElementException("next on empty iterator")
}
}

/**
* Left-to-right iterator through a batch's tree.
*/
private[algebird] class ForwardItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
def descend(v: Batched[A]): A = {
@inline @tailrec def descend0(v: Batched[A]): A =
v match {
case Items(lhs, rhs) =>
stack = rhs :: stack
descend0(lhs)
case Item(value) =>
value
}
descend0(v)
}
}

/**
* Right-to-left iterator through a batch's tree.
*/
private[algebird] class ReverseItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
def descend(v: Batched[A]): A = {
@inline @tailrec def descend0(v: Batched[A]): A =
v match {
case Items(lhs, rhs) =>
stack = lhs :: stack
descend0(rhs)
case Item(value) =>
value
}
descend0(v)
}
}
}

/**
* Compacting semigroup for batched values.
*
* This semigroup ensures that the batch's tree structure has fewer
* than `batchSize` values in it. When more values are added, the
* tree is compacted using `s`.
*/
class BatchedSemigroup[T: Semigroup](batchSize: Int) extends Semigroup[Batched[T]] {

require(batchSize > 0, s"Batch size must be > 0, found: $batchSize")

def plus(a: Batched[T], b: Batched[T]): Batched[T] =
a.combine(b).compact(batchSize)
}

/**
* Compacting monoid for batched values.
*
* This monoid ensures that the batch's tree structure has fewer
* than `batchSize` values in it. When more values are added, the
* tree is compacted using `m`.
*/
class BatchedMonoid[T: Monoid](batchSize: Int) extends BatchedSemigroup[T](batchSize) with Monoid[Batched[T]] {
val zero: Batched[T] = Batched(Monoid.zero)

// if we knew that (a+b=0) only for (a=0, b=0), we could instead do:
// new Batched.ItemsIterator(b).exists(monoid.isNonZero)
override def isNonZero(b: Batched[T]): Boolean =
Monoid.isNonZero(b.sum)
}
