-
Notifications
You must be signed in to change notification settings - Fork 343
/
Batched.scala
320 lines (284 loc) · 10.4 KB
/
Batched.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
package com.twitter.algebird
import scala.annotation.tailrec
import scala.collection.compat._
/**
* Batched: the free semigroup.
*
* For any type `T`, `Batched[T]` represents a way to lazily combine T values as a semigroup would (i.e.
* associatively). A `Semigroup[T]` instance can be used to recover a `T` value from a `Batched[T]`.
*
* Like other free structures, Batched trades space for time. A sum of batched values defers the underlying
* semigroup action, instead storing all values in memory (in a tree structure). If an underlying semigroup is
* available, `Batched.compactingSemigroup` and `Batched.compactingMonoid` can be configured to periodically
* sum the tree to keep the overall size below `batchSize`.
*
* `Batched[T]` values are guaranteed not to be empty -- that is, they will contain at least one `T` value.
*/
sealed abstract class Batched[T] extends Serializable {

  /**
   * Reduce every `T` value held by this batch to a single `T` using the supplied semigroup.
   */
  def sum(implicit sg: Semigroup[T]): T

  /**
   * Join this batch with `that`.
   *
   * Nothing is summed here: the result is a new tree node whose left child is `this` and whose right child
   * is `that`.
   */
  def combine(that: Batched[T]): Batched[T] =
    Batched.Items(this, that)

  /**
   * Collapse this batch once it holds `batchSize` values or more.
   *
   * A batch with fewer than `batchSize` values is returned untouched. Otherwise the batch is summed with `s`
   * and the total is wrapped in a new single-item batch.
   */
  def compact(batchSize: Int)(implicit s: Semigroup[T]): Batched[T] =
    if (size >= batchSize) Batched.Item(sum(s)) else this

  /**
   * Add each value of `that`, in iteration order, to this batch.
   *
   * Every value becomes the right child of a new node whose left child is the tree built so far, so the
   * resulting tree deepens on its left side.
   */
  def append(that: TraversableOnce[T]): Batched[T] =
    that.iterator.foldLeft(this) { (acc, next) =>
      acc.combine(Batched(next))
    }

  /**
   * Iterate over the stored values from left to right.
   *
   * This is the order used by `.sum`. If the original expression was (w + x + y + z), the iterator yields
   * w, x, y, and then z.
   */
  def iterator: Iterator[T] =
    this match {
      case Batched.Item(only) => Iterator.single(only)
      case _                  => new Batched.ForwardItemsIterator(this)
    }

  /**
   * Materialize the stored values as a `List[T]`, in left-to-right order.
   */
  def toList: List[T] =
    // Walking right-to-left while prepending yields the left-to-right list without a final reverse.
    reverseIterator.foldLeft[List[T]](Nil)((tail, head) => head :: tail)

  /**
   * Iterate over the stored values from right to left.
   *
   * If the original expression was (w + x + y + z), the iterator yields z, y, x, and then w.
   */
  def reverseIterator: Iterator[T] =
    this match {
      case Batched.Item(only) => Iterator.single(only)
      case _                  => new Batched.ReverseItemsIterator(this)
    }

  /**
   * Number of `T` values stored in this batch.
   *
   * This is an O(1) operation -- each subtree knows how big it is.
   */
  def size: Int
}
object Batched {

  /**
   * Construct a batch from a single value.
   */
  def apply[T](t: T): Batched[T] =
    Item(t)

  /**
   * Construct an optional batch from a collection of values.
   *
   * Since batches cannot be empty, this method returns `None` if `ts` is empty, and `Some(batch)` otherwise.
   *
   * The iterator of `ts` is requested exactly once: the same iterator is used for the emptiness check and
   * for reading the values, so an input whose iterator can only be obtained or traversed once is handled
   * correctly.
   */
  def items[T](ts: TraversableOnce[T]): Option[Batched[T]] = {
    val it = ts.iterator
    if (it.hasNext) Some(Item(it.next()).append(it))
    else None
  }

  /**
   * Equivalence for batches.
   *
   * Batches are equivalent if they sum to the same value. Since the free semigroup is associative, it's not
   * correct to take tree structure into account when determining equality.
   *
   * One thing to note here is that two equivalent batches might produce different lists (for instance, if one
   * of the batches has more zeros in it than another one).
   */
  implicit def equiv[A](implicit e: Equiv[A], s: Semigroup[A]): Equiv[Batched[A]] =
    new Equiv[Batched[A]] {
      override def equiv(x: Batched[A], y: Batched[A]): Boolean =
        e.equiv(x.sum(s), y.sum(s))
    }

