# Goal of this project
Implement an AVL tree by extending the BinarySearchTree class provided below.

Reference:
https://runestone.academy/ns/books/published/pythonds3/Trees/AVLTreeImplementation.html

# Binary Search Tree

## Implementation code

In [None]:
class TreeNode:
    """One node of a binary search tree.

    Holds a key/value pair together with links to the left child, the
    right child and the parent; any of the links may be None.
    """

    def __init__(self, key, value, left=None, right=None, parent=None):
        self.key, self.value = key, value
        self.left_child, self.right_child = left, right
        self.parent = parent

    def is_left_child(self):
        """Return a truthy value if this node is its parent's left child.

        Returns None (falsy) when the node has no parent.
        """
        parent = self.parent
        return parent and parent.left_child is self

    def is_right_child(self):
        """Return a truthy value if this node is its parent's right child.

        Returns None (falsy) when the node has no parent.
        """
        parent = self.parent
        return parent and parent.right_child is self

    def is_root(self):
        """Return True when the node has no parent."""
        return not self.parent

    def is_leaf(self):
        """Return True when the node has neither a left nor a right child."""
        return not self.right_child and not self.left_child

    def has_any_child(self):
        """Return a child node (truthy) if at least one child exists, else None."""
        return self.right_child if self.right_child else self.left_child

    def has_children(self):
        """Return the left child (truthy) only when both children exist."""
        return self.left_child if self.right_child else self.right_child

    def replace_value(self, key, value, left, right):
        """Overwrite this node's key, value and children; re-parent the new children."""
        self.key, self.value = key, value
        self.left_child, self.right_child = left, right
        for child in (self.left_child, self.right_child):
            if child:
                child.parent = self

    def find_successor(self):
        """Return the node with the next larger key, or None if there is none."""
        if self.right_child:
            return self.right_child.find_min()
        if not self.parent:
            return None
        if self.is_left_child():
            return self.parent
        # A right child without a right subtree: the answer is the parent's
        # successor. Hide this node from the parent for the duration of the
        # call so the parent does not descend back into this subtree.
        self.parent.right_child = None
        successor = self.parent.find_successor()
        self.parent.right_child = self
        return successor

    def find_min(self):
        """Return the leftmost (smallest-key) node of the subtree rooted here."""
        smallest = self
        while smallest.left_child:
            smallest = smallest.left_child
        return smallest

    def splice_out(self):
        """Detach this node, linking its single child (if any) to its parent.

        Assumes the node has a parent and at most one child.
        """
        if self.is_leaf():
            if self.is_left_child():
                self.parent.left_child = None
            else:
                self.parent.right_child = None
            return
        orphan = self.left_child if self.left_child else self.right_child
        if self.is_left_child():
            self.parent.left_child = orphan
        else:
            self.parent.right_child = orphan
        orphan.parent = self.parent

    def __iter__(self):
        """Yield the keys of this subtree in ascending (in-order) order."""
        if not self:
            return
        if self.left_child:
            yield from self.left_child
        yield self.key
        if self.right_child:
            yield from self.right_child


