In [1]:
class Node:
    """A single node of a binary tree that caches its own subtree height.

    Attributes are plain and public: ``value`` holds the stored key,
    ``left`` / ``right`` reference the child nodes (or ``None``), and
    ``height`` is the cached height of the subtree rooted at this node.
    """

    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None
        # A freshly created node has no children, so its height starts at 0.
        self.height = 0

    def __repr__(self):
        # Show the payload instead of the default "<Node object at 0x...>",
        # so printed nodes are readable.
        return "{}(value={!r}, height={})".format(
            type(self).__name__, self.value, self.height
        )

In [2]:
class AVLTree:
    """Self-balancing binary search tree (AVL tree).

    Heights use the convention that an empty subtree has height -1 and a
    leaf has height 0. Duplicate values are accepted; ``_insert`` sends a
    value equal to the current node into the left subtree.
    """

    def __init__(self):
        self.root = None

    # ------------------------------------------------------------------
    # Height bookkeeping
    # ------------------------------------------------------------------
    def _get_height(self, node):
        """Return the cached height of ``node`` (-1 for an empty subtree)."""
        if node is None:
            return -1
        return node.height

    def _update_height(self, node):
        """Recompute ``node.height`` from the cached heights of its children."""
        node.height = 1 + max(
            self._get_height(node.left), self._get_height(node.right)
        )

    def _balance_factor(self, node):
        """Return height(left) - height(right); 0 for an empty subtree."""
        if node is None:
            return 0
        return self._get_height(node.left) - self._get_height(node.right)

    # ------------------------------------------------------------------
    # Rotations (each returns the new root of the rotated subtree)
    # ------------------------------------------------------------------
    def _left_rotate(self, x):
        """Rotate left around ``x``; its right child becomes the subtree root."""
        y = x.right
        B = y.left

        x.right = B
        y.left = x

        # x is now below y, so x's height must be refreshed first.
        self._update_height(x)
        self._update_height(y)

        return y

    def _right_rotate(self, x):
        """Rotate right around ``x``; its left child becomes the subtree root."""
        y = x.left
        B = y.right

        x.left = B
        y.right = x

        # x is now below y, so x's height must be refreshed first.
        self._update_height(x)
        self._update_height(y)

        return y

    def _left_right_rotate(self, x):
        """Double rotation: left-rotate ``x.left``, then right-rotate ``x``."""
        x.left = self._left_rotate(x.left)
        return self._right_rotate(x)

    def _right_left_rotate(self, x):
        """Double rotation: right-rotate ``x.right``, then left-rotate ``x``."""
        x.right = self._right_rotate(x.right)
        return self._left_rotate(x)

    def _fix_AVL(self, node):
        """Rebalance ``node`` if needed and return the root of its subtree.

        ``node.height`` must already be up to date when this is called.
        """
        balance_factor_x = self._balance_factor(node)

        if balance_factor_x > 1:
            # Left-heavy: a single right rotation suffices unless the left
            # child leans right, which needs the double rotation.
            if self._balance_factor(node.left) >= 0:
                return self._right_rotate(node)
            return self._left_right_rotate(node)
        if balance_factor_x < -1:
            # Right-heavy: mirror image of the case above.
            if self._balance_factor(node.right) <= 0:
                return self._left_rotate(node)
            return self._right_left_rotate(node)
        return node

    # ------------------------------------------------------------------
    # Insertion
    # ------------------------------------------------------------------
    def insert(self, value):
        """Insert ``value`` into the tree (duplicates are kept)."""
        self.root = self._insert(self.root, value)

    def _insert(self, node, value):
        """Insert ``value`` below ``node``; return the new subtree root."""
        if node is None:
            return Node(value)
        if value <= node.value:
            node.left = self._insert(node.left, value)
        else:
            node.right = self._insert(node.right, value)

        self._update_height(node)
        return self._fix_AVL(node)

    # ------------------------------------------------------------------
    # Deletion
    # ------------------------------------------------------------------
    def delete(self, value):
        """Remove one occurrence of ``value``; do nothing if it is absent."""
        self.root = self._delete(self.root, value)

    def _delete(self, node, value):
        """Delete ``value`` from the subtree at ``node``; return its new root."""
        if node is None:
            return None
        if value < node.value:
            node.left = self._delete(node.left, value)
        elif value > node.value:
            node.right = self._delete(node.right, value)
        else:
            # Zero or one child: splice the node out. When both children
            # are missing the first branch returns None.
            if node.left is None:
                return node.right
            if node.right is None:
                return node.left
            # Two children: copy the in-order successor's value here, then
            # remove that successor from the right subtree.
            successor = self._get_min_val_node(node.right)
            node.value = successor.value
            node.right = self._delete(node.right, successor.value)

        self._update_height(node)
        return self._fix_AVL(node)

    def _get_min_val_node(self, node):
        """Return the leftmost node of the subtree rooted at ``node``."""
        current = node
        while current.left:
            current = current.left
        return current

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------
    def search(self, value):
        """Return a node holding ``value``, or ``None`` if there is none."""
        return self._search(self.root, value)

    def _search(self, node, value):
        """Walk down from ``node`` looking for ``value``."""
        while node is not None:
            if value < node.value:
                node = node.left
            elif value > node.value:
                node = node.right
            else:
                return node
        return None

    def is_in_AVL(self, value):
        """Return ``True`` when ``value`` is stored in the tree."""
        return self.search(value) is not None

    def update(self, old_value, value):
        """Replace one occurrence of ``old_value`` with ``value``.

        Nothing happens when ``old_value`` is not in the tree.
        """
        if self.is_in_AVL(old_value):
            self.delete(old_value)
            self.insert(value)

    # ------------------------------------------------------------------
    # Traversal
    # ------------------------------------------------------------------
    def in_order_traversal(self):
        """Return the stored values as a list in in-order sequence."""
        result = []
        self._in_order_traversal(self.root, result)
        return result

    def _in_order_traversal(self, node, result):
        """Append the values below ``node`` to ``result`` in in-order."""
        # ``result`` is the caller's list object (not a copy), so the
        # appends made here are visible to the caller.
        if node:
            self._in_order_traversal(node.left, result)
            result.append(node.value)
            self._in_order_traversal(node.right, result)

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    def is_AVL(self):
        """Return ``True`` when the whole tree satisfies the AVL balance rule.

        Heights are measured from the actual structure rather than read
        from the cached ``height`` fields, so a tree with stale caches is
        reported as invalid. Key ordering is not checked here.
        """
        return self._is_AVL(self.root)

    def _is_AVL(self, node):
        """Return ``True`` when the subtree at ``node`` is a valid AVL subtree."""
        return self._checked_height(node) is not None

    def _checked_height(self, node):
        """Return the true height of a valid AVL subtree, else ``None``.

        A subtree is valid when, at every node, the children's measured
        heights differ by at most 1 and the cached ``height`` equals the
        measured height.
        """
        if node is None:
            return -1
        left_height = self._checked_height(node.left)
        if left_height is None:
            return None
        right_height = self._checked_height(node.right)
        if right_height is None:
            return None
        if abs(left_height - right_height) > 1:
            return None
        true_height = 1 + max(left_height, right_height)
        if node.height != true_height:
            return None
        return true_height