  /**
   * The free semigroup for batched values.
   *
   * This semigroup just accumulates batches and doesn't ever evaluate them to flatten the tree.
   */
  implicit def semigroup[A]: Semigroup[Batched[A]] =
    new Semigroup[Batched[A]] {
      override def plus(x: Batched[A], y: Batched[A]): Batched[A] = x.combine(y)
    }

  /**
   * Compacting semigroup for batched values.
   *
   * This semigroup ensures that the batch's tree structure has fewer than `batchSize` values in it. When more
   * values are added, the tree is compacted using the implicit `Semigroup[A]`.
   */
  def compactingSemigroup[A: Semigroup](batchSize: Int): Semigroup[Batched[A]] =
    new BatchedSemigroup[A](batchSize)

  /**
   * Compacting monoid for batched values.
   *
   * This monoid ensures that the batch's tree structure has fewer than `batchSize` values in it. When more
   * values are added, the tree is compacted using the implicit `Monoid[A]`.
   *
   * It's worth noting that `x + 0` here will produce the same sum as `x`, but `.toList` will produce
   * different lists (one will have an extra zero).
   */
  def compactingMonoid[A: Monoid](batchSize: Int): Monoid[Batched[A]] =
    new BatchedMonoid[A](batchSize)

  /**
   * This aggregator batches up `agg` so that all the addition can be performed at once.
   *
   * It is useful when `sumOption` is much faster than using `plus` (e.g. when there is temporary mutable
   * state used to make summation fast).
   *
   * Each input is prepared with `agg.prepare` and wrapped in a single-item batch; batches are combined with
   * a compacting semigroup built from `batchSize` and `agg.semigroup`; the final batch is summed with
   * `agg.semigroup` and handed to `agg.present`.
   */
  def aggregator[A, B, C](batchSize: Int, agg: Aggregator[A, B, C]): Aggregator[A, Batched[B], C] =
    new Aggregator[A, Batched[B], C] {
      override def prepare(a: A): Batched[B] = Item(agg.prepare(a))
      override def semigroup: Semigroup[Batched[B]] =
        new BatchedSemigroup(batchSize)(agg.semigroup)
      override def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
    }

  /**
   * This monoid aggregator batches up `agg` so that all the addition can be performed at once.
   *
   * It is useful when `sumOption` is much faster than using `plus` (e.g. when there is temporary mutable
   * state used to make summation fast).
   *
   * Each input is prepared with `agg.prepare` and wrapped in a single-item batch; batches are combined with
   * a compacting monoid built from `batchSize` and `agg.monoid`; the final batch is summed with
   * `agg.semigroup` and handed to `agg.present`.
   */
  def monoidAggregator[A, B, C](
      batchSize: Int,
      agg: MonoidAggregator[A, B, C]
  ): MonoidAggregator[A, Batched[B], C] =
    new MonoidAggregator[A, Batched[B], C] {
      override def prepare(a: A): Batched[B] = Item(agg.prepare(a))
      override def monoid: Monoid[Batched[B]] = new BatchedMonoid(batchSize)(agg.monoid)
      override def present(b: Batched[B]): C = agg.present(b.sum(agg.semigroup))
    }

  /**
   * A fold that sums `T` values through a batch, compacting the batch with `batchSize` after each value is
   * added.
   *
   * The fold starts from an empty `Option`, so the result is `None` when no values were folded and
   * `Some(total)` otherwise.
   */
  def foldOption[T: Semigroup](batchSize: Int): Fold[T, Option[T]] =
    Fold
      .foldLeft[T, Option[Batched[T]]](Option.empty[Batched[T]]) {
        case (Some(b), t) => Some(b.combine(Item(t)).compact(batchSize))
        case (None, t)    => Some(Item(t))
      }
      .map(_.map(_.sum))

  /**
   * A fold that sums `T` values through a batch, compacting the batch with `batchSize` after each value is
   * added.
   *
   * The fold starts from a single-item batch holding `m.zero`, so that zero takes part in the final sum.
   */
  def fold[T](batchSize: Int)(implicit m: Monoid[T]): Fold[T, T] =
    Fold
      .foldLeft[T, Batched[T]](Batched(m.zero))((b, t) => b.combine(Item(t)).compact(batchSize))
      .map(_.sum)

  /**
   * This represents a single (unbatched) value.
   */
  private[algebird] case class Item[T](t: T) extends Batched[T] {
    override def size: Int = 1
    // A single value needs no semigroup action; `sg` is unused.
    override def sum(implicit sg: Semigroup[T]): T = t
  }

