In [1]:
# Node colours are encoded as booleans so that they can be compared with `is`.
RED = True
BLACK = False


class Node:
    """A single tree node holding a value, its links and its colour.

    Attributes:
        value:  the payload stored in the node.
        left:   left child, or None.
        right:  right child, or None.
        parent: parent node, or None.
        color:  RED or BLACK; a freshly created node is RED.
    """

    def __init__(self, value):
        self.value = value
        # A new node starts out RED and completely detached.
        self.color = RED
        self.parent = None
        self.left = self.right = None

In [2]:
class RedBlackTree:
    """Red-black tree of mutually comparable values (duplicates are allowed).

    Missing children are represented by ``None`` and are treated as BLACK
    leaves. Nodes are ``Node`` objects carrying ``value``, ``left``,
    ``right``, ``parent`` and ``color`` attributes.
    """

    def __init__(self):
        """Create an empty tree."""
        self.root = None

    # ------------------------------------------------------------------
    # Rotations
    # ------------------------------------------------------------------
    def _left_rotate(self, x):
        """Rotate left around ``x`` and return the new subtree root (``x.right``).

        Parent pointers of the three nodes that move are updated, but the
        child link of ``x``'s former parent is NOT; the caller must
        re-attach the returned node.
        """
        y = x.right
        B = y.left

        x.right = B
        y.left = x

        y.parent = x.parent
        x.parent = y
        if B is not None:
            B.parent = x

        return y

    def _right_rotate(self, x):
        """Rotate right around ``x`` and return the new subtree root (``x.left``).

        Mirror image of ``_left_rotate``; the caller must re-attach the
        returned node under ``x``'s former parent.
        """
        y = x.left
        B = y.right

        x.left = B
        y.right = x

        y.parent = x.parent
        x.parent = y
        if B is not None:
            B.parent = x

        return y

    def _left_right_rotate(self, x):
        """Left-rotate ``x.left``, then right-rotate ``x``; return the new subtree root."""
        x.left = self._left_rotate(x.left)
        return self._right_rotate(x)

    def _right_left_rotate(self, x):
        """Right-rotate ``x.right``, then left-rotate ``x``; return the new subtree root."""
        x.right = self._right_rotate(x.right)
        return self._left_rotate(x)

    def _replace_subtree_root(self, old_root, new_root):
        """Hook ``new_root`` into the slot that ``old_root`` occupied before a rotation.

        The rotation has already set ``new_root.parent`` to ``old_root``'s
        former parent, whose child link still points at ``old_root``.
        """
        parent = new_root.parent
        if parent is None:
            self.root = new_root
        elif parent.left is old_root:
            parent.left = new_root
        else:
            parent.right = new_root

    # ------------------------------------------------------------------
    # Insertion
    # ------------------------------------------------------------------
    def insert(self, value):
        """Insert ``value`` into the tree; equal values are kept as duplicates."""
        if self.root is None:
            # case 1: empty tree -> the new node becomes a BLACK root
            # (the other cases are handled in _insert_fixup)
            self.root = Node(value)
            self.root.color = BLACK
        else:
            self._insert(self.root, value)

    def _insert(self, node, value):
        """Descend from ``node``, attach a new RED leaf for ``value`` and rebalance.

        Values less than or equal to a node's value go to its left subtree.
        """
        while True:
            if value <= node.value:
                if node.left is None:
                    node.left = Node(value)
                    node.left.parent = node
                    self._insert_fixup(node, node.left)
                    return
                node = node.left
            else:
                if node.right is None:
                    node.right = Node(value)
                    node.right.parent = node
                    self._insert_fixup(node, node.right)
                    return
                node = node.right

    def _insert_fixup(self, P, C):
        """Restore the red-black properties after RED node ``C`` was placed under ``P``."""
        # case 2: parent is BLACK -> nothing to repair
        if P.color is BLACK:
            return

        # case 3: parent is RED (so a BLACK grandparent exists)
        G = P.parent
        U = G.left if P is G.right else G.right

        # (b) uncle is RED: recolor parent and uncle; if G is not the root,
        # turn G RED and re-check it against its own parent
        if U is not None and U.color is RED:
            P.color = BLACK
            U.color = BLACK
            if G is not self.root:
                G.color = RED
                self._insert_fixup(G.parent, G)
            return

        # (a) uncle is BLACK (or missing): rotate, then recolor
        if P is G.left:
            if C is P.left:      # LL case
                new_node = self._right_rotate(G)
            else:                # LR case
                new_node = self._left_right_rotate(G)
        else:
            if C is P.right:     # RR case
                new_node = self._left_rotate(G)
            else:                # RL case
                new_node = self._right_left_rotate(G)

        # The node lifted to the top was RED and G was BLACK; swap those roles.
        new_node.color = BLACK
        G.color = RED
        self._replace_subtree_root(G, new_node)

    # ------------------------------------------------------------------
    # Deletion
    # ------------------------------------------------------------------
    def delete(self, value):
        """Remove one occurrence of ``value``; do nothing if it is absent."""
        self._delete(self.root, value)

    def _delete(self, node, value):
        """Remove one node holding ``value`` from the subtree rooted at ``node``."""
        # Locate a node with the requested value.
        while node is not None:
            if value < node.value:
                node = node.left
            elif value > node.value:
                node = node.right
            else:
                break
        if node is None:
            return

        # case 3: two children -> copy the in-order successor's value here and
        # physically remove the successor instead (it has no left child).
        if node.left is not None and node.right is not None:
            successor = self._get_in_order_successor(node.right)
            node.value = successor.value
            node = successor

        # case 1: no child (Child is None); case 2: exactly one child
        Child = node.left if node.left is not None else node.right
        parent = node.parent

        if node is self.root:
            self.root = Child
        elif node is parent.left:
            parent.left = Child
        else:
            parent.right = Child
        if Child is not None:
            Child.parent = parent

        # Removing a RED node never changes any black-height: done.
        if node.color is RED:
            return
        if Child is not None and Child.color is RED:
            # A RED child absorbs the removed BLACK.
            Child.color = BLACK
        else:
            # Child (possibly None) becomes DOUBLE BLACK.
            self._delete_fixup(Child, parent)

    def _delete_fixup(self, DB, P):
        """Resolve the DOUBLE BLACK at ``DB`` (may be ``None``), whose parent is ``P``.

        A BLACK node was removed, so the position ``DB`` carries one extra
        BLACK that must be pushed up or absorbed.
        """
        # case 1 / termination: DB is the root, or DB is RED and can absorb
        # the extra BLACK. DB may be None here when the last node of the tree
        # was just removed (root is None too); then there is nothing to do.
        if DB is self.root or (DB is not None and DB.color is RED):
            if DB is not None:
                DB.color = BLACK
            return

        is_left = DB is P.left
        S = P.right if is_left else P.left
        if is_left:
            near, far = S.left, S.right
        else:
            near, far = S.right, S.left

        if S.color is RED:
            # case 3: sibling S is RED: recolor P and S, rotate P towards DB,
            # then retry with the new (BLACK) sibling
            P.color = RED
            S.color = BLACK
            if is_left:
                new_node = self._left_rotate(P)
            else:
                new_node = self._right_rotate(P)
            self._replace_subtree_root(P, new_node)
            self._delete_fixup(DB, P)

        elif far is not None and far.color is RED:
            # case 4: S is BLACK and its far child is RED: swap the colors of
            # P and S, rotate P towards DB, paint the far child BLACK
            P.color, S.color = S.color, P.color
            if is_left:
                new_node = self._left_rotate(P)
            else:
                new_node = self._right_rotate(P)
            far.color = BLACK
            self._replace_subtree_root(P, new_node)

        elif near is not None and near.color is RED:
            # case 5: S is BLACK, far child BLACK, near child RED: recolor S
            # and the near child, rotate S away from DB, then go to case 4
            S.color = RED
            near.color = BLACK
            if is_left:
                P.right = self._right_rotate(S)
            else:
                P.left = self._left_rotate(S)
            self._delete_fixup(DB, P)

        else:
            # case 2: S and both of its children are BLACK: S becomes RED and
            # the extra BLACK moves up to P
            S.color = RED
            if P.color is RED:
                P.color = BLACK
            else:  # P becomes the new DOUBLE BLACK
                self._delete_fixup(P, P.parent)

    def _get_in_order_successor(self, node):
        """Return the left-most node of the subtree rooted at ``node``."""
        temp = node
        while temp.left:
            temp = temp.left
        return temp

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def search(self, value):
        """Return a node whose value equals ``value``, or ``None`` if absent."""
        return self._search(self.root, value)

    def _search(self, node, value):
        """Binary-search the subtree rooted at ``node`` for ``value``."""
        while node is not None:
            if value < node.value:
                node = node.left
            elif value > node.value:
                node = node.right
            else:
                return node
        return None

    def is_in_Red_Black_Tree(self, value):
        """Return True if ``value`` is stored in the tree."""
        return self.search(value) is not None

    def update(self, old_value, value):
        """Replace one occurrence of ``old_value`` by ``value``.

        Does nothing when ``old_value`` is not in the tree.
        """
        if self.is_in_Red_Black_Tree(old_value):
            self.delete(old_value)
            self.insert(value)

    def in_order_traversal(self):
        """Return the stored values as a list in in-order (left, node, right) sequence."""
        result = []
        self._in_order_traversal(self.root, result)
        return result

    def _in_order_traversal(self, node, result):
        """Append the values of the subtree rooted at ``node`` to ``result`` in order.

        ``result`` is mutated in place, so the caller sees the appended values.
        """
        if node:
            self._in_order_traversal(node.left, result)
            result.append(node.value)
            self._in_order_traversal(node.right, result)

    def is_Red_Black_Tree(self):
        """Check the color properties of the tree.

        Returns True when the root is BLACK, no RED node has a RED child, and
        every path from the root to a missing child passes through the same
        number of BLACK nodes. An empty tree is valid. Key ordering is not
        checked here.
        """
        if self.root is None:
            return True

        # root must be black
        if self.root.color is not BLACK:
            return False

        def black_height(node):
            """Return the black-height of ``node``'s subtree, or -1 if it is invalid."""
            if node is None:
                return 0

            # No Red-Red adjacency: a RED node may not have ANY RED child.
            if node.color is RED:
                if (node.left is not None and node.left.color is RED) or \
                        (node.right is not None and node.right.color is RED):
                    return -1

            left_height = black_height(node.left)
            if left_height < 0:
                return -1
            right_height = black_height(node.right)
            if right_height < 0 or right_height != left_height:
                return -1

            return left_height + (1 if node.color is BLACK else 0)

        return black_height(self.root) >= 0

