# Binary Search Tree

In [23]:
class Node(object):
    def __init__(self, data):
        self._left, self._right = None, None
        self.data = data
        
    def __repr__(self):
        return 'Node({})'.format(self.data)
   
    @property
    def left(self):
        return self._left
    
    @left.setter
    def left(self, node):
        self._left = node
    
    @property
    def right(self):
        return self._right
    
    @right.setter
    def right(self, node):
        self._right = node
    
class BinarySearchTree(object):        
    def __init__(self, root=None):
        self.root = root
        self.search_mode = 'in_order'
        
            
    # O(logN) time complexity if balanced, it could reduce to O(N)
    def insert(self, data, **kwargs):     
        """Insert from root"""
        BinarySearchTree.insert_node(self.root, data, **kwargs)
        
    # O(logN) time complexity if balanced, it could reduce to O(N)
    def remove(self, data):     
        """Insert from root"""
        BinarySearchTree.remove_node(self.root, data)
    
    @staticmethod
    def insert_node(node, data, **kwargs):
        node_consturctor = kwargs.get('node_constructor', None) or Node
        if node:
            if data < node.data:
                if node.left is None:
                    node.left = node_consturctor(data)
                else:
                    BinarySearchTree.insert_node(node.left, data, **kwargs)
            elif data > node.data:
                if node.right is None:
                    node.right = node_consturctor(data)
                else:
                    BinarySearchTree.insert_node(node.right, data, **kwargs)
        else:
            node.data = data
        return node
         
    @staticmethod
    def remove_node(node, data):            

        if not node:
            return None
        
        if data < node.data:
            node.left = BinarySearchTree.remove_node(node.left, data)
        elif data > node.data:
            node.right = BinarySearchTree.remove_node(node.right, data)
        else:
            if not (node.left and node.right):  # leaf
                del node
                return None
            if not node.left:
                tmp = node.right
                del node
                return tmp
            if not node.right:
                tmp = node.left
                del node
                return tmp
            predeccessor = BinarySearchTree.get_max_node(node.left)
            node.data = predeccessor.data
            node.left = BinarySearchTree.remove_node(node.left, predeccessor.data)
        return node
            
    def get_min(self):
        return self.get_min_node(self.root)
    
    @staticmethod
    def get_min_node(node):
        if node.left:
            return BinarySearchTree.get_max_node(node.left)
        return node
        
    def get_max(self):
        return self.get_max_node(self.root)
    
    @staticmethod
    def get_max_node(node):
        if node.right:
            return BinarySearchTree.get_max_node(node.right)
        return node
             
    def search_decorator(func):
        def interface(*args, **kwargs):
            res = func(*args, **kwargs)
            if isinstance(res, Node):
                return res
            elif 'data' in kwargs:
                for node in res:
                    if node.data == kwargs['data']:
                        return node   
            return res
        return interface
    
    @staticmethod
    @search_decorator
    def in_order(root, **kwargs):
        """left -> root -> right"""
        f = BinarySearchTree.in_order
        res = []
        if root:
            left = f(root.left, **kwargs)
            if isinstance(left, Node):
                return left
            right = f(root.right, **kwargs)
            if isinstance(right, Node):
                return right
            res = left + [root] + right
        return res

    @staticmethod
    @search_decorator
    def pre_order(root, **kwargs):
        """root -> left -> right"""
        f = BinarySearchTree.pre_order
        res = []
        if root:
            left = f(root.left, **kwargs)
            if isinstance(left, Node):
                return left
            right = f(root.right, **kwargs)
            if isinstance(right, Node):
                return right
            res = [root] + left + right      
        return res

    @staticmethod
    @search_decorator
    def post_order(root, **kwargs):
        """root -> right -> root"""
        f = BinarySearchTree.post_order
        res = []
        if root:
            left = f(root.left, **kwargs)
            if isinstance(left, Node):
                return left
            right = f(root.right, **kwargs)
            if isinstance(right, Node):
                return right
            res = left + right + [root]
        return res
    
    def traversal(self, 
                  order:"in_order|post_order|post_order"=None,
                  data=None):
        order = order or self.search_mode
        if order == 'in_order':
            return BinarySearchTree.in_order(self.root, data=data)
        elif order == 'pre_order':
            return BinarySearchTree.pre_order(self.root, data=data)
        elif order == 'post_order':
            return BinarySearchTree.post_order(self.root, data=data)
        else:
            raise NotImplementedError()
            
    def search(self, data, *args, **kwargs):
        return self.traversal(*args, data=data, **kwargs)
    

In [24]:
bt = BinarySearchTree(Node(10))
bt.insert(20)
bt.insert(15)
bt.insert(2)
bt.insert(66)
bt.insert(24)
bt.insert(17)
bt.insert(21)
bt.root.right

Node(20)

In [3]:
bt.traversal('in_order')

[Node(2), Node(10), Node(15), Node(17), Node(20), Node(21), Node(24), Node(66)]

