# Red Black Tree

## Tree Implementations

In [261]:
# Constant identifiers for BLACK and RED
BLACK = 0
RED = 1

In [262]:
# Node Data Structure
class Node:
    def __init__(self, value, color=RED):
        self.value = value
        self.color = color
        
        self.parent = None
        self.left_child = None
        self.right_child = None

In [263]:
# Binary Tree Data Structure
class BinaryTree:
    def __init__(self):
        self.root = None

    def search(self, key):
        return self._search(key, self.root, "")

    def _search(self, key, node: Node, path):
        if node is None:
            return "not present"
        
        if key < node.value:
            path += "l"
            return self._search(key, node.left_child, path)
        elif key > node.value:
            path += "r"
            return self._search(key, node.right_child, path)
        else:
            return node, path

    def insert(self, value):
        if self.root is None:
            self.root = Node(value)
        else:
            self._insert(value, self.root)
    
    def _insert(self, value, node: Node):
        if value < node.value:
            if node.left_child is None:
                node.left_child = Node(value)
                node.left_child.parent = node
            else:
                self._insert(value, node.left_child)
        else:
            if node.right_child is None:
                node.right_child = Node(value)
                node.right_child.parent = node
            else:
                self._insert(value, node.right_child)
        
    def delete(self, key):
        self._delete(key, self.root)

    def _delete(self, key, node: Node):
        if key < node.value:
            self._delete(key, node.left_child)
        elif key > node.value:
            self._delete(key, node.right_child)
        else:
            # node doesn't exist
            if node is None:
                return
            
            # delete case 1: no children
            if node.left_child is None and node.right_child is None:
                if node.parent is None:
                    self.root = None
                elif node.parent.left_child == node:
                    node.parent.left_child = None
                else:
                    node.parent.right_child = None

            # delete case 2a: one child (right)
            elif node.left_child is None:
                if node.parent is None:
                    self.root = node.right_child
                elif node.parent.left_child == node:
                    node.parent.left_child = node.right_child
                else:
                    node.parent.right_child = node.right_child
                node.right_child.parent = node.parent

            # delete case 2b: one child (left)
            elif node.right_child is None:
                if node.parent is None:
                    self.root = node.left_child
                elif node.parent.left_child == node:
                    node.parent.left_child = node.left_child
                else:
                    node.parent.right_child = node.left_child
                node.left_child.parent = node.parent

            # delete case 3: two children
            else:
                # find smallest in right subtree
                min_right_subtree = self._min(node.right_child)
                node.value = min_right_subtree.value
                self._delete(min_right_subtree.value, node.right_child)

    def _min(self, node):
        current = node
        while current.left_child is not None:
            current = current.left_child
        return current

    def print_tree(self):
        self._print_tree(self.root, 0)

    def _print_tree(self, node, level):
        if node is None:
            return
        if node is not None:
            self._print_tree(node.right_child, level + 1)
            print(f"{'----'*level}{str(node.value)}")
            self._print_tree(node.left_child, level + 1)

