In [1]:
class Node:
    """A single B-tree node: a bare container with no behaviour of its own.

    Attributes are created empty; the tree code that owns the node fills them in.
    """

    def __init__(self):
        # keys:     the keys stored in this node
        # children: child nodes (left empty for a leaf)
        # parent:   owning node, or None while the node has no parent
        self.keys, self.children, self.parent = [], [], None

In [2]:
class BTree:
    """B-tree of minimum degree ``t`` that stores comparable keys.

    Duplicate keys are allowed (the tree behaves as a multiset). Insertion
    splits any full node (``2t - 1`` keys) it meets on the way down. Deletion
    tops a non-root node up to at least ``t`` keys (by stealing from a sibling
    or merging with one) before working inside it, so both operations are
    single top-down passes.

    Attributes:
        root: Root ``Node``, or ``None`` while the tree is empty.
        t: Minimum degree of the tree.
    """

    def __init__(self, t=2):
        """Create an empty tree.

        Args:
            t: Minimum degree. A node is split once it holds ``2t - 1`` keys.

        Raises:
            ValueError: If ``t`` is smaller than 2. With ``t == 1`` a split
                would produce nodes without any key.
        """
        if t < 2:
            raise ValueError(f"minimum degree t must be at least 2, got {t!r}")
        self.root = None
        self.t = t

    @staticmethod
    def _child_index(node, key):
        """Return how many keys of ``node`` are strictly smaller than ``key``.

        Because node keys are kept in ascending order this is both the slot at
        which ``key`` would be inserted and the index of the child subtree
        that may contain it.
        """
        return sum(1 for node_key in node.keys if key > node_key)

    @staticmethod
    def _adjacent_siblings(node):
        """Return ``(parent, j, left, right)`` for a node that has a parent.

        ``j`` is the position of ``node`` among its parent's children;
        ``left`` / ``right`` are the neighbouring children or ``None`` when
        ``node`` is the first / last child.
        """
        parent = node.parent
        j = parent.children.index(node)
        left = parent.children[j - 1] if j > 0 else None
        right = parent.children[j + 1] if j + 1 < len(parent.children) else None
        return parent, j, left, right

    def _split_node(self, node):
        """Split a full node around its median key.

        The median moves up into the parent (a new root is created when
        ``node`` is the root) and ``node`` is replaced by two fresh nodes.
        ``node`` itself is not modified, so callers may still inspect its
        ``keys`` / ``children`` afterwards.

        Returns:
            The ``(left_node, right_node)`` pair that replaced ``node``.
        """
        t = self.t
        mid = node.keys[t - 1]
        left_node = Node()
        left_node.keys = node.keys[:t - 1]
        right_node = Node()
        right_node.keys = node.keys[t:]

        if node.children:
            # Slices build new lists; the original child list stays intact.
            left_node.children = node.children[:t]
            right_node.children = node.children[t:2 * t]
            for child in left_node.children:
                child.parent = left_node
            for child in right_node.children:
                child.parent = right_node

        if node is self.root:
            new_root = Node()
            new_root.keys.append(mid)
            new_root.children.append(left_node)
            new_root.children.append(right_node)
            left_node.parent = new_root
            right_node.parent = new_root
            self.root = new_root
        else:
            parent = node.parent
            i = parent.children.index(node)
            parent.keys.insert(i, mid)
            # Replace the single slot of `node` by the two halves.
            parent.children[i:i + 1] = [left_node, right_node]
            left_node.parent = parent
            right_node.parent = parent

        return left_node, right_node

    def insert(self, key):
        """Insert ``key`` into the tree (duplicates are kept)."""
        if not self.root:
            self.root = Node()
            self.root.keys.append(key)
        else:
            self._insert(self.root, key)

    def _insert(self, node, key):
        """Insert ``key`` into the subtree rooted at ``node``.

        A full node is split before it is entered, so the parent of any node
        that needs splitting always has room for the promoted median.
        """
        t = self.t
        i = self._child_index(node, key)

        if len(node.keys) < 2 * t - 1:
            if not node.children:
                node.keys.insert(i, key)
            else:
                self._insert(node.children[i], key)
            return

        left_node, right_node = self._split_node(node)
        # Slots 0..t-1 of the old node ended up in the left half, the rest in
        # the right half (shifted by t).
        if i < t:
            target, offset = left_node, i
        else:
            target, offset = right_node, i - t

        if not node.children:
            self._insert(target, key)
        else:
            self._insert(target.children[offset], key)

    def _steal(self, node):
        """Try to give ``node`` one extra key by rotating through its parent.

        The left sibling is tried first, then the right one; a sibling can
        donate only if it holds more than ``t - 1`` keys. When the donor is an
        internal node the matching child subtree moves along with the key.

        Returns:
            ``True`` if a key was moved into ``node``, ``False`` otherwise
            (including when ``node`` has no parent).
        """
        if not node.parent:
            return False
        t = self.t
        parent, j, left, right = self._adjacent_siblings(node)

        if left is not None and len(left.keys) > t - 1:
            # Separator comes down in front, donor's largest key goes up.
            node.keys.insert(0, parent.keys[j - 1])
            parent.keys[j - 1] = left.keys.pop()
            if left.children:
                moved = left.children.pop()
                node.children.insert(0, moved)
                moved.parent = node
            return True

        if right is not None and len(right.keys) > t - 1:
            # Separator comes down at the back, donor's smallest key goes up.
            node.keys.append(parent.keys[j])
            parent.keys[j] = right.keys.pop(0)
            if right.children:
                moved = right.children.pop(0)
                node.children.append(moved)
                moved.parent = node
            return True

        return False

    def _merge(self, node):
        """Merge ``node`` into an adjacent sibling (left one preferred).

        The separator key between the two is pulled down from the parent and
        ``node`` is removed from the parent's child list. The caller must make
        sure the parent can spare a key.

        Returns:
            The sibling that absorbed ``node``, or ``None`` if ``node`` has no
            parent.
        """
        if not node.parent:
            return None
        parent, j, left, right = self._adjacent_siblings(node)

        if left is not None:
            left.keys.append(parent.keys.pop(j - 1))
            left.keys.extend(node.keys)
            parent.children.pop(j)
            for child in node.children:
                child.parent = left
            left.children.extend(node.children)
            return left

        if right is not None:
            # Result order: node's keys, separator, sibling's own keys.
            right.keys[:0] = node.keys + [parent.keys.pop(j)]
            parent.children.pop(j)
            for child in node.children:
                child.parent = right
            right.children[:0] = node.children
            return right

        return None

    def delete(self, key):
        """Remove one occurrence of ``key`` if present.

        Nodes on the search path may be rebalanced even when ``key`` is not
        found; the stored keys are unaffected in that case.
        """
        self._delete(self.root, key)

    def _delete(self, node, key):
        """Delete ``key`` from the subtree rooted at ``node``.

        Preprocessing guarantees that a non-root node holds at least ``t``
        keys before anything is removed from it or from below it, so removing
        or pulling down one key never leaves a node short.
        """
        if not node:
            return
        t = self.t

        # (i) The root is a leaf: remove directly, drop the root when empty.
        if node is self.root and not node.children:
            if key in node.keys:
                node.keys.remove(key)
            if not node.keys:
                self.root = None
            return

        # (ii) The root has a single key and both children are minimal:
        # merge them, which makes the merged node the new (lower) root.
        if (node is self.root and len(node.keys) == 1
                and len(node.children[0].keys) == t - 1
                and len(node.children[1].keys) == t - 1):
            self.root = self._merge(node.children[0])
            self.root.parent = None
            self._delete(self.root, key)
            return

        # (iii) A non-root node with fewer than t keys: top it up first,
        # by stealing if a sibling can donate, otherwise by merging.
        if node is not self.root and len(node.keys) < t:
            if self._steal(node):
                self._delete(node, key)
            else:
                self._delete(self._merge(node), key)
            return

        # Normal case: the node can afford to lose a key.
        i = self._child_index(node, key)
        if i < len(node.keys) and node.keys[i] == key:
            if not node.children:
                # Leaf: simply remove the key.
                node.keys.pop(i)
                return

            left_child = node.children[i]
            right_child = node.children[i + 1]
            if len(left_child.keys) > t - 1:
                # Replace by the in-order predecessor, then delete that one.
                predecessor = self._get_in_order_predecessor(left_child)
                node.keys[i] = predecessor
                self._delete(left_child, predecessor)
            elif len(right_child.keys) > t - 1:
                # Replace by the in-order successor, then delete that one.
                successor = self._get_in_order_successor(right_child)
                node.keys[i] = successor
                self._delete(right_child, successor)
            else:
                # Both neighbours are minimal: merging them pulls `key` down
                # into the merged child, where the deletion continues.
                self._delete(self._merge(right_child), key)
            return

        if node.children:
            self._delete(node.children[i], key)

    def _get_in_order_successor(self, node):
        """Return the smallest key in the subtree rooted at ``node``."""
        current = node
        while current.children:
            current = current.children[0]
        return current.keys[0]

    def _get_in_order_predecessor(self, node):
        """Return the largest key in the subtree rooted at ``node``."""
        current = node
        while current.children:
            current = current.children[-1]
        return current.keys[-1]

    def search(self, key):
        """Return the node that holds ``key``, or ``None`` if it is absent."""
        return self._search(self.root, key)

    def _search(self, node, key):
        """Search the subtree rooted at ``node``; see :meth:`search`."""
        while node:
            i = self._child_index(node, key)
            if i < len(node.keys) and node.keys[i] == key:
                return node
            if not node.children:
                return None
            node = node.children[i]
        return None

    def is_in_BTree(self, key):
        """Return ``True`` if ``key`` is stored in the tree."""
        return self.search(key) is not None

    def update(self, old_key, key):
        """Replace one occurrence of ``old_key`` by ``key``.

        Does nothing when ``old_key`` is not in the tree.
        """
        if self.is_in_BTree(old_key):
            self.delete(old_key)
            self.insert(key)

    def in_order_traversal(self):
        """Return all keys in ascending order (``[]`` for an empty tree)."""
        result = []
        if self.root:
            self._in_order_traversal(self.root, result)
        return result

    def _in_order_traversal(self, node, result):
        """Append the keys of the subtree rooted at ``node`` to ``result``."""
        if not node.children:
            result.extend(node.keys)
            return
        for i, key in enumerate(node.keys):
            self._in_order_traversal(node.children[i], result)
            result.append(key)
        # The child to the right of the last key.
        self._in_order_traversal(node.children[len(node.keys)], result)

    def level_order_traversal(self):
        """Return the keys grouped by level, then by node.

        Returns:
            ``None`` for an empty tree, otherwise a list with one entry per
            level; each entry lists the key lists of that level's nodes from
            left to right. The key lists are copies, so changing them does
            not affect the tree.
        """
        if not self.root:
            return None

        result = []
        current_level = [self.root]
        while current_level:
            result.append([list(node.keys) for node in current_level])
            current_level = [child
                             for node in current_level
                             for child in node.children]
        return result

    def is_perfect(self):
        """Check that every level is completely populated.

        For each level below the root, the number of nodes must equal the
        number of child slots (keys + 1 per node) offered by the level above.
        An empty tree counts as perfect.
        """
        levels = self.level_order_traversal()
        if not levels:
            return True

        expected_nodes = None
        for level in levels:
            if expected_nodes is not None and len(level) != expected_nodes:
                return False
            expected_nodes = sum(len(node_keys) + 1 for node_keys in level)
        return True

    def display(self):
        """Print the tree level by level (only the banner for an empty tree)."""
        levels = self.level_order_traversal() or []
        print('-'*20 + 'B-Tree' + '-'*20)
        for i, level in enumerate(levels):
            print(f'level {i} size {len(level)}: ', end='')
            for node_keys in level:
                print(node_keys, end=' ')
            print()

