I'm really glad I went with a class-based representation (one `Node` object per tree node) for storing the binary tree — an approach I had somehow overlooked in my original implementation.

In [2]:
class Node:
    """One binary-search-tree node that also caches its AVL height and balance factor."""

    def __init__(self, value, parent):
        self.value = value
        self.parent = parent
        self.left = self.right = None
        # balance_factor = right_height - left_height
        self.balance_factor = 0
        # Heights are cached on each node so balance factors can be recomputed
        # locally from the children instead of re-measuring whole subtrees.
        self.height = 0

    def _attach(self, side, child, no_retrace):
        # Hook `child` into the `side` slot ('left' or 'right') and point it back at us.
        setattr(self, side, child)
        if child is not None:
            child.parent = self  # a no-op when the child was just created with this parent
        if no_retrace:
            return
        retrace(self)

    def set_right(self, right, no_retrace=False):
        self._attach('right', right, no_retrace)

    def set_left(self, left, no_retrace=False):
        self._attach('left', left, no_retrace)

In [3]:
def construct_tree(sorted_array, parent=None):
    """Build a height-balanced BST from an already-sorted sequence and return its root.

    The element at index ``len // 2`` becomes the root and the two halves on either
    side are built recursively. Children are attached with ``set_left``/``set_right``,
    which retrace the cached heights and balance factors.

    Args:
        sorted_array: values in ascending order (duplicates allowed); it is not
            re-sorted here.
        parent: node the new subtree will hang under; ``None`` for a whole tree.

    Returns:
        The subtree's root ``Node``, or ``None`` (an empty tree) when
        ``sorted_array`` is empty, instead of failing with ``IndexError``.
    """
    n = len(sorted_array)
    if n == 0:
        return None
    mid = n // 2
    subtree = Node(sorted_array[mid], parent)
    if mid > 0:  # left half sorted_array[:mid] is non-empty (n > 1)
        subtree.set_left(construct_tree(sorted_array[:mid], subtree))
    if mid + 1 < n:  # right half sorted_array[mid+1:] is non-empty (n > 2)
        subtree.set_right(construct_tree(sorted_array[mid + 1:], subtree))
    return subtree

In [6]:
# Smoke test: build a tree from a sorted list that contains a duplicate (5); the
# returned root Node is displayed as the cell output.
construct_tree([1,2,3,4,5,5,6,7,8,9])

<__main__.Node at 0x104b43ac0>

In [4]:
# copied from https://stackoverflow.com/questions/34012886/print-binary-tree-level-by-level-in-python
# thanks @yozn
def printTree(node, level=0):
    if node.left is not None:
        printTree(node.left, level + 1)
    print(' ' * 4 * level + '-> ' + str(node.value))
    if node.right is not None:
        printTree(node.right, level + 1)
        
def testPrintTree(node, level=0):
    if node.left is not None:
        testPrintTree(node.left, level + 1)
    print(' ' * 4 * level + '-> ' + str(node.value)+", "+ str(node.balance_factor))
    if node.right is not None:
        testPrintTree(node.right, level + 1)

In [49]:
# Build a tree over 1..9 and print it sideways: the root sits at the left margin and
# each deeper level is indented another 4 spaces.
printTree(construct_tree([1,2,3,4,5,6,7,8,9]))

            -> 1
        -> 2
    -> 3
        -> 4
-> 5
            -> 6
        -> 7
    -> 8
        -> 9


Some accessor methods — so much more elegant than the list-based implementation!

In [5]:
class Node(Node):
    """Extends the earlier ``Node`` class with read-only queries and a Newick export."""

    def distance_to_bottom_left(self):
        """Return how many ``left`` links can be followed from this node (0 if none)."""
        if self.left is None:
            return 0
        return 1 + self.left.distance_to_bottom_left()

    def distance_to_bottom_right(self):
        """Return how many ``right`` links can be followed from this node (0 if none)."""
        if self.right is None:
            return 0
        return 1 + self.right.distance_to_bottom_right()

    def get_min(self):
        """Return the value of the leftmost node, i.e. the minimum of this subtree."""
        if self.left is None:
            return self.value
        return self.left.get_min()

    def get_max(self):
        """Return the value of the rightmost node, i.e. the maximum of this subtree."""
        if self.right is None:
            return self.value
        return self.right.get_max()

    def exists(self, x):
        """Return whether ``x`` is stored in this subtree.

        Follows BST ordering, so only a single root-to-leaf path is visited.

        Raises:
            ValueError: if ``x`` is neither equal to, greater than, nor less than a
                visited value (e.g. ``float('nan')``), so its position is undefined.
        """
        if self.value == x:
            return True
        elif x > self.value:
            if self.right is not None:
                return self.right.exists(x)
            else: return False
        elif x < self.value:
            if self.left is not None:
                return self.left.exists(x)
            else: return False
        else:
            raise ValueError(f"cannot order {x!r} relative to {self.value!r}")

    def export(self):
        """Return this subtree as a Newick-format string terminated by ``;``.

        Format reference: https://en.wikipedia.org/wiki/Newick_format
        A leaf is written as its value; a node with children is written as
        ``(left,right)value`` (missing children are omitted, so a node with one
        child does not record which side it was on). Values whose text contains
        whitespace, Newick punctuation or an underscore are wrapped in single
        quotes with embedded quotes doubled. Only ``value``/``left``/``right``
        are read, so nodes created from the base class are handled too.

        Example: a root 2 with children 1 and 3 exports as ``(1,3)2;``.
        """
        def label(value):
            text = str(value)
            if any(ch in " \t\n()[]':;,_" for ch in text):
                return "'" + text.replace("'", "''") + "'"
            return text

        def subtree(node):
            children = [subtree(child) for child in (node.left, node.right)
                        if child is not None]
            if not children:
                return label(node.value)
            return "(" + ",".join(children) + ")" + label(node.value)

        return subtree(self) + ";"

In [10]:
# Balanced tree over the keys 1..20, used by the accessor and printing cells that follow.
test_tree = construct_tree(list(range(1,21)))

In [11]:
# Exercise each accessor on the 1..20 tree; the recorded output was 4, 3, 1, 20, True, False.
print(test_tree.distance_to_bottom_left())
print(test_tree.distance_to_bottom_right())
print(test_tree.get_min())
print(test_tree.get_max())
print(test_tree.exists(7))
print(test_tree.exists(21))

4
3
1
20
True
False


In [12]:
# Sideways in-order dump of test_tree (values only).
printTree(test_tree)

                -> 1
            -> 2
        -> 3
                -> 4
            -> 5
    -> 6
                -> 7
            -> 8
        -> 9
            -> 10
-> 11
                -> 12
            -> 13
        -> 14
            -> 15
    -> 16
                -> 17
            -> 18
        -> 19
            -> 20


In [13]:
# Same dump, with each node's cached balance_factor (right height - left height) after its value.
testPrintTree(test_tree)

                -> 1, 0
            -> 2, -1
        -> 3, 0
                -> 4, 0
            -> 5, -1
    -> 6, 0
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, 0
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 17, 0
            -> 18, -1
        -> 19, -1
            -> 20, 0


In [7]:
# https://en.wikipedia.org/wiki/AVL_tree
def _refresh_node_height(node):
    """Recompute node.height and node.balance_factor from its children's cached heights.

    Only ``node`` itself is updated; its ancestors are left for ``retrace``.
    """
    right_height = node.right.height if node.right is not None else -1
    left_height = node.left.height if node.left is not None else -1
    node.height = max(right_height, left_height) + 1
    node.balance_factor = right_height - left_height


def _replace_in_parent(parent, old_child, new_child):
    """Put ``new_child`` in ``old_child``'s slot under ``parent`` (``None`` = tree root).

    The slot is found by identity, never by value, so duplicate keys cannot confuse
    it. If ``parent`` no longer points at ``old_child`` (the caller already re-linked
    it), the parent's child pointers are left untouched.
    """
    new_child.parent = parent
    if parent is None:
        return
    if parent.left is old_child:
        parent.left = new_child
    elif parent.right is old_child:
        parent.right = new_child


def left_rotation(root):
    """Rotate the subtree rooted at ``root`` to the left and return its new root.

    ``root.right`` (the pivot) takes ``root``'s place, ``root`` becomes the pivot's
    left child, and the pivot's former left subtree becomes ``root.right``. Parent
    links, the former parent's child pointer, and the cached height/balance factor
    of ``root`` and the pivot are updated; ancestors still need ``retrace``.
    """
    pivot = root.right
    parent = root.parent
    root.right = pivot.left
    if root.right is not None:
        root.right.parent = root
    pivot.left = root
    root.parent = pivot
    _replace_in_parent(parent, root, pivot)
    _refresh_node_height(root)   # root now sits below the pivot, so refresh it first
    _refresh_node_height(pivot)
    return pivot


def right_rotation(root):
    """Rotate the subtree rooted at ``root`` to the right and return its new root.

    Mirror image of ``left_rotation``: ``root.left`` becomes the subtree root.
    """
    pivot = root.left
    parent = root.parent
    root.left = pivot.right
    if root.left is not None:
        root.left.parent = root
    pivot.right = root
    root.parent = pivot
    _replace_in_parent(parent, root, pivot)
    _refresh_node_height(root)
    _refresh_node_height(pivot)
    return pivot


def rightleft_rotation(root):
    """Right-rotate ``root.right``, then left-rotate ``root``; return the new subtree root.

    The first rotation re-points ``root.right`` at the promoted grandchild, so the
    second rotation lifts that grandchild (not the stale former right child).
    """
    right_rotation(root.right)
    return left_rotation(root)


def leftright_rotation(root):
    """Left-rotate ``root.left``, then right-rotate ``root``; return the new subtree root."""
    left_rotation(root.left)
    return right_rotation(root)

In [8]:
def retrace(node):
    """Recompute the cached height and balance factor of ``node`` and every ancestor.

    Pass the node whose children changed: a parent keeps no record of which child
    was edited, so the walk starts at the edited node and climbs ``parent`` links
    to the root. Accepts ``None`` (nothing to do).
    """
    def subtree_height(child):
        # A missing child counts as height -1, which gives a leaf height 0.
        return -1 if child is None else child.height

    ancestor = node
    while ancestor is not None:
        left_h, right_h = subtree_height(ancestor.left), subtree_height(ancestor.right)
        ancestor.balance_factor = right_h - left_h
        ancestor.height = 1 + max(left_h, right_h)
        ancestor = ancestor.parent

def _rotate_and_relink(node, to_left):
    """Single rotation at ``node`` that also fixes the former parent's child pointer.

    The pointer surgery inside the subtree is delegated to ``left_rotation`` /
    ``right_rotation``. Afterwards ``node``'s former parent is made to point at the
    node that replaced it, matched by identity (never by value, which breaks on
    duplicate keys); if that pointer is already correct nothing changes.
    Cached heights are NOT refreshed here.

    Returns the new root of the rotated subtree.
    """
    parent = node.parent
    pivot = node.right if to_left else node.left
    if to_left:
        left_rotation(node)
    else:
        right_rotation(node)
    pivot.parent = parent
    if parent is not None:
        if parent.left is node:
            parent.left = pivot
        elif parent.right is node:
            parent.right = pivot
    return pivot


def rebalance(node):  # not suitable for multi-element insertions
    """Restore the AVL invariant on the path from ``node`` up to the root after one insert.

    ``node`` should be the parent of the freshly inserted leaf, and heights/balance
    factors on that path must already be current (``insert`` attaches the leaf with
    ``set_left``/``set_right``, which call ``retrace``). Every node on the path is
    checked; a node with balance factor +2/-2 is fixed by a single rotation, or by a
    double rotation when its heavy child leans the other way. The cached heights of
    the nodes pushed down by the rotation, and of all their ancestors, are then
    recomputed with ``retrace``.

    Returns the root of the whole tree. A rotation at the root replaces it, so
    callers must rebind their reference (``tree = insert(x, tree)``).
    """
    current = node
    top = node
    while current is not None:
        balance = current.balance_factor
        if balance > 1 or balance < -1:
            demoted = [current]
            if balance > 1:  # right-heavy
                child = current.right
                if child.balance_factor < 0:
                    # right-left case: straighten the left-leaning right child first
                    _rotate_and_relink(child, to_left=False)
                    demoted.append(child)
                current = _rotate_and_relink(current, to_left=True)
            else:  # left-heavy
                child = current.left
                if child.balance_factor > 0:
                    # left-right case: straighten the right-leaning left child first
                    _rotate_and_relink(child, to_left=True)
                    demoted.append(child)
                current = _rotate_and_relink(current, to_left=False)
            # retrace walks to the root, so the shared new subtree root and every
            # ancestor end up recomputed from final child heights.
            for moved in demoted:
                retrace(moved)
        top = current
        current = current.parent
    return top
        

In [9]:
def insert(x, tree):
    """Insert ``x`` into the AVL tree rooted at ``tree`` and return the tree's root.

    Values equal to an existing key go into the left subtree. The new leaf is
    attached with ``set_left``/``set_right`` (which retrace heights up to the root),
    then ``rebalance`` runs from the leaf's parent. A rotation can replace the root,
    so always rebind: ``tree = insert(x, tree)``.

    Raises:
        ValueError: if ``x`` cannot be ordered against a key (e.g. ``float('nan')``);
            previously the descent loop never terminated in that case.
    """
    current_node = tree
    while True:
        if x > current_node.value:
            if current_node.right is None:
                current_node.set_right(Node(x, current_node))
                return rebalance(current_node)
            current_node = current_node.right
        elif x <= current_node.value:
            if current_node.left is None:
                current_node.set_left(Node(x, current_node))
                return rebalance(current_node)
            current_node = current_node.left
        else:
            raise ValueError(f"cannot order {x!r} relative to {current_node.value!r}")

In [123]:
# insert() returns the tree's root, which a rotation may have replaced, so rebind it.
test_tree = insert(0, test_tree)

In [124]:
testPrintTree(test_tree)  # output below was captured with an earlier insert() that only retraced heights (no rebalancing), hence the -2 at node 2

                    -> 0, 0
                -> 1, -1
            -> 2, -2
        -> 3, -1
                -> 4, 0
            -> 5, -1
    -> 6, -1
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, -1
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 17, 0
            -> 18, -1
        -> 19, -1
            -> 20, 0


In [14]:
# Rebuild a fresh 1..20 tree so the rebalancing inserts below start from a clean state.
test_tree = construct_tree(list(range(1,21)))

In [15]:
# insert() returns the tree's root, which a rotation may have replaced, so rebind it.
test_tree = insert(0, test_tree)

<__main__.Node at 0x103dcb040>

In [16]:
# After inserting 0, the left-left imbalance at node 2 has been fixed by a right rotation.
testPrintTree(test_tree)

                -> 0, 0
            -> 1, 0
                -> 2, 0
        -> 3, 0
                -> 4, 0
            -> 5, -1
    -> 6, 0
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, 0
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 17, 0
            -> 18, -1
        -> 19, -1
            -> 20, 0


In [17]:
# insert() returns the tree's root, which a rotation may have replaced, so rebind it.
test_tree = insert(16.5, test_tree)

<__main__.Node at 0x103dcb040>

In [18]:
# Inspect the subtree under 19 after inserting 16.5 (values with their balance factors).
testPrintTree(test_tree)

                -> 0, 0
            -> 1, 0
                -> 2, 0
        -> 3, 0
                -> 4, 0
            -> 5, -1
    -> 6, 0
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, 0
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 16.5, 0
            -> 17, 0
                -> 18, 0
        -> 19, -1
            -> 20, 0


In [223]:
# Balanced tree over 1..31 (31 = 2**5 - 1 keys, so every level is full).
test_tree = construct_tree(list(range(1,32)))

In [231]:
# insert() returns the tree's root, which a rotation may have replaced, so rebind it.
test_tree = insert(12.125, test_tree)

12.5 -2 13 <__main__.Node object at 0x106618d30> None


In [232]:
# NOTE(review): the recorded output also contains 12.25, 12.5 and 13.5, which came from
# inserts run in cells not shown here.
testPrintTree(test_tree)

                -> 1, 0
            -> 2, 0
                -> 3, 0
        -> 4, 0
                -> 5, 0
            -> 6, 0
                -> 7, 0
    -> 8, 1
                -> 9, 0
            -> 10, 0
                -> 11, 0
        -> 12, 1
                    -> 12.125, 0
                -> 12.25, 0
                    -> 12.5, 0
            -> 13, 0
                    -> 13.5, 0
                -> 14, 0
                    -> 15, 0
-> 16, -1
                -> 17, 0
            -> 18, 0
                -> 19, 0
        -> 20, 0
                -> 21, 0
            -> 22, 0
                -> 23, 0
    -> 24, 0
                -> 25, 0
            -> 26, 0
                -> 27, 0
        -> 28, 0
                -> 29, 0
            -> 30, 0
                -> 31, 0


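Testing rotations with a worst-case input: inserting keys in strictly descending order, which would produce a degenerate (linked-list-shaped) tree without rebalancing.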

In [19]:
# Start from a single node (256) and insert 255, 254, ..., 2 in descending order,
# rebinding the root after each insert because a rotation can replace it.
new256 = construct_tree([256])
for num in reversed(range(2, 256)):
    new256 = insert(num, new256)

In [20]:
# Stray check: builds and displays a throwaway one-node tree; new256 is not affected.
construct_tree([256])

<__main__.Node at 0x103dcb100>

In [21]:
# Print the tree produced by the descending inserts, with each node's balance factor.
testPrintTree(new256)

                            -> 2, 0
                        -> 3, 0
                            -> 4, 0
                    -> 5, 0
                            -> 6, 0
                        -> 7, 0
                            -> 8, 0
                -> 9, 0
                            -> 10, 0
                        -> 11, 0
                            -> 12, 0
                    -> 13, 0
                            -> 14, 0
                        -> 15, 0
                            -> 16, 0
            -> 17, 0
                            -> 18, 0
                        -> 19, 0
                            -> 20, 0
                    -> 21, 0
                            -> 22, 0
                        -> 23, 0
                            -> 24, 0
                -> 25, 0
                            -> 26, 0
                        -> 27, 0
                            -> 28, 0
                    -> 29, 0
                            -> 30, 0
                        -> 31, 0
  

Wow, this is euphoric!