#### Creating a Binary Tree

In [153]:
class Node:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
    
    def add(self, data):
        if self.data == data:
            return
        
        # If given data or value is smaller than root value
        # Then traverse Left Link
        if data < self.data:
            # left Node is empty then make a BinarySubTree
            if self.left is None:
                self.left = Node(data)
                return
            
            else:
                # If left Node is not empty then add below to it
                self.left.add(data)
                self.left = self.left.fixImbalanceIfExists()
                
        # If given data or value is greater than root value
        # Then Traverse Right link
        if data > self.data:
            if self.right is None:
                self.right = Node(data)
                return
            else:
                # If Right node is empty then make a BinarySubTree
                self.right.add(data)
                self.right = self.right.fixImbalanceIfExists()
            
        
        if self.left and self.data > data:
            return self.left.add(data)
        if self.right and self.data < data:
            return self.right.add(data)
    
    def findMin(self):
        if self.left:
            return self.left.findMin()
        return self.data
    
    def delete(self, target):
        if self.data == target:
            if self.left and self.right:
                minValue = self.right.findMin()
                self.data = minValue
                self.right = self.right.delete(minValue)
                return self
            else:
                return self.left or self.right
        
        if self.right and target > self.data:
            self.right = self.right.delete(target)
        
        if self.left and target < self.data:
            self.left = self.left.delete(target)
        
        return self.fixImbalanceIfExists()
    
    def isBalanced(self):
        leftHeight = self.left.height()+1 if self.left else 0
        rightHeight = self.right.height()+1 if self.right else 0
        return abs(leftHeight - rightHeight) < 2
    
    def toStr(self):
        if not self.isBalanced():
            return str(self.data)+'*'
        return str(self.data)
    
    # Remember small values always goes to left
    # large values always goes to right
    def search(self, target):
        if self.data == target:
            print("found it")
            return self
        if self.left and self.data > target:
            return self.left.search(target)
        if self.right and self.data < target:
            return self.right.search(target)
        
        print("Value is not in trees")
    
    # Start with root or parent node first
    # Traverse all left node first then right
    def traversePreorder(self):
        print(self.data)
        if self.left:
            self.left.traversePreorder()
        
        if self.right:
            self.right.traversePreorder()
        
    # Start with smallest node
    # Start with left leaf node
    def traverseInorder(self):
        if self.left:
            self.left.traverseInorder()
        print(self.data)
        if self.right:
            self.right.traverseInorder()
    
    # Start with left most node
    # visit left node to right
    # vist root node in the end
    def traversePostorder(self):
        if self.left:
            self.left.traversePostorder()
        if self.right:
            self.right.traversePostorder()
        print(self.data)
    
    ## Remember height is important becoz it determines the maximum runtime for the search of tree.
    ## if leaf node is reached function return kicks his height back up to the stack and 
    ## node above it return deepest height on either side.
    def height(self, h =0):
        leftHeight = self.left.height(h+1) if self.left else h
        rightHeight = self.right.height(h+1) if self.right else h
        return max(leftHeight, rightHeight)
    
    # Get the Nodes at particular Depth
    def getNodesAtDepth(self, depth, nodes=[]):
        if depth == 0:
            nodes.append(self)
            return nodes
        if self.left:
            self.left.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        if self.right:
            self.right.getNodesAtDepth(depth-1, nodes)
        else:
            nodes.extend([None]*2**(depth-1))
        return nodes
    
    # Show the Right Node and left node difference
    def getLeftRightHeightDifference(self):
        leftHeight = self.left.height()+1 if self.left else 0
        rightHeight = self.right.height()+1 if self.right else 0
        return leftHeight - rightHeight
    
    # Fixes itself if found the imbalance
    def fixImbalanceIfExists(self):
        if self.getLeftRightHeightDifference() > 1:
            # Left imbalance
            if self.left.getLeftRightHeightDifference() > 0:
                # left left imbalance
                return rotateRight(self)
            else:
                # left right imbalance
                self.left = rotateLeft(self.left)
                return rotateRight(self)
            
        elif self.getLeftRightHeightDifference() < -1:
            # right imbalance
            if self.unbalancedLeftLeftright.getLeftRightHeightDifference() < 0:
                # right right
                return rotateLeft(self)
            else:
                # right left
                self.right = rotateRight(self.right)
                return rotateLeft(self)
        return self
    
    # Recursively call fixImbalanceIfExists to re-balance itself
    def rebalance(self):
        if self.left:
            self.left.rebalance()
            self.left = self.left.fixImbalanceIfExists()
        if self.right:
            self.right.rebalance()
            self.right = self.right.fixImbalanceIfExists()
    
        

