# Count-min sketch



Think of the count-min sketch as a generalization of the Bloom filter: where a Bloom filter can err only by reporting that we've seen a key we haven't, the count-min sketch can err only by overestimating _how many times_ we've seen it. You could solve this problem precisely with a map from keys to counts (a tree, an associative array, or a hash table, for example), but -- just as with the Bloom filter -- there are cases in which the space requirements of a precise structure are unacceptable.

We'll start by importing the libraries we need again: `numpy`, `pandas`, and our hash functions, plus `seaborn` and `matplotlib` for plotting.

In [None]:
from datasketching.hashing import hashes_for
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
class CMS(object):
    def __init__(self, width, hashes):
        """ Initializes a Count-min sketch with the
            given width and a collection of hashes, 
            which are functions taking arbitrary 
            values and returning integers.  The depth
            of the sketch structure is taken from the
            number of supplied hash functions.
            
            hashes can be either a function taking 
            a value and returning a list of results
            or a list of functions.  In the latter 
            case, this constructor will synthesize 
            the former """
        self.__width = width
        
        if callable(hashes):
            self.__hashes = hashes
            # inspect the tuple returned by the hash function to get a depth
            self.__depth = len(hashes(bytes()))
        else:
            funs = hashes[:]
            self.__depth = len(hashes)
            def h(value):
                return [int(f(value)) for f in funs]
            self.__hashes = h
        
        self.__buckets = np.zeros((int(width), int(self.__depth)), np.uint64)
    
    
    def width(self):
        return self.__width
    
    def depth(self):
        return self.__depth
    
    def insert(self, value):
        """ Inserts a value into this sketch """
        for (row, col) in enumerate(self.__hashes(value)):
            self.__buckets[col % self.__width][row] += 1
    
    def lookup(self, value):
        """ Returns a biased estimate (never an underestimate) of the number of times value has been inserted into this sketch """
        return min([self.__buckets[col % self.__width][row] for (row, col) in enumerate(self.__hashes(value))])
    
    def merge_from(self, other):
        """ Merges other into this sketch by 
            adding the counts from each bucket in other
            to the corresponding buckets in this.  Both
            sketches must have the same width and use the
            same hash functions.
            
            Updates this. """
        self.__buckets += other.__buckets
    
    def merge(self, other):
        """ Creates a new sketch by merging this sketch's
            counts with those of another sketch. """
        
        cms = CMS(self.width(), self.__hashes)
        cms.__buckets += self.__buckets
        cms.__buckets += other.__buckets
        return cms
    
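    def inner(self, other):
        """ returns an estimate of the inner product of self and 
            other, which estimates the equijoin size between the 
            streams modeled by self and other """
        # buckets are indexed [column][row], so summing the elementwise
        # products over axis 0 gives one dot product per hash function;
        # as with lookup, the smallest of these is the estimate
        per_hash = np.sum(self.__buckets * other.__buckets, axis=0)
        return per_hash.min()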
    
    def minimum(self, other):
        """ Creates a new sketch by taking the elementwise minimum 
            of this sketch and another. """
        cms = CMS(self.width(), self.__hashes)
        cms.__buckets = np.minimum(self.__buckets, other.__buckets)
        return cms

    def dup(self):
        """ Creates a new sketch with a copy of this sketch's counts. """
        cms = CMS(self.width(), self.__hashes)
        cms.merge_from(self)
        return cms

In [None]:
cms = CMS(16384, hashes_for(3,8))

In [None]:
cms.lookup("foo")

In [None]:
cms.insert("foo")
cms.lookup("foo")
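
The class above depends on `hashes_for` from the `datasketching` package. To try the same insert, lookup, and merge behavior without that dependency, here is a minimal standalone sketch. The `TinyCMS` name and its salted BLAKE2b hashing are assumptions made for this illustration; they are not the notebook's implementation.

```python
import hashlib
import numpy as np

class TinyCMS:
    """Illustrative count-min sketch; a stand-in for the CMS class above."""

    def __init__(self, width, depth):
        self.width = width
        self.depth = depth
        # one row of counters per hash function
        self.buckets = np.zeros((depth, width), dtype=np.int64)

    def _cells(self, value):
        # Assumption: salting BLAKE2b with the row index gives us
        # `depth` different hash functions.
        data = repr(value).encode("utf-8")
        for row in range(self.depth):
            digest = hashlib.blake2b(
                data, digest_size=8, salt=row.to_bytes(8, "little")
            ).digest()
            yield row, int.from_bytes(digest, "little") % self.width

    def insert(self, value):
        for row, col in self._cells(value):
            self.buckets[row, col] += 1

    def lookup(self, value):
        # the smallest counter is the one least inflated by collisions
        return int(min(self.buckets[row, col] for row, col in self._cells(value)))

    def merge(self, other):
        merged = TinyCMS(self.width, self.depth)
        merged.buckets = self.buckets + other.buckets
        return merged

a, b = TinyCMS(64, 3), TinyCMS(64, 3)
for word in ["foo", "foo", "bar"]:
    a.insert(word)
for word in ["foo", "baz"]:
    b.insert(word)

both = a.merge(b)
print(a.lookup("foo"), b.lookup("foo"), both.lookup("foo"))
```

Because every insert increments every counter a key maps to, an estimate can be too high when keys collide but never too low. Merging adds counters cell by cell, so the merged sketch is the one we would have built from the concatenated streams.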

While hash collisions in Bloom filters lead to false positives, hash collisions in count-min sketches lead to overestimated counts.  To see how much this affects us in practice, we can run an experiment and plot the cumulative distribution of the overestimation factor (the estimated count divided by the true count) for sketches of various sizes.

In [None]:
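def cms_experiment(sample_count, size, hashes, seed=0x15300625):
    """ Inserts sample_count random 64-bit keys into a sketch of
        the given width.  Most keys are inserted between 1 and 8 
        times; every hundredth key is a heavy hitter inserted 
        between 1 and 512 times.  Returns a list of 
        (estimated count, true count) pairs. """
    import random
    
    random.seed(seed)
    cms = CMS(size, hashes)
    
    result = []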
    
    # update the counts
    for i in range(sample_count):
        bits = random.getrandbits(64)
        if i % 100 == 0:
            # every hundredth entry is a heavy hitter
            insert_count = (bits % 512) + 1
        else:
            insert_count = (bits % 8) + 1
        
        for _ in range(insert_count):
            cms.insert(bits)
    
    random.seed(seed)
    # look up the bit sequences again
    for i in range(sample_count):
        bits = random.getrandbits(64)
        if i % 100 == 0:
            # every hundredth entry is a heavy hitter
            expected_count = (bits % 512) + 1
        else:
            expected_count = (bits % 8) + 1

        result.append((int(cms.lookup(bits)), int(expected_count)))
    
    return result

In [None]:
results = cms_experiment(1 << 14, 4096, hashes_for(3, 8))
df = pd.DataFrame.from_records(results)
df.rename(columns={0: "actual count", 1: "expected count"}, inplace=True)
sns.distplot(df["actual count"] / df["expected count"], hist_kws=dict(cumulative=True), kde_kws=dict(cumulative=True))

As you can see, about 55% of our counts for this small sketch are overestimated by less than a factor of three, although the worst overestimates are quite large indeed.  Let's try with a larger sketch structure.

In [None]:
results = cms_experiment(1 << 14, 8192, hashes_for(3, 8))
df = pd.DataFrame.from_records(results)
df.rename(columns={0: "actual count", 1: "expected count"}, inplace=True)

sns.distplot(df["actual count"] / df["expected count"], hist_kws=dict(cumulative=True), kde_kws=dict(cumulative=True))
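
Percentages like the one quoted for the first experiment can also be computed from the data rather than read off the plot. The sketch below shows one way to do that with `pandas`. The `toy_results` list is made-up data standing in for the output of `cms_experiment`, so the numbers it prints say nothing about a real sketch.

```python
import pandas as pd

# Made-up (estimated count, true count) pairs for illustration only.
toy_results = [(3, 1), (8, 8), (12, 4), (5, 5), (40, 2), (7, 6)]
toy_df = pd.DataFrame.from_records(
    toy_results, columns=["actual count", "expected count"]
)

# overestimation factor: 1.0 means the estimate was exact
ratio = toy_df["actual count"] / toy_df["expected count"]

# share of keys overestimated by less than a factor of three
fraction_below_3 = (ratio < 3).mean()
print(fraction_below_3)

# median, 90th percentile, and worst-case overestimation factor
print(ratio.quantile([0.5, 0.9, 1.0]))
```

Applying the same two expressions to the `df` built from a real experiment gives the corresponding figures for that sketch.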

With a wider sketch (more columns) *and* more hash functions (more rows), we can dramatically reduce the bias.

In [None]:
results = cms_experiment(1 << 14, 8192, hashes_for(8, 5))
df = pd.DataFrame.from_records(results)
df.rename(columns={0: "actual count", 1: "expected count"}, inplace=True)

sns.distplot(df["actual count"] / df["expected count"], hist_kws=dict(cumulative=True), kde_kws=dict(cumulative=True))
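
The trade-off between width and depth can be made concrete with the standard count-min sketch analysis. With a width of `ceil(e / epsilon)` and a depth of `ceil(ln(1 / delta))`, and assuming pairwise-independent hash functions, an estimate exceeds the true count by more than `epsilon * N` with probability at most `delta`, where `N` is the total number of insertions. The helper below turns those targets into dimensions; its name, `cms_dimensions`, is made up for this illustration, and whether the hash functions used in this notebook meet the independence assumption is not something this sketch checks.

```python
import math

def cms_dimensions(epsilon, delta):
    """Returns (width, depth) for the given error and failure targets."""
    if not (0 < epsilon < 1 and 0 < delta < 1):
        raise ValueError("epsilon and delta must both be in (0, 1)")
    width = math.ceil(math.e / epsilon)      # columns: controls the error size
    depth = math.ceil(math.log(1 / delta))   # rows: controls the failure probability
    return width, depth

width, depth = cms_dimensions(epsilon=0.001, delta=0.01)
print(width, depth)
```

Note that the width grows linearly as the error target shrinks, while the depth grows only logarithmically as the failure probability shrinks.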

## Exercises

Here are some exercises to try out if you're interested in extending the count-min sketch:

* The count-min sketch is a biased estimator.  Implement a technique to adjust the estimates for expected bias.
* When paired with an auxiliary structure like a priority queue, the count-min sketch can be used to track the top-_k_ event types in a stream.  Try implementing a couple of approaches!
* Consider how you'd handle negative inserts.  How would you need to change the query code?  What else might change?
* The implementation includes a `minimum` method.  What might it be useful for?  What limitations might it have?
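
As one possible starting point for the top-_k_ exercise, the sketch below keeps a small set of candidate keys next to the counters. It is deliberately simplified: it uses a plain dictionary with a linear scan for the weakest candidate instead of a real priority queue, and its `columns` and `top_k` helpers and salted BLAKE2b hashing are assumptions made for this illustration rather than part of the `CMS` class above.

```python
import hashlib
import numpy as np

def columns(key, width, depth):
    # Assumption: one salted BLAKE2b hash per row stands in for the
    # notebook's hash functions.
    data = repr(key).encode("utf-8")
    cols = []
    for row in range(depth):
        digest = hashlib.blake2b(
            data, digest_size=8, salt=row.to_bytes(8, "little")
        ).digest()
        cols.append(int.from_bytes(digest, "little") % width)
    return cols

def top_k(stream, k, width=1024, depth=4):
    """Returns up to k (key, estimated count) pairs, largest first.
    Assumes k >= 1."""
    buckets = np.zeros((depth, width), dtype=np.int64)
    rows = np.arange(depth)
    candidates = {}  # key -> estimate at the time the key was last seen
    for key in stream:
        cols = columns(key, width, depth)
        buckets[rows, cols] += 1
        estimate = int(buckets[rows, cols].min())
        if key in candidates or len(candidates) < k:
            candidates[key] = estimate
        else:
            # evict the weakest candidate if this key now looks heavier
            weakest = min(candidates, key=candidates.get)
            if estimate > candidates[weakest]:
                del candidates[weakest]
                candidates[key] = estimate
    return sorted(candidates.items(), key=lambda item: item[1], reverse=True)

stream = ["a"] * 50 + ["b"] * 30 + ["c"] * 5 + ["x%d" % i for i in range(100)]
print(top_k(stream, k=2))
```

Replacing the dictionary scan with a heap, and deciding what to do when an overestimated key evicts a genuinely frequent one, are left as part of the exercise.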