class BinarySearchTree:
    """Unbalanced binary search tree mapping keys to values.

    Keys smaller than a node's key are stored in its left subtree, all
    other keys in its right subtree. The class supports
    ``tree[key] = value``, ``tree[key]`` (None when missing),
    ``key in tree``, ``del tree[key]``, ``len(tree)`` and ascending
    iteration over the keys.
    """

    def __init__(self):
        self.root = None  # top TreeNode, or None while the tree is empty
        self.size = 0  # number of distinct keys stored

    def __len__(self):
        return self.size

    def __iter__(self):
        """Yield the keys in ascending (in-order) order.

        Uses an explicit stack instead of recursion, so an empty tree simply
        yields nothing and degenerate (list-shaped) trees deeper than
        Python's recursion limit can still be iterated.
        """
        pending = []
        node = self.root
        while pending or node:
            # Descend as far left as possible, remembering the path.
            while node:
                pending.append(node)
                node = node.left_child
            node = pending.pop()
            yield node.key
            node = node.right_child

    def put(self, key, value):
        """Store ``value`` under ``key``.

        If ``key`` is already present its value is replaced and ``size`` is
        unchanged; otherwise a new leaf node is added and ``size`` grows by 1.
        """
        existing = self._get(key, self.root)
        if existing:
            # Overwrite rather than inserting a duplicate node: a duplicate
            # would stay hidden behind the older node in get() and inflate size.
            existing.value = value
            return
        if self.root:
            self._put(key, value, self.root)
        else:
            self.root = TreeNode(key, value)
        self.size = self.size + 1

    def _put(self, key, value, current_node):
        """Attach a new TreeNode for ``key`` somewhere below ``current_node``.

        Smaller keys go left, all others go right. The descent is iterative
        so an unbalanced tree cannot exhaust the recursion limit.
        """
        while True:
            if key < current_node.key:
                if not current_node.left_child:
                    current_node.left_child = TreeNode(
                        key, value, parent=current_node
                    )
                    return
                current_node = current_node.left_child
            else:
                if not current_node.right_child:
                    current_node.right_child = TreeNode(
                        key, value, parent=current_node
                    )
                    return
                current_node = current_node.right_child

    def __setitem__(self, key, value):
        self.put(key, value)

    def get(self, key):
        """Return the value stored under ``key``, or None if it is absent."""
        node = self._get(key, self.root)
        return node.value if node else None

    def _get(self, key, current_node):
        """Return the node holding ``key`` in the subtree at ``current_node``.

        Returns None when the key is not found or ``current_node`` is None.
        """
        while current_node:
            if current_node.key == key:
                return current_node
            if key < current_node.key:
                current_node = current_node.left_child
            else:
                current_node = current_node.right_child
        return None

    def __getitem__(self, key):
        return self.get(key)

    def __contains__(self, key):
        return self._get(key, self.root) is not None

    def delete(self, key):
        """Remove ``key`` (and its value) from the tree.

        Raises:
            KeyError: if ``key`` is not in the tree.
        """
        if self.size > 1:
            node_to_remove = self._get(key, self.root)
            if node_to_remove:
                self._delete(node_to_remove)
                self.size = self.size - 1
            else:
                raise KeyError("Error, key not in tree")
        elif self.size == 1 and self.root.key == key:
            # Removing the only node empties the tree.
            self.root = None
            self.size = self.size - 1
        else:
            raise KeyError("Error, key not in tree")

    def _delete(self, current_node):
        """Unlink ``current_node``, which must be in a tree of two or more nodes."""
        if current_node.is_leaf():  # removing a leaf
            if current_node == current_node.parent.left_child:
                current_node.parent.left_child = None
            else:
                current_node.parent.right_child = None
        elif current_node.has_children():  # removing a node with two children
            # Move the in-order successor's payload into this node and remove
            # the successor instead; it has no left child, so it splices out.
            successor = current_node.find_successor()
            successor.splice_out()
            current_node.key = successor.key
            current_node.value = successor.value
        else:  # removing a node with one child
            if current_node.left_child:
                if current_node.is_left_child():
                    current_node.left_child.parent = current_node.parent
                    current_node.parent.left_child = current_node.left_child
                elif current_node.is_right_child():
                    current_node.left_child.parent = current_node.parent
                    current_node.parent.right_child = current_node.left_child
                else:
                    # The root has no parent link to rewire, so copy the
                    # child's contents into the root node itself.
                    current_node.replace_value(
                        current_node.left_child.key,
                        current_node.left_child.value,
                        current_node.left_child.left_child,
                        current_node.left_child.right_child,
                    )
            else:
                if current_node.is_left_child():
                    current_node.right_child.parent = current_node.parent
                    current_node.parent.left_child = current_node.right_child
                elif current_node.is_right_child():
                    current_node.right_child.parent = current_node.parent
                    current_node.parent.right_child = current_node.right_child
                else:
                    current_node.replace_value(
                        current_node.right_child.key,
                        current_node.right_child.value,
                        current_node.right_child.left_child,
                        current_node.right_child.right_child,
                    )

    def __delitem__(self, key):
        self.delete(key)


## Examples

In [None]:
class TreePrinter:
    """Render a binary tree as indented text, one node per line."""

    def print_tree(self, node, indent="", position="Root"):
        """Print the tree in a vertical structure.

        Each node is printed as ``<indent><position>: <key>``. Its left and
        right subtrees follow, indented three more spaces and labelled "L"
        and "R" respectively; a missing node prints nothing.
        """
        if node is None:
            return
        print(indent + f"{position}: {node.key}")
        deeper = indent + "   "
        for label, child in (("L", node.left_child), ("R", node.right_child)):
            self.print_tree(child, deeper, label)


