/
SketchMap.scala
229 lines (190 loc) · 7.86 KB
/
SketchMap.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
/*
Copyright 2012 Twitter, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.twitter.algebird
/**
 * A Sketch Map is a generalization of the Count-Min Sketch: it approximates a
 * Map[K, V] and keeps references to the top heavy hitters (the keys with the
 * largest values). A Sketch Map can approximate the sum of any value type that
 * has a Monoid.
 */
/**
 * Monoid responsible for creating and combining instances of SketchMap.
 *
 * Every SketchMap produced by one monoid instance shares the same parameters
 * (and therefore the same hash functions). Combining sketches with `plus`
 * relies on that.
 *
 * @param eps               error bound; converted to a table width via SketchMap.width
 * @param delta             failure probability; converted to a table depth via SketchMap.depth
 * @param seed              seed for the random generator that picks the per-row hash
 *                          parameters; also passed to every SketchMapHash
 * @param heavyHittersCount maximum number of heavy-hitter keys each sketch retains
 */
class SketchMapMonoid[K, V](eps: Double, delta: Double, seed: Int, heavyHittersCount: Int)
  (implicit serialization: K => Array[Byte], valueOrdering: Ordering[V], monoid: Monoid[V])
  extends Monoid[SketchMap[K, V]] {

  /** One hash function per row of the values table, each mapping a key to a column. */
  val hashes: Seq[K => Int] = {
    val r = new scala.util.Random(seed)
    val numHashes = SketchMap.depth(delta)
    val numCounters = SketchMap.width(eps)
    (0 until numHashes).map { _ =>
      SketchMapHash[K](CMSHash(r.nextInt, 0, numCounters), seed)
    }
  }

  /**
   * All Sketch Maps created with this monoid will have the same parameter configuration.
   */
  val params: SketchMapParams[K, V] =
    SketchMapParams[K, V](hashes, SketchMap.width(eps), SketchMap.depth(delta), heavyHittersCount)

  /**
   * The identity sketch: every cell holds `monoid.zero`, there are no heavy
   * hitters, and the total is `monoid.zero`.
   */
  val zero: SketchMap[K, V] =
    SketchMap[K, V](params, AdaptiveMatrix.fill(params.depth, params.width)(monoid.zero), Nil, monoid.zero)

  /**
   * Sums two sketches. Assumes the left and right sketches use the same hash
   * functions (i.e. were built by this monoid); this is not checked here.
   */
  def plus(left: SketchMap[K, V], right: SketchMap[K, V]): SketchMap[K, V] = left ++ right

  /**
   * Creates a Sketch Map holding a single key/value pair.
   */
  def create(pair: (K, V)): SketchMap[K, V] = zero + pair

  /**
   * Creates a Sketch Map from a sequence of key/value pairs. An empty sequence
   * yields `zero`.
   */
  def create(data: Seq[(K, V)]): SketchMap[K, V] =
    // Pass the tuple explicitly instead of relying on auto-tupling of
    // `create(key, value)` against this overloaded method.
    data.foldLeft(zero) { (acc, pair) => plus(acc, create(pair)) }
}
object SketchMap {
  import scala.math.{ceil, exp, log}

  /*
   * Conversions between the user-facing error bounds (eps, delta) and the
   * table dimensions (width, depth):
   *   width = ceil(e / eps)        and   eps   = e / width
   *   depth = ceil(ln(1 / delta))  and   delta = 1 / e^depth
   */

  /** Error bound implied by a table with `width` columns. */
  def eps(width: Int): Double = exp(1.0) / width

  /** Failure probability implied by a table with `depth` rows. */
  def delta(depth: Int): Double = 1.0 / exp(depth)

  /** Number of rows (hash functions) needed to reach failure probability `delta`. */
  def depth(delta: Double): Int = {
    val rows = ceil(log(1.0 / delta))
    rows.toInt
  }

  /** Number of columns needed to reach error bound `eps`. */
  def width(eps: Double): Int = {
    val columns = ceil(exp(1) / eps)
    columns.toInt
  }

  /**
   * Factory for a SketchMapMonoid, which in turn creates SketchMap instances.
   * Needs a serialization of K to bytes (for hashing), an ordering on V (to
   * rank heavy hitters), and a monoid on V (to sum values).
   */
  def monoid[K, V](eps: Double, delta: Double, seed: Int, heavyHittersCount: Int)
    (implicit serialization: K => Array[Byte], valueOrdering: Ordering[V], monoid: Monoid[V]): SketchMapMonoid[K, V] =
    new SketchMapMonoid[K, V](eps, delta, seed, heavyHittersCount)(serialization, valueOrdering, monoid)
}
/**
 * Data structure representing an approximation of Map[K, V], where V has an
 * implicit ordering and monoid. This is a more generic version of
 * CountMinSketch.
 *
 * Values are stored in valuesTable, a 2D table with one row per hash function
 * (params.depth rows). Each cell holds the monoid sum of the values whose key
 * hashes to that cell in that row.
 *
 * The data structure also stores the keys with the largest estimated values,
 * called heavy hitters. They are sorted by the reverse of the implicit
 * Ordering[V] (largest first). At most params.heavyHittersCount of them are
 * kept. Sketches produced by `+` and `++` contain no duplicate heavy hitters.
 *
 * Use SketchMapMonoid to create instances of this class.
 */
