In [342]:
from tree_utils import printTree as ptree

In [353]:
# Node colors used by the red-black tree.
RED, BLACK = 'red', 'black'
# Labels for which child slot a node occupies under its parent (``side``).
RIGHT, LEFT = 'right', 'left'
# Placeholder label; not referenced elsewhere in this section.
NULL = 'NULL'

In [389]:
class Value:
    """A key/payload pair that compares and orders by ``key`` alone.

    ``value`` is carried along as payload and never participates in
    comparisons. Because ``__eq__`` is defined without ``__hash__``,
    instances are unhashable.
    """

    def __init__(self, key, value):
        self.key, self.value = key, value

    def __repr__(self) -> str:
        return f'[{self.key}:{self.value}]'

    # Each rich comparison forwards the same operator to the keys.
    def __lt__(self, other):
        return self.key < other.key

    def __le__(self, other):
        return self.key <= other.key

    def __eq__(self, other):
        return self.key == other.key

    def __ne__(self, other):
        return self.key != other.key

    def __gt__(self, other):
        return self.key > other.key

    def __ge__(self, other):
        return self.key >= other.key

In [390]:
class NodeBase:
    """Binary-tree node holding a payload ``value`` plus parent/child links.

    Comparisons between nodes are forwarded to their payloads. ``side``
    names the child slot (LEFT or RIGHT) this node occupies under
    ``parent`` and must be given whenever a parent is supplied.
    """

    def __init__(self, value, side = None, parent = None, leftChild = None, rightChild = None):
        if not side and parent:
            raise Exception('side not assigned. node must have a side assigned when it is not a root node')
        self.value, self.side, self.parent = value, side, parent
        self.leftChild, self.rightChild = leftChild, rightChild

    # Ordering/equality of nodes is the ordering/equality of their payloads.
    def __lt__(self, other):
        return self.value < other.value

    def __le__(self, other):
        return self.value <= other.value

    def __eq__(self, other):
        return self.value == other.value

    def __ne__(self, other):
        return self.value != other.value

    def __gt__(self, other):
        return self.value > other.value

    def __ge__(self, other):
        return self.value >= other.value

    @property
    def key(self):
        """The payload's ``key`` attribute."""
        return self.value.key

    @property
    def isLeft(self):
        return self.side == LEFT

    @property
    def isRight(self):
        return self.side == RIGHT

    @property
    def isNone(self):
        return self.value is None

    @property
    def isLeaf(self):
        return self.leftChild is None and self.rightChild is None

    @property
    def isRoot(self):
        return self.parent is None

    @property
    def children(self):
        """Existing children, left first; missing slots are skipped."""
        return [kid for kid in (self.leftChild, self.rightChild) if kid]

    def __repr__(self) -> str:
        return f"{self.value}"

In [391]:
class Node(NodeBase):
    """Red-black tree node: a ``NodeBase`` plus a ``color`` (RED or BLACK).

    Newly created nodes default to RED.
    """

    def __init__(self,
                 value,
                 side       = None,
                 color      = RED,
                 leftChild  = None,
                 rightChild = None,
                 parent     = None):
        super().__init__(value, side = side, leftChild = leftChild, rightChild = rightChild, parent = parent)
        self.color = color

    @property
    def isRed(self):
        return self.color == RED

    @property
    def isBlack(self):
        # Previously referenced the non-existent ``self.isRED`` and therefore
        # always raised AttributeError. Any non-red node counts as black.
        return not self.isRed

    def __repr__(self) -> str:
        # Base repr followed by the first letter of the color, e.g. "[4:16]:b".
        return "{value}:{color}".format(value = super().__repr__(), color = self.color[0])

    @property
    def grandParent(self):
        """The parent's parent.

        Raises:
            Exception: if this node is the root, or its parent is the root.
        """
        if self.isRoot:
            raise Exception('root node {node} does not have grand parent'.format(node = self))
        if self.parent.isRoot:
            raise Exception('node {node} has parent {root} which is root'.format(node = self, root = self.parent))
        return self.parent.parent

    @property
    def hasGrandParent(self):
        """True when both this node and its parent are non-root."""
        return not (self.isRoot or self.parent.isRoot)

    @property
    def aunt(self):
        """The parent's sibling (may be None when that slot is empty).

        Raises:
            Exception: if this node has no grandparent.
        """
        if self.hasGrandParent:
            return self.grandParent.rightChild if self.parent.isLeft else self.grandParent.leftChild
        raise Exception('node {node} does not have a grand parent'.format(node = self))