# Build an (unbalanced) BST; each value is the key's string form.
my_tree = BinarySearchTree()
for key in (4, 1, 10, 3, 7, 20, 8, 9, 6, 2, 22):
    my_tree[key] = str(key)

printer = TreePrinter()
printer.print_tree(my_tree.root)

Root: 4
   L: 1
      R: 3
         L: 2
   R: 10
      L: 7
         L: 6
         R: 8
            R: 9
      R: 20
         R: 22


# AVL Tree

## Implementation Code

In [None]:
class AVLTreeNode(TreeNode):
    """TreeNode that additionally carries an AVL balance factor.

    ``balance_factor`` is maintained by AVLTree as the height of the left
    subtree minus the height of the right subtree. It always starts at 0,
    which is correct for the leaf nodes AVLTree creates.
    """

    def __init__(self, key, value, left=None, right=None, parent=None):
        super().__init__(key, value, left=left, right=right, parent=parent)
        self.balance_factor = 0


class AVLTree(BinarySearchTree):
    """Self-balancing binary search tree (AVL tree).

    Nodes are AVLTreeNode instances whose ``balance_factor`` is
    height(left subtree) - height(right subtree). Insertions and deletions
    update the factors along the affected path and rotate wherever a factor
    leaves [-1, 1], keeping the tree height O(log n).
    """

    def put(self, key, value):
        """Store ``value`` under ``key``, creating AVLTreeNode instances.

        If ``key`` already exists only its value is replaced: the tree
        shape, the balance factors and ``size`` stay unchanged.
        """
        existing = self._get(key, self.root)
        if existing:
            # Overwrite instead of inserting a duplicate node.
            existing.value = value
            return
        if self.root:
            self._put(key, value, self.root)
        else:
            self.root = AVLTreeNode(key, value)
        self.size += 1

    def _put(self, key, value, current_node):
        """Attach a new AVLTreeNode leaf below ``current_node`` and rebalance.

        Smaller keys go left, all others right. Once the leaf is linked in,
        ``update_balance`` propagates the height change towards the root.
        """
        if key < current_node.key:
            if current_node.left_child:
                self._put(key, value, current_node.left_child)
            else:
                current_node.left_child = AVLTreeNode(key, value, parent=current_node)
                self.update_balance(current_node.left_child)
        else:
            if current_node.right_child:
                self._put(key, value, current_node.right_child)
            else:
                current_node.right_child = AVLTreeNode(key, value, parent=current_node)
                self.update_balance(current_node.right_child)

    def update_balance(self, node):
        """Propagate an insertion's height increase upward from ``node``.

        If ``node`` is out of balance it is rotated back and propagation
        stops. Otherwise the parent's factor is adjusted (+1 when ``node`` is
        a left child, -1 when it is a right child). Propagation continues
        while the parent's factor is non-zero, i.e. while its subtree grew.
        """
        if node.balance_factor > 1 or node.balance_factor < -1:
            self.rebalance(node)
            return
        if node.parent:
            if node.is_left_child():
                node.parent.balance_factor += 1
            elif node.is_right_child():
                node.parent.balance_factor -= 1

            if node.parent.balance_factor != 0:
                self.update_balance(node.parent)

    def rebalance(self, node):
        """Rotate around ``node``, whose balance factor is +2 or -2.

        If the heavy child leans the opposite way, it is rotated first,
        which turns the double-rotation case into a single rotation.
        """
        if node.balance_factor < 0:
            if node.right_child.balance_factor > 0:
                self.rotate_right(node.right_child)
            self.rotate_left(node)
        elif node.balance_factor > 0:
            if node.left_child.balance_factor < 0:
                self.rotate_left(node.left_child)
            self.rotate_right(node)

    def rotate_left(self, rotation_root):
        """Perform a left rotation: the right child becomes the subtree root."""
        new_root = rotation_root.right_child
        rotation_root.right_child = new_root.left_child
        if new_root.left_child:
            new_root.left_child.parent = rotation_root
        new_root.parent = rotation_root.parent
        if rotation_root.is_root():
            self.root = new_root
        else:
            if rotation_root.is_left_child():
                rotation_root.parent.left_child = new_root
            else:
                rotation_root.parent.right_child = new_root
        new_root.left_child = rotation_root
        rotation_root.parent = new_root

        # Update balance factors (general formulas, valid for any prior values)
        rotation_root.balance_factor = (
            rotation_root.balance_factor + 1 - min(new_root.balance_factor, 0)
        )
        new_root.balance_factor = (
            new_root.balance_factor + 1 + max(rotation_root.balance_factor, 0)
        )

    def rotate_right(self, rotation_root):
        """Perform a right rotation: the left child becomes the subtree root."""
        new_root = rotation_root.left_child
        rotation_root.left_child = new_root.right_child
        if new_root.right_child:
            new_root.right_child.parent = rotation_root
        new_root.parent = rotation_root.parent
        if rotation_root.is_root():
            self.root = new_root
        else:
            if rotation_root.is_left_child():
                rotation_root.parent.left_child = new_root
            else:
                rotation_root.parent.right_child = new_root
        new_root.right_child = rotation_root
        rotation_root.parent = new_root

        # Update balance factors (mirror image of rotate_left)
        rotation_root.balance_factor = (
            rotation_root.balance_factor - 1 - max(new_root.balance_factor, 0)
        )
        new_root.balance_factor = (
            new_root.balance_factor - 1 + min(rotation_root.balance_factor, 0)
        )

    def _delete(self, current_node):
        """Unlink ``current_node`` as BinarySearchTree does, then rebalance.

        The base implementation physically removes either ``current_node``
        itself or, when it has two children, its in-order successor. The
        subtree that lost that node may now be one level shorter, so balance
        factors are repaired walking up from the removed node's parent.
        """
        if current_node.has_children():
            removed = current_node.find_successor()
        else:
            removed = current_node
        parent = removed.parent
        if parent is None:
            # Root with exactly one child: the base class copies that child
            # (necessarily a leaf in a valid AVL tree) into the root node, so
            # the root takes over the child's balance factor.
            child = removed.left_child or removed.right_child
            super()._delete(current_node)
            current_node.balance_factor = child.balance_factor
            return
        # Record which side of ``parent`` loses a node before it is unlinked.
        shrank_left = bool(removed.is_left_child())
        super()._delete(current_node)
        self._rebalance_after_delete(parent, shrank_left)

    def _rebalance_after_delete(self, node, shrank_left):
        """Propagate a one-level height decrease upward, starting at ``node``.

        ``shrank_left`` says which child subtree of ``node`` got shorter.
        Stops as soon as some subtree keeps its previous height.
        """
        while node is not None:
            node.balance_factor += -1 if shrank_left else 1
            if node.balance_factor > 1 or node.balance_factor < -1:
                self.rebalance(node)
                # After the rotation(s) ``node`` sits one level lower; its
                # parent is the new root of this subtree.
                node = node.parent
                if node.balance_factor != 0:
                    # Heavy child was balanced: subtree height is unchanged.
                    return
            elif node.balance_factor != 0:
                # Factor went from 0 to +/-1: subtree height is unchanged.
                return
            # The subtree rooted at ``node`` got shorter; inform its parent.
            if node.parent is None:
                return
            shrank_left = bool(node.is_left_child())
            node = node.parent

## Examples

In [None]:
class TreePrinter:
    """Pretty-print a binary tree top-down, indenting one step per level."""

    def print_tree(self, node, indent="", position="Root"):
        """Print ``node`` and, recursively, its subtrees.

        Each node yields one line ``<indent><position>: <key>``. Its children
        follow, indented three more spaces and labelled "L" (left) and "R"
        (right). A None node prints nothing.
        """
        if node is None:
            return
        line = indent + f"{position}: {node.key}"
        print(line)
        child_pad = indent + "   "
        self.print_tree(node.left_child, child_pad, "L")
        self.print_tree(node.right_child, child_pad, "R")


# Insert the same keys as the BST demo; the AVL tree rebalances as it goes.
insertion_order = (4, 1, 10, 3, 7, 20, 8, 9, 6, 2, 22)
my_tree = AVLTree()
for key in insertion_order:
    my_tree[key] = str(key)

printer = TreePrinter()
printer.print_tree(my_tree.root)

Root: 8
   L: 4
      L: 2
         L: 1
         R: 3
      R: 7
         L: 6
   R: 10
      L: 9
      R: 20
         R: 22
