In [1]:
from tree_utils import printTree as ptree
from tree_classes import Value, RedBlackNode
from tree_classes import RED, BLACK, RIGHT, LEFT, NULL

In [2]:
# to make the BinaryTree class available in this notebook
%run bst.ipynb

In [3]:
class RedBlackTree(BinaryTree):
    """Binary search tree that rebalances itself with red/black node colouring.

    Builds on ``BinaryTree``, from which it uses ``root``, ``search``,
    ``_transplant`` and ``_inOrderTraversal``. Nodes are ``RedBlackNode``
    objects; besides ``leftChild``/``rightChild``/``parent`` each node carries a
    ``side`` marker (``LEFT``/``RIGHT``, ``None`` for the root) and a ``color``,
    and both are kept up to date by every rotation performed here.

    Attributes:
        size: number of nodes currently stored; incremented by ``insert`` and
            decremented by ``delete``.
    """

    def __init__(self):
        super().__init__()
        self.size = 0

    def insert(self, value):
        """Insert ``value`` and restore the red-black properties.

        The first value becomes a black root; every later value is attached as
        a red node and then repaired by ``_insertFix``.

        Returns:
            The ``RedBlackNode`` that now holds ``value``.
        """
        if self.root is None:
            self.root = RedBlackNode(value, side = None, parent = None, color = BLACK)
            self.size += 1
            return self.root

        node = RedBlackNode(value = value, color = RED)
        node = self._insert(self.root, node)
        self.size += 1
        self._insertFix(node)
        return node

    def _insert(self, parent: RedBlackNode, node: RedBlackNode):
        """Attach ``node`` below ``parent`` at its binary-search position.

        Nodes that compare less than ``parent`` descend left; everything else
        (including equal nodes) descends right. Sets ``node.parent`` and
        ``node.side`` when the free slot is found.

        Returns:
            ``node`` itself.
        """
        if node < parent:
            if not parent.leftChild:
                node.parent = parent
                node.side = LEFT
                parent.leftChild = node
            else:
                self._insert(parent.leftChild, node)
        else:
            if not parent.rightChild:
                node.parent = parent
                node.side = RIGHT
                parent.rightChild = node
            else:
                self._insert(parent.rightChild, node)
        return node

    def _leftRotate(self, node: RedBlackNode):
        """Rotate the subtree rooted at ``node`` to the left.

        ``node``'s right child takes over ``node``'s position (and ``side``),
        ``node`` becomes that child's left child, and the child's former left
        subtree becomes ``node``'s right subtree. Colours are not touched.

        Raises:
            Exception: if ``node`` has no right child.
        """
        rightChild = node.rightChild
        if not rightChild:
            raise Exception('invalid left rotate')
        # Hand the pivot's inner (left) subtree over to `node`.
        node.rightChild = rightChild.leftChild
        if rightChild.leftChild:
            rightChild.leftChild.parent = node
            node.rightChild.side = RIGHT
        # Hook the pivot into the slot `node` used to occupy.
        rightChild.parent = node.parent
        if node.isRoot:
            self.root = rightChild
        elif node.isLeft:
            node.parent.leftChild = rightChild
        else:
            node.parent.rightChild = rightChild
        # Finally hang `node` below the pivot; the pivot inherits node's old side.
        rightChild.leftChild = node
        rightChild.side = node.side
        node.parent = rightChild
        node.side = LEFT

    def _rightRotate(self, node: RedBlackNode):
        """Rotate the subtree rooted at ``node`` to the right.

        Mirror image of ``_leftRotate``: ``node``'s left child takes over
        ``node``'s position (and ``side``), ``node`` becomes that child's right
        child, and the child's former right subtree becomes ``node``'s left
        subtree. Colours are not touched.

        Raises:
            Exception: if ``node`` has no left child.
        """
        leftChild = node.leftChild
        if not leftChild:
            raise Exception('invalid right rotate')
        # Hand the pivot's inner (right) subtree over to `node`.
        node.leftChild = leftChild.rightChild
        if leftChild.rightChild:
            leftChild.rightChild.parent = node
            node.leftChild.side = LEFT
        # Hook the pivot into the slot `node` used to occupy.
        leftChild.parent = node.parent
        if node.isRoot:
            self.root = leftChild
        elif node.isLeft:
            node.parent.leftChild = leftChild
        else:
            node.parent.rightChild = leftChild
        # Finally hang `node` below the pivot; the pivot inherits node's old side.
        leftChild.rightChild = node
        leftChild.side = node.side
        node.parent = leftChild
        node.side = RIGHT

    def _insertFix(self, node:RedBlackNode):
        """Repair a possible red-red violation between ``node`` and its parent.

        Cases, in the order they are tested:
          * ``node`` is the root: paint it black.
          * the parent is black: nothing to do.
          * the aunt is red: recolour parent/aunt black and the grandparent
            red, then continue from the grandparent.
          * the aunt is black or missing: if ``node`` is an "inner" grandchild,
            rotate it outward first and continue from the old parent;
            otherwise recolour and rotate the grandparent.
        """
        if node.isRoot:
            node.color = BLACK
            return
        if node.parent.color == BLACK:
            return
        if node.aunt and node.aunt.color == RED:
            node.parent.color = BLACK
            node.aunt.color = BLACK
            node.grandParent.color = RED
            self._insertFix(node.grandParent)
            return
        if node.parent.isLeft:
            if node.isRight:
                # Left-right shape: after the rotation the old parent is
                # node.leftChild, and it is now the outer grandchild.
                self._leftRotate(node.parent)
                self._insertFix(node.leftChild)
                return
            node.parent.color = BLACK
            node.grandParent.color = RED
            self._rightRotate(node.grandParent)
            return
        if node.isLeft:
            # Right-left shape: mirror of the case above.
            self._rightRotate(node.parent)
            self._insertFix(node.rightChild)
            return
        node.parent.color = BLACK
        node.grandParent.color = RED
        self._leftRotate(node.grandParent)

    def delete(self, key):
        """Remove the node found by ``self.search(key)``, if there is one.

        Does nothing when the tree is empty or the key is not found. A node
        with two children is replaced by its in-order successor (the first
        node of the in-order traversal of its right subtree), which takes over
        the removed node's position, side and colour.

        NOTE(review): a black leaf is unlinked without any rebalancing, and a
        black successor that has no right child triggers no fix-up either.
        Both can leave unequal black-heights; repairing them needs a fix-up
        that can start from an empty position. TODO.
        """
        if not self.root:
            return
        deleteNode = self.search(key)
        # Bail out on a missing key *before* reading any attribute of the result.
        if not deleteNode:
            return
        color = deleteNode.color
        if deleteNode.isLeaf:
            self._transplant(deleteNode, None)
            self.size -= 1
            return
        if len(deleteNode.children) == 1:
            deleteNodeChild = deleteNode.children[0]
            self._transplant(deleteNode, deleteNodeChild)
            self.size -= 1
            if color == BLACK:
                self._delete_fix(deleteNodeChild)
            return
        leftMostNode = self._inOrderTraversal(deleteNode.rightChild)[0]
        # From here on `color` is the colour that physically leaves the tree:
        # the successor's original one.
        color = leftMostNode.color
        leftMostNodeRightChild = leftMostNode.rightChild
        if leftMostNode.parent != deleteNode:
            # Detach the successor from deeper in the right subtree, then let it
            # adopt the whole right subtree of the node being removed.
            self._transplant(leftMostNode, leftMostNode.rightChild)
            leftMostNode.rightChild = deleteNode.rightChild
            leftMostNode.rightChild.parent = leftMostNode
        self._transplant(deleteNode, leftMostNode)
        leftMostNode.leftChild = deleteNode.leftChild
        leftMostNode.leftChild.parent = leftMostNode
        leftMostNode.parent = deleteNode.parent
        leftMostNode.side = deleteNode.side
        # The successor stands in for the removed node, so it must wear its colour.
        leftMostNode.color = deleteNode.color
        self.size -= 1
        if color == BLACK and leftMostNodeRightChild:
            self._delete_fix(leftMostNodeRightChild)

    def _delete_fix(self, node:RedBlackNode):
        """Restore the red-black properties after a black node was removed above ``node``.

        ``node`` is the node that moved into the vacated position. A red node
        or the root simply absorbs the missing black by being painted black.
        Otherwise the sibling decides:
          * red sibling: recolour and rotate so that the sibling becomes black;
          * sibling with two black (or missing) children: paint the sibling
            red and continue from the parent;
          * far child black but near child red: rotate the sibling so that the
            red child ends up on the far side;
          * far child red: recolour, rotate the parent and finish.
        A missing sibling is treated like a sibling with two black children.
        """
        if node.isRoot or node.color!=BLACK:
            node.color = BLACK
            return
        if node.sibling and node.sibling.isRed:
            node.sibling.color = BLACK
            node.parent.color = RED
            rotation = self._leftRotate if node.isLeft else self._rightRotate
            rotation(node.parent)
        # Read the sibling only now: the rotation above may have replaced it.
        sibling = node.sibling
        if not sibling:
            # Nothing to recolour on the other side; push the problem upward.
            self._delete_fix(node.parent)
            return
        leftChildisBlack = sibling.leftChild is None or sibling.leftChild.isBlack
        rightChildisBlack = sibling.rightChild is None or sibling.rightChild.isBlack
        if leftChildisBlack and rightChildisBlack:
            sibling.color = RED
            self._delete_fix(node.parent)
            return
        # "far" is the sibling's child pointing away from `node`, "near" the other one.
        if node.isLeft:
            farSide, nearSide = 'rightChild', 'leftChild'
        else:
            farSide, nearSide = 'leftChild', 'rightChild'
        farChild = getattr(sibling, farSide)
        if (farChild is None) or (farChild.isBlack):
            nearChild = getattr(sibling, nearSide)
            if nearChild:
                nearChild.color = BLACK
            sibling.color = RED
            rotation = self._rightRotate if node.isLeft else self._leftRotate
            rotation(sibling)
            # The rotation promoted the near child: it is the new sibling, and
            # the old sibling is now its far child.
            sibling = node.sibling
            farChild = getattr(sibling, farSide)
        sibling.color = node.parent.color
        node.parent.color = BLACK
        if farChild:
            farChild.color = BLACK
        rotation = self._leftRotate if node.isLeft else self._rightRotate
        rotation(node.parent)
        # Terminal case: calling the fix on the root just guarantees it is black.
        self._delete_fix(self.root)
        return