In [392]:
class Tree:
    """Minimal binary-tree container; ``root`` is None while the tree is empty."""

    def __init__(self):
        self.root = None

    def __repr__(self) -> str:
        # ptree (from tree_utils) is unpacked into two parallel sequences of
        # strings; they are interleaved one line from each to render the tree.
        linestrList, pstrList = ptree(self.root)
        lines = []
        for linestr, pstr in zip(linestrList, pstrList):
            lines.extend((linestr, pstr))
        return '\n'.join(lines)

    @property
    def height(self):
        """Number of levels in the tree: 0 when empty, 1 for a lone root.

        Previously an empty tree reported height 1.
        """
        if self.root is None:
            return 0
        return 1 + self._height(self.root)

    def _height(self, node):
        """Edges on the longest downward path from ``node`` (0 for None or a leaf)."""
        if node is None or node.isLeaf:
            return 0
        return max(self._height(node.leftChild), self._height(node.rightChild)) + 1

In [396]:
class RedBlackTree(Tree):
    """Red-black binary search tree of items ordered by their ``key``.

    Items must expose ``key`` and support ordering comparisons (e.g. ``Value``).
    Equal keys are inserted to the right, so duplicates are kept.
    ``size`` counts inserted items.
    """

    def __init__(self):
        super().__init__()
        self.size = 0

    def insert(self, value):
        """Insert ``value`` into the tree and restore the red-black invariants."""
        if self.root is None:
            # The first node becomes the root, which is always black.
            self.root = Node(value, side = None, parent = None, color = BLACK)
            self.size += 1
            return

        # New nodes start red; _insertFix repairs any red-red conflict this causes.
        node = Node(value = value, color = RED)
        node = self._insert(self.root, node)
        self.size += 1
        self._insertFix(node)

    def _insert(self, parent: Node, node: Node):
        """Attach ``node`` as a new leaf below ``parent`` following BST order.

        Keys not smaller than the current node's key descend right.

        Returns:
            ``node``, now linked into the tree.
        """
        while True:
            if node < parent:
                if parent.leftChild is None:
                    node.parent, node.side = parent, LEFT
                    parent.leftChild = node
                    return node
                parent = parent.leftChild
            else:
                if parent.rightChild is None:
                    node.parent, node.side = parent, RIGHT
                    parent.rightChild = node
                    return node
                parent = parent.rightChild

    def _leftRotate(self, node: Node):
        """Rotate left around ``node``; its right child takes its place.

        ``node`` becomes the left child of its former right child. Parent,
        side and root pointers are updated.

        Raises:
            Exception: if ``node`` has no right child.
        """
        rightChild = node.rightChild
        if not rightChild:
            raise Exception('invalid left rotate')
        # The pivot's inner (left) subtree becomes node's right subtree.
        node.rightChild = rightChild.leftChild
        if rightChild.leftChild:
            rightChild.leftChild.parent = node
            node.rightChild.side = RIGHT
        # Put the pivot in node's old slot; node.parent/node.side still hold the old values here.
        rightChild.parent = node.parent
        if node.isRoot:
            self.root = rightChild
        elif node.isLeft:
            node.parent.leftChild = rightChild
        else:
            node.parent.rightChild = rightChild
        rightChild.leftChild = node
        rightChild.side = node.side
        node.parent = rightChild
        node.side = LEFT

    def _rightRotate(self, node: Node):
        """Mirror image of ``_leftRotate``: ``node``'s left child takes its place.

        Raises:
            Exception: if ``node`` has no left child.
        """
        leftChild = node.leftChild
        if not leftChild:
            raise Exception('invalid right rotate')
        # The pivot's inner (right) subtree becomes node's left subtree.
        node.leftChild = leftChild.rightChild
        if leftChild.rightChild:
            leftChild.rightChild.parent = node
            node.leftChild.side = LEFT
        # Put the pivot in node's old slot; node.parent/node.side still hold the old values here.
        leftChild.parent = node.parent
        if node.isRoot:
            self.root = leftChild
        elif node.isLeft:
            node.parent.leftChild = leftChild
        else:
            node.parent.rightChild = leftChild
        leftChild.rightChild = node
        leftChild.side = node.side
        node.parent = leftChild
        node.side = RIGHT

    def _insertFix(self, node: Node):
        """Restore the red-black invariants after the red ``node`` was attached.

        Recurses upward when recoloring moves a red node toward the root.
        """
        if node.isRoot:
            # The root is always black.
            node.color = BLACK
            return
        if node.parent.color == BLACK:
            # A red child under a black parent violates nothing.
            return
        # The parent is red. The root is kept black, so the parent is not the
        # root and a grandparent exists.
        if node.aunt and node.aunt.color == RED:
            # Case 1: red aunt. Recolor, then re-check the grandparent, which is now red.
            node.parent.color = BLACK
            node.aunt.color = BLACK
            node.grandParent.color = RED
            self._insertFix(node.grandParent)
            return
        if node.parent.isLeft:
            if node.isRight:
                # Case 2 (left-right): rotate the parent to get the outer case,
                # then fix the former parent, which is now node's left child.
                self._leftRotate(node.parent)
                self._insertFix(node.leftChild)
                return
            # Case 3 (left-left): recolor and rotate the grandparent right.
            node.parent.color = BLACK
            node.grandParent.color = RED
            self._rightRotate(node.grandParent)
            return
        if node.isLeft:
            # Case 2 (right-left): mirror of the above.
            self._rightRotate(node.parent)
            self._insertFix(node.rightChild)
            return
        # Case 3 (right-right): recolor and rotate the grandparent left.
        node.parent.color = BLACK
        node.grandParent.color = RED
        self._leftRotate(node.grandParent)

    def search(self, key):
        """Return the ``.value`` payload of the item whose key equals ``key``.

        Returns None when no such item exists, including on an empty tree
        (which previously raised AttributeError).
        """
        found = self._search(self.root, key)
        return None if found is None else found.value

    def _search(self, node: Node, key):
        """Return the stored item (``node.value``) with ``key`` in the subtree at ``node``, or None."""
        while node is not None:
            if node.key == key:
                return node.value
            node = node.leftChild if key < node.key else node.rightChild
        return None