In [6]:

# Example usage: grow a default (t=2) tree in stages, then delete and update
# keys, printing a level-order snapshot after every stage.
bt = BTree()

for key in (10, 5, 20):
    bt.insert(key)
print(bt.level_order_traversal())

for key in (3, 7, 15, 30):
    bt.insert(key)
print(bt.level_order_traversal())

for key in (11, 29):
    bt.insert(key)
print(bt.level_order_traversal())

# 10 is inserted a second time here.
for key in (10, 2):
    bt.insert(key)
print(bt.level_order_traversal())

print(bt.is_in_BTree(15))
bt.display()

# 0 was never inserted, so the second delete targets an absent key.
for key in (10, 0):
    bt.delete(key)
print(bt.level_order_traversal())

bt.delete(30)
print(bt.level_order_traversal())

bt.update(5, 321)
print(bt.level_order_traversal())

bt.display()
print(bt.is_perfect())
print(bt.in_order_traversal())

[[[5, 10, 20]]]
[[[10]], [[3, 5, 7], [15, 20, 30]]]
[[[10, 20]], [[3, 5, 7], [11, 15], [29, 30]]]
[[[10]], [[5], [20]], [[2, 3], [7, 10], [11, 15], [29, 30]]]
True
--------------------B-Tree--------------------
level 0 size 1: [10] 
level 1 size 2: [5] [20] 
level 2 size 4: [2, 3] [7, 10] [11, 15] [29, 30] 
[[[5, 10, 20]], [[2, 3], [7], [11, 15], [29, 30]]]
[[[5, 10, 20]], [[2, 3], [7], [11, 15], [29]]]
[[[10]], [[3], [20]], [[2], [7], [11, 15], [29, 321]]]
--------------------B-Tree--------------------
level 0 size 1: [10] 
level 1 size 2: [3] [20] 
level 2 size 4: [2] [7] [11, 15] [29, 321] 
True
[2, 3, 7, 10, 11, 15, 20, 29, 321]


