In [1]:
"""
red-black tree

properties:
1. Every node is either red or black;
2. Root node is black;
3. Every Nil leaf is black;
4. A red node has two black children nodes;
5. Every path from root to leaf has same number of black nodes (black-height).

Equivalent to a 2-3-4 tree: each black node together with its red children forms one 2-, 3- or 4-node.
"""

# TODO
# merge all Nil to one rbt.nil

# base node class
class Node:
    """Binary-tree node: a read-only value plus left/right/parent links and a color.

    Links and color are exposed as read/write properties. Subclasses
    (``Nil``, ``RBNode``) override the initial links/color as needed.
    """

    def __init__(self, value):
        self._value = value
        self._left = None
        self._right = None
        self._parent = None
        # Initialized here so ``color`` is readable on every Node; previously
        # only subclasses set it and ``Node(x).color`` raised AttributeError.
        self._color = None

    @property
    def value(self):
        """The stored key (read-only)."""
        return self._value

    @property
    def left(self):
        """Left child link."""
        return self._left

    @left.setter
    def left(self, l):
        self._left = l

    @property
    def right(self):
        """Right child link."""
        return self._right

    @right.setter
    def right(self, r):
        self._right = r

    @property
    def parent(self):
        """Parent link."""
        return self._parent

    @parent.setter
    def parent(self, p):
        self._parent = p

    @property
    def color(self):
        """Node color: the tree code uses 'R' / 'B'; None until assigned."""
        return self._color

    @color.setter
    def color(self, color):
        self._color = color
        
# empty node class        
class Nil(Node):
    """Sentinel leaf of the red-black tree: holds no value and is always black."""

    def __init__(self):
        super().__init__(None)
        self._color = 'B'

    def __str__(self, level=0, offset=0):
        # ``offset`` is accepted only so callers can treat Nil and RBNode
        # uniformly; a leaf renders as the bare word "Nil" at its depth.
        indent = "\t" * level
        return indent + "Nil" + "\n"
    
    
# tree node class       
class RBNode(Node):
    """Red-black tree node whose children and parent start out as Nil leaves.

    Its color starts as None until the tree assigns one.
    """

    def __init__(self, value):
        super().__init__(value)
        self._left, self._right, self._parent = Nil(), Nil(), Nil()
        self._color = None

    def __str__(self, level=0, offset=0):
        """Render this subtree, one node per line, one tab per depth level.

        Nodes are emitted in pre-order: this node, then the left subtree,
        then the right subtree.
        """
        header = "\t" * level + f"{self._value!r} {self._color!r}\n"
        children = (self._left, self._right)
        return header + "".join(child.__str__(level + 1) for child in children)
        
