# Binary Search Tree

In [9]:
class Node:
    def __init__(self, val):
        self.key = val
        self.lchild = None
        self.rchild = None

class Tree:
    def __init__(self):
        self.root = None
        
    def insert(self,val):
        n = Node(val)
        if self.is_empty():
            self.root = n
        else:
            tmp = self.root
            prev_tmp = None
            while True:
                prev_tmp = tmp
                if n.key <= tmp.key:
                    tmp = tmp.lchild
                    if tmp is None:
                        prev_tmp.lchild = n
                        break
                else:
                    tmp = tmp.rchild
                    if tmp is None:
                        prev_tmp.rchild = n
                        break
    
    def remove(self, val):
        # Replace with the left most child in the right subtree of nodeto be deleted 
        tmp = self.root
        prev_tmp = None
        while tmp is not None:
            prev_tmp = tmp
            if val <= tmp.key:
                tmp = tmp.lchild
                if tmp is not None and tmp.key == val:
                    if tmp.lchild is None and tmp.rchild is None: # Leaf node
                        prev_tmp.lchild = None
                        break
                    #one of the child nodes is None
                    elif tmp.lchild is None:
                        prev_tmp.lchild = tmp.rchild
                    elif tmp.rchild is None:
                        prev_tmp.lchild = tmp.lchild
                    #del node has both children
                    else:
                        ptr = tmp.rchild #right subtree
                        prev_ptr = tmp
                        while ptr.lchild is not None: # find sucessor left most child of right subtree
                            prev_ptr = ptr
                            ptr = ptr.lchild

                        ptr.lchild = tmp.lchild
                        prev_tmp.lchild = ptr #copy the value from successor
                        #prev_ptr.lchild = None #delete the node
            else:
                tmp = tmp.rchild            
                if tmp is not None and tmp.key == val:
                    if tmp.lchild is None and tmp.rchild is None: # Leaf node
                        prev_tmp.rchild = None
                        break
                    #one of the child nodes is None
                    elif tmp.lchild is None:
                        prev_tmp.rchild = tmp.rchild
                    elif tmp.rchild is None:
                        prev_tmp.rchild = tmp.lchild  
                    #del node has both children
                    else:
                        ptr = tmp.rchild #right subtree
                        prev_ptr = tmp
                        while ptr.lchild is not None: # find sucessor left most child of right subtree
                            prev_ptr = ptr
                            ptr = ptr.lchild

                        ptr.lchild = tmp.lchild
                        prev_tmp.rchild = ptr
                        #tmp.key = ptr.key #copy the value from successor
                        #prev_ptr.lchild = None #delete the node


    def is_empty(self):
        return self.root is None
        
        
    def in_order(self):
        self.in_print(self.root)
        
    def in_print(self,ptr):
        if ptr is None:
           return
        else:
            self.in_print(ptr.lchild)
            print(ptr.key)
            self.in_print(ptr.rchild)
            
    def pre_order(self):
        self.pre_print(self.root)
        
    def pre_print(self,ptr):
        if ptr is None:
            return
        else:
            print(ptr.key)
            self.pre_print(ptr.lchild)
            self.pre_print(ptr.rchild)

    def post_order(self):
        self.post_print(self.root)
        
    def post_print(self,ptr):
        if ptr is None:
            return
        else:
            self.post_print(ptr.lchild)
            self.post_print(ptr.rchild)
            print(ptr.key)

            
nodes = [4,6,1,2,7,9,5,3]
t = Tree()
for i in nodes:
    t.insert(i)
t.in_order()
print("preorder")
t.pre_order()
print("post_order")
t.post_order()
t.remove(6)
print("After removal")
t.in_order()

#           4
#      1       6
#        2    5  7
#          3      9

1
2
3
4
5
6
7
9
preorder
4
1
2
3
6
5
7
9
post_order
3
2
1
5
9
7
6
4
After removal
1
2
3
4
5
7
9


# Implementing AVL tree

Notes: 
Avoid using parent pointer in insertion

-----------------------------------------------------------------------
store the balance factor in the node itself as an int that's either:

-1: the node's left subtree is a level higher than the right one (left-heavy)
0 the node is balanced; or
1 the right subtree is higher (right-heavy).
You insert(Node subtree) function returns a boolean, which is true if the insertion made the height of subtree increased. You update the balance factor and rebalance the tree as you return from the recursive insert() calls.

This is probably best explained with a few examples:

If the current node is at balance factor -1, you're inserting into the right subtree, and insert(rchild) returns true, you:

update the balance factor of the current node to 0 - the left subtree was higher before the insertion, and the right subtree's height increased, so they're the same height now; and
return false - the shallower tree's height increased, so the current node's height stays the same
If you're inserting into either subtree, and insert(…) returns false:

the balance factor of the current node is unchanged - the subtree heights are the same as before, and so is the balance
return false - the subtree heights haven't changed, so neither has the current node height
If the current node's balance factor is 0, you're inserting into the left subtree, and insert(lchild) returns true:

the balance factor changes to -1 - the subtrees were the same height before the insertion, and the insertion made the left one higher
return true
(Analogously, if inserting into the right subtree, the balance factor will change to 1.)

If the current node's balance factor is -1, you're inserting into the left subtree, and insert(lchild) returns true:

The balance factor would change to -2, which means you have to rebalance the node by doing the appropriate rotation. I'll admit I'm drawing a blank at what each of the four rotations will do to the balance factor and what insert(current) will return, hopefully the previous examples explain the approach to tracking the nodes' balance sufficiently.

-------------------------------------------------------------

It is cleaner to use height function instead of manuall trying to calculate and update balance factor after every insertion

