/
CountMinSketchTest.scala
200 lines (160 loc) · 6.94 KB
/
CountMinSketchTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
package com.twitter.algebird
import org.specs._
import org.scalacheck.Arbitrary
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Properties
import org.scalacheck.Gen.choose
import org.scalacheck.Prop.forAll
/**
 * ScalaCheck properties checking that CMS values, combined through a
 * CountMinSketchMonoid, obey the monoid laws.
 */
object CountMinSketchLaws extends Properties("CountMinSketch") {
  import BaseProperties._

  // Arguments handed to the CountMinSketchMonoid constructor below.
  val DELTA = 1E-10
  val EPS = 0.001
  val SEED = 1

  // Monoid instance picked up implicitly by monoidLaws.
  implicit val cmsMonoid = new CountMinSketchMonoid(EPS, DELTA, SEED)

  // Generator of sketches, each created from one value chosen between 0 and 10000.
  implicit val cmsGen = Arbitrary(choose(0, 10000).map { item => cmsMonoid.create(item) })

  property("CountMinSketch is a Monoid") = monoidLaws[CMS]
}
/**
 * Specification for CMS / CountMinSketchMonoid: total counts, frequency
 * estimates, inner-product estimates and heavy hitters are compared against
 * exact values computed directly from the input streams.
 */
class CountMinSketchTest extends Specification {
  noDetailedDiffs()

  // Arguments handed to the CountMinSketchMonoid constructor.
  val DELTA = 1E-10
  val EPS = 0.001
  val SEED = 1

  val CMS_MONOID = new CountMinSketchMonoid(EPS, DELTA, SEED)
  val RAND = new scala.util.Random

  /**
   * Builds a stream of {size} random values, each drawn from [0, range).
   */
  private def randomStream(size : Int, range : Int) : Seq[Long] = {
    (0 until size).map { _ => RAND.nextInt(range).toLong }
  }

  /**
   * Returns the exact frequency of {x} in {data}.
   */
  def exactFrequency(data : Seq[Long], x : Long) : Long = {
    data.count { _ == x }
  }

  /**
   * Returns the exact inner product between two data streams, when the streams
   * are viewed as count vectors.
   */
  def exactInnerProduct(data1 : Seq[Long], data2 : Seq[Long]) : Long = {
    val counts1 = data1.groupBy( x => x ).mapValues( _.size )
    val counts2 = data2.groupBy( x => x ).mapValues( _.size )
    val sharedKeys = counts1.keys.toSet & counts2.keys.toSet

    // Convert to a Seq before mapping: mapping a Set would collapse keys whose
    // products happen to be equal and silently under-count the result.
    // Widen to Long before multiplying so the product cannot overflow Int.
    sharedKeys.toSeq.map { k => counts1(k).toLong * counts2(k) }.sum
  }

  /**
   * Returns the elements in {data} that appear at least heavyHittersPct * data.size times.
   */
  def exactHeavyHitters(data : Seq[Long], heavyHittersPct : Double) : Set[Long] = {
    val counts = data.groupBy( x => x ).mapValues( _.size )
    val totalCount = counts.values.sum
    counts.filter { _._2 >= heavyHittersPct * totalCount }.keys.toSet
  }

