In [1]:
class Node:
    """One node of the AVL tree: a key, links to two children, and a cached height."""

    def __init__(self, data):
        # The key stored in this node; AVLTree compares keys against it.
        self.data = data
        # Both child links start empty, so a freshly created node is a leaf.
        self.left, self.right = None, None
        # Height of a single-node subtree.
        self.height = 1

In [2]:
class AVLTree:
    """Operations on a self-balancing (AVL) binary search tree built from ``Node``s.

    The class keeps no state of its own. A tree is identified by its root
    ``Node`` (or ``None`` when empty). Every mutating method takes the current
    root and returns the possibly different new root, so callers must reassign
    it, e.g. ``root = tree.insert(root, key)``.

    During descent, a key smaller than a node's key goes left. Any other key,
    including an equal one, goes right, so duplicate keys are stored as
    separate nodes. Rotations preserve the in-order sequence of keys.
    """

    def insert(self, root, key):
        """Insert ``key`` into the subtree rooted at ``root``.

        Args:
            root: Root ``Node`` of the subtree, or ``None`` for an empty tree.
            key: Value to insert; must be comparable with existing keys via ``<``.

        Returns:
            The root of the (rebalanced) subtree.
        """
        # Step 1: ordinary BST insertion.
        if root is None:
            return Node(key)
        if key < root.data:
            root.left = self.insert(root.left, key)
        else:
            root.right = self.insert(root.right, key)
        # Steps 2-4: refresh this node's height and restore the AVL property.
        return self._rebalance(root)

    def remove(self, root, key):
        """Remove one node whose key equals ``key`` from the subtree at ``root``.

        If no node matches, the subtree is left structurally unchanged.

        Args:
            root: Root ``Node`` of the subtree, or ``None`` for an empty tree.
            key: Key to delete.

        Returns:
            The root of the (rebalanced) subtree, or ``None`` if it became empty.
        """
        # Step 1: ordinary BST deletion.
        if root is None:
            return None
        if key < root.data:
            root.left = self.remove(root.left, key)
        elif key > root.data:
            root.right = self.remove(root.right, key)
        else:
            # Zero or one child: splice this node out by returning its only child.
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left
            # Two children: copy the in-order successor's key into this node,
            # then delete the successor from the right subtree.
            successor = self.get_min(root.right)
            root.data = successor.data
            root.right = self.remove(root.right, successor.data)
        # Steps 2-4: refresh height and rebalance on the way back up.
        return self._rebalance(root)

    def _rebalance(self, node):
        """Recompute ``node``'s height and rotate if it is out of AVL balance.

        The rotation case is chosen from the balance factor of the heavier
        child rather than by comparing keys. This works for both insertion and
        deletion, and it stays correct when duplicate keys are present.
        ``node``'s children must already have up-to-date heights.

        Returns:
            The new root of this subtree (``node`` itself if no rotation ran).
        """
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))
        balance = self.get_balance(node)
        if balance > 1:  # left subtree is two levels taller
            if self.get_balance(node.left) < 0:
                # Left-Right case: rotate the child first to reduce it to Left-Left.
                node.left = self.left_rotate(node.left)
            # Left-Left case.
            return self.right_rotate(node)
        if balance < -1:  # right subtree is two levels taller
            if self.get_balance(node.right) > 0:
                # Right-Left case: rotate the child first to reduce it to Right-Right.
                node.right = self.right_rotate(node.right)
            # Right-Right case.
            return self.left_rotate(node)
        # Already balanced.
        return node

    def left_rotate(self, z):
        """Rotate the subtree rooted at ``z`` to the left.

        ``z.right`` (``y``) must not be ``None``. ``y`` becomes the subtree
        root, ``z`` becomes its left child, and ``y``'s former left subtree
        becomes ``z``'s right subtree.

        Returns:
            ``y``, the new root of the subtree.
        """
        y = z.right
        t2 = y.left
        # Rotation.
        y.left = z
        z.right = t2
        # z now sits below y, so its height must be refreshed before y's.
        z.height = 1 + max(self.get_height(z.left), self.get_height(z.right))
        y.height = 1 + max(self.get_height(y.left), self.get_height(y.right))
        return y

    def right_rotate(self, z):
        """Rotate the subtree rooted at ``z`` to the right.

        ``z.left`` (``y``) must not be ``None``. ``y`` becomes the subtree
        root, ``z`` becomes its right child, and ``y``'s former right subtree
        becomes ``z``'s left subtree.

        Returns:
            ``y``, the new root of the subtree.
        """
        y = z.left
        t3 = y.right
        # Rotation.
        y.right = z
        z.left = t3
        # z now sits below y, so its height must be refreshed before y's.
        z.height = 1 + max(self.get_height(z.left), self.get_height(z.right))
        y.height = 1 + max(self.get_height(y.left), self.get_height(y.right))
        return y

    def get_height(self, root):
        """Return the cached height of ``root``; an empty subtree has height 0."""
        if root is None:
            return 0
        return root.height

    def get_balance(self, root):
        """Return height(left) - height(right) for ``root`` (0 for an empty subtree)."""
        if root is None:
            return 0
        return self.get_height(root.left) - self.get_height(root.right)

    def get_min(self, root):
        """Return the leftmost (smallest-key) node of the subtree, or ``None`` if empty."""
        if root is None:
            return None
        while root.left is not None:
            root = root.left
        return root

    def inorder_traverse(self, root):
        """Print every key of the subtree in in-order (ascending) order, one per line."""
        if root is None:
            return
        self.inorder_traverse(root.left)
        print(root.data)
        self.inorder_traverse(root.right)

In [3]:
tree = AVLTree()

# Insert the keys 1..6 in ascending order; each insert hands back the new root.
root = None
for key in range(1, 7):
    root = tree.insert(root, key)

# Print the stored keys in in-order sequence, one per line.
tree.inorder_traverse(root)

1
2
3
4
5
6


In [4]:
# Remove the largest key (6) and show the keys that remain.
root = tree.remove(root=root, key=6)
tree.inorder_traverse(root)

1
2
3
4
5


In [5]:
# Remove the smallest key (1) and show the keys that remain.
root = tree.remove(root=root, key=1)
tree.inorder_traverse(root)

2
3
4
5