  /**
   * This represents two (or more) batched values being added.
   *
   * The actual addition is deferred until the `.sum` method is called.
   */
  private[algebird] case class Items[T](left: Batched[T], right: Batched[T]) extends Batched[T] {
    // Items#size will always be >= 2. Computed once at construction so `size` stays O(1).
    override val size: Int = left.size + right.size

    // The iterator yields at least two values here, so `sumOption` is expected to return `Some`
    // (NOTE(review): relies on the `Semigroup.sumOption` contract for non-empty input).
    override def sum(implicit sg: Semigroup[T]): T =
      sg.sumOption(new ForwardItemsIterator(this)).get
  }

  /**
   * Abstract iterator through a batch's tree.
   *
   * This class is agnostic about whether the traversal is left-to-right or right-to-left. The abstract method
   * `descend` controls which direction the iterator moves.
   *
   * Field order matters: `stack` and `running` must be initialized before `ready`, because computing
   * `ready` calls `descend`, which pushes onto `stack`.
   */
  private[algebird] abstract class ItemsIterator[A](root: Batched[A]) extends Iterator[A] {
    // Subtrees that still have to be visited, nearest first.
    var stack: List[Batched[A]] = Nil
    // Becomes false once the last value has been handed out.
    var running: Boolean = true
    // The value the next call to `next()` will return (meaningful only while `running`).
    var ready: A = descend(root)

    /**
     * Advance to the next pending subtree, or mark the iterator as finished when none remain.
     */
    def ascend(): Unit =
      stack match {
        case Nil =>
          running = false
        case h :: t =>
          stack = t
          ready = descend(h)
      }

    /**
     * Walk down `v` to the next value to emit, pushing the subtrees skipped on the way onto `stack`.
     */
    def descend(v: Batched[A]): A

    override def hasNext: Boolean =
      running

    override def next(): A =
      if (running) {
        val result = ready
        ascend()
        result
      } else {
        throw new NoSuchElementException("next on empty iterator")
      }
  }

  /**
   * Left-to-right iterator through a batch's tree.
   */
  private[algebird] class ForwardItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
    override def descend(v: Batched[A]): A = {
      // Follow left children down to a leaf, remembering each right sibling for later.
      @inline @tailrec def descend0(v: Batched[A]): A =
        v match {
          case Items(lhs, rhs) =>
            stack = rhs :: stack
            descend0(lhs)
          case Item(value) =>
            value
        }
      descend0(v)
    }
  }

  /**
   * Right-to-left iterator through a batch's tree.
   */
  private[algebird] class ReverseItemsIterator[A](root: Batched[A]) extends ItemsIterator[A](root) {
    override def descend(v: Batched[A]): A = {
      // Follow right children down to a leaf, remembering each left sibling for later.
      @inline @tailrec def descend0(v: Batched[A]): A =
        v match {
          case Items(lhs, rhs) =>
            stack = lhs :: stack
            descend0(rhs)
          case Item(value) =>
            value
        }
      descend0(v)
    }
  }
}
/**
* Compacting semigroup for batched values.
*
* This semigroup ensures that the batch's tree structure has fewer than `batchSize` values in it. When more
* values are added, the tree is compacted using `s`.
*/
class BatchedSemigroup[T: Semigroup](batchSize: Int) extends Semigroup[Batched[T]] {
  // Reject non-positive batch sizes as soon as the semigroup is constructed.
  require(batchSize > 0, s"Batch size must be > 0, found: $batchSize")

  /**
   * Join `a` and `b` into one tree, then compact the result against `batchSize` using the implicit
   * `Semigroup[T]`.
   */
  override def plus(a: Batched[T], b: Batched[T]): Batched[T] = {
    val merged = a.combine(b)
    merged.compact(batchSize)
  }
}
/**
* Compacting monoid for batched values.
*
* This monoid ensures that the batch's tree structure has fewer than `batchSize` values in it. When more
* values are added, the tree is compacted using `m`.
*/
class BatchedMonoid[T: Monoid](batchSize: Int)
    extends BatchedSemigroup[T](batchSize)
    with Monoid[Batched[T]] {

  /**
   * Single-item batch wrapping the zero of the underlying `Monoid[T]`.
   */
  override val zero: Batched[T] = Batched(Monoid.zero)

  /**
   * A batch is non-zero when the sum of its values is non-zero in the underlying monoid.
   */
  // If (a + b = 0) were known to hold only for (a = 0, b = 0), summing could be avoided with:
  //   new Batched.ItemsIterator(b).exists(monoid.isNonZero)
  override def isNonZero(b: Batched[T]): Boolean = {
    val total = b.sum
    Monoid.isNonZero(total)
  }
}