In [154]:
class Tree:
    def __init__(self, root, name=''):
        self.root = root
        self.name = name
    
    def search(self, target):
        return self.root.search(target)
    
    def traversePreorder(self):
        self.root.traversePreorder()
    
    def traverseInorder(self):
        self.root.traverseInorder()
    
    def traversePostorder(self):
        self.root.traversePostorder()
    
    def height(self):
        return self.root.height()
    
    def add(self, data):
        self.root.add(data)
        self.root = self.root.fixImbalanceIfExists()
    
    def delete(self, target):
        self.root = self.root.delete(target)
    
    def rebalance(self):
        self.root.rebalance()
        self.root = self.root.fixImbalanceIfExists()
    
    def getNodesAtDepth(self, depth):
        return self.root.getNodesAtDepth(depth)
    
    def _nodeToChar(self, n, spacing):
        if n is None:
            return '_'+(' '*spacing)
        spacing = spacing-len(n.toStr())+1
        return n.toStr()+(' '*spacing)
    
    def print(self, label=''):
        print(self.name+' '+label)
        height = self.root.height()
        spacing = 3
        width = int((2**height-1) * (spacing+1) + 1)
        # Root offset
        offset = int((width-1)/2)
        for depth in range(0, height+1):
            if depth > 0:
                # print directional lines
                print(' '*(offset+1) + (' '*(spacing+2)).join(['/'+ (' '*(spacing-2))+'\\']*(2**(depth-1))))
            row = self.root.getNodesAtDepth(depth, [])
            print((' '*offset)+''.join([self._nodeToChar(n, spacing) for n in row]))
            spacing = offset+1
            offset = int(offset/2)  -1
        print('')

In [155]:
### Rotate Left
def rotateLeft(root):
    pivot = root.right
    reattachNode = pivot.left
    root.right = reattachNode
    pivot.left = root
    return pivot
    

In [156]:
#### Rotate Right
def rotateRight(root):
    pivot = root.left
    reattachNode = pivot.right
    root.left = reattachNode
    pivot.right = root
    return pivot

In [157]:
# Adding a data by add method
tree = Tree(Node(50))
tree.add(25)
tree.add(75)
tree.add(10)
tree.add(35)
tree.add(30)
tree.add(5)
tree.add(2)
tree.add(1)
tree.print()

 
              35  
       /             \
      5               50              
   /     \         /     \
  2       25      _       75      
 / \     / \     / \     / \
1   _   10  30  _   _   _   _   



In [158]:
# Deleting a children with value 30
tree.delete(25)
tree.print()

 
              35  
       /             \
      5               50              
   /     \         /     \
  2       30      _       75      
 / \     / \     / \     / \
1   _   10  _   _   _   _   _   



In [159]:
# Traverse Pre-Order
tree.traversePreorder()

35
5
2
1
30
10
50
75


In [160]:
# Traverse In-Order
tree.traverseInorder()

1
2
5
10
30
35
50
75


In [161]:
# Traverse Post-Order
tree.traversePostorder()

1
2
10
30
5
75
50
35


#### Getting The Height of a Tree

In [162]:
tree.root.height()

3

#### Getting All Nodes At Particular Depth

In [163]:
nodes = []
# tree.root.getNodesAtDepth(2, nodes)
[ _.data  for _ in tree.root.getNodesAtDepth(3, nodes) if _ is not None ]

[1, 10]

In [164]:
UnbalancedLeftLeft = Tree(Node(30), 'unbalanced left left')
UnbalancedLeftLeft.root.left = Node(20)
UnbalancedLeftLeft.root.left.right = Node(21)
UnbalancedLeftLeft.root.left.left = Node(10)
UnbalancedLeftLeft.root.left.left.left = Node(9)
UnbalancedLeftLeft.root.left.left.right = Node(11)
UnbalancedLeftLeft.print()

