# Improvements to Consistent Hashing

Normally, [consistent hashing](https://en.wikipedia.org/wiki/Consistent_hashing) is a little expensive, because each node needs the whole set of keys to know which subset it should be working with.

But with a little ingenuity in key design, we can enable a pattern that allows each node to only query the work it needs to do!

## How Consistent Hashing Works

Consistent hashing works by effectively splitting up a ring into multiple parts, and assigning each node a (more or less) equal share.

It does this by having each node put the same number of dots on a circle:

In [24]:
import math
from collections import namedtuple

PointNode = namedtuple("PointNode", ["point", "node"])

# four nodes, one point each, evenly spaced around a ring of length 2 * PI
POINTS_BY_NODE = [
    PointNode(0, "a"),
    PointNode(math.pi / 2, "b"),
    PointNode(math.pi, "c"),
    PointNode(math.pi * 3 / 2, "d"),
]

This effectively creates buckets between the points. To bucket a value, we find the first point greater than the value, wrapping around to the first point on the ring when we run past the last one:

In [28]:
import bisect

def get_node_for_point(node_by_point, point):
    """ given the node_by_point, return the node that the point belongs to. """
    # "_" is only a placeholder so that the probe compares like a PointNode
    as_point_node = PointNode(point, "_")
    index = bisect.bisect_right(node_by_point, as_point_node)
    if index == len(node_by_point):
        # past the last point: wrap around to the first point on the ring
        index = 0
    return node_by_point[index].node

get_node_for_point(POINTS_BY_NODE, math.pi * 7 / 4)

'a'

We can construct our own ring from any arbitrary set of nodes, as long as we have a way to uniquely name one versus the other:

In [37]:
import bisect
import math
import pprint
from collections import namedtuple
LENGTH = 2 * math.pi

PointNode = namedtuple("PointNode", ["point", "node"])

def _calculate_point_for_node(node, point_num):
    """ return back the point for the node, between 0 and 2 * PI """
    return hash(node + str(point_num)) % LENGTH

def points_for_node(node, num_points):
    return [_calculate_point_for_node(node, i) for i in range(num_points)]

def get_node_by_point(node_names, num_points):
    """ return a list of PointNode(point, node) covering every node, sorted by point """
    point_by_node = [PointNode(p, n) for n in node_names for p in points_for_node(n, num_points)]
    point_by_node.sort()
    return point_by_node

node_by_point = get_node_by_point(["a", "b", "c", "d"], 4)
get_node_for_point(node_by_point, 2)

'a'
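
One caveat with the snippet above: Python's built-in `hash()` is salted per process for strings (unless `PYTHONHASHSEED` is set), so two nodes running this code in separate processes would not agree on where the points fall. Below is a minimal sketch of a process-independent alternative, assuming an MD5 digest from `hashlib` is an acceptable hash here. `stable_point` is a hypothetical helper, not part of the code above.

```python
import hashlib
import math

LENGTH = 2 * math.pi

def stable_point(value):
    """ map a string to a point in [0, 2 * PI), identically in every process """
    digest = hashlib.md5(value.encode("utf-8")).digest()
    # interpret the first 8 bytes of the digest as an unsigned integer
    as_int = int.from_bytes(digest[:8], "big")
    # same modulo step as the hash() version above
    return as_int % LENGTH

print(stable_point("a0") == stable_point("a0"))
```

Swapping this in for `hash()` would change where every point lands, so the outputs shown in this article would differ.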

## Bucketing the Points Without All the Keys

Normally, consistent hashing requires whoever executes the algorithm to be aware of two sets of data:

1. the identifiers of all the nodes in the cluster
2. the set of keys to assign

This is because the standard algorithm runs through the list of all keys, and assigns each one to a node:

In [34]:
def assign_nodes(node_by_point, items):
    key_by_bucket = {}
    for i in items:
        value = hash(i) % LENGTH
        node = get_node_for_point(node_by_point, value)
        key_by_bucket.setdefault(node, [])
        key_by_bucket[node].append(i)
    return key_by_bucket

items = list(range(40))
assign_nodes(node_by_point, items) 

{'a': [1, 2, 4, 6, 8, 11, 14, 17, 20, 23, 26, 27, 28, 31, 33, 36, 39],
 'b': [3, 5, 9, 12, 16, 18, 22, 24, 34, 37],
 'c': [10, 25, 29, 32, 35],
 'd': [0, 7, 13, 15, 19, 21, 30, 38]}

(Note the uneven distribution here: because the points are placed pseudorandomly, some buckets end up larger than others. We'll talk about that later.)

But getting all keys can be inefficient for larger data sets. What happens when we want to consistently hash a data set of 1 million keys?

As described above, the standard approach requires every node to have the full set of keys. But what if each node could query for just the data that matters to it?

There is a way to know which data that is. Given all the nodes, we can calculate which ranges of the ring each node is responsible for:

In [36]:
def get_ranges_by_node(node_by_point):
    """ return a Dict[node, List[Tuple[lower_bound, upper_bound]]] for the raw nodes by point """
    range_by_node = {}
    previous_point = 0
    for point, node in node_by_point:
        # each node owns the range that ends at its point
        range_by_node.setdefault(node, []).append((previous_point, point))
        previous_point = point
    # close the loop: the range from the last point to the end of the ring
    # wraps around to the first node
    first_node = node_by_point[0].node
    range_by_node[first_node].append((previous_point, LENGTH))

    return range_by_node

get_ranges_by_node(node_by_point)

{'a': [(0.7221837052256888, 2.1398018192684205),
  (2.8125462380799036, 2.9036464881262134),
  (3.939380619805206, 4.844545637049649),
  (5.724601884363189, 6.102014634518035)],
 'b': [(2.4981986005276724, 2.8125462380799036),
  (2.9036464881262134, 3.4668217318932335),
  (3.4668217318932335, 3.5793269111334993),
  (4.909914260552789, 5.724601884363189)],
 'c': [(0.4962708040938111, 0.7089051885000046),
  (2.4620542383121276, 2.4981986005276724),
  (3.5793269111334993, 3.939380619805206),
  (6.102014634518035, 6.169806707365048)],
 'd': [(0, 0.4962708040938111),
  (0.7089051885000046, 0.7221837052256888),
  (2.1398018192684205, 2.4620542383121276),
  (4.844545637049649, 4.909914260552789),
  (6.169806707365048, 6.283185307179586)]}
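
As a sanity check, the ranges should tile the whole ring and agree with the per-key lookup. Here is a small sketch under a few assumptions: the ring is hand-written rather than hashed, the lookup wraps around to the first point once it runs past the end of the ring, and `owner_from_ranges` is a hypothetical helper.

```python
import bisect
import math
from collections import namedtuple

LENGTH = 2 * math.pi
PointNode = namedtuple("PointNode", ["point", "node"])

# hand-written ring (an assumption for illustration), sorted by point
RING = [PointNode(1.0, "a"), PointNode(2.5, "b"), PointNode(4.0, "a"), PointNode(5.5, "c")]

def lookup(ring, point):
    """ per-key lookup: first point greater than the value, wrapping to the first point """
    index = bisect.bisect_right(ring, PointNode(point, "_"))
    return ring[index % len(ring)].node

def ranges_by_node(ring):
    """ the same assignment expressed as (lower, upper) ranges per node """
    ranges = {}
    previous_point = 0
    for point, node in ring:
        ranges.setdefault(node, []).append((previous_point, point))
        previous_point = point
    # the stretch after the last point wraps around to the first node
    ranges[ring[0].node].append((previous_point, LENGTH))
    return ranges

def owner_from_ranges(ranges, point):
    """ hypothetical helper: find the node whose range contains the point """
    for node, bounds in ranges.items():
        if any(lower < point <= upper for lower, upper in bounds):
            return node
    return None

ranges = ranges_by_node(RING)
samples = [0.5, 1.7, 3.0, 4.2, 6.0]
print([(lookup(RING, s), owner_from_ranges(ranges, s)) for s in samples])
# [('a', 'a'), ('b', 'b'), ('a', 'a'), ('c', 'c'), ('a', 'a')]
```

Both views name the same owner for every sample, including `6.0`, which sits past the last point and wraps around to the first node.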

Now we have the ranges each node is responsible for. All we need next is a database that knows how to query these ranges.

We can accomplish this by storing each key's position on the ring in the database itself, and indexing on it:

In [44]:
import bisect
import random
import string

def _calculate_point(value):
    return hash(value) % LENGTH

def _random_string():
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))

VALUES = [_random_string() for _ in range(100)]
DATABASE = {_calculate_point(v): v for v in VALUES}
INDEX = sorted(DATABASE.keys())

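def query_database(index, database, bounds):
    """ return the values whose point falls in the half-open range (lower, upper] """
    lower, upper = bounds
    # bisect_right on both bounds: a point equal to lower belongs to the previous
    # range, while a point equal to upper belongs to this one
    lower_index = bisect.bisect_right(index, lower)
    upper_index = bisect.bisect_right(index, upper)
    return [database[point] for point in index[lower_index:upper_index]]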

query_database(INDEX, DATABASE, (0.5, 0.6))

['D6N66HENOM']
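
The dictionary and sorted list above stand in for a real database. As a rough sketch of the same idea against SQL, here is a version that assumes SQLite via Python's `sqlite3` module and a hypothetical `items` table with an indexed `point` column. The rows are hand-picked for illustration.

```python
import sqlite3

connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE items (point REAL, value TEXT)")
# index the point column so the range query below can use it instead of a full scan
connection.execute("CREATE INDEX items_by_point ON items (point)")

# hand-picked points; in practice each would be the value's position on the ring
rows = [(0.25, "alpha"), (0.55, "bravo"), (0.58, "charlie"), (1.9, "delta")]
connection.executemany("INSERT INTO items (point, value) VALUES (?, ?)", rows)

def query_range(connection, bounds):
    """ return the values whose point falls within (lower, upper] """
    lower, upper = bounds
    cursor = connection.execute(
        "SELECT value FROM items WHERE point > ? AND point <= ? ORDER BY point",
        (lower, upper),
    )
    return [value for (value,) in cursor]

print(query_range(connection, (0.5, 0.6)))  # ['bravo', 'charlie']
```

A node would run one such query per range it owns, so the database never has to hand over the full key set.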

With that in place, we can pinpoint and query only the values that are relevant to our node, using nothing more than the information about the nodes themselves:

In [45]:
def query_values_for_node(node_by_point, index, database, node):
    range_by_node = get_ranges_by_node(node_by_point)
    values = []
    for bounds in range_by_node[node]:
        values += query_database(index, database, bounds)
        
    return values
    
query_values_for_node(node_by_point, INDEX, DATABASE, "a")

['98YLIK05FO',
 'G65IPQJPXK',
 'KMF6NLYDEB',
 '0RBYONF0XK',
 '7U8PC79F3V',
 'CEOLWMNI3W',
 'Y7QNLCAXEO',
 '3JFM658SZZ',
 'AOT371FQGD',
 'PVPMM7S75V',
 'A89JB63ULD',
 '4NDV0AWK6U',
 'UAVSW4MQBN',
 'VBX3JSM3TY',
 'T4CW8ASMES',
 'TC17WA4A7X',
 '1PLBQO1Q9N',
 'MGM68X168W',
 'L21PQREYGF',
 '316IBN0BHP',
 'M05207VFGC',
 '6MC5TS7OJN',
 'I6CH3AXE76',
 'J6OXH0UHZL',
 'MD5ZXGSQS7',
 '5XIV9B1CKA',
 '4WDGYWCA43',
 'Z86M8ILNL3',
 'ZPGE2WL9PF',
 'VLTQKJ44Z3',
 'V8D46BOPIH',
 'GLDCOECKE3',
 'YRVACTQ6LF',
 'GQH0ZEIAKJ',
 'F11EV0HSP8',
 'MLTTRGRVH5',
 'QLP8FSLY50',
 'BW507S1M1C',
 'T9Q46PDYFA',
 'EPDNXCGLDX',
 'H9CLUQZ35M',
 'W1WTBYAWJR',
 'XFL30R5CHB',
 'FIWLOXG4FE',
 'B2F4218G10']

There are additional performance benefits that can come from storing the index as its position on the ring. If your database ensures data locality using the same key (such as DynamoDB's sort key, which orders items within a partition), you gain the advantage of each node's keys living close to each other on disk. This can make the reads for each node's items even faster.


## Bucketing Values Evenly

As you may have noted earlier, the buckets themselves are not always even. That depends entirely on the distribution of points: for a random distribution and a high enough number of points, there is an extremely high likelihood of bucketing evenly.
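
To get a feel for how the number of points affects evenness, one rough experiment is to measure what fraction of the ring each node owns as the point count grows. This is a sketch under assumptions: points come from an MD5 digest so that results are repeatable across runs, and `point_for` and `ring_share_by_node` are hypothetical helpers.

```python
import hashlib
import math

LENGTH = 2 * math.pi

def point_for(node, point_num):
    """ repeatable point in [0, 2 * PI) for the node's nth point """
    digest = hashlib.md5((node + str(point_num)).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % LENGTH

def ring_share_by_node(node_names, num_points):
    """ return a Dict[node, fraction of the ring owned by that node] """
    points = sorted((point_for(n, i), n) for n in node_names for i in range(num_points))
    share = {n: 0.0 for n in node_names}
    previous_point = 0.0
    for point, node in points:
        # each node owns the arc that ends at its point
        share[node] += (point - previous_point) / LENGTH
        previous_point = point
    # the arc after the last point wraps around to the first node
    share[points[0][1]] += (LENGTH - previous_point) / LENGTH
    return share

for num_points in (1, 4, 16, 64, 256):
    share = ring_share_by_node(["a", "b", "c", "d"], num_points)
    spread = max(share.values()) - min(share.values())
    print(num_points, round(spread, 3))
```

With four nodes, a perfectly even split gives each node 0.25 of the ring, so a spread near 0 means the buckets are close to even. The spread generally shrinks as the point count grows, though not necessarily at every step.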

So how many buckets are enough? With the approach explained above, it's important to keep the bucket count low: the fewer the buckets, the fewer queries each node has to make against the database, and the more performant each query is.
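
One small way to trim the query count, offered here as an assumption rather than something the approach above already does: when two of a node's ranges touch (as the two `b` ranges meeting at `3.4668217318932335` do in the earlier output), they can be merged into a single query. `merge_adjacent_ranges` is a hypothetical helper.

```python
def merge_adjacent_ranges(ranges):
    """ merge (lower, upper) ranges that touch end to end, so each merged range is one query """
    merged = []
    for lower, upper in sorted(ranges):
        if merged and merged[-1][1] == lower:
            # this range starts exactly where the previous one ended: extend it
            merged[-1] = (merged[-1][0], upper)
        else:
            merged.append((lower, upper))
    return merged

# the ranges for node 'b' from the earlier output
b_ranges = [
    (2.4981986005276724, 2.8125462380799036),
    (2.9036464881262134, 3.4668217318932335),
    (3.4668217318932335, 3.5793269111334993),
    (4.909914260552789, 5.724601884363189),
]
print(len(b_ranges), len(merge_adjacent_ranges(b_ranges)))  # 4 3
```

Exact float equality is safe in this case because both bounds are copies of the same point value on the ring.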