# Red Black Tree

## Tree Implementations

In [2]:
# Constant identifiers for BLACK and RED
BLACK = 0
RED = 1

In [3]:
# Node Data Structure
class Node:
    def __init__(self, value, color=RED):
        self.value = value
        self.color = color
        
        self.parent = None
        self.left_child = None
        self.right_child = None

In [4]:
# Binary Tree Data Structure
class BinaryTree:
    def __init__(self):
        self.root = None

    def search(self, key):
        return self._search(key, self.root, "")

    def _search(self, key, node: Node, path):
        if node is None:
            return "not present"
        
        if key < node.value:
            path += "l"
            return self._search(key, node.left_child, path)
        elif key > node.value:
            path += "r"
            return self._search(key, node.right_child, path)
        else:
            return node, path

    def insert(self, value):
        if self.root is None:
            self.root = Node(value)
        else:
            self._insert(value, self.root)
    
    def _insert(self, value, node: Node):
        if value < node.value:
            if node.left_child is None:
                node.left_child = Node(value)
                node.left_child.parent = node
            else:
                self._insert(value, node.left_child)
        else:
            if node.right_child is None:
                node.right_child = Node(value)
                node.right_child.parent = node
            else:
                self._insert(value, node.right_child)
        
    def delete(self, key):
        self._delete(key, self.root)

    def _delete(self, key, node: Node):
        if key < node.value:
            self._delete(key, node.left_child)
        elif key > node.value:
            self._delete(key, node.right_child)
        else:
            # node doesn't exist
            if node is None:
                return
            
            # delete case 1: no children
            if node.left_child is None and node.right_child is None:
                if node.parent is None:
                    self.root = None
                elif node.parent.left_child == node:
                    node.parent.left_child = None
                else:
                    node.parent.right_child = None

            # delete case 2a: one child (right)
            elif node.left_child is None:
                if node.parent is None:
                    self.root = node.right_child
                elif node.parent.left_child == node:
                    node.parent.left_child = node.right_child
                else:
                    node.parent.right_child = node.right_child
                node.right_child.parent = node.parent

            # delete case 2b: one child (left)
            elif node.right_child is None:
                if node.parent is None:
                    self.root = node.left_child
                elif node.parent.left_child == node:
                    node.parent.left_child = node.left_child
                else:
                    node.parent.right_child = node.left_child
                node.left_child.parent = node.parent

            # delete case 3: two children
            else:
                # find smallest in right subtree
                min_right_subtree = self._min(node.right_child)
                node.value = min_right_subtree.value
                self._delete(min_right_subtree.value, node.right_child)

    def _min(self, node):
        current = node
        while current.left_child is not None:
            current = current.left_child
        return current

    def print_tree(self):
        self._print_tree(self.root, 0)

    def _print_tree(self, node, level):
        if node is None:
            return
        if node is not None:
            self._print_tree(node.right_child, level + 1)
            print(f"{'----'*level}{str(node.value)}")
            self._print_tree(node.left_child, level + 1)