In [264]:
# Red-Black Tree Data Structure
class RedBlackTree:
    def __init__(self):
        self.root = None

    def search(self, key):
        return self._search(key, self.root, "", 0, 0)

    def _search(self, key, node: Node, path, black_travelled, red_travelled):
        if node is None:
            return None
        
        if node.color == BLACK:
            black_travelled += 1
        else:
            red_travelled += 1

        if key < node.value:
            path += "l"
            return self._search(key, node.left_child, path, black_travelled, red_travelled)
        elif key > node.value:
            path += "r"
            return self._search(key, node.right_child, path, black_travelled, red_travelled)
        else:
            return node, path, black_travelled, red_travelled

    def _rotate_left(self, n: Node):
        y = n.right_child

        # left(y) becomes N's new right child
        n.right_child = y.left_child
        if y.left_child is not None:
            y.left_child.parent = n

        # fix y's parent
        y.parent = n.parent
        if n.parent is None:
            self.root = y
        elif n == n.parent.left_child:
            n.parent.left_child = y
        else:
            n.parent.right_child = y 
        y.left_child = n

        # y becomes N's parent
        n.parent = y

    def _rotate_right(self, n: Node):
        y = n.left_child

        # left(y) becomes N's new right child
        n.left_child = y.right_child
        if y.right_child is not None:
            y.right_child.parent = n

        # fix y's parent
        y.parent = n.parent
        if n.parent is None:
            self.root = y
        elif n == n.parent.right_child:
            n.parent.right_child = y
        else:
            n.parent.left_child = y 
        y.right_child = n

        # y becomes N's parent
        n.parent = y

    def insert(self, value):
        # case 1
        if self.root is None:
            new_node = Node(value, RED)
            self.root = new_node
        else:
            new_node = self._insert(value, self.root)
        self._insert_fix(new_node)
    
    def _insert(self, value, node: Node):
        if value < node.value:
            if node.left_child is None:
                new_node = Node(value, RED)
                node.left_child = new_node
                node.left_child.parent = node
                return new_node
            else:
                return self._insert(value, node.left_child)
        else:
            if node.right_child is None:
                new_node = Node(value, RED)
                node.right_child = new_node
                node.right_child.parent = node
                return new_node
            else:
                return self._insert(value, node.right_child)
            
    def _insert_fix(self, node: Node):
        while node.parent is not None and node.parent.color == RED:
            # get parent and parent
            parent = node.parent
            if parent is None:
                grandparent = None
            else:
                grandparent = parent.parent

            # parent is left child
            if parent == grandparent.left_child:
                uncle = grandparent.right_child if parent == grandparent.left_child else grandparent.left_child
                
                # case 2
                if uncle is not None and uncle.color == RED:
                    parent.color = BLACK
                    uncle.color = BLACK
                    grandparent.color = RED
                    node = grandparent
                
                else:
                    # case 3
                    if node == parent.right_child:  
                        node = node.parent
                        self._rotate_left(node)
                        parent, grandparent = node.parent, node.parent.parent  # TODO to check

                    # case 4
                    parent.color = BLACK
                    grandparent.color = RED
                    self._rotate_right(grandparent)

            # parent is right child
            else:
                uncle = grandparent.right_child if parent == grandparent.left_child else grandparent.left_child
                
                # case 2
                if uncle is not None and uncle.color == RED:
                    parent.color = BLACK
                    uncle.color = BLACK
                    grandparent.color = RED
                    node = grandparent
                
                else:
                    # case 3
                    if node == parent.left_child:
                        node = node.parent
                        self._rotate_right(node)
                        parent, grandparent = node.parent, node.parent.parent  # TODO to check
                        
                    # case 4
                    parent.color = BLACK
                    grandparent.color = RED
                    self._rotate_left(grandparent)

            if node == self.root:
                break
        
        self.root.color = BLACK
        
    def delete(self, key):
        node_to_remove = self._delete(key, self.root)
        if node_to_remove is not None:
            self._delete_fix(node_to_remove)
    
    def _delete(self, key, node):
        if node is None:
            return None
    
        # find node
        if key < node.value:
            return self._delete(key, node.left_child)
        elif key > node.value:
            return self._delete(key, node.right_child)
        else:
            # node has 2 children
            if node.left_child and node.right_child:
                right_min = self._min(node.right_child)
                # replace node with successor
                node.value = right_min.value
                # delete old successor node
                return self._delete(right_min.value, node.right_child)
            else:
                # select child to replace
                if node.left_child:
                    child = node.left_child
                elif node.right_child:
                    child = node.right_child
                else:
                    child = None
    
                # replace node with child
                if child is not None:
                    child.parent = node.parent

                # if deleted node was root 
                if node.parent is None:
                    self.root = child

                # assign new parent to children
                elif node == node.parent.left_child:
                    node.parent.left_child = child
                else:
                    node.parent.right_child = child
    
                return child

    def _min(self, node):
        current = node
        while current.left_child is not None:
            current = current.left_child
        return current

    def _delete_fix(self, node):
        while node != self.root and node.color == BLACK:
            
            # node is left child
            if node == node.parent.left_child:
                sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                if sibling is None: break

                # case 1
                if sibling.color == RED:
                    sibling.color = BLACK
                    node.parent.color = RED
                    self._rotate_left(node.parent)
                    sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                    if sibling is None: break
                
                # case 2a
                if (sibling.left_child is None or sibling.left_child.color == BLACK) and (sibling.right_child is None or sibling.right_child.color == BLACK):
                    sibling.color = RED
                    node = node.parent

                else:
                    # case 2b
                    if sibling.right_child is None or sibling.right_child.color == BLACK:
                        sibling.left_child.color = BLACK
                        sibling.color = RED
                        self._rotate_right(sibling)
                        sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                        if sibling is None: break
                    
                    # case 2c
                    sibling.color = node.parent.color
                    node.parent.color = BLACK
                    if sibling.right_child:
                        sibling.right_child.color = BLACK
                    self._rotate_left(node.parent)
                    node = self.root
            
            # node is right child
            else:
                sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                if sibling is None: break

                # case 1
                if sibling.color == RED:
                    sibling.color = BLACK
                    node.parent.color = RED
                    self._rotate_right(node.parent)
                    sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                    if sibling is None: break

                # case 2a
                if (sibling.left_child is None or sibling.left_child.color == BLACK) and (sibling.right_child is None or sibling.right_child.color == BLACK):
                    sibling.color = RED
                    node = node.parent

                else:
                    # case 2b
                    if sibling.left_child is None or sibling.left_child.color == BLACK:
                        sibling.right_child.color = BLACK
                        sibling.color = RED
                        self._rotate_left(sibling)
                        sibling = node.parent.right_child if node == node.parent.left_child else node.parent.left_child
                        # if sibling is None: break

                    # case 2c
                    sibling.color = node.parent.color
                    node.parent.color = BLACK
                    if sibling.left_child:
                        sibling.left_child.color = BLACK
                    self._rotate_right(node.parent)
                    node = self.root

        node.color = BLACK
        
    def print_tree(self):
        self._print_tree(self.root, 0)

    def _print_tree(self, node, level):
        if node is None:
            return
        if node is not None:
            self._print_tree(node.right_child, level + 1)
            print(f"{'----'*level}{str(node.value)} ({'BLACK' if node.color == BLACK else 'RED'})")
            self._print_tree(node.left_child, level + 1)

