In [1]:
from functools import wraps
import mmh3

In [2]:
# node with their sockets
from server_config import NODES

class NodeRing():
    """Map hex-digest keys onto a fixed list of nodes.

    Two placement strategies are offered: ``get_node`` (key modulo node count)
    and ``hrw`` (highest-random-weight style, seeded murmur3 hashing).
    """

    def __init__(self, nodes):
        """Create a ring over ``nodes``.

        Args:
            nodes: non-empty sequence of node descriptors (in this notebook,
                dicts with 'host' and 'port'); the ring returns these objects as-is.

        Raises:
            AssertionError: if ``nodes`` is empty.
        """
        assert len(nodes) > 0
        self.nodes = nodes

    def get_node(self, key_hex):
        """Return the node owning ``key_hex`` by simple modulus placement.

        Args:
            key_hex: hexadecimal digest string (e.g. an md5 hexdigest).

        Returns:
            ``self.nodes[int(key_hex, 16) % len(self.nodes)]``.
        """
        # The digest is hex-encoded, so parse it as a base-16 integer.
        key = int(key_hex, 16)
        return self.nodes[key % len(self.nodes)]

    def hrw(self, key_hex):
        """Return the node owning ``key_hex`` by highest-random-weight hashing.

        Node ``i``'s weight is ``mmh3.hash(key_hex, i)`` (its list index is the
        murmur3 seed); the node with the largest weight wins, ties going to the
        lowest index. Because weights are tied to positions, reordering
        ``self.nodes`` changes the assignment.
        """
        # Weigh exactly this ring's nodes; using the module-level NODES here would
        # give wrong weights / out-of-range indices for rings built from other lists.
        best_index = max(range(len(self.nodes)), key=lambda i: mmh3.hash(key_hex, i))
        return self.nodes[best_index]

In [3]:
def test():
    """Smoke-test NodeRing on the configured NODES.

    Prints the modulus placement of two sample md5 keys, then their HRW placement.
    """
    sample_keys = ('9ad5794ec94345c4873c4e591788743a', 'ed9440c442632621b608521b3f2650b8')
    ring = NodeRing(nodes=NODES)
    for sample in sample_keys:
        print(ring.get_node(sample))
    for sample in sample_keys:
        print(ring.hrw(sample))

In [4]:
# Run the NodeRing smoke test: two modulus placements, then two HRW placements.
test()

{'host': '127.0.0.1', 'port': 4002}
{'host': '127.0.0.1', 'port': 4000}
{'host': '127.0.0.1', 'port': 4003}
{'host': '127.0.0.1', 'port': 4002}


In [56]:
import bisect
class ConsistentHashing():
    """Consistent-hash ring with virtual nodes.

    Every node owns ``replication_factor + 1`` positions on a ring of size 2**32.
    A key is served by the first position clockwise from its hash, and
    ``get_node`` also reports the next *different* node clockwise.
    """

    def __init__(self, nodes, replication_factor=8):
        """Build the ring.

        Args:
            nodes: non-empty sequence of node descriptors; ``str(node)`` is hashed
                to place each node, and the node objects themselves are returned.
            replication_factor: number of extra virtual positions per node.

        Raises:
            AssertionError: if ``nodes`` is empty.
        """
        assert len(nodes) > 0
        self.nodes = nodes
        self.M = pow(2, 32)            # ring size; positions lie in [0, 2**32)
        self.rep = replication_factor  # virtual positions per node (besides its own)
        self.nodering = []             # sorted list of occupied ring positions
        self.nodehash = {}             # ring position -> node
        for node in self.nodes:
            self.add_node(node)

    def _ring_positions(self, node):
        """Return the ring positions owned by ``node``: its own, then ``self.rep`` virtual ones."""
        label = str(node)
        positions = [mmh3.hash(label.encode()) % self.M]
        for i in range(self.rep):
            positions.append(mmh3.hash((label + f"#{i}").encode()) % self.M)
        return positions

    def add_node(self, node):
        """Place ``node`` (and its virtual replicas) on the ring, keeping it sorted."""
        for pos in self._ring_positions(node):
            self.nodehash[pos] = node
            bisect.insort(self.nodering, pos)

    def remove_node(self, node):
        """Remove ``node`` and its virtual replicas from the ring.

        Raises:
            ValueError: if one of the node's positions is not on the ring
                (raised by ``list.remove``).
        """
        for pos in self._ring_positions(node):
            self.nodering.remove(pos)
            self.nodehash.pop(pos, None)

    def get_node(self, key):
        """Return ``[owner, next_distinct_node]`` for ``key``.

        ``owner`` is the node at the first ring position >= hash(key), wrapping
        past the top of the ring to position index 0. ``next_distinct_node`` is the
        next node clockwise that differs from ``owner``; it is omitted (a
        one-element list is returned) when every position belongs to ``owner``.

        Raises:
            IndexError: if the ring is empty.
        """
        if not self.nodering:
            raise IndexError("get_node() called on an empty ring")
        size = len(self.nodering)
        k_hash = mmh3.hash(key) % self.M
        # bisect_left returns `size` when k_hash is past the last position: wrap to 0.
        start = bisect.bisect_left(self.nodering, k_hash) % size
        owner = self.nodehash[self.nodering[start]]
        node_list = [owner]
        # Walk at most one lap so a ring holding a single physical node cannot loop forever.
        for step in range(1, size):
            candidate = self.nodehash[self.nodering[(start + step) % size]]
            if candidate != owner:
                node_list.append(candidate)
                break
        return node_list

    def check(self):
        """Print the sorted ring positions and the position -> node map (debug aid)."""
        print(self.nodering)
        print(self.nodehash)
    

In [57]:
def test():
    """Smoke-test ConsistentHashing on the configured NODES.

    Prints the node list returned by ``get_node`` for two sample md5 keys.
    """
    ring = ConsistentHashing(nodes=NODES)
    for sample in ('9ad5794ec94345c4873c4e591788743a', 'ed9440c442632621b608521b3f2650b8'):
        print(ring.get_node(sample))

In [58]:
# Run the ConsistentHashing smoke test; each printed list is the key's node
# followed by the next distinct node clockwise on the ring.
test()

[{'host': '127.0.0.1', 'port': 4000}, {'host': '127.0.0.1', 'port': 4001}]
[{'host': '127.0.0.1', 'port': 4000}, {'host': '127.0.0.1', 'port': 4002}]