In [14]:
# Red-Black Tree Data Structure
class RedBlackTree:
    def __init__(self):
        self.root = None

    def search(self, key):
        return self._search(key, self.root, "", 0, 0)

    def _search(self, key, node: Node, path, black_travelled, red_travelled):
        if node is None:
            return None
        
        if node.color == BLACK:
            black_travelled += 1
        else:
            red_travelled += 1

        if key < node.value:
            path += "l"
            return self._search(key, node.left_child, path, black_travelled, red_travelled)
        elif key > node.value:
            path += "r"
            return self._search(key, node.right_child, path, black_travelled, red_travelled)
        else:
            return node, path, black_travelled, red_travelled

    def _rotate_left(self, n: Node):
        y = n.right_child

        # left(y) becomes N's new right child
        n.right_child = y.left_child
        if y.left_child is not None:
            y.left_child.parent = n

        # fix y's parent
        y.parent = n.parent
        if n.parent is None:
            self.root = y
        elif n == n.parent.left_child:
            n.parent.left_child = y
        else:
            n.parent.right_child = y 
        y.left_child = n

        # y becomes N's parent
        n.parent = y

    def _rotate_right(self, n: Node):
        y = n.left_child

        # left(y) becomes N's new right child
        n.left_child = y.right_child
        if y.right_child is not None:
            y.right_child.parent = n

        # fix y's parent
        y.parent = n.parent
        if n.parent is None:
            self.root = y
        elif n == n.parent.right_child:
            n.parent.right_child = y
        else:
            n.parent.left_child = y 
        y.right_child = n

        # y becomes N's parent
        n.parent = y

    def insert(self, value):
        if self.root is None:  # insertion case 1
            new_node = Node(value, RED)
            self.root = new_node
        else:
            new_node = self._insert(value, self.root)
        self._insert_fix(new_node)
    
    def _insert(self, value, node: Node):
        if value < node.value:
            if node.left_child is None:
                new_node = Node(value, RED)
                node.left_child = new_node
                node.left_child.parent = node
                return new_node
            else:
                return self._insert(value, node.left_child)
        else:
            if node.right_child is None:
                new_node = Node(value, RED)
                node.right_child = new_node
                node.right_child.parent = node
                return new_node
            else:
                return self._insert(value, node.right_child)
            
    def _insert_fix(self, node: Node):
        while node.parent != None and node.parent.color == RED:
            parent, grandparent = node.parent, node.parent.parent

            if parent == grandparent.left_child:  # if parent is left child
                uncle = grandparent.right_child
                if uncle is not None and uncle.color == RED:  # if uncle is red, insert case 2
                    parent.color = BLACK
                    uncle.color = BLACK
                    grandparent.color = RED
                    node = grandparent  # propogate upwards
                else:
                    if node == parent.right_child:  # insert case 3
                        node = node.parent
                        parent, grandparent = node.parent, node.parent.parent
                        self._rotate_left(node)

                    parent.color == BLACK  # insert case 4
                    grandparent.color = RED
                    self._rotate_right(grandparent)
            else:  # if parent is right child
                uncle = grandparent.left_child
                if uncle is not None and uncle.color == RED:  # if uncle is red, insert case 2
                    parent.color = BLACK
                    uncle.color = BLACK
                    grandparent.color = RED
                    node = grandparent  # propogate upwards
                else:
                    if node == parent.left_child:  # insert case 3
                        node = node.parent
                        parent, grandparent = node.parent, node.parent.parent
                        self._rotate_right(node)
                        
                    parent.color = BLACK  # insert case 4
                    grandparent.color = RED
                    self._rotate_left(grandparent)
            if node == self.root:
                break
        self.root.color = BLACK
        
    def delete(self, key):
        node_to_remove = self._delete(key, self.root)
        if node_to_remove is not None:
            self._delete_fix(node_to_remove)
    
    def _delete(self, key, node):
        if node is None:
            return None
    
        # find node
        if key < node.value:
            return self._delete(key, node.left_child)
        elif key > node.value:
            return self._delete(key, node.right_child)
        else:
            # node has 2 children
            if node.left_child and node.right_child:
                right_min = self._min(node.right_child)
                # replace node with successor
                node.value = right_min.value
                # delete old successor node
                return self._delete(right_min.value, node.right_child)
            else:
                # select child to replace
                if node.left_child:
                    child = node.left_child
                elif node.right_child:
                    child = node.right_child
                else:
                    child = None
    
                # replace node with child
                if child is not None:
                    child.parent = node.parent

                # if deleted node was root 
                if node.parent is None:
                    self.root = child

                # assign new parent to children
                elif node == node.parent.left_child:
                    node.parent.left_child = child
                else:
                    node.parent.right_child = child
    
                return child

    def _min(self, node):
        current = node
        while current.left_child is not None:
            current = current.left_child
        return current

    def _delete_fix(self, node):
        while node != self.root and node.color == BLACK:
            
            # node is left child
            if node == node.parent.left_child:
                sibling = node.parent.right_child

                # case 1
                if sibling.color == RED:
                    sibling.color = BLACK
                    node.parent.color = RED
                    self._rotate_left(node.parent)
                    sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                
                # case 2a
                if (sibling.left_child is None or sibling.left_child.color == BLACK) and (sibling.right_child is None or sibling.right_child.color == BLACK):
                    sibling.color = RED
                    node = node.parent

                else:
                    # case 2b
                    if sibling.right_child is None or sibling.right_child.color == BLACK:
                        sibling.left_child.color = BLACK
                        sibling.color = RED
                        self.rotate_right(sibling)
                        sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                    
                    # case 2c
                    sibling.color = node.parent.color
                    node.parent.color = BLACK
                    if sibling.right_child:
                        sibling.right_child.color = BLACK
                    self._rotate_left(node.parent)
                    node = self.root
            
            # node is right child
            else:
                sibling = node.parent.left_child

                # case 1
                if sibling.color == RED:
                    sibling.color = BLACK
                    node.parent.color = RED
                    self._rotate_right(node.parent)
                    sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child

                # case 2a
                if (sibling.left_child is None or sibling.left_child.color == BLACK) and (sibling.right_child is None or sibling.right_child.color == BLACK):
                    sibling.color = RED
                    node = node.parent

                else:
                    # case 2b
                    if sibling.left_child is None or sibling.left_child.color == BLACK:
                        sibling.right_child.color = BLACK
                        sibling.color = RED
                        self._rotate_left(sibling)
                        sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child

                    # case 2c
                    sibling.color = node.parent.color
                    node.parent.color = BLACK
                    if sibling.left_child:
                        sibling.left_child.color = BLACK
                    self._rotate_right(node.parent)
                    node = self.root
        
        node.color = BLACK
        
    def print_tree(self):
        self._print_tree(self.root, 0)

    def _print_tree(self, node, level):
        if node is None:
            return
        if node is not None:
            self._print_tree(node.right_child, level + 1)
            print(f"{'----'*level}{str(node.value)} ({'BLACK' if node.color == BLACK else 'RED'})")
            self._print_tree(node.left_child, level + 1)

