# Trees  

* [Binary Search Tree]()
* [AVL Tree]()
* [Red-Black Tree]()


## Binary Search Tree
`O(logn)` time complexity when the tree is balanced (best-case shape), `O(n)` worst-case time complexity when it degenerates into a list, and `O(n)` space complexity

In [17]:
import random
from collections import deque

class Node:
    """A binary-tree node holding one value and two child links."""

    def __init__(self, value):
        # A fresh node starts out as a leaf.
        self.left = self.right = None
        self.value = value

class BinarySearchTree:
    """Unbalanced binary search tree.

    Every operation takes the subtree root explicitly, so callers normally
    pass ``self.root``. Nodes are any objects exposing ``value``, ``left``
    and ``right``. Duplicate values are allowed; a value equal to a node's
    value is inserted into that node's left subtree.
    """

    def __init__(self):
        self.root = None

    def insert(self, root, node):
        """Attach ``node`` below ``root``.

        If the tree is empty, ``node`` becomes ``self.root`` and ``root`` is
        ignored. Otherwise ``root`` must be a node, not ``None``.
        """
        if self.root is None:
            self.root = node
            return

        current = root
        while True:
            if node.value <= current.value:
                if current.left is None:
                    current.left = node
                    return
                current = current.left
            else:
                if current.right is None:
                    current.right = node
                    return
                current = current.right

    def search(self, root, node):
        """Return the first node under ``root`` whose value equals ``node.value``.

        Prints ``'find it'`` on success and returns ``None`` when no node
        matches. The walk is iterative so that a match found at any depth is
        actually handed back to the caller.
        """
        current = root
        while current is not None:
            if node.value < current.value:
                current = current.left
            elif node.value > current.value:
                current = current.right
            else:
                print('find it')
                return current
        return None

    def delete(self, root, node):
        """Remove one node whose value equals ``node.value`` from the subtree.

        Returns the new root of the subtree; callers must store it back
        (e.g. ``bst.root = bst.delete(bst.root, node)``). A missing value
        leaves the subtree unchanged.
        """
        if root is None:
            return root

        if node.value < root.value:
            root.left = self.delete(root.left, node)
        elif node.value > root.value:
            root.right = self.delete(root.right, node)
        elif root.left is None:
            # Covers both the leaf case (result is None) and right-child-only.
            root = root.right
        elif root.right is None:
            root = root.left
        else:
            # Two children: copy the in-order successor's value into this
            # node, then remove that successor from the right subtree.
            successor = root.right
            while successor.left is not None:
                successor = successor.left
            root.value = successor.value
            root.right = self.delete(root.right, successor)

        return root

    def dfs_pre_order(self, root):
        """Print the subtree's values in node, left, right order."""
        if root is None:
            return
        print(root.value)
        self.dfs_pre_order(root.left)
        self.dfs_pre_order(root.right)

    def dfs_in_order(self, root):
        """Print the subtree's values in left, node, right (sorted) order."""
        if root is None:
            return
        self.dfs_in_order(root.left)
        print(root.value)
        self.dfs_in_order(root.right)

    def dfs_post_order(self, root):
        """Print the subtree's values in left, right, node order."""
        if root is None:
            return
        self.dfs_post_order(root.left)
        self.dfs_post_order(root.right)
        print(root.value)

    def bfs(self, queue, root):
        """Print the subtree's values level by level.

        ``queue`` is a caller-supplied deque-like object (``append`` and
        ``popleft``) used as the work list; it is drained by the time this
        returns. An empty tree (``root is None``) prints nothing.
        """
        if root is None:
            return

        queue.append(root)
        while len(queue):
            current = queue.popleft()
            if current.left is not None:
                queue.append(current.left)
            if current.right is not None:
                queue.append(current.right)

            print(current.value)
    
    
    

# Demo: build a BST from seven distinct random values in [0, 30), add a 17,
# remove one 17 again, then print the tree in every traversal order.
bst = BinarySearchTree()
queue = deque()
nums = random.sample(range(0, 30), 7)
print(f'nums: {nums}')
for val in nums:
    node = Node(val)
    bst.insert(bst.root, node)
bst.insert(bst.root, Node(17))

bst.dfs_in_order(bst.root)
print()

# delete() returns the new subtree root, so the result is stored back.
bst.root = bst.delete(bst.root, Node(17))
for traverse in (bst.dfs_pre_order, bst.dfs_in_order, bst.dfs_post_order):
    traverse(bst.root)
    print()
bst.bfs(queue, bst.root)


nums: [15, 22, 3, 21, 16, 2, 14]
2
3
14
15
16
17
21
22

15
3
2
14
22
21
16

2
3
14
15
16
21
22

2
14
3
16
21
22
15

15
3
22
2
14
21
16


## AVL Tree
A strictly height-balanced binary search tree: `O(logn)` time complexity in both the best and worst cases, and `O(n)` space complexity. It is a good fit for search-heavy workloads because the tree is always kept balanced.

In [11]:
import sys

class TNode:
    """AVL tree node: a value, two child links and a cached subtree height."""

    def __init__(self, val):
        self.left = self.right = None
        # A freshly created node is a leaf, which counts as height 1.
        self.height = 1
        self.value = val

class AVLTree:
    """Self-balancing binary search tree (AVL).

    Nodes expose ``value``, ``left``, ``right`` and a cached ``height``
    (a leaf has height 1, an empty subtree counts as 0). ``insert`` and
    ``delete`` return the possibly different root of the subtree they were
    given, so callers must store the result, e.g.
    ``tree.root = tree.insert(tree.root, val)``. Duplicate values are
    inserted into the right subtree.
    """

    def __init__(self):
        self.root = None

    def get_height(self, root):
        """Return the cached height of ``root``, or 0 for an empty subtree."""
        if not root:
            return 0

        return root.height

    def get_balance_factor(self, root):
        """Return height(left) - height(right); 0 for an empty subtree."""
        if not root:
            return 0

        return self.get_height(root.left) - self.get_height(root.right)

    def _update_height(self, node):
        """Recompute ``node.height`` from the cached heights of its children."""
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))

    def left_rotate(self, root):
        """Rotate left around ``root`` and return the new subtree root.

        ``root.right`` must not be ``None``; it becomes the new root.
        """
        z = root
        y = root.right
        x = y.left

        z.right = x
        y.left = z

        # z is now below y, so its height has to be refreshed first.
        self._update_height(z)
        self._update_height(y)

        return y

    def right_rotate(self, root):
        """Rotate right around ``root`` and return the new subtree root.

        ``root.left`` must not be ``None``; it becomes the new root.
        """
        z = root
        y = z.left
        x = y.right

        z.left = x
        y.right = z

        # z is now below y, so its height has to be refreshed first.
        self._update_height(z)
        self._update_height(y)

        return y

    def _rebalance(self, root):
        """Refresh ``root.height`` and restore the AVL invariant at ``root``.

        The rotation is chosen from the children's balance factors, which is
        valid after both insertions and deletions and does not depend on how
        duplicate keys compare. Returns the new subtree root.
        """
        self._update_height(root)
        bf = self.get_balance_factor(root)

        if bf > 1:
            # Left-heavy. A right-leaning left child needs a pre-rotation.
            if self.get_balance_factor(root.left) < 0:
                root.left = self.left_rotate(root.left)
            return self.right_rotate(root)

        if bf < -1:
            # Right-heavy. A left-leaning right child needs a pre-rotation;
            # a balanced (factor 0) right child must take the single rotation.
            if self.get_balance_factor(root.right) > 0:
                root.right = self.right_rotate(root.right)
            return self.left_rotate(root)

        return root

    def insert(self, root, val):
        """Insert ``val`` below ``root`` and return the new subtree root."""
        if not root:
            return TNode(val)

        if val >= root.value:
            root.right = self.insert(root.right, val)
        else:
            root.left = self.insert(root.left, val)

        return self._rebalance(root)

    def get_min_value_node(self, root):
        """Return the leftmost node of the subtree (``root`` must not be None)."""
        current = root
        while current.left is not None:
            current = current.left
        return current

    def delete(self, root, val):
        """Delete one node holding ``val`` and return the new subtree root.

        A missing value leaves the subtree unchanged.
        """
        if not root:
            return root

        if val > root.value:
            root.right = self.delete(root.right, val)
        elif val < root.value:
            root.left = self.delete(root.left, val)
        else:
            # With at most one child, the child (or None) replaces this node.
            # That subtree is already balanced and carries a correct height.
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left

            # Two children: take over the in-order successor's value and
            # delete the successor from the right subtree instead.
            successor = self.get_min_value_node(root.right)
            root.value = successor.value
            root.right = self.delete(root.right, successor.value)

        return self._rebalance(root)

    def print_helper(self, root, indent, last):
        """Print the subtree sideways, one node per line.

        ``indent`` is the prefix for the current line; ``last`` selects the
        ``R----`` marker (right child / root) over ``L----`` (left child).
        """
        if root is None:
            return

        sys.stdout.write(indent)
        if last:
            sys.stdout.write("R----")
            indent += "     "
        else:
            sys.stdout.write("L----")
            indent += "|    "
        print(root.value)
        self.print_helper(root.left, indent, False)
        self.print_helper(root.right, indent, True)
            

# Demo: build an AVL tree, print it, delete 9 and print it again.
avl_t = AVLTree()
nums = [33, 13, 52, 9, 8]
for val in nums:
    avl_t.root = avl_t.insert(avl_t.root, val)

avl_t.print_helper(avl_t.root, "", True)
# delete() returns the (possibly new) root of the tree. It has to be stored
# back on the tree, otherwise avl_t.root goes stale whenever the deletion
# rotates or replaces the root. `root` is kept bound for later cells.
avl_t.root = root = avl_t.delete(avl_t.root, 9)
print('after deletion...')
avl_t.print_helper(avl_t.root, "", True)


R----33
     L----9
     |    L----8
     |    R----13
     R----52
after deletion...
R----33
     L----13
     |    L----8
     R----52


## Red-Black Tree
`O(logn)` time complexity and `O(n)` space complexity. It is a good fit for workloads that frequently grow and shrink the tree (many insertions and deletions).

In [None]:
class RBTreeNode:
    """Red-black tree node with a value, a color flag and parent/child links."""

    def __init__(self, val):
        # No links until the node is attached to a tree.
        self.parent = None
        self.left = self.right = None
        self.color = 0  # 0 means black, which is the default
        self.val = val

class RBT:
    """Red-black tree skeleton; only the empty-subtree insert case exists."""

    def __init__(self):
        self.root = None

    def insert(self, root, val):
        """Return a new node for ``val`` when ``root`` is ``None``.

        NOTE(review): for a non-empty subtree nothing is inserted and
        ``None`` is returned; the implementation looks unfinished.
        """
        if root is not None:
            return None
        return RBTreeNode(val)
        
        
    

## QuadTree
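`O(n)` time complexity to build, where `n` is the number of grid cells, and `O(logn)` auxiliary space for the recursion stack (the resulting tree itself can hold up to `O(n)` nodes)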

In [24]:
class Node:
    """Quadtree node: a value, a leaf flag and four quadrant children."""

    def __init__(self, val, is_leaf, top_left, top_right, bottom_left, bottom_right):
        self.val, self.is_leaf = val, is_leaf
        # Upper half of the region.
        self.top_left, self.top_right = top_left, top_right
        # Lower half of the region.
        self.bottom_left, self.bottom_right = bottom_left, bottom_right
        
class QuadTree:
    """Builds a quadtree of ``Node`` objects from a square grid of values."""

    def __init__(self, grid):
        self.grid = grid

    def build(self, tr, tc, br, bc):
        """Return the quadtree for rows ``tr..br-1`` and columns ``tc..bc-1``.

        A 1x1 region becomes a leaf holding that cell's value. Four child
        leaves that all carry the same value are collapsed into one leaf;
        otherwise an internal node (``val`` 1, ``is_leaf`` 0) is returned.
        NOTE(review): the halving presumably requires a side length that is
        a power of two; confirm before using other grid sizes.
        """
        if (br - tr, bc - tc) == (1, 1):
            return Node(self.grid[tr][tc], 1, None, None, None, None)

        row_mid = (tr + br) // 2
        col_mid = (tc + bc) // 2

        # Order: top-left, top-right, bottom-left, bottom-right.
        quadrants = (
            self.build(tr, tc, row_mid, col_mid),
            self.build(tr, col_mid, row_mid, bc),
            self.build(row_mid, tc, br, col_mid),
            self.build(row_mid, col_mid, br, bc),
        )

        all_leaves = all(quadrant.is_leaf for quadrant in quadrants)
        if all_leaves and all(a.val == b.val for a, b in zip(quadrants, quadrants[1:])):
            return Node(quadrants[0].val, 1, None, None, None, None)
        return Node(1, 0, *quadrants)

    def construct(self):
        """Build the tree for the whole grid, taking ``len(grid)`` as the side."""
        side = len(self.grid)
        return self.build(0, 0, side, side)
    
def print_node(node, count):
    """Print ``[is_leaf, val]`` for ``node`` and, recursively, its children.

    Each line is prefixed with ``count`` dashes; children are printed one
    level deeper in top-left, top-right, bottom-left, bottom-right order.
    """
    prefix = '-' * count
    print(f'{prefix}{[node.is_leaf, node.val]}')
    if node.is_leaf:
        return

    for child in (node.top_left, node.top_right, node.bottom_left, node.bottom_right):
        print_node(child, count + 1)
                
    
# Demo: build the quadtree of a 4x4 grid and print it.
grid = [
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]

qt = QuadTree(grid)
nodes = qt.construct()
print_node(nodes, 0)

[0, 1]
-[1, 1]
-[0, 1]
--[1, 0]
--[1, 0]
--[1, 1]
--[1, 0]
-[1, 0]
-[1, 1]


In [30]:
def intersect(qt1, qt2):
    """Combine two quadtrees region by region, keeping a region's value only
    where both trees have a truthy value (a logical AND for 0/1 leaf values).

    A truthy leaf yields the other tree's subtree and a falsy leaf yields
    itself, so the result may share nodes with the inputs. Four child leaves
    with equal values are collapsed into a single leaf.
    """
    if qt1.is_leaf:
        if qt1.val:
            return qt2
        return qt1
    if qt2.is_leaf:
        if qt2.val:
            return qt1
        return qt2

    # Order: top-left, top-right, bottom-left, bottom-right.
    parts = (
        intersect(qt1.top_left, qt2.top_left),
        intersect(qt1.top_right, qt2.top_right),
        intersect(qt1.bottom_left, qt2.bottom_left),
        intersect(qt1.bottom_right, qt2.bottom_right),
    )

    if all(part.is_leaf for part in parts) and all(
        a.val == b.val for a, b in zip(parts, parts[1:])
    ):
        return Node(parts[0].val, 1, None, None, None, None)
    return Node(1, 0, *parts)

def union(qt1, qt2):
    """Combine two quadtrees region by region, keeping a truthy value
    wherever either tree has one (a logical OR for 0/1 leaf values).

    A truthy leaf yields itself and a falsy leaf yields the other tree's
    subtree, so the result may share nodes with the inputs. Four child leaves
    with equal values are collapsed into a single leaf.
    """
    if qt1.is_leaf:
        if qt1.val:
            return qt1
        return qt2
    if qt2.is_leaf:
        if qt2.val:
            return qt2
        return qt1

    # Order: top-left, top-right, bottom-left, bottom-right.
    parts = (
        union(qt1.top_left, qt2.top_left),
        union(qt1.top_right, qt2.top_right),
        union(qt1.bottom_left, qt2.bottom_left),
        union(qt1.bottom_right, qt2.bottom_right),
    )

    if all(part.is_leaf for part in parts) and all(
        a.val == b.val for a, b in zip(parts, parts[1:])
    ):
        return Node(parts[0].val, 1, None, None, None, None)
    return Node(1, 0, *parts)

# Demo: intersect and union the quadtrees of two 4x4 grids.
grid1 = [
    [1, 0, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
grid2 = [
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]

qt1 = QuadTree(grid1)
qt2 = QuadTree(grid2)
qt1_nodes = qt1.construct()
qt2_nodes = qt2.construct()

print('intersect:')
intersect_nodes = intersect(qt1_nodes, qt2_nodes)
print_node(intersect_nodes, 0)

print('union:')
union_nodes = union(qt1_nodes, qt2_nodes)
print_node(union_nodes, 0)

    

intersect:
[0, 1]
-[0, 1]
--[1, 1]
--[1, 0]
--[1, 1]
--[1, 1]
-[1, 0]
-[1, 0]
-[1, 1]
union:
[0, 1]
-[1, 1]
-[0, 1]
--[1, 0]
--[1, 1]
--[1, 1]
--[1, 0]
-[1, 0]
-[1, 1]
