I'm really glad I went with the class method for storing the binary tree, which for some reason I overlooked in my original implementation.

In [146]:
class Node:
    def __init__(self, value, parent):
        self.value = value
        self.parent = parent
        self.left = None
        self.right = None
        self.balance_factor = 0  # balance_factor = right_height - left_height
        self.height = 0
        # while the purist implementation of AVL calculates the change in height every time, its easier saving it
    
    def set_right(self, right, no_retrace = False):
        if self.right is not None:
            old_height = self.right.height
        else:  # no node existing, height of -1
            old_height = -1
        self.right = right
        if right is not None:
            if right.height + 1 > self.height: self.height = right.height + 1
            self.right.parent = self  # redundant if freshly creating the node
            right_height = self.right.height
        else: right_height = -1
        self.balance_factor += right_height - old_height  # change in heights, added
        if not no_retrace: retrace(self)
        
    def set_left(self, left, no_retrace = False):
        if self.left is not None:
            old_height = self.left.height
        else:  # no node existing, height of -1
            old_height = -1
        self.left = left
        if left is not None:
            if left.height + 1 > self.height: self.height = left.height + 1
            self.left.parent = self
            left_height = self.left.height
        else: left_height = -1
        self.balance_factor -= left_height - old_height  # change in heights, subtracted
        if not no_retrace: retrace(self)