# tree class
class RBTree():
    """Red-black tree (CLRS-style) whose empty links are ``Nil`` leaf objects.

    Callers create ``RBNode`` objects and hand them to ``insert``. ``delete``
    takes a node that is currently linked into this tree.
    """

    def __init__(self):
        self._root = Nil()

    # print all nodes (pre-order, one per line, indented by depth)
    def __str__(self):
        return self._root.__str__(0, 0)

    def inorder_walk(self, node, level, offset, method):
        """Call ``method(node, level, offset)`` on every non-Nil node in key order.

        ``level`` grows by one per depth; ``offset`` decreases by one when
        descending left and increases by one when descending right.
        """
        if not isinstance(node, Nil):
            self.inorder_walk(node.left, level+1, offset-1, method)
            method(node, level, offset)
            self.inorder_walk(node.right, level+1, offset+1, method)

    # rotation op to change pointer structure
    def _left_rotate(self, node_x):
        """Rotate left around ``node_x``; its right child takes its place."""
        node_y = node_x.right
        node_x.right = node_y.left
        if not isinstance(node_y.left, Nil):
            node_y.left.parent = node_x
        node_y.parent = node_x.parent
        if isinstance(node_x.parent, Nil):
            self._root = node_y
        elif node_x is node_x.parent.left:
            node_x.parent.left = node_y
        else:
            node_x.parent.right = node_y
        node_y.left = node_x
        node_x.parent = node_y

    # symmetric to _left_rotate
    # where node_x => node_y, node_y => node_x
    # .left => .right and .right => .left
    def _right_rotate(self, node_y):
        """Rotate right around ``node_y``; its left child takes its place."""
        node_x = node_y.left
        node_y.left = node_x.right
        if not isinstance(node_x.right, Nil):
            node_x.right.parent = node_y
        node_x.parent = node_y.parent
        if isinstance(node_y.parent, Nil):
            self._root = node_x
        elif node_y is node_y.parent.right:
            node_y.parent.right = node_x
        else:
            node_y.parent.left = node_x
        node_x.right = node_y
        node_y.parent = node_x

    def _transplant(self, u, v):
        """Make ``v`` take ``u``'s place under ``u``'s parent.

        ``u``'s own children are not touched. ``v.parent`` is always set,
        even when ``v`` is a Nil leaf, because ``_delete_fixup`` walks up
        from it.
        """
        if isinstance(u.parent, Nil):
            # u was the root (previously this wrote ``self.root``, creating a
            # stray attribute and leaving ``self._root`` stale)
            self._root = v
        elif u is u.parent.left:
            u.parent.left = v
        else:
            u.parent.right = v
        v.parent = u.parent

    def _successor(self, rbnode):
        """Return the node with the next key in sorted order, or a Nil if none."""
        if not isinstance(rbnode.right, Nil):
            return self._min(rbnode.right)
        # no right subtree: climb until we arrive from a left child
        p = rbnode.parent
        while not isinstance(p, Nil) and rbnode is p.right:
            rbnode = p
            p = p.parent
        return p

    def _predecessor(self, rbnode):
        """Return the node with the previous key in sorted order, or a Nil if none."""
        if not isinstance(rbnode.left, Nil):
            return self._max(rbnode.left)
        # no left subtree: climb until we arrive from a right child
        p = rbnode.parent
        while not isinstance(p, Nil) and rbnode is p.left:
            rbnode = p
            p = p.parent
        return p

    def _min(self, rbnode):
        """Return the leftmost node of the subtree at ``rbnode`` (a Nil is returned as-is)."""
        if isinstance(rbnode, Nil):
            return rbnode
        while not isinstance(rbnode.left, Nil):
            rbnode = rbnode.left
        return rbnode

    def _max(self, rbnode):
        """Return the rightmost node of the subtree at ``rbnode`` (a Nil is returned as-is)."""
        if isinstance(rbnode, Nil):
            return rbnode
        while not isinstance(rbnode.right, Nil):
            rbnode = rbnode.right
        return rbnode

    def insert(self, rbnode):
        """Insert ``rbnode`` and rebalance so the red-black properties hold.

        Placement follows a plain BST insert; a key equal to an existing key
        goes into that key's right subtree. ``rbnode`` gets fresh Nil
        children and is painted red before fixing up. Each fixup case that
        runs is printed ('case 1' ... 'case 6') as a trace.
        """
        # same as BST insert
        value = rbnode.value
        n = self._root
        parent = Nil()
        while not isinstance(n, Nil):
            parent = n
            if n.value > value:
                n = n.left
            else:
                n = n.right
        rbnode.parent = parent
        if isinstance(parent, Nil):
            self._root = rbnode
        else:
            if parent.value > value:
                parent.left = rbnode
            else:
                parent.right = rbnode

        # a new red node with two Nil leaves cannot change any black-height
        rbnode.left = Nil()
        rbnode.right = Nil()
        rbnode.color = 'R'

        # fix a potential property 4 violation (red node with a red child):
        # rbnode's parent may be red
        n = rbnode
        while not isinstance(n.parent, Nil) and n.parent.color == 'R':
            # the root is forced black at the end of every insert, so a red
            # parent is not the root and the grandparent should exist
            if not isinstance(n.parent.parent, Nil):
                # parent is a left child
                if n.parent is n.parent.parent.left:
                    if not isinstance(n.parent.parent.right, Nil) and n.parent.parent.right.color == 'R':
                        # red uncle: recolor and continue from the grandparent
                        print('case 1')
                        n.parent.parent.color = 'R'
                        n.parent.parent.left.color = 'B'
                        n.parent.parent.right.color = 'B'
                        n = n.parent.parent
                    else:
                        if n is n.parent.right:
                            # inner child: rotate it to the outside first
                            print('case 2')
                            n = n.parent
                            self._left_rotate(n)
                        else:
                            print('case 3')
                        # outer child: recolor and rotate the grandparent;
                        # the parent becomes black, so the loop stops
                        n.parent.color = 'B'
                        n.parent.parent.color = 'R'
                        self._right_rotate(n.parent.parent)

                # parent is a right child: mirror image of the branch above
                elif n.parent is n.parent.parent.right:
                    if not isinstance(n.parent.parent.left, Nil) and n.parent.parent.left.color == 'R':
                        print('case 4')
                        n.parent.parent.color = 'R'
                        n.parent.parent.left.color = 'B'
                        n.parent.parent.right.color = 'B'
                        n = n.parent.parent
                    else:
                        if n is n.parent.left:
                            print('case 5')
                            n = n.parent
                            self._right_rotate(n)
                        else:
                            print('case 6')
                        n.parent.color = 'B'
                        n.parent.parent.color = 'R'
                        self._left_rotate(n.parent.parent)

        # fix a potential property 2 violation (root must be black):
        # rbnode may be the root, or case 1/4 may have recolored the root red
        self._root.color = 'B'

    def delete(self, rbnode):
        """Unlink ``rbnode`` from the tree and restore the red-black properties.

        ``rbnode`` must currently be linked into this tree.
        """
        removed_color = rbnode.color
        if isinstance(rbnode.left, Nil):
            # at most a right child: splice it up into rbnode's place
            x = rbnode.right
            self._transplant(rbnode, rbnode.right)
        elif isinstance(rbnode.right, Nil):
            # only a left child: splice it up into rbnode's place
            x = rbnode.left
            self._transplant(rbnode, rbnode.left)
        else:
            # two children: the successor (leftmost node of the right
            # subtree, hence without a left child) moves into rbnode's place
            suc = self._min(rbnode.right)
            removed_color = suc.color
            x = suc.right
            if suc.parent is rbnode:
                # x may be a Nil leaf; the fixup needs its parent link
                x.parent = suc
            else:
                # detach suc from deeper in the subtree, then adopt
                # rbnode's right subtree
                self._transplant(suc, suc.right)
                suc.right = rbnode.right
                suc.right.parent = suc
            self._transplant(rbnode, suc)
            suc.left = rbnode.left
            suc.left.parent = suc
            # suc inherits rbnode's color so black-heights above stay intact
            suc.color = rbnode.color

        # removing (or moving) a black node shortens the black-height of
        # every path through x, violating property 5
        if removed_color == 'B':
            self._delete_fixup(x)

    def _delete_fixup(self, x):
        """Restore the red-black properties after a black node was removed above ``x``.

        ``x`` is treated as carrying an extra black, which is pushed up the
        tree or absorbed by recoloring and rotations (CLRS RB-DELETE-FIXUP).
        """
        while x is not self._root and x.color == 'B':
            if x is x.parent.left:
                w = x.parent.right
                if w.color == 'R':
                    # red sibling: rotate so that x gets a black sibling
                    w.color = 'B'
                    x.parent.color = 'R'
                    self._left_rotate(x.parent)
                    w = x.parent.right
                if w.left.color == 'B' and w.right.color == 'B':
                    # black sibling with black children: move extra black up
                    w.color = 'R'
                    x = x.parent
                else:
                    if w.right.color == 'B':
                        # sibling's near child red: rotate it to the far side
                        w.left.color = 'B'
                        w.color = 'R'
                        self._right_rotate(w)
                        w = x.parent.right
                    # sibling's far child red: rotate and finish
                    w.color = x.parent.color
                    x.parent.color = 'B'
                    w.right.color = 'B'
                    self._left_rotate(x.parent)
                    x = self._root
            else:
                # mirror image of the branch above
                w = x.parent.left
                if w.color == 'R':
                    w.color = 'B'
                    x.parent.color = 'R'
                    self._right_rotate(x.parent)
                    w = x.parent.left
                if w.right.color == 'B' and w.left.color == 'B':
                    w.color = 'R'
                    x = x.parent
                else:
                    if w.left.color == 'B':
                        w.right.color = 'B'
                        w.color = 'R'
                        self._left_rotate(w)
                        w = x.parent.left
                    w.color = x.parent.color
                    x.parent.color = 'B'
                    w.left.color = 'B'
                    self._right_rotate(x.parent)
                    x = self._root
        x.color = 'B'
            
            
        
        