In [399]:
# Build a tree whose items map each integer 1..15 to its square.
t = RedBlackTree()
for key in range(1, 16):
    t.insert(Value(key, key ** 2))

# Alternative mixed-order dataset kept from earlier experiments:
# for v in [4, 5, 2, 3, 1, 4.5, 6, 7, 8, 11, 9]:
#     t.insert(Value(v, v ** 2))
t


                                                        [4:16]:b
                               /¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯               ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯\
                         [2:4]:b                                                        [8:64]:r
               /¯¯¯¯¯¯¯¯¯¯¯¯       ¯¯¯¯¯¯¯¯¯¯¯¯\                               /¯¯¯¯¯¯¯¯¯¯¯¯       ¯¯¯¯¯¯¯¯¯¯¯¯\
         [1:1]:b                         [3:9]:b                        [6:36]:b                      [10:100]:b
                                                                       /¯¯¯¯¯¯   ¯¯¯¯¯¯\               /¯¯¯¯¯¯   ¯¯¯¯¯¯\
                                                                [5:25]:b        [7:49]:b        [9:81]:b      [12:144]:r
                                                                                                                   /¯¯¯ ¯¯¯\
                                                                                                          [11:121]:b[14:196]:b
                    

In [402]:
t.search(key=13)  # payload stored under key 13

169

In [351]:
# Some quick sanity tests, handy for verifying that everything works. I will eventually move them to a dedicated test file.

In [352]:
#checks to validate if all pointers are set correctly after rotation.
parentCheck = lambda node: all([n.parent == node for n in node.children])
sideCheck = lambda node: all( [ node.leftChild.side == LEFT if node.leftChild else True, node.rightChild.side == RIGHT if node.rightChild else True ] )
rootCheck = lambda tree: ( tree.root.parent is None ) and ( tree.root.side is None )

#check if parent and child both have red colors
consecutiveRedCheck = lambda node: all( [ node.color != child.color for child in node.children if node.color == RED ] )

#a BFS traversal while performing all the checks
queue = [t.root]
visited = []
assert rootCheck(t), "root: {} isn't valid".format(t.root)
while queue:
    node = queue.pop(0)
    if not parentCheck(node):
        raise Exception('children {children} of node: {node} do not point to it'.format(children = node.children, node = node))
    if not sideCheck(node):
        raise Exception('children {children} of node: {node} do not have the correct side attribute'.format(children = node.children, node = node))
    if not consecutiveRedCheck(node):
        raise Exception('node {node} and children {children} have red color'.format(node = node, children = node.children))
    queue.extend( node.children )
    visited.append(node)