In [4]:

# Randomised stress test: build many trees from random keys, delete part of
# them again and compare the tree against a plain sorted list.
import random
from tqdm import tqdm

SEED = 0               # fixed seed so a failing run can be replayed exactly
NUM_INSERTS = 1000     # keys inserted into every tree
NUM_DELETES = 100      # most recently inserted keys that are deleted again
rng = random.Random(SEED)

n = 1 << 15            # number of independent trials
cnt = 0                # trials that passed
whole = []             # full insertion sequence of every failing trial
wrong = []             # the failing trees themselves
correct = []           # expected sorted content for every failing trial

for _ in tqdm(range(n)):
    bt = BTree(4)
    nums = []
    for _ in range(NUM_INSERTS):
        num = rng.randint(1, 10000)
        bt.insert(num)
        nums.append(num)
    a = nums.copy()    # keep the full sequence; `nums` is consumed below
    for _ in range(NUM_DELETES):
        num = nums.pop()
        bt.delete(num)
    if bt.in_order_traversal() == sorted(nums) and bt.is_perfect():
        cnt += 1
    else:
        whole.append(a)
        wrong.append(bt)
        correct.append(sorted(nums))
print(cnt)
print(n)
print(cnt == n)

  0%|          | 0/32768 [00:00<?, ?it/s]

100%|██████████| 32768/32768 [02:41<00:00, 202.98it/s]

32768
32768
True





In [5]:
# Insert the keys 1..15 in ascending order into a t=2 tree and show the
# resulting level structure.
bt = BTree(t=2)
for num in range(1, 15 + 1):
    bt.insert(num)
bt.display()

--------------------B-Tree--------------------
level 0 size 1: [4, 8] 
level 1 size 3: [2] [6] [10, 12] 
level 2 size 7: [1] [3] [5] [7] [9] [11] [13, 14, 15] 