case class SketchMap[K, V](
  val params: SketchMapParams[K, V],
  val valuesTable: AdaptiveMatrix[V],
  val heavyHitterKeys: List[K],
  val totalValue: V
)(implicit ordering: Ordering[V], monoid: Monoid[V]) extends java.io.Serializable {

  /**
   * Estimated values of all heavy hitters, calculated once at construction.
   */
  private val heavyHittersMapping: Map[K, V] = calculateHeavyHittersMapping(heavyHitterKeys, valuesTable)

  /**
   * Orders keys by their heavy-hitter value, largest first (the reverse of the
   * implicit Ordering[V]). Only defined for keys present in heavyHittersMapping.
   * NOTE(review): this is not referenced anywhere in this file. It is left in place
   * because removing a field changes this Serializable class's layout. Confirm
   * nothing external relies on it before deleting.
   */
  private implicit val keyValueOrdering: Ordering[K] = Ordering.by[K, V](heavyHittersMapping).reverse

  def eps: Double = params.eps

  def delta: Double = params.delta

  /**
   * Returns the heavy hitters paired with their estimated values, in the order
   * of heavyHitterKeys.
   */
  def heavyHitters: List[(K, V)] = heavyHitterKeys.map { key => (key, heavyHittersMapping(key)) }

  /**
   * Maps each of `keys` to its estimated value in `table`.
   */
  private def calculateHeavyHittersMapping(keys: Iterable[K], table: AdaptiveMatrix[V]): Map[K, V] =
    keys.map { key => (key, frequency(key, table)) }.toMap

  /**
   * Count-Min estimate for `key` in `table`. Each row is read at the column its
   * hash function assigns to `key`, and the minimum of those cells is returned.
   */
  private def frequency(key: K, table: AdaptiveMatrix[V]): V = {
    val estimates = table.contents.zipWithIndex.map { case (row, i) =>
      row(params.hashes(i)(key))
    }
    estimates.min
  }

  /**
   * Calculates the approximate value for any key. Heavy hitters are answered
   * from the precalculated mapping; other keys are estimated from valuesTable.
   */
  def frequency(key: K): V =
    heavyHittersMapping.getOrElse(key, frequency(key, valuesTable))

  /**
   * Returns a new Sketch Map with `pair`'s value added under its key.
   */
  def +(pair: (K, V)): SketchMap[K, V] = {
    val (key, value) = pair
    val newValuesTable = (0 until params.depth).foldLeft(valuesTable) { (table, row) =>
      val pos = (row, params.hashes(row)(key))
      table.updated(pos, Monoid.plus(table.getValue(pos), value))
    }
    // `key` may already be a heavy hitter. updatedHeavyHitters removes the
    // duplicate so it cannot occupy two of the heavyHittersCount slots.
    val newHeavyHitters = key :: heavyHitterKeys
    SketchMap(params, newValuesTable, updatedHeavyHitters(newHeavyHitters, newValuesTable), Monoid.plus(totalValue, value))
  }

  /**
   * Returns a new Sketch Map summed with another Sketch Map. These should have
   * the same parameters and be generated from the same monoid; this is not
   * checked here.
   */
  def ++(other: SketchMap[K, V]): SketchMap[K, V] = {
    val newValuesTable = Monoid.plus(valuesTable, other.valuesTable)
    val newHeavyHitters = heavyHitterKeys ++ other.heavyHitterKeys
    SketchMap(params, newValuesTable, updatedHeavyHitters(newHeavyHitters, newValuesTable), Monoid.plus(totalValue, other.totalValue))
  }

  /**
   * Removes duplicates from `hitters` and sorts them by estimated value in
   * `table`, largest first. Keeps at most params.heavyHittersCount of them.
   */
  private def updatedHeavyHitters(hitters: Seq[K], table: AdaptiveMatrix[V]): List[K] = {
    val candidates = hitters.distinct
    val mapping = calculateHeavyHittersMapping(candidates, table)
    val byDescendingValue = Ordering.by[K, V](mapping).reverse
    candidates.sorted(byDescendingValue).take(params.heavyHittersCount).toList
  }
}
/**
 * Convenience class for holding the constant parameters of a Sketch Map.
 *
 * @param hashes            one hash function per row, mapping a key to a column index
 * @param width             number of columns in the values table; must be positive
 * @param depth             number of rows in the values table; must be positive
 * @param heavyHittersCount maximum number of heavy hitters to retain; may be 0
 */
case class SketchMapParams[K, V](hashes: Seq[K => Int], width: Int, depth: Int, heavyHittersCount: Int) {
  assert(0 < width, "width must be greater than 0")
  assert(0 < depth, "depth must be greater than 0")
  // Zero is accepted: such a sketch still estimates values but keeps no heavy
  // hitters. The message previously said "greater than 0", contradicting the check.
  assert(0 <= heavyHittersCount, "heavyHittersCount must be greater than or equal to 0")

  /** Error bound implied by `width` (via SketchMap.eps). */
  val eps: Double = SketchMap.eps(width)

  /** Failure probability implied by `depth` (via SketchMap.delta). */
  val delta: Double = SketchMap.delta(depth)
}
/**
 * Turns a key of arbitrary type T into a column index for one row of a Sketch
 * Map. The steps are:
 *   1. Serialize the key to bytes.
 *   2. Hash the bytes with MurmurHash128 seeded by `seed`.
 *   3. XOR the two Long halves of the result into a single Long.
 *   4. Pass that Long to `hasher`.
 */
case class SketchMapHash[T](hasher: CMSHash, seed: Int)
  (implicit serialization: T => Array[Byte]) extends (T => Int) {
  def apply(obj: T): Int = {
    val bytes = serialization(obj)
    val (high, low) = MurmurHash128(seed)(bytes)
    hasher(high ^ low)
  }
}