unbalanced left left 
              30* 
       /             \
      20              _               
   /     \         /     \
  10      21      _       _       
 / \     / \     / \     / \
9   11  _   _   _   _   _   _   



In [166]:
# UnbalancedLeftLeft.root
UnbalancedLeftLeft.root = rotateRight(UnbalancedLeftLeft.root)

In [167]:
UnbalancedLeftLeft.print()

unbalanced left left 
      20  
   /     \
  10      30      
 / \     / \
9   11  21  _   



In [181]:
unbalanced_right_right = Tree(Node(10), 'Unbalanced right right')
unbalanced_right_right.root.right = Node(20)
unbalanced_right_right.root.right.left = Node(19)
unbalanced_right_right.root.right.right = Node(30)
unbalanced_right_right.root.right.right.left = Node(29)
unbalanced_right_right.root.right.right.right = Node(31)
unbalanced_right_right.print()

Unbalanced right right 
              10* 
       /             \
      _               20              
   /     \         /     \
  _       _       19      30      
 / \     / \     / \     / \
_   _   _   _   _   _   29  31  



In [182]:
unbalanced_right_right.root = rotateLeft(unbalanced_right_right.root)
unbalanced_right_right.print()

Unbalanced right right 
      20  
   /     \
  10      30      
 / \     / \
_   19  29  31  



In [183]:
unbalanced_left_right = Tree(Node(30), 'Unbalanced left right')
unbalanced_left_right.root.right = Node(31)
unbalanced_left_right.root.left = Node(10)
unbalanced_left_right.root.left.right = Node(20)
unbalanced_left_right.root.left.left = Node(9)
unbalanced_left_right.root.left.right.left = Node(19)
unbalanced_left_right.root.left.right.right = Node(21)
unbalanced_left_right.print()

Unbalanced left right 
              30* 
       /             \
      10              31              
   /     \         /     \
  9       20      _       _       
 / \     / \     / \     / \
_   _   19  21  _   _   _   _   



In [184]:
unbalanced_left_right.root.left = rotateLeft(unbalanced_left_right.root.left)
unbalanced_left_right.root = rotateRight(unbalanced_left_right.root)
unbalanced_left_right.print()

Unbalanced left right 
      20  
   /     \
  10      30      
 / \     / \
9   19  21  31  



In [185]:
unbalance_right_left = Tree(Node(30), 'UNBALANCED RIGHT LEFT')
unbalance_right_left.root.left = Node(31)
unbalance_right_left.root.right = Node(10)
unbalance_right_left.root.right.left = Node(20)
unbalance_right_left.root.right.right = Node(9)
unbalance_right_left.root.right.left.right = Node(19)
unbalance_right_left.root.right.left.left = Node(21)
unbalance_right_left.print()

UNBALANCED RIGHT LEFT 
              30* 
       /             \
      31              10              
   /     \         /     \
  _       _       20      9       
 / \     / \     / \     / \
_   _   _   _   21  19  _   _   



In [186]:
unbalance_right_left.root.right = rotateRight(unbalance_right_left.root.right)
unbalance_right_left.root = rotateLeft(unbalance_right_left.root)
unbalance_right_left.print()

UNBALANCED RIGHT LEFT 
      20  
   /     \
  30      10      
 / \     / \
31  21  19  9   



In [187]:
# Fix the imbalance if exists
tree = Tree(Node(50))
tree.root.left = Node(25)
tree.root.right = Node(75)
tree.root.left.left = Node(10)
tree.root.left.right = Node(35)
tree.root.left.right.left = Node(30)
tree.root.left.left.left = Node(5)
tree.root.left.left.right = Node(13)
tree.print()
tree.rebalance()
tree.print()

 
              50* 
       /             \
      25              75              
   /     \         /     \
  10      35      _       _       
 / \     / \     / \     / \
5   13  30  _   _   _   _   _   

 
              35  
       /             \
      25              50              
   /     \         /     \
  10      30      _       75      
 / \     / \     / \     / \
5   13  _   _   _   _   _   _   