In [10]:
# Demo tree: the keys 1..10, each paired with its square.
t = RedBlackTree()

for i in range(1, 11):
    t.insert(Value(i, i * i))

# Alternative, unsorted key set for manual experiments:
#   [4.5, 8, 6, 5, 7, 2, 3, 1, 4, 11, 9, 20, 21, 23, 24]

# Bare last expression, so the notebook renders the tree.
t


                        [4:16]:b
               /¯¯¯¯¯¯¯¯¯¯¯¯       ¯¯¯¯¯¯¯¯¯¯¯¯\
         [2:4]:b                        [6:36]:b
       /¯¯¯¯¯¯   ¯¯¯¯¯¯\               /¯¯¯¯¯¯   ¯¯¯¯¯¯\
 [1:1]:b         [3:9]:b        [5:25]:b        [8:64]:r
                                                   /¯¯¯ ¯¯¯\
                                            [7:49]:b[9:81]:b
                                                            ¯\
                                                    [10:100]:r

In [11]:
# Display the in-order traversal of the demo tree built above.
t.inOrderTraversal()

[[1:1]:b,
 [2:4]:b,
 [3:9]:b,
 [4:16]:b,
 [5:25]:b,
 [6:36]:b,
 [7:49]:b,
 [8:64]:r,
 [9:81]:b,
 [10:100]:r]