## Test the Binary Tree and Red-Black Tree Classes

In [316]:
# Binary Tree Testing

tree = BinaryTree()

tree.insert(10)
tree.insert(5)
tree.insert(15)
tree.insert(3)
tree.insert(7)
tree.insert(12)
tree.insert(18)
tree.insert(14)
tree.insert(9)
tree.insert(8)
tree.insert(4)
tree.insert(17)
tree.insert(23)
tree.insert(26)
tree.insert(58)
tree.insert(40)
tree.insert(2)
tree.print_tree()

print()
print()

tree.delete(12)
tree.delete(14)
tree.delete(17)
tree.delete(40)
tree.delete(9)
tree.delete(8)
tree.delete(7)
tree.delete(4)
tree.print_tree()

--------------------58
------------------------40
----------------26
------------23
--------18
------------17
----15
------------14
--------12
10
------------9
----------------8
--------7
----5
------------4
--------3
------------2


--------------------58
----------------26
------------23
--------18
----15
10
----5
--------3
------------2


In [317]:
# Red Black Tree Testing

RBtree = RedBlackTree()

RBtree.insert(10)
RBtree.insert(5)
RBtree.insert(15)
RBtree.insert(3)
RBtree.insert(7)
RBtree.insert(12)
RBtree.insert(18)
RBtree.insert(14)
RBtree.insert(9)
RBtree.insert(8)
RBtree.insert(4)
RBtree.insert(17)
RBtree.insert(23)
RBtree.insert(26)
RBtree.insert(58)
RBtree.insert(40)
RBtree.insert(2)
RBtree.print_tree()

print()
print()

RBtree.delete(12)
RBtree.delete(14)
RBtree.delete(17)
RBtree.delete(40)
RBtree.delete(9)
RBtree.delete(8)
RBtree.delete(7)
RBtree.delete(4)
RBtree.print_tree()

------------58 (BLACK)
----------------40 (RED)
--------26 (RED)
------------23 (BLACK)
----18 (BLACK)
------------17 (BLACK)
--------15 (RED)
----------------14 (RED)
------------12 (BLACK)
10 (BLACK)
------------9 (RED)
--------8 (BLACK)
------------7 (RED)
----5 (BLACK)
------------4 (RED)
--------3 (BLACK)
------------2 (RED)


------------58 (BLACK)
--------26 (RED)
------------23 (BLACK)
----18 (BLACK)
--------15 (RED)
10 (BLACK)
----5 (BLACK)
--------3 (BLACK)
------------2 (RED)


## Efficiency Testing

In [302]:
# Import modules for testing

import time
import random

In [309]:
# Create random datasets
data = random.sample(range(0, 100000), 100000)
data_remove = random.sample(data, 50000)
data_remaining = [item for item in data if item not in data_remove]
data_find = random.sample(data_remaining, 30000)

In [310]:
# Initialize Binary Tree and Red-Black Tree

bt = BinaryTree()
rbt = RedBlackTree()

# insert, delete, search
time_bt = [0] * 3
time_rbt = [0] * 3

In [311]:
# Insert

# Binary Tree
start_time = time.time()
[bt.insert(item) for item in data]
end_time = time.time()
time_bt[0] = round(end_time - start_time, 3)
print(time_bt)

# Red-Black Tree
start_time = time.time()
[rbt.insert(item) for item in data]
end_time = time.time()
time_rbt[0] = round(end_time - start_time, 3)
print(time_rbt)

[3.087, 0, 0]
[2.757, 0, 0]


In [312]:
# Delete

# Binary Tree
start_time = time.time()
[bt.delete(item) for item in data_remove]
end_time = time.time()
time_bt[1] = round(end_time - start_time, 3)
print(time_bt)

# Red-Black Tree
start_time = time.time()
[rbt.delete(item) for item in data_remove]
end_time = time.time()
time_rbt[1] = round(end_time - start_time, 3)
print(time_rbt)

[3.087, 2.528, 0]
[2.757, 1.638, 0]


In [315]:
# Search

# Binary Tree
start_time = time.time()
[bt.search(item) for item in data_find]
end_time = time.time()
time_bt[2] = round(end_time - start_time, 3)
print(time_bt)

# Red-Black Tree
start_time = time.time()
[rbt.search(item) for item in data_find]
end_time = time.time()
time_rbt[2] = round(end_time - start_time, 3)
print(time_rbt)

[3.087, 2.528, 0.935]
[2.757, 1.638, 0.772]