## Test the Binary Tree and Red-Black Tree Classes

In [6]:
# Binary Tree Testing

tree = BinaryTree()

tree.insert(10)
tree.insert(5)
tree.insert(15)
tree.insert(3)
tree.insert(7)
tree.insert(12)
tree.insert(18)
tree.insert(14)
tree.insert(9)
tree.insert(8)
tree.insert(4)
tree.insert(2)
tree.print_tree()

print()
print("AFTER REMOVING 10")
print()

tree.delete(10)
# print(f"tree val: {node.value}\ntree left: {node.left_child}\ntree right: {node.right_child.value}\ntree parent: {node.parent.value}")

tree.print_tree()
# print()
# print(f"8: {tree.search(8)[1]}")
# print(f"18: {tree.search(18)[1]}")
# print(f"4: {tree.search(4)[1]}")

--------18
----15
------------14
--------12
10
------------9
----------------8
--------7
----5
------------4
--------3
------------2

AFTER REMOVING 10

--------18
----15
--------14
12
------------9
----------------8
--------7
----5
------------4
--------3
------------2


In [19]:
# Red Black Tree Testing

RBtree = RedBlackTree()
RBtree.insert(10)
RBtree.insert(5)
RBtree.insert(15)
RBtree.insert(3)
RBtree.insert(7)
RBtree.insert(12)
RBtree.insert(18)
RBtree.insert(28)
RBtree.insert(64)
RBtree.insert(23)
RBtree.insert(194)
RBtree.print_tree()

print()
print(f"{RBtree.search(10)}")
print(f"{RBtree.search(28)}")

print()
print(f"{RBtree.search(64)}")
print(f"{RBtree.search(18)}")
print(f"{RBtree.search(12)}")
print(f"{RBtree.search(5)}")

print()
print(f"{RBtree.search(194)}")
print(f"{RBtree.search(23)}")
print(f"{RBtree.search(7)}")
print(f"{RBtree.search(3)}")

print()
RBtree.delete(10)
RBtree.delete(28)
RBtree.delete(15)
RBtree.delete(12)
RBtree.delete(23)
RBtree.print_tree()

print()
print(f"{RBtree.search(194)}")
print(f"{RBtree.search(18)}")

------------194 (RED)
--------64 (BLACK)
----28 (RED)
------------23 (RED)
--------18 (BLACK)
15 (BLACK)
--------12 (BLACK)
----10 (RED)
------------7 (RED)
--------5 (BLACK)
------------3 (RED)

(<__main__.Node object at 0x10a272b10>, 'l', 1, 1)
(<__main__.Node object at 0x10a339dd0>, 'r', 1, 1)

(<__main__.Node object at 0x10a339910>, 'rr', 2, 1)
(<__main__.Node object at 0x10a338f50>, 'rl', 2, 1)
(<__main__.Node object at 0x10a33b0d0>, 'lr', 2, 1)
(<__main__.Node object at 0x10a2996d0>, 'll', 2, 1)

(<__main__.Node object at 0x10a33b950>, 'rrr', 2, 2)
(<__main__.Node object at 0x10a33a490>, 'rlr', 2, 2)
(<__main__.Node object at 0x10a3395d0>, 'llr', 2, 2)
(<__main__.Node object at 0x10a0aa990>, 'lll', 2, 2)

----194 (BLACK)
64 (BLACK)
----18 (BLACK)
------------7 (RED)
--------5 (BLACK)
------------3 (RED)

(<__main__.Node object at 0x10a33b950>, 'r', 2, 0)
(<__main__.Node object at 0x10a0ab0d0>, 'l', 2, 0)


## Efficiency Testing

In [None]:
# Import modules for testing

# timeit from: https://docs.python.org/3/library/timeit.html
import timeit
import random

In [47]:
# Create random datasets
data_10000 = random.sample(range(0, 100000), 10000)
data_10000_remove = random.sample(data_10000, 5000)
data_10000_remaining = [item for item in data_10000 if item not in data_10000_remove]
data_10000_find = random.sample(data_10000_remaining, 2500)

In [48]:
# Initialize Binary Tree and Red-Black Tree

bt = BinaryTree()
rbt = RedBlackTree()

# insert, delete, search
time_bt = [0] * 3
time_rbt = [0] * 3

In [59]:
# Measure insert time

# Binary Tree
start_time = time.time()
for item in data_10000:
    bt.insert(item)
end_time = time.time()

time_bt[0] = round(end_time - start_time, 3)

print(time_bt)

[0.24, 0, 0]
