In [15]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [64]:
class BSTNode:
    
    @classmethod
    def construct_from_list(cls, bst_list, print_step = True):
        print(f'Input List:\n {bst_list}')
        bst = BSTNode(bst_list[0], print_step=print_step)
        if len(bst_list) > 1:
            for e in bst_list:                
                bst.insert(e)
        return bst 
    
    def __init__(self, val, parent = None, child_type = None, print_step = False):
        self.val = val
        self.left = None
        self.right = None
        self.parent = parent
        self.child_type = child_type
        self.print_step = print_step
        
    def insert(self, val):

        if self.val:
            if val < self.val:
                if self.left is None:
                    self.left = BSTNode(val, self, 'left')
                else:
                    self.left.insert(val)
            elif val > self.val:
                if self.right is None:
                    self.right = BSTNode(val, self, 'right')
                else:
                    self.right.insert(val)
        else:
            self.val = val
        
        if self.print_step:
            print(f'\nInserting {val}\n')
            self.print_tree()
            
            
    def find_element(self, val):
        if self.val:
            if self.val == val:               
                return self
            elif self.val < val:
                if self.right:
                    return self.right.find_element(val)
                else:
                    return None
            else:
                if self.left:
                    return self.left.find_element(val)
                else:
                    return None
        else:
            return None
    def find_maximum(self):
        x = self
        while x.right:
            x = x.right
        return x
    
    def find_minimum(self):
        x = self
        while x.left:
            x = x.left
        return x
        
    def delete(self, val):
        node = self.find_element(val)
        if self.print_step:
            self.print_tree()
        if node:
            # case - 1: node is a leaf node
            if not node.left and not node.right:
                if self.print_step:
                    print(f'\nNode {val} has no children(leaf)')
                node.del_node()
            # case - 2: node has single child
            elif not node.left and node.right: 
                if self.print_step:
                    print(f'\nNode {val} has single child (right)')
                node.right.short_circuit()
                    
            elif node.left and not node.right:
                if self.print_step:
                    print(f'\nNode {val} has single child (left)')
                node.left.short_circuit()
                #node.del_node()
            # case - 3: node has both children
            elif node.left and node.right:
                if self.print_step:
                    print(f'\nNode {val} has both children')
                    print(f'\nFinding successor of {val}')
                successor = self.find_successor()                
                node.val = successor.val
                if self.print_step:
                    print(f'\nSuccessor is {successor.val}')
                    print(f'\nReplacing {val} with {successor.val}\n')
                    self.print_tree()
                if self.print_step:
                    print(f'\nDeleting {successor.val}')
                successor.delete(successor.val)
        else:
            print(f'element {val} not found in the tree')
            
        if self.print_step:
            self.print_tree()
            
    def short_circuit(self):  
        if self.print_step:
             print(f'Short circuiting  {self.val}')
        if self.parent.child_type == 'right':            
            self.parent.parent.right = self
            self.child_type = 'right'
        elif self.parent.child_type == 'left':             
            self.parent.parent.left = self
            self.child_type = 'left'
        self.parent = self.parent.parent
                    
    def find_successor(self):
        if self.right:
            return self.right.find_minimum()
        else:
            return self
    def find_predessor(self):
        if self.left:
            return self.left.find_maximum()
        else:
            return self
    
    def go_left(self):
        if self.left:
            return self.left.go_left()
        else:
            return self
    
    def del_node(self):
        if self.print_step:
            print(f'\nRemoving Node {val}')
        self.val = None
        if self.child_type:
            if self.child_type == 'left':
                self.parent.left = None
            else:
                self.parent.right = None
        self.parent = None
    
    def print_tree(self):
        lines, *_ = self.display()
        for line in lines:
            print(line)
            
        
    def display(self):
        
        # No child.
        if not self.left and not self.right:
            line = str(self.val)
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle

        # Only left child.
        elif self.left and not self.right:
            lines, n, p, x = self.left.display()
            s = str(self.val)
            u = len(s)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2

        # Only right child.
        elif not self.left and self.right:
            lines, n, p, x = self.right.display()
            s = str(self.val)
            u = len(s)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2

        # Two children.
        else:
            left, n, p, x = self.left.display()
            right, m, q, y = self.right.display()
            s = str(self.val)
            u = len(s)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
            second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
            if p < q:
                left += [n * ' '] * (q - p)
            elif q < p:
                right += [m * ' '] * (p - q)
            zipped_lines = zip(left, right)
            lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zipped_lines]
            return lines, n + m + u, max(p, q) + 2, n + u // 2
           
            



In [65]:
bst = BSTNode.construct_from_list([5,4,10,2,8,13,44,9])



Input List:
 [5, 4, 10, 2, 8, 13, 44, 9]

Inserting 5

5

Inserting 4

 5
/ 
4 

Inserting 10

 5_ 
/  \
4 10

Inserting 2

  5_ 
 /  \
 4 10
/    
2    

Inserting 8

  5__ 
 /   \
 4  10
/  /  
2  8  

Inserting 13

  5__   
 /   \  
 4  10_ 
/  /   \
2  8  13

Inserting 44

  5__     
 /   \    
 4  10_   
/  /   \  
2  8  13_ 
         \
        44

Inserting 9

  5___     
 /    \    
 4  _10_   
/  /    \  
2  8   13_ 
    \     \
    9    44


In [62]:
bst.print_tree()

  5___     
 /    \    
 4  _10_   
/  /    \  
2  8   13_ 
    \     \
    9    44


In [23]:
bst.find_maximum()

44

In [24]:
bst.find_minimum()

2

In [36]:
bst.find_predessor().val

4

In [26]:
bst.find_successor()

8

In [63]:
bst.delete(5)

  5___     
 /    \    
 4  _10_   
/  /    \  
2  8   13_ 
    \     \
    9    44

Node 5 has both children

Finding successor of 5

Successor is 8

Replacing 5 with 8

  8___     
 /    \    
 4  _10_   
/  /    \  
2  8   13_ 
    \     \
    9    44

Deleting 8
  8__     
 /   \    
 4  10_   
/  /   \  
2  9  13_ 
         \
        44


In [59]:
bst.print_tree()

  8__     
 /   \    
 4  10_   
/  /   \  
2  9  13_ 
         \
        44


In [257]:
rt.delete(13)

In [258]:
rt.print_tree()

   8__ 
  /   \
 _4  10
/   /  
2   9  
 \     
 3     


In [259]:
rt.delete(2)

In [260]:
rt.print_tree()

  8__ 
 /   \
 4  10
/  /  
3  9  


In [261]:
rt.delete(8)

In [262]:
rt.print_tree()

  9_ 
 /  \
 4 10
/    
3    


In [263]:
rt.delete(9)
print_tree(rt)

  10
 /  
 4  
/   
3   


In [264]:
rt.delete(3)
rt.print_tree()

 10
/  
4  


In [265]:
rt.delete(4)
rt.print_tree()

10


In [266]:
rt.delete(10)
rt.print_tree()

None