In [12]:
# Some quick tests: these are handy for verifying that everything is working. I will eventually move them to a test file.

In [13]:
def test(t):
    #checks to validate if all pointers are set correctly after rotation.
    parentCheck = lambda node: all([n.parent == node for n in node.children])
    sideCheck = lambda node: all( [ node.leftChild.side == LEFT if node.leftChild else True, node.rightChild.side == RIGHT if node.rightChild else True ] )
    rootCheck = lambda tree: ( tree.root.parent is None ) and ( tree.root.side is None )

    #check if parent and child both have red colors
    consecutiveRedCheck = lambda node: all( [ node.color != child.color for child in node.children if node.color == RED ] )

    #a BFS traversal while performing all the checks
    queue = [t.root]
    visited = []
    assert rootCheck(t), "root: {} isn't valid".format(t.root)
    while queue:
        node = queue.pop(0)
        if not parentCheck(node):
            raise Exception('children {children} of node: {node} do not point to it'.format(children = node.children, node = node))
        if not sideCheck(node):
            raise Exception('children {children} of node: {node} do not have the correct side attribute'.format(children = node.children, node = node))
        if not consecutiveRedCheck(node):
            raise Exception('node {node} and children {children} have red color'.format(node = node, children = node.children))
        queue.extend( node.children )
        visited.append(node)
    return True

In [14]:
# Run the structural checks on the freshly built demo tree.
test(t)

True

In [15]:
# Remove every node in traversal order, re-validating the tree after each
# removal; stop as soon as nothing is left to validate.
for node in t.inOrderTraversal():
    t.delete(node.key)
    if not t.isEmpty:
        test(t)
    else:
        break
print('isEmpty: ', t.isEmpty)

isEmpty:  True