In [5]:
def construct_tree(sorted_array, parent=None):
    n = len(sorted_array)
    subtree = Node(sorted_array[n//2], parent)
    if n > 1:
        subtree.set_left(construct_tree(sorted_array[:n//2], subtree))
        if n > 2:
            subtree.set_right(construct_tree(sorted_array[(n//2)+1:], subtree))
    return subtree

In [6]:
construct_tree([1,2,3,4,5,5,6,7,8,9])

<__main__.Node at 0x104b43ac0>

In [100]:
# copied from https://stackoverflow.com/questions/34012886/print-binary-tree-level-by-level-in-python
# thanks @yozn
def printTree(node, level=0):
    if node.left is not None:
        printTree(node.left, level + 1)
    print(' ' * 4 * level + '-> ' + str(node.value))
    if node.right is not None:
        printTree(node.right, level + 1)
        
def testPrintTree(node, level=0):
    if node.left is not None:
        testPrintTree(node.left, level + 1)
    print(' ' * 4 * level + '-> ' + str(node.value)+", "+ str(node.balance_factor))
    if node.right is not None:
        testPrintTree(node.right, level + 1)

In [49]:
printTree(construct_tree([1,2,3,4,5,6,7,8,9]))

            -> 1
        -> 2
    -> 3
        -> 4
-> 5
            -> 6
        -> 7
    -> 8
        -> 9


Some accessor functions. So much more elegant than the list implementation!

In [147]:
class Node(Node):
    def distance_to_bottom_left(self):
        if self.left is None:
            return 0
        return 1 + self.left.distance_to_bottom_left()
    
    def distance_to_bottom_right(self):
        if self.right is None:
            return 0
        return 1 + self.right.distance_to_bottom_right()
    
    def get_min(self):
        if self.left is None:
            return self.value
        return self.left.get_min()
    
    def get_max(self):
        if self.right is None:
            return self.value
        return self.right.get_max()
    
    def exists(self, x):
        if self.value == x:
            return True
        elif x > self.value:
            if self.right is not None:
                return self.right.exists(x)
            else: return False
        elif x < self.value:
            if self.left is not None:
                return self.left.exists(x)
            else: return False
        else:
            raise ValueError
            
    def export(self):
        # exports in the Newick format https://en.wikipedia.org/wiki/Newick_format
        pass

In [89]:
test_tree = construct_tree(list(range(1,21)))

In [27]:
print(test_tree.distance_to_bottom_left())
print(test_tree.distance_to_bottom_right())
print(test_tree.get_min())
print(test_tree.get_max())
print(test_tree.exists(7))
print(test_tree.exists(21))

4
3
1
20
True
False


In [90]:
printTree(test_tree)

                -> 1
            -> 2
        -> 3
                -> 4
            -> 5
    -> 6
                -> 7
            -> 8
        -> 9
            -> 10
-> 11
                -> 12
            -> 13
        -> 14
            -> 15
    -> 16
                -> 17
            -> 18
        -> 19
            -> 20


In [91]:
testPrintTree(test_tree)

                -> 1, 0
            -> 2, 1
        -> 3, 2
                -> 4, 0
            -> 5, 1
    -> 6, 3
                -> 7, 0
            -> 8, 1
        -> 9, 2
            -> 10, 0
-> 11, 4
                -> 12, 0
            -> 13, 1
        -> 14, 2
            -> 15, 0
    -> 16, 3
                -> 17, 0
            -> 18, 1
        -> 19, 2
            -> 20, 0


In [178]:
# https://en.wikipedia.org/wiki/AVL_tree
def left_rotation(root):
    root.right.parent = root.parent
    #root.parent = root.right
    old_right_left = root.right.left
    root.right.set_left(root, no_retrace=True)
    root.set_right(old_right_left, no_retrace=True)
    
def right_rotation(root):
    root.left.parent = root.parent
    #root.parent = root.left
    old_left_right = root.left.right
    root.left.set_right(root, no_retrace=True)
    root.set_left(old_left_right, no_retrace=True)
    
def rightleft_rotation(root):
    right_rotation(root.right)
    left_rotation(root)
    
def leftright_rotation(root):
    left_rotation(root.left)
    right_rotation(root)

In [185]:
def retrace(node):  # correct the heights and balance factors of ancestor nodes
    # arg must be the edited node because its parent doesn't know which child has been edited
    height_counter = node.height + 1
    current_node = node.parent
    while current_node is not None:
        if height_counter > current_node.height:
            current_node.height = height_counter
            right_height = current_node.right.height if current_node.right is not None else -1
            left_height = current_node.left.height if current_node.left is not None else -1
            current_node.balance_factor = right_height - left_height
        height_counter += 1
        current_node = current_node.parent

def rebalance(node):  # not suitable for multi-element insertions
    current_node = node.parent  # the lowest node that can necessitate rotation is the parent of the edited node
    while current_node is not None:
        if abs(current_node.balance_factor) == 2:
            print(current_node.value, current_node.balance_factor, current_node.parent.value, current_node.left, current_node.right)
            # more efficient than checking each node for +2 or -2 bc only 2 rotations max are necessary
            if current_node.balance_factor == 2:
                if current_node.right.balance_factor == -1:
                    # fixing the child of the parent was such a mess
                    if current_node.parent.right.right.value == current_node.right.value:
                        current_node.parent.set_right(current_node.right.left, no_retrace=True)
                    else:
                        current_node.parent.set_left(current_node.right.left, no_retrace=True)
                    rightleft_rotation(current_node)
                else:
                    if current_node.parent.right.right.value == current_node.right.value:
                        current_node.parent.set_right(current_node.right, no_retrace=True)
                    else:
                        current_node.parent.set_left(current_node.right, no_retrace=True)
                    left_rotation(current_node)
            else:  # -2
                if current_node.left.balance_factor == 1:
                    if current_node.parent.left.left.value == current_node.left.value:
                        current_node.parent.set_left(current_node.left.right, no_retrace=True)
                    else:
                        current_node.parent.set_right(current_node.left.right, no_retrace=True)
                    leftright_rotation(current_node)
                else:
                    if current_node.parent.left.left.value == current_node.left.value:
                        current_node.parent.set_left(current_node.left, no_retrace=True)
                    else:
                        current_node.parent.set_right(current_node.left, no_retrace=True)
                    right_rotation(current_node)
                    
                current_node = current_node.parent.parent
        else:
            current_node = current_node.parent
        time.sleep(0.1)
        

In [128]:
def insert(x, tree):
    current_val = tree.value
    current_node = tree
    while x != current_node.value:
        if x > current_node.value:
            if current_node.right is not None:
                current_node = current_node.right
            else:
                current_node.set_right(Node(x, current_node))
                rebalance(current_node)
                return
        elif x <= current_node.value:
            if current_node.left is not None:
                current_node = current_node.left
            else:
                current_node.set_left(Node(x, current_node))
                rebalance(current_node)
                return

In [123]:
insert(0, test_tree)

In [124]:
testPrintTree(test_tree)  # test without rebalancing, only retracing

                    -> 0, 0
                -> 1, -1
            -> 2, -2
        -> 3, -1
                -> 4, 0
            -> 5, -1
    -> 6, -1
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, -1
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 17, 0
            -> 18, -1
        -> 19, -1
            -> 20, 0


In [186]:
test_tree = construct_tree(list(range(1,21)))

In [187]:
insert(0, test_tree)

2 -2 3 <__main__.Node object at 0x106be4f70> None


In [188]:
testPrintTree(test_tree)

                -> 0, 0
            -> 1, 2
                -> 2, 2
        -> 3, 0
                -> 4, 0
            -> 5, -1
    -> 6, -1
                -> 7, 0
            -> 8, -1
        -> 9, -1
            -> 10, 0
-> 11, -1
                -> 12, 0
            -> 13, -1
        -> 14, -1
            -> 15, 0
    -> 16, 0
                -> 17, 0
            -> 18, -1
        -> 19, -1
            -> 20, 0
