In [1]:
from tree_classes import Value, BinaryNode, Tree, NodeBase
from tree_classes import RIGHT, LEFT

In [2]:
class BinaryTree(Tree):
    """Unbalanced binary search tree made of ``BinaryNode`` objects.

    A node that compares less than its parent is placed in the left subtree;
    every other node (equal ones included) goes to the right.  ``size`` holds
    the number of nodes currently stored in the tree.
    """

    def __init__(self):
        super().__init__()
        self.size = 0

    def insert(self, value):
        """Wrap ``value`` in a ``BinaryNode``, add it to the tree and return the node."""
        if self.root is None:
            self.root = BinaryNode(value, side = None, parent = None)
            self.size += 1
            return self.root

        node = self._insert(self.root, BinaryNode(value = value))
        self.size += 1
        return node

    def _insert(self, parent: BinaryNode, node: BinaryNode):
        """Attach ``node`` as a new leaf somewhere below ``parent`` and return it.

        The descent is iterative so that a degenerate (list-shaped) tree does
        not exhaust the recursion limit.
        """
        while True:
            if node < parent:
                if not parent.leftChild:
                    node.parent = parent
                    node.side = LEFT
                    parent.leftChild = node
                    return node
                parent = parent.leftChild
            else:
                if not parent.rightChild:
                    node.parent = parent
                    node.side = RIGHT
                    parent.rightChild = node
                    return node
                parent = parent.rightChild

    def search(self, key):
        """Return the node whose ``key`` equals ``key``, or ``None`` if there is none.

        An empty tree yields ``None`` instead of raising.
        """
        return self._search(self.root, key)

    def _search(self, node: BinaryNode, key):
        """Look for ``key`` in the subtree rooted at ``node`` (which may be ``None``)."""
        while node:
            if node.key == key:
                return node
            # Smaller keys live on the left, everything else on the right.
            node = node.leftChild if key < node.key else node.rightChild
        return None

    def inOrderTraversal(self, node = None):
        """Return the nodes of the subtree at ``node`` in sorted (in-order) order.

        ``node`` defaults to the root; an empty tree gives an empty list.
        """
        node = node or self.root
        if not node:
            return []
        return self._inOrderTraversal(node)

    def _inOrderTraversal(self, node: BinaryNode):
        """Collect the subtree rooted at ``node`` in-order using an explicit stack.

        A single result list is appended to, which avoids the repeated list
        concatenation (and recursion depth) of a naive recursive version.
        """
        ordered = []
        pending = []
        current = node
        while pending or current:
            # Dive to the left-most node, remembering the path taken.
            while current:
                pending.append(current)
                current = current.leftChild
            current = pending.pop()
            ordered.append(current)
            current = current.rightChild
        return ordered

    def _transplant(self, deleteNode: BinaryNode, replaceNode):
        """Put ``replaceNode`` (possibly ``None``) where ``deleteNode`` hangs.

        Only the link from ``deleteNode``'s parent (or the tree root) and the
        ``parent``/``side`` attributes of ``replaceNode`` are updated; the
        children of either node are left untouched.
        """
        if deleteNode.isRoot:
            self.root = replaceNode
        elif deleteNode.isLeft:
            deleteNode.parent.leftChild = replaceNode
        else:
            deleteNode.parent.rightChild = replaceNode
        if replaceNode:
            replaceNode.parent = deleteNode.parent
            replaceNode.side = deleteNode.side

    def delete(self, key):
        """Remove the node with the given ``key``; do nothing if it is absent.

        ``size`` is decremented whenever a node is actually removed.
        """
        if not self.root:
            return
        deleteNode = self.search(key)
        if not deleteNode:
            return

        if deleteNode.isLeaf:
            self._transplant(deleteNode, None)
        elif len(deleteNode.children) == 1:
            self._transplant(deleteNode, deleteNode.children[0])
        else:
            # Two children: the in-order successor is the left-most node of
            # the right subtree; walk to it directly instead of materialising
            # a full traversal.
            successor = deleteNode.rightChild
            while successor.leftChild:
                successor = successor.leftChild

            # Identity comparison: we care whether the successor is the direct
            # right child, regardless of how BinaryNode defines equality.
            if successor.parent is not deleteNode:
                # Lift the successor out of its place, then let it adopt the
                # deleted node's right subtree.
                self._transplant(successor, successor.rightChild)
                successor.rightChild = deleteNode.rightChild
                successor.rightChild.parent = successor

            # _transplant also sets the successor's parent and side.
            self._transplant(deleteNode, successor)
            successor.leftChild = deleteNode.leftChild
            successor.leftChild.parent = successor

        self.size -= 1


In [3]:
# Build a small demo tree; each Value is created from a number and its square.
t = BinaryTree()

# A larger sample for manual experiments:
# values = [4.5, 8, 6, 5, 7, 2, 3, 1, 4, 11, 9, 20, 21, 23, 24]
values = [Value(number, number ** 2) for number in (2, 1, 3)]

for value in values:
    t.insert(value)

# Bare expression as the last statement so the notebook displays the tree.
t


[2:4]
 /¯ ¯\
[1:1][3:9]

In [4]:
# Remove the node whose key is 3 (a leaf of the tree built above), then
# display the tree again to inspect the result.
t.delete(3)
t


[2:4]
 /¯
[1:1]

In [5]:
from collections import deque

# Sanity checks that every parent/side pointer in the tree is consistent
# after the tree has been modified.


def parentCheck(node):
    """Return True when every child of ``node`` points back to it as parent."""
    return all(child.parent == node for child in node.children)


def sideCheck(node):
    """Return True when existing children carry the matching LEFT/RIGHT side."""
    leftOk = node.leftChild.side == LEFT if node.leftChild else True
    rightOk = node.rightChild.side == RIGHT if node.rightChild else True
    return leftOk and rightOk


def rootCheck(tree):
    """Return True when the root has neither parent nor side (or the tree is empty)."""
    if tree.root is None:
        return True
    return (tree.root.parent is None) and (tree.root.side is None)


def consecutiveRedCheck(node):
    """Return True unless ``node`` is RED and shares its color with a child.

    NOTE(review): ``RED`` is not imported in the visible cells and this check
    is not called below -- it presumably belongs to a red-black tree variant;
    define/import ``RED`` before using it.
    """
    return all(node.color != child.color for child in node.children if node.color == RED)


# A BFS traversal that runs the checks on every node.
queue = deque([t.root] if t.root is not None else [])
visited = []
# Explicit raise instead of ``assert`` so the check survives ``python -O``.
if not rootCheck(t):
    raise AssertionError("root: {} isn't valid".format(t.root))
while queue:
    node = queue.popleft()  # O(1), unlike list.pop(0)
    if not parentCheck(node):
        raise Exception('children {children} of node: {node} do not point to it'.format(children = node.children, node = node))
    if not sideCheck(node):
        raise Exception('children {children} of node: {node} do not have the correct side attribute'.format(children = node.children, node = node))
    queue.extend(node.children)
    visited.append(node)