In [2]:
class Node:
    def __init__(self, val):
        self.key = val
        self.bal_factr = 0  # height of left subtree - height of right subtree
        self.lchild = None
        self.rchild = None

        
class AVLTree:
    def __init__(self):
        self.root = None

    def set_bal_factr(self, nd):
        nd.bal_factr = self.height(nd.lchild) - self.height(nd.rchild)

    def insert(self, key):
        if self.root is None:
            self.root = Node(key)
        else:
            #self.root.bal_factr += self.traverse(key, self.root, None, None)
            self.traverse(key, self.root, None)
            print("inserted {}".format(key))
            print("height is {}".format(self.height(self.root)))

    def traverse(self, key, ptr, parent):
        # insert value, current_pointer, parent_pointer, left-right_flag for parent
        if key <= ptr.key:
            if ptr.lchild is None:
                ptr.lchild = Node(key)
            else:
                self.traverse(key, ptr.lchild, ptr)  # if traverse return true then height has changed
                if ptr.bal_factr > 1:  # L case
                    if ptr.lchild.bal_factr < 0:  # R case
                        self.LRRotation(ptr, parent,'l')
                    elif ptr.lchild.bal_factr > 0:  # L case
                        self.LLRotation(ptr, parent, 'l')
        else:
            if ptr.rchild is None:
                ptr.rchild = Node(key)
            else:
                self.traverse(key, ptr.rchild, ptr)
                if ptr.bal_factr < -1:  # R case
                    if ptr.rchild.bal_factr < 0:  # R case
                        self.RRRotation(ptr, parent,'r')
                    elif ptr.rchild.bal_factr > 0:  # L case
                        self.RLRotation(ptr, parent,'r')
        self.set_bal_factr(ptr) #calculate the balance factor after the changes to that node liniage

    def LLRotation(self, current, parent, flag):
        new_root = current.lchild
        current.lchild = new_root.rchild
        new_root.rchild = current
        if parent is not None:
            if flag == 'l':
                parent.lchild = new_root
            elif flag == 'r':
                parent.rchild = new_root
        else:
            self.root = new_root
            
        self.set_bal_factr(new_root) #calculate the balance factor after the changes to that node liniage

    def RRRotation(self, current, parent, flag):
        new_root = current.rchild
        current.rchild = new_root.lchild
        new_root.lchild = current

        if parent is not None:
            if flag == 'l':
                parent.lchild = new_root
            elif flag == 'r':
                parent.rchild = new_root
        else:
            self.root = new_root

        self.set_bal_factr(new_root) #calculate the balance factor after the changes to that node liniage

    def LRRotation(self, current, parent, flag):
        new_root = current.lchild.rchild

        current.lchild.rchild = new_root.lchild
        new_root.lchild = current.lchild

        current.lchild = new_root.rchild
        new_root.rchild = current

        if parent is not None:
            if flag == 'l':
                parent.lchild = new_root
            elif flag == 'r':
                parent.rchild = new_root
        else:
            self.root = new_root

        self.set_bal_factr(new_root) #calculate the balance factor after the changes to that node liniage

    def RLRotation(self, current, parent, flag):
        new_root = current.rchild.lchild

        current.rchild.lchild = new_root.rchild
        new_root.rchild = current.rchild

        current.rchild = new_root.lchild
        new_root.lchild = current

        if parent is not None:
            if flag == 'l':
                parent.lchild = new_root
            elif flag == 'r':
                parent.rchild = new_root
        else:
            self.root = new_root

        self.set_bal_factr(new_root) #calculate the balance factor after the changes to that node liniage

    def height(self, ptr):
        if ptr is None:
            return 0
        else:
            return max(self.height(ptr.lchild), self.height(ptr.rchild)) + 1

    def is_empty(self):
        return self.root is None

    def in_order(self):
        self.in_print(self.root)
        
    def remove(self,key):
        if self.root.key == key:
            self.root = None
        else:
            #self.root.bal_factr += self.traverse(key, self.root, None, None)
            self.rm_traverse(key, self.root, None, None)
            
    def rm_traverse(self,key,ptr,parent,flag):
        if ptr.key == key:
            if ptr.lchild == None:
                if ptr.rchild == None # No child nodes
                    if flag == 'l':
                        parent.lchild = None
                    elif flag == 'r':
                        parent.rchild = None
                else: # Right child Node
                    if flag == 'l':
                        parent.lchild = ptr.rchild
                    elif flag == 'r':
                        parent.rchild = ptr.rchild
            elif ptr.rchild == None: #Left child Node
                if flag == 'l':
                    parent.lchild = ptr.lchild
                elif flag == 'r':
                    parent.rchild = ptr.lchild
            else: # ptr has two children
                new_node = self.minimum(ptr.rchild)
                
            
        elif key < ptr.key:
            rm_traverse(key,ptr.lchild,ptr,'l')
            
            

    def in_print(self, ptr):
        if ptr is None:
            return
        else:
            self.in_print(ptr.lchild)
            print(ptr.key)
            self.in_print(ptr.rchild)


# nodes = [4,6,1,2,7,9,5,3]
#nodes = [1, 2, 3, 4, 5, 6, 7, 8, 9]
nodes = [9,8,7,6,5,4,3,2,1]
t = AVLTree()
for i in nodes:
    t.insert(i)

t.in_order()
print("--test--")
print(t.root.key)
print(t.root.lchild.key)
print(t.root.rchild.key)


inserted 8
height is 2
inserted 7
height is 3
inserted 6
height is 3
inserted 5
height is 4
inserted 4
height is 3
inserted 3
height is 4
inserted 2
height is 4
inserted 1
height is 5
1
2
3
4
5
6
7
8
9
--test--
6
4
8
