In [92]:
class Node:
    """One node of a 2-3-4 tree.

    All attributes are created empty here and mutated in place by
    ``TwoThreeFourTree``:
        keys: the node's keys, kept in ascending order by the tree.
        children: child nodes; an empty list means the node is a leaf.
        parent: the parent node, or None for the root.
    """

    def __init__(self):
        # Fresh lists per instance, so nodes never share key/child storage.
        self.keys, self.children, self.parent = [], [], None

In [94]:

class TwoThreeFourTree:
    """A 2-3-4 tree: every node holds 1-3 sorted keys and an internal node
    with ``k`` keys has ``k + 1`` children.

    Insertion splits full nodes on the way down, and deletion steals or merges
    on the way down, so both complete in a single top-down pass. Duplicate
    keys are allowed; a key equal to a node key is routed to the left child.
    """

    def __init__(self):
        # None represents the empty tree.
        self.root = None

    # ------------------------------------------------------------------
    # small helpers shared by insert / delete / search
    # ------------------------------------------------------------------
    @staticmethod
    def _child_index(node, key):
        """Return how many keys of ``node`` are strictly smaller than ``key``.

        Because ``node.keys`` is sorted, this is both the position where
        ``key`` would be inserted into ``node.keys`` and the index of the child
        subtree to descend into. If ``key`` occurs in the node, it occurs at
        this index.
        """
        return sum(1 for node_key in node.keys if key > node_key)

    @staticmethod
    def _adjacent_siblings(node):
        """Return ``(j, left, right)`` for a node that has a parent.

        ``j`` is the index of ``node`` in ``node.parent.children``. ``left`` and
        ``right`` are the immediately adjacent siblings, or None at an edge.
        """
        siblings = node.parent.children
        j = siblings.index(node)
        left = siblings[j - 1] if j > 0 else None
        right = siblings[j + 1] if j + 1 < len(siblings) else None
        return j, left, right

    # ------------------------------------------------------------------
    # insertion
    # ------------------------------------------------------------------
    def _split_node(self, node):
        """Split a full node (3 keys) around its middle key.

        The middle key moves up. It goes into a new root if ``node`` is the
        root, otherwise into ``node.parent`` at the position ``node`` occupied.
        ``_insert`` only descends into non-full nodes, so the parent always has
        room and the split never cascades upward.

        Returns:
            (left_node, right_node): the two new 1-key nodes replacing ``node``.
        """
        mid = node.keys[1]
        left_node = Node()
        left_node.keys.append(node.keys[0])
        right_node = Node()
        right_node.keys.append(node.keys[2])

        if node.children:
            # The first two children go with the left half, the last two with the right.
            for child in node.children[:2]:
                left_node.children.append(child)
                child.parent = left_node
            for child in node.children[2:]:
                right_node.children.append(child)
                child.parent = right_node

        if node is self.root:
            # Splitting the root is the only way the tree grows in height.
            self.root = Node()
            self.root.keys.append(mid)
            self.root.children.extend([left_node, right_node])
            parent = self.root
        else:
            parent = node.parent
            i = parent.children.index(node)
            parent.keys.insert(i, mid)
            # Replace ``node`` in the parent's child list by its two halves.
            parent.children[i:i + 1] = [left_node, right_node]
        left_node.parent = parent
        right_node.parent = parent

        return left_node, right_node

    def insert(self, key):
        """Insert ``key`` into the tree. Duplicates are allowed."""
        if self.root is None:
            self.root = Node()
            self.root.keys.append(key)
        else:
            self._insert(self.root, key)

    def _insert(self, node, key):
        """Insert ``key`` into the subtree rooted at ``node``.

        Full nodes are split before descending, so the key always lands in a
        leaf that has room for it.
        """
        i = self._child_index(node, key)
        if len(node.keys) < 3:
            if not node.children:
                node.keys.insert(i, key)
            else:
                self._insert(node.children[i], key)
            return

        left_node, right_node = self._split_node(node)
        # ``node`` is detached from the tree now, but its ``children`` list
        # still tells us whether it was a leaf.
        # i <= 1 means key <= mid, so the key belongs in the left half.
        if not node.children:
            self._insert(left_node if i <= 1 else right_node, key)
        elif i <= 1:
            self._insert(left_node.children[i], key)
        else:
            self._insert(right_node.children[i - 2], key)

    # ------------------------------------------------------------------
    # deletion
    # ------------------------------------------------------------------
    def _steal(self, node):
        """Give a 1-key node a second key by rotating a key through its parent.

        The key is borrowed from an adjacent sibling that has at least 2 keys,
        preferring the left sibling. The sibling's nearest key moves up into
        the parent, and the separating parent key moves down into ``node``. For
        internal nodes, the sibling's nearest child moves over as well.

        Returns:
            True if a rotation happened. False if ``node`` is the root or
            neither adjacent sibling has a spare key.
        """
        P = node.parent
        if P is None:
            return False
        j, S_L, S_R = self._adjacent_siblings(node)

        if S_L is not None and len(S_L.keys) > 1:
            # Rotate right: S_L's largest key goes up; separator P.keys[j-1] comes down.
            P_key = P.keys[j - 1]
            P.keys[j - 1] = S_L.keys.pop()
            node.keys.insert(0, P_key)
            if S_L.children:
                S_child = S_L.children.pop()
                node.children.insert(0, S_child)
                S_child.parent = node
            return True

        if S_R is not None and len(S_R.keys) > 1:
            # Rotate left: S_R's smallest key goes up; separator P.keys[j] comes down.
            P_key = P.keys[j]
            P.keys[j] = S_R.keys.pop(0)
            node.keys.append(P_key)
            if S_R.children:
                S_child = S_R.children.pop(0)
                node.children.append(S_child)
                S_child.parent = node
            return True

        return False

    def _merge(self, node):
        """Fuse a 1-key ``node`` with an adjacent sibling, preferring the left one.

        The separating key is pulled down from the parent, so the fused node
        holds the sibling's keys, the separator and ``node``'s key, in order.
        The parent loses one key and one child. ``_delete`` only merges when
        that is safe: either the parent is a non-root node with at least 2 keys
        (guaranteed by its preprocessing), or the parent is the 1-key root that
        is about to be replaced by the fused node.

        Returns:
            The fused sibling node, or None if ``node`` has no parent.
        """
        P = node.parent
        if P is None:
            return None
        j, S_L, S_R = self._adjacent_siblings(node)

        if S_L is not None:
            S_L.keys.append(P.keys.pop(j - 1))
            S_L.keys.append(node.keys[0])
            P.children.pop(j)
            for child in node.children:
                S_L.children.append(child)
                child.parent = S_L
            return S_L

        if S_R is not None:
            # Prepend [node key, separator] so S_R's keys stay sorted.
            S_R.keys[:0] = [node.keys[0], P.keys.pop(j)]
            P.children.pop(j)
            S_R.children[:0] = node.children
            for child in node.children:
                child.parent = S_R
            return S_R

        return None

    def delete(self, key):
        """Remove one occurrence of ``key``.

        If ``key`` is absent, the stored keys are unchanged, although nodes
        may still be rebalanced on the way down.
        """
        self._delete(self.root, key)

    def _delete(self, node, key):
        """Delete ``key`` from the subtree rooted at ``node`` in one top-down pass.

        Preprocessing: before working in a non-root node that has only 1 key,
        steal or merge so it has at least 2 keys. Removing a key from it, or
        merging two of its children, then never leaves it empty. The root may
        keep a single key; it is merged only when it has 1 key and both of its
        children have exactly 1 key.
        """
        if node is None:
            return

        # (i) The root is a leaf: remove the key directly; drop an emptied root.
        if node is self.root and not node.children:
            if key in node.keys:
                node.keys.remove(key)
            if not node.keys:
                self.root = None
            return

        # (ii) A 1-key root over two 1-key children: fuse all three into a new
        # 3-key root (the tree loses one level), then retry from there.
        if (node is self.root and len(node.keys) == 1
                and len(node.children[0].keys) == 1
                and len(node.children[1].keys) == 1):
            self.root = self._merge(node.children[0])
            self.root.parent = None
            self._delete(self.root, key)
            return

        # (iii) A root whose children are not both 1-key, and (iv) a root with
        # several keys, are handled by the normal case below.

        # (v) A non-root node with a single key: grow it before working in it.
        # A steal gives it 2 keys; otherwise (every adjacent sibling has 1 key)
        # a merge gives 3 keys.
        if node is not self.root and len(node.keys) == 1:
            if self._steal(node):
                self._delete(node, key)
            else:
                self._delete(self._merge(node), key)
            return

        # Normal case: ``node`` can afford to lose a key.
        i = self._child_index(node, key)
        if i < len(node.keys) and node.keys[i] == key:
            if not node.children:
                # Case 1: leaf with at least 2 keys: simply remove the key.
                node.keys.pop(i)
                return
            # Case 2: internal node.
            C_L = node.children[i]
            C_R = node.children[i + 1]
            if len(C_L.keys) > 1:
                # (a) Replace the key by its in-order predecessor, then delete
                #     that predecessor from the left subtree.
                predecessor = self._get_in_order_predecessor(C_L)
                node.keys[i] = predecessor
                self._delete(C_L, predecessor)
            elif len(C_R.keys) > 1:
                # (b) Symmetric: use the in-order successor from the right subtree.
                successor = self._get_in_order_successor(C_R)
                node.keys[i] = successor
                self._delete(C_R, successor)
            else:
                # (c) Both children have 1 key: merge C_R into C_L. This pulls
                #     ``key`` down as the middle key of a 3-key node; the
                #     preprocessing guarantees ``node`` can spare it.
                self._delete(self._merge(C_R), key)
            return

        # The key is not in this node: descend into the subtree that would contain it.
        if node.children:
            self._delete(node.children[i], key)

    def _get_in_order_successor(self, node):
        """Return the smallest key in the subtree rooted at ``node``."""
        while node.children:
            node = node.children[0]
        return node.keys[0]

    def _get_in_order_predecessor(self, node):
        """Return the largest key in the subtree rooted at ``node``."""
        while node.children:
            node = node.children[-1]
        return node.keys[-1]

    # ------------------------------------------------------------------
    # queries
    # ------------------------------------------------------------------
    def search(self, key):
        """Return the node containing ``key``, or None if it is not present."""
        return self._search(self.root, key)

    def _search(self, node, key):
        """Search the subtree rooted at ``node``. Returns a node or None."""
        while node is not None:
            i = self._child_index(node, key)
            if i < len(node.keys) and node.keys[i] == key:
                return node
            if not node.children:
                return None
            node = node.children[i]
        return None

    def is_in_TwoThreeFourTree(self, key):
        """Return True if ``key`` is stored in the tree."""
        return self.search(key) is not None

    def update(self, old_key, key):
        """Replace one occurrence of ``old_key`` with ``key``.

        Does nothing if ``old_key`` is absent.
        """
        if self.is_in_TwoThreeFourTree(old_key):
            self.delete(old_key)
            self.insert(key)

    # ------------------------------------------------------------------
    # traversals and diagnostics
    # ------------------------------------------------------------------
    def in_order_traversal(self):
        """Return all keys in ascending order ([] for an empty tree)."""
        result = []
        if self.root is not None:
            self._in_order_traversal(self.root, result)
        return result

    def _in_order_traversal(self, node, result):
        """Append the keys of the subtree rooted at ``node`` to ``result`` in order."""
        if not node.children:
            result.extend(node.keys)
            return
        # Interleave child subtrees with separator keys; there is one more child than keys.
        for child, key in zip(node.children, node.keys):
            self._in_order_traversal(child, result)
            result.append(key)
        self._in_order_traversal(node.children[len(node.keys)], result)

    def level_order_traversal(self):
        """Return the tree's keys grouped by depth, or None for an empty tree.

        The result is a list of levels (root first). Each level is a list of
        the nodes' ``keys`` lists, left to right. These are the live lists,
        not copies.
        """
        if self.root is None:
            return None

        result = []
        level_nodes = [self.root]
        while level_nodes:
            result.append([node.keys for node in level_nodes])
            # Each level is built in one pass: O(n) overall instead of repeated list.pop(0).
            level_nodes = [child for node in level_nodes for child in node.children]
        return result

    def is_perfect(self):
        """Check the shape of the tree level by level.

        Returns True if every level holds exactly ``len(keys) + 1`` nodes for
        each node of the level above, i.e. all leaves sit at the same depth in
        a well-formed tree. An empty tree counts as perfect.
        """
        levels = self.level_order_traversal()
        if not levels:
            return True
        for upper, lower in zip(levels, levels[1:]):
            expected_children = sum(len(keys) + 1 for keys in upper)
            if len(lower) != expected_children:
                return False
        return True

    def display(self):
        """Print the tree one level per line (only the header for an empty tree)."""
        levels = self.level_order_traversal() or []
        print('-'*20 + '2-3-4 Tree' + '-'*20)
        for depth, level in enumerate(levels):
            print(f'level {depth} size {len(level)}: ', end='')
            for keys in level:
                print(keys, end=' ')
            print()