In [4]:
bt.traversal('post_order')

[Node(2), Node(17), Node(15), Node(21), Node(24), Node(66), Node(20), Node(10)]

In [5]:
bt.traversal('pre_order')

[Node(10), Node(2), Node(20), Node(15), Node(17), Node(66), Node(24), Node(21)]

In [6]:
bt.search(data=17)

Node(17)

In [7]:
bt.get_max()

Node(66)

In [8]:
bt.get_min()

Node(2)

In [9]:
bt.remove(15)
bt.traversal()

[Node(2), Node(10), Node(20), Node(21), Node(24), Node(66)]

# AVL Tree

In [25]:
class HNode(Node):    
    def __init__(self, *args, **kwargs):
        super(HNode, self).__init__(*args, **kwargs)
        self._height = 0
        
    def __repr__(self):
        return 'HNode({})'.format(self.data)
    
    @property
    def height(self):
        return self._height
    
    def set_height(self):        
        if self.left is None and self.right is None:
            self._height = 0
        else:
            self._height = max(self.left_height, self.right_height) + 1
        return self._height


    @Node.left.setter
    def left(self, node):
        self._left = node
        self.set_height()
            
    @Node.right.setter
    def right(self, node):
        self._right = node
        self.set_height()
        
    @property
    def sub_diff(self):
        return self.left_height - self.right_height 
    
    @property
    def left_height(self):
        if self.left:
            return self.left.height
        return -1
    
    @property
    def right_height(self):
        if self.right:
            return self.right.height
        return -1
    
    @property
    def is_balance(self):
        return abs(self.sub_diff) <= 1        
        
    def balance(self, data):
        
        if self.sub_diff > 1:
            if data < self.left.data:  # left left heavy
                return self.rotate('right')
            if data > self.left.data:  # left right heavy
                self.left = self.left.rotate('left')
                return self.rotate('right')
            
        if self.sub_diff < -1:
            if data > self.right.data:
                return self.rotate('left')  # right right heavy
            if data < self.right.data:  # right left heavy
                self.right = self.right.rotate('right')
                return self.rotate('left')
            
        return self
        
    def rotate(self, to:"left|right"):
        if to == 'right':
            tmp = self.left
            tmp_right = tmp.right
            # update
            tmp.right = self
            self.left = tmp_right        
            print('Node {} right rotate to {}!'.format(self, tmp))
            return tmp  # return new root
        if to == 'left':
            tmp = self.right
            tmp_left = tmp.left
            # update
            tmp.left = self
            self.right = tmp_left
            print('Node {} left rotate to {}!'.format(self, tmp))
            return tmp  # return new root
        raise NotImplementedError()
            
class AVLTree(BinarySearchTree):    
    def __init__(self, *args, **kwargs):
        super(AVLTree, self).__init__(*args, **kwargs)
        
    def insert(self, data):    
        AVLTree.insert_node(self.root, data, tree=self)  # pass self as keyword argument to update self.root
        self.update_height()
        
    def remove(self, data):
        AVLTree.remove_node(self.root, data, tree=self)  # pass self as keyword argument to update self.root
        self.update_height()
    
    def rotate_decorator(func):
        def interface(*args, **kwargs):
            node = func(*args, **kwargs)
            
            data = args[1]
            tree = kwargs.get('tree')
                        
            new_root = node.balance(data)
            
            if node == tree.root:
                tree.root = new_root
                    
        return interface
    
    def update_height(self):
        for n in self.traversal(order='in_order'):
            n.set_height()
    
    @property
    def is_balance(self):
        return self.root.is_balance
    
    @rotate_decorator
    def insert_node(*args, **kwargs):
        return BinarySearchTree.insert_node(*args, node_constructor=HNode, **kwargs)
   
    @rotate_decorator
    def remove_node(*args, **kwargs):
        return BinarySearchTree.remove_node(*args, **kwargs) 

In [11]:
bt = AVLTree(HNode(1))
bt.insert(2)
bt.insert(3)
bt.insert(4)
bt.insert(5)
bt.insert(6)

Node HNode(1) left rotate to HNode(2)!
Node HNode(2) left rotate to HNode(3)!


In [12]:
bt.root

HNode(3)

In [13]:
bt.root.height

2

In [14]:
bt.traversal()

[HNode(1), HNode(2), HNode(3), HNode(4), HNode(5), HNode(6)]

In [15]:
bt.is_balance

True

In [16]:
bt = AVLTree(HNode(1))
bt.insert(5)
bt.insert(3)

Node HNode(5) right rotate to HNode(3)!
Node HNode(1) left rotate to HNode(3)!


In [17]:
bt.root

HNode(3)

In [18]:
bt.root.left

HNode(1)

In [19]:
bt.root.right

HNode(5)

In [26]:
bt = AVLTree(HNode('D'))
bt.insert('B')
bt.insert('E')
bt.insert('A')
bt.insert('C')

In [28]:
bt.is_balance

True