In [4]:
import time

In [6]:
class Trie0(object):

    def __init__(self, key: str=""):
        self.key = key # eg. "01110101"
        self.branch = [None, None]
        self.size = 0 # only counts the leaves
                
    # add the provided key to the trie
    # returns True on success and False on failure (usually
    # when the key is already in the trie)
    def add(self, key: str) -> bool:
        if self.branch[int(key[len(self.key)])] is not None:
            # branch already exists
            branch = self.branch[int(key[len(self.key)])]

            minLen = min(len(key), len(branch.key))
            if key[:minLen] == branch.key[:minLen]:
                if len(key)==len(branch.key):
                    # key already in the trie, cannot insert it
                    return False
                else:
                    # key is a branch of branch
                    #
                    #     11  +insert(111001)  *111* +insert(111001)
                    #    /  \                  /   \
                    # 110  *111*     -->     1110  1111

                    success = branch.add(key)
            else:
                # insert between two nodes in the trie
                # 
                #   11 +insert(111001)     11 self
                #  / \                    / \
                #     \       -->          1110 mid
                #      \                  /    \
                #     111011     key 111001 111011 branch

                for i in range(minLen):
                    if branch.key[i]!=key[i]:
                        # first bit where branch.key and key diverge

                        # create mid Trie node
                        mid = Trie0(key=branch.key[:i])
                        # define mid branches
                        mid.branch[int(key[i])] = Trie0(key=key)
                        mid.branch[1-int(key[i])] = branch
                        # set its size
                        mid.size = branch.size + 1
                        # update self branch to mid
                        self.branch[int(key[len(self.key)])] = mid
                        success=True

                        break
        else:
            # self doesn't have the appropriate branch
            # only useful for root node
            #
            #    . +insert(001)  .            . +insert(110)   .
            #         -->       /     OR     /      -->       / \
            #                 001          001              001 110

            self.branch[int(key[len(self.key)])] = Trie0(key=key)
            success=True

        if success:
            self.size+=1
        return success
                
    # returns True if key in trie and False otherwise
    def find(self, key: str) -> bool:
        if len(self.key) >= len(key):
            return self.key == key
        elif self.branch[int(key[len(self.key)])] is not None:
            return self.branch[int(key[len(self.key)])].find(key)
        else:
            return False

    # returns a list of the n closest keys in the Trie to the given key
    def nClosest(self, key:str, n:int) -> list[str]:
        if self.branch[0] == self.branch[1] == None:
            # leaf of the trie
            return [self.key]
        
        nclosest = []
        if self.branch[int(key[len(self.key)])] is not None:
            # get n closest on the closest branch
            nclosest += self.branch[int(key[len(self.key)])].nClosest(key, n)
        if len(nclosest) < n and self.branch[1-int(key[len(self.key)])] is not None:    
            # if we don't have n keys yet, get the difference from the other branch
            nclosest += self.branch[1-int(key[len(self.key)])].nClosest(key, n-len(nclosest))
        return nclosest