  "CountMinSketch" should {

    "count total number of elements in a stream" in {
      val totalCount = 1243
      val range = 234
      val data = randomStream(totalCount, range)
      val cms = CMS_MONOID.create(data)

      cms.totalCount must be_==(totalCount)
    }

    "estimate frequencies" in {
      val totalCount = 5678
      val range = 897
      val data = randomStream(totalCount, range)
      val cms = CMS_MONOID.create(data)

      (0 to 100).foreach { _ =>
        val x = RAND.nextInt(range).toLong
        val exact = exactFrequency(data, x)
        val frequency = cms.frequency(x)
        val approx = frequency.estimate
        val maxError = approx - frequency.min

        // The estimate may exceed the exact count, but only by at most maxError.
        approx must be_>=(exact)
        (approx - exact) must be_<=(maxError)
      }
    }

    "exactly compute frequencies in a small stream" in {
      val one = CMS_MONOID.create(1)
      val two = CMS_MONOID.create(2)
      val cms = CMS_MONOID.plus(CMS_MONOID.plus(one, two), two)

      cms.frequency(0).estimate must be_==(0)
      cms.frequency(1).estimate must be_==(1)
      cms.frequency(2).estimate must be_==(2)

      val three = CMS_MONOID.create(Seq(1L, 1L, 1L))
      three.frequency(1L).estimate must be_==(3)
      val four = CMS_MONOID.create(Seq(1L, 1L, 1L, 1L))
      four.frequency(1L).estimate must be_==(4)
      val cms2 = CMS_MONOID.plus(four, three)
      cms2.frequency(1L).estimate must be_==(7)
    }

    "estimate inner products" in {
      val totalCount = 5234
      val range = 1390
      val data1 = randomStream(totalCount, range)
      val data2 = randomStream(totalCount, range)
      val cms1 = CMS_MONOID.create(data1)
      // The second sketch must be built from the second stream, so that the
      // estimate is compared against the exact inner product of data1 and data2.
      val cms2 = CMS_MONOID.create(data2)

      val approxA = cms1.innerProduct(cms2)
      val approx = approxA.estimate
      val exact = exactInnerProduct(data1, data2)
      val maxError = approx - approxA.min

      // The estimate must not depend on the order of the operands.
      approx must be_==(cms2.innerProduct(cms1).estimate)
      approx must be_>=(exact)
      (approx - exact) must be_<=(maxError)
    }

    "exactly compute inner product of small streams" in {
      // Nothing in common.
      val a1 = List(1L, 2L, 3L)
      val a2 = List(4L, 5L, 6L)
      CMS_MONOID.create(a1).innerProduct(CMS_MONOID.create(a2)).estimate must be_==(0)

      // One element in common.
      val b1 = List(1L, 2L, 3L)
      val b2 = List(3L, 5L, 6L)
      CMS_MONOID.create(b1).innerProduct(CMS_MONOID.create(b2)).estimate must be_==(1)

      // Multiple, non-repeating elements in common.
      val c1 = List(1L, 2L, 3L)
      val c2 = List(3L, 2L, 6L)
      CMS_MONOID.create(c1).innerProduct(CMS_MONOID.create(c2)).estimate must be_==(2)

      // Multiple, repeating elements in common.
      val d1 = List(1L, 2L, 2L, 3L, 3L)
      val d2 = List(2L, 3L, 3L, 6L)
      CMS_MONOID.create(d1).innerProduct(CMS_MONOID.create(d2)).estimate must be_==(6)
    }

    "estimate heavy hitters" in {
      // Simple way of making some elements appear much more often than others.
      val data1 = randomStream(3000, 3)
      val data2 = randomStream(3000, 10)
      val data3 = (1 to 1450).map { _ => -1L } // element close to being a 20% heavy hitter
      val data = data1 ++ data2 ++ data3

      // Find elements that appear at least 20% of the time.
      val cms = (new CountMinSketchMonoid(EPS, DELTA, SEED, 0.2)).create(data)

      val trueHhs = exactHeavyHitters(data, cms.heavyHittersPct)
      val estimatedHhs = cms.heavyHitters

      // All true heavy hitters must be claimed as heavy hitters.
      trueHhs.subsetOf(estimatedHhs) must be_==(true)

      // It should be very unlikely that any element with count less than
      // (heavyHittersPct - eps) * totalCount is claimed as a heavy hitter.
      val minHhCount = (cms.heavyHittersPct - cms.eps) * cms.totalCount
      val infrequent = data.groupBy{ x => x }.mapValues{ _.size }.filter{ _._2 < minHhCount }.keys.toSet
      infrequent.intersect(estimatedHhs).size must be_==(0)
    }

    "drop old heavy hitters when new heavy hitters replace them" in {
      val monoid = new CountMinSketchMonoid(EPS, DELTA, SEED, 0.3)

      val cms1 = monoid.create(Seq(1L, 2L, 2L))
      cms1.heavyHitters must be_==(Set(1L, 2L))

      // After adding another 2, element 1 is 1 of 4 items: below the 0.3 threshold.
      val cms2 = cms1 ++ monoid.create(2L)
      cms2.heavyHitters must be_==(Set(2L))

      // After adding another 1, element 1 is 2 of 5 items: above the threshold again.
      val cms3 = cms2 ++ monoid.create(1L)
      cms3.heavyHitters must be_==(Set(1L, 2L))

      val cms4 = cms3 ++ monoid.create(Seq(0L, 0L, 0L, 0L, 0L, 0L))
      cms4.heavyHitters must be_==(Set(0L))
    }

    "exactly compute heavy hitters in a small stream" in {
      // Element k appears k times, 15 items in total.
      val data1 = Seq(1L, 2L, 2L, 3L, 3L, 3L, 4L, 4L, 4L, 4L, 5L, 5L, 5L, 5L, 5L)
      val cms1 = (new CountMinSketchMonoid(EPS, DELTA, SEED, 0.01)).create(data1)
      val cms2 = (new CountMinSketchMonoid(EPS, DELTA, SEED, 0.1)).create(data1)
      val cms3 = (new CountMinSketchMonoid(EPS, DELTA, SEED, 0.3)).create(data1)
      val cms4 = (new CountMinSketchMonoid(EPS, DELTA, SEED, 0.9)).create(data1)

      cms1.heavyHitters must be_==(Set(1L, 2L, 3L, 4L, 5L))
      cms2.heavyHitters must be_==(Set(2L, 3L, 4L, 5L))
      cms3.heavyHitters must be_==(Set(5L))
      cms4.heavyHitters must be_==(Set[Long]())
    }
  }
}