In [3]:
# Example usage: build a tree, then exercise delete, search and update.
avl = AVLTree()
for key in (10, 5, 20, 3, 7, 15, 30, 11, 29, 10, 2):
    avl.insert(key)
print(avl.is_AVL())

# Remove a value and confirm the tree is still ordered and balanced.
avl.delete(15)
print(avl.in_order_traversal())
print(avl.is_AVL())

# Look a value up before and after replacing it via update().
n = 20
print(avl.search(n), avl.is_in_AVL(n))
avl.update(n, -1)
print(avl.search(n), avl.is_in_AVL(n))
print(avl.in_order_traversal())
print(avl.is_AVL())

# 10 was inserted twice above; update() is asked to replace it with 15.
avl.update(10, 15)
print(avl.in_order_traversal())
print(avl.is_AVL())

# 0 was never inserted.
avl.delete(0)
print(avl.in_order_traversal())

True
[2, 3, 5, 7, 10, 10, 11, 20, 29, 30]
True
<__main__.Node object at 0x00000281047AF130> True
None False
[-1, 2, 3, 5, 7, 10, 10, 11, 29, 30]
True
[-1, 2, 3, 5, 7, 10, 11, 15, 29, 30]
True
[-1, 2, 3, 5, 7, 10, 11, 15, 29, 30]


In [4]:
# Randomized check: fill many fresh trees with random values and count how
# often the in-order traversal equals the sorted list of inserted values.
import random

n = 1 << 14
cnt = 0
for i in range(n):
    nums = [random.randint(1, 10000) for _ in range(1000)]
    avl = AVLTree()
    for num in nums:
        avl.insert(num)
    # The comparison yields a bool; int() turns it into 0 or 1.
    cnt += int(avl.in_order_traversal() == sorted(nums))
print(cnt)
print(cnt == n)

16384
True