In [3]:
# Example usage: build a small tree (10 is inserted twice), then exercise
# delete, validation, membership and update, printing the in-order view.
rbt = RedBlackTree()
for key in (10, 5, 20, 3, 7, 15, 30, 11, 29, 10, 2):
    rbt.insert(key)
print(rbt.in_order_traversal())

# Remove one of the two 10s.
rbt.delete(10)
print(rbt.in_order_traversal())
print(rbt.is_Red_Black_Tree())
print(rbt.is_in_Red_Black_Tree(11))

# Replace the remaining 10 by 99.
rbt.update(10, 99)
print(rbt.in_order_traversal())

[2, 3, 5, 7, 10, 10, 11, 15, 20, 29, 30]
[2, 3, 5, 7, 10, 11, 15, 20, 29, 30]
True
True
[2, 3, 5, 7, 11, 15, 20, 29, 30, 99]


In [4]:
# Randomised insert-only stress check: build `n` independent trees of 1000
# random keys each and count how many of them satisfy the red-black color
# properties and yield their keys in sorted order.
import random

# Fixed seed so that Restart & Run All rebuilds exactly the same trees.
random.seed(0)

n = 1 << 14
cnt = 0

for i in range(n):
    rbt = RedBlackTree()
    nums = []
    for _ in range(1000):
        num = random.randint(1, 10000)
        rbt.insert(num)
        nums.append(num)
    # Deletion is not exercised by this check; the disabled variant is kept
    # for reference:
    # for j in range(100):
    #     num = nums.pop()
    #     rbt.delete(num)
    if rbt.is_Red_Black_Tree() and rbt.in_order_traversal() == sorted(nums):
        cnt += 1

# Every trial is expected to pass, i.e. cnt == n.
print(cnt)
print(cnt == n)

16384
True