In [95]:
# Example usage: build a small tree in batches and print its levels after each
# batch. Then delete a few keys and print the levels again. Key 0 was never
# inserted, so deleting it leaves the key set unchanged. Finally, confirm the
# tree is still perfectly balanced.
ttft = TwoThreeFourTree()
for _insert_batch in ((10, 5, 20), (3, 7, 15, 30), (11, 29), (10, 2)):
    for _key in _insert_batch:
        ttft.insert(_key)
    print(ttft.level_order_traversal())
print(ttft.is_in_TwoThreeFourTree(15))
ttft.display()
for _delete_batch in ((10, 0), (30,), (5,)):
    for _key in _delete_batch:
        ttft.delete(_key)
    print(ttft.level_order_traversal())
ttft.display()
print(ttft.is_perfect())

[[[5, 10, 20]]]
[[[10]], [[3, 5, 7], [15, 20, 30]]]
[[[10, 20]], [[3, 5, 7], [11, 15], [29, 30]]]
[[[10]], [[5], [20]], [[2, 3], [7, 10], [11, 15], [29, 30]]]
True
--------------------2-3-4 Tree--------------------
level 0 size 1: [10] 
level 1 size 2: [5] [20] 
level 2 size 4: [2, 3] [7, 10] [11, 15] [29, 30] 
[[[5, 10, 20]], [[2, 3], [7], [11, 15], [29, 30]]]
[[[5, 10, 20]], [[2, 3], [7], [11, 15], [29]]]
[[[3, 10, 20]], [[2], [7], [11, 15], [29]]]
--------------------2-3-4 Tree--------------------
level 0 size 1: [3, 10, 20] 
level 1 size 4: [2] [7] [11, 15] [29] 
True