In [7]:
class Key:
    
    # val: value of the key in bytes
    # size: len of the key in bits
    
    # the unused bits are the leading ones
    # e.g b=bytes('01010001 10111100'), size=12 -> key='00001 10111100'
    def __init__(self, bitstring: str="", b: bytes=bytes(), size: int=0):
        if bitstring != "":
            for c in bitstring:
                if c not in ['0', '1']:
                    print("invalid bistring key:", bitstring)
                    return
            self.size = len(bitstring)
            self.val = int(bitstring, 2).to_bytes((len(bitstring) + 7) // 8, byteorder='big')
        elif len(b) > 0:
            if size != 0:
                if size > 8*len(b) or size < 8*(len(b)-1):
                    print("invalid size ("+str(size)+")was given for key:", b)
                    return
                self.size = size
            else:
                self.size = 8*len(b)
            self.val = bytes([b[0] & (255>>8-self.size%8 if self.size%8!=0 else 0)]) + b[1:]

        else:
            # empty key
            self.val = bytes(1)
            self.size = 0
    
    def __len__(self) -> int:
        return self.size
    
    def __eq__(self, other) -> bool:
        return self.size == other.size and \
            self.val[0]<<(self.size%8) == other.val[0]<<(self.size%8) and \
            self.val[1:] == other.val[1:]
    
    def __repr__(self) -> str:
        return "".join(f'{byte:08b}' for byte in self.val)[8-self.size%8:]
    
    def bitAt(self, pos) -> int:
        if pos>=self.size:
            print("keysize="+str(self.size)+", cannot get bitAt("+str(pos)+")")
            return 0
        
        pos += 8 - self.size%8 if self.size%8 != 0 else 0
        if self.val[pos // 8] & (1 << (7 - pos%8)) == 0:
            return 0
        else:
            return 1
            
    def isPrefixOf(self, other) -> bool:
        return self.size <= other.size and str(self)==str(other)[:len(self)]
    
class Trie1:
    
    def __init__(self, key: Key=Key()):
        self.key = key
        self.branch = [None, None]
        self.size = 0 # only counts the number of leaves
            
    def getNodes(self):
        if self.size == 1:
            return [self.key]
        
        nodes = []
        for i in range(2):
            if self.branch[i] is not None:
                nodes += self.branch[i].getNodes()
        return nodes
    
    def add(self, key) -> bool:
        if self.branch[key.bitAt(len(self.key))] is not None:
            branch = self.branch[key.bitAt(len(self.key))]
            if branch.key.isPrefixOf(key):
                if len(key) == len(branch.key):
                    return False
                else:
                    success = branch.add(key)
            else:
                for i in range(min(len(key),len(branch.key))):
                    if key.bitAt(i) != branch.key.bitAt(i):
                        midKey = Key(bitstring=str(key)[:i])
                        break
            
                mid = Trie1(key=midKey)
                mid.branch[key.bitAt(len(midKey))] = Trie1(key=key)
                mid.branch[1-key.bitAt(len(midKey))] = branch

                mid.size = branch.size+1
                self.branch[key.bitAt(len(self.key))] = mid
                success = True
            
        else:
            self.branch[key.bitAt(len(self.key))] = Trie1(key=key)
            success = True
            
        if success:
            self.size += 1      
        return success

    # return the Trie associated with the key on success and None otherwise
    def find(self, key):
        if len(self.key) >= len(key):
            if self.key == key:
                return self
            else:
                return None
        elif self.branch[key.bitAt(len(self.key))] is not None:
            return self.branch[key.bitAt(len(self.key))].find(key)
        else:
            return None
        
    # return an array of the n "closest" peers to a target key, sorted by distance
    def nClosest(self,key,n:int):
        # leaf is the closest of its try to any key
        if self.branch[0] == self.branch[1] == None:
            return [self.key]
        
        nclosest = []
        if self.branch[key.bitAt(len(self.key))] is not None:
            nclosest += self.branch[key.bitAt(len(self.key))].nClosest(key,n)
        if len(nclosest) < n and self.branch[1-key.bitAt(len(self.key))] is not None:
            nclosest += self.branch[1-key.bitAt(len(self.key))].nClosest(key,n-len(nclosest))
        return nclosest

In [8]:
def bytes_to_bit_string(data: bytes) -> str:
    return "".join(f'{byte:08b}' for byte in data)

def int_to_bit_string(i: int, l) -> str:
    return bytes_to_bit_string((i).to_bytes(l//8+1, 'big', signed=False))[-l:]

In [22]:
startTime = time.time()
t0 = Trie0()
for i in range(2**16):
    t0.add(int_to_bit_string(i,16))

executionTime = (time.time() - startTime)
print('Trie0: ' + str(executionTime))

startTime = time.time()
t1 = Trie1()
for i in range(2**16):
    t1.add(Key(b=(i).to_bytes(2, byteorder='big',signed=False)))

executionTime = (time.time() - startTime)
print('Trie1: ' + str(executionTime))

Trie0: 0.9080667495727539
Trie1: 24.897801399230957


In [23]:
startTime = time.time()
for _ in range(1024):
    t1.find(Key(b=(10234).to_bytes(2, byteorder='big',signed=False)))
executionTime = (time.time() - startTime)
print('Trie1: ' + str(executionTime))

startTime = time.time()
for _ in range(1024):
    t0.find(int_to_bit_string(10234,16))
executionTime = (time.time() - startTime)
print('Trie0: ' + str(executionTime))

Trie1: 0.26752758026123047
Trie0: 0.006219148635864258


In [28]:
startTime = time.time()
for i in range(10):
    t1.nClosest(Key(b=(10234).to_bytes(2, byteorder='big',signed=False)),i*10)
executionTime = (time.time() - startTime)
print('Trie1: ' + str(executionTime))

startTime = time.time()
for _ in range(10):
    t0.nClosest(int_to_bit_string(10234,16),i*10)
executionTime = (time.time() - startTime)
print('Trie0: ' + str(executionTime))

Trie1: 0.005110979080200195
Trie0: 0.0010249614715576172