In [2]:
rbt = RBTree()
print("INIT", rbt, sep="\n")

# insert a fixed key sequence, dumping the whole tree after each step
for i in (2, 3, 5, 4, 0, 6, 1, 9, 7, 8):
    rbt.insert(RBNode(i))
    print("AFTER inserting {0}".format(i), rbt, sep="\n")

INIT
Nil

AFTER inserting 2
2 'B'
	Nil
	Nil

AFTER inserting 3
2 'B'
	Nil
	3 'R'
		Nil
		Nil

case 6
AFTER inserting 5
3 'B'
	2 'R'
		Nil
		Nil
	5 'R'
		Nil
		Nil

case 4
AFTER inserting 4
3 'B'
	2 'B'
		Nil
		Nil
	5 'B'
		4 'R'
			Nil
			Nil
		Nil

AFTER inserting 0
3 'B'
	2 'B'
		0 'R'
			Nil
			Nil
		Nil
	5 'B'
		4 'R'
			Nil
			Nil
		Nil

AFTER inserting 6
3 'B'
	2 'B'
		0 'R'
			Nil
			Nil
		Nil
	5 'B'
		4 'R'
			Nil
			Nil
		6 'R'
			Nil
			Nil

case 2
AFTER inserting 1
3 'B'
	1 'B'
		0 'R'
			Nil
			Nil
		2 'R'
			Nil
			Nil
	5 'B'
		4 'R'
			Nil
			Nil
		6 'R'
			Nil
			Nil

case 4
AFTER inserting 9
3 'B'
	1 'B'
		0 'R'
			Nil
			Nil
		2 'R'
			Nil
			Nil
	5 'R'
		4 'B'
			Nil
			Nil
		6 'B'
			Nil
			9 'R'
				Nil
				Nil

case 5
AFTER inserting 7
3 'B'
	1 'B'
		0 'R'
			Nil
			Nil
		2 'R'
			Nil
			Nil
	5 'R'
		4 'B'
			Nil
			Nil
		7 'B'
			6 'R'
				Nil
				Nil
			9 'R'
				Nil
				Nil

case 4
case 6
AFTER inserting 8
5 'B'
	3 'R'
		1 'B'
			0 'R'
				Nil
				Nil
			2 'R'
	