In [96]:

# Randomized stress test. For many random trees:
#   1. insert random keys,
#   2. delete the most recently inserted ones,
#   3. check that the in-order traversal equals the sorted remaining keys and
#      that the tree is still perfectly balanced.
# Failing trials are kept in whole / wrong / correct for inspection.
import random
from tqdm import tqdm

SEED = 0                      # fixed seed so any failing trial can be reproduced
N_TRIALS = 1 << 15            # number of independent random trees
N_INSERTS = 1000              # keys inserted into each tree
N_DELETES = 100               # keys deleted from each tree, last inserted first
KEY_LOW, KEY_HIGH = 1, 10000  # inclusive key range; duplicate keys are likely

random.seed(SEED)
n = N_TRIALS
cnt = 0       # number of trials that passed both checks
whole = []    # full insertion sequence of each failing trial
wrong = []    # the tree produced by each failing trial
correct = []  # the expected sorted keys of each failing trial

for _ in tqdm(range(n)):
    ttft = TwoThreeFourTree()
    nums = [random.randint(KEY_LOW, KEY_HIGH) for _ in range(N_INSERTS)]
    for num in nums:
        ttft.insert(num)
    inserted = nums.copy()  # snapshot before deletions shrink ``nums``
    for _ in range(N_DELETES):
        ttft.delete(nums.pop())
    if ttft.in_order_traversal() == sorted(nums) and ttft.is_perfect():
        cnt += 1
    else:
        whole.append(inserted)
        wrong.append(ttft)
        correct.append(sorted(nums))
print(cnt)
print(n)
print(cnt == n)
assert cnt == n, f'{n - cnt} of {n} trials failed; inspect whole/wrong/correct'

100%|██████████| 32768/32768 [04:18<00:00, 126.84it/s]

32768
32768
True



