# Binary Search Tree

![image](./images/bst_01.png)

#### Define Node class

In [1]:
# this code makes the tree that we'll traverse

class Node(object):
    """A single node of a binary search tree.

    Stores a value plus references to a left and a right child; both child
    references start out as ``None``.
    """

    def __init__(self, value=None):
        self.value = value
        self.left = None
        self.right = None

    def set_value(self, value):
        self.value = value

    def get_value(self):
        return self.value

    def set_left_child(self, left):
        self.left = left

    def set_right_child(self, right):
        self.right = right

    def get_left_child(self):
        return self.left

    def get_right_child(self):
        return self.right

    def has_left_child(self):
        # Identity test: `!= None` would dispatch to the child's __ne__.
        return self.left is not None

    def has_right_child(self):
        return self.right is not None

    # define __repr__ to decide what a print statement displays for a Node object
    def __repr__(self):
        return f"Node({self.get_value()})"

    def __str__(self):
        return repr(self)


In [3]:
from collections import deque
class Queue():
    """First-in, first-out queue backed by a ``collections.deque``.

    Items are added at the left end of the deque and removed from the right
    end, so the oldest item is always the next one out.
    """

    def __init__(self):
        self.q = deque()

    def enq(self, value):
        """Add ``value`` to the back of the queue."""
        self.q.appendleft(value)

    def deq(self):
        """Remove and return the oldest item, or ``None`` when the queue is empty."""
        if not self.q:
            return None
        return self.q.pop()

    def __len__(self):
        return len(self.q)

    def __repr__(self):
        if not self.q:
            return "<queue is empty>"
        divider = "\n_________________\n"
        body = divider.join(str(item) for item in self.q)
        return "<enqueue here>" + divider + body + divider + "<dequeue here>"

#### Define insert

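Let's assume that duplicates are overridden by the new node that is to be inserted: the new value is written into the node that already holds an equal value, so no extra node is added.  Other options are to keep a counter of duplicate nodes, or to keep a list of duplicate nodes with the same value.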

In [5]:
class Tree():
    """Binary search tree supporting iterative and recursive insertion.

    A value equal to one already stored overwrites that node's value rather
    than adding a second node.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        self.root = Node(value)

    def get_root(self):
        return self.root

    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    # Insert, version 1: walk down the tree with a loop.
    def insert_with_loop(self, new_value):
        """Insert ``new_value`` iteratively; an equal value overwrites in place."""
        node = self.root
        if node is None:
            self.set_root(new_value)
            return

        new_node = Node(new_value)
        while True:
            cmp = self.compare(node, new_node)
            if cmp == 0:
                # Duplicate: overwrite the existing node's value, add nothing.
                node.set_value(new_node.get_value())
                return
            if cmp == -1:
                if not node.has_left_child():
                    node.set_left_child(new_node)
                    return
                node = node.get_left_child()
            else:
                if not node.has_right_child():
                    node.set_right_child(new_node)
                    return
                node = node.get_right_child()

    # Insert, version 2: recursion.
    def _insert_helper(self, node, new_node):
        '''
        Insert new_node in a tree rooted at node (node must not be None).
        '''
        comparison = self.compare(node, new_node)
        if comparison == 0:
            node.set_value(new_node.get_value())
        elif comparison == -1:
            if node.has_left_child():
                self._insert_helper(node.get_left_child(), new_node)
            else:
                node.set_left_child(new_node)
        else:
            if node.has_right_child():
                self._insert_helper(node.get_right_child(), new_node)
            else:
                node.set_right_child(new_node)

    def insert_with_recursion(self, value):
        """Insert ``value`` recursively; an equal value overwrites in place."""
        if self.root is None:
            self.set_root(value)
            return

        self._insert_helper(self.root, Node(value))

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Each output line is one depth of the tree; missing children are shown
        as ``<empty>`` so the shape of the tree stays visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # A missing child comes back as None and is rendered as <empty>.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        # BFS visits levels in nondecreasing order, so each level is one run.
        rows = []
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                rows[-1].append(str(node))
            else:
                rows.append([str(node)])
                previous_level = level
        return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)


In [7]:
# Build the demo tree with the iterative insert; the trailing 5 is a
# duplicate and should not add a node.
tree = Tree()
for value in (5, 6, 4, 2, 5):
    tree.insert_with_loop(value)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


In [8]:
# Same demo tree, built with the recursive insert (trailing 5 is a duplicate).
tree = Tree()
for value in (5, 6, 4, 2, 5):
    tree.insert_with_recursion(value)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Search

Define a search function that takes a value, and returns `True` if a node containing that value exists in the tree, otherwise `False`.

In [19]:
class Tree():
    """Binary search tree with insertion and iterative/recursive search.

    A value equal to one already stored overwrites that node's value rather
    than adding a second node.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        self.root = Node(value)

    def get_root(self):
        return self.root

    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert(self, new_value):
        """Insert ``new_value``; an equal value overwrites the existing node's value."""
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while True:
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # Override the existing node. Rebinding the local name `node`
                # would leave the tree untouched, so write the value instead.
                node.set_value(new_node.get_value())
                return
            if comparison == -1:
                # go left
                if not node.has_left_child():
                    node.set_left_child(new_node)
                    return
                node = node.get_left_child()
            else:  # comparison == 1
                # go right
                if not node.has_right_child():
                    node.set_right_child(new_node)
                    return
                node = node.get_right_child()

    # Search, version 1: iterative.
    def search(self, value):
        """Return True if a node holding ``value`` exists in the tree, else False."""
        search_node = Node(value)
        node = self.get_root()
        # Descend until the value is found or we fall off the tree.
        while node is not None:
            comparison = self.compare(node, search_node)
            if comparison == 0:
                return True
            if comparison == -1:
                node = node.get_left_child()
            else:  # comparison == 1
                node = node.get_right_child()
        return False

    # Search, version 2: recursive.
    def _search_helper(self, node, search_node):
        '''
        Search for search_node in a tree rooted at node (node must not be None).
        '''
        comparison = self.compare(node, search_node)

        if comparison == 0:
            return True
        elif comparison == -1:
            if node.has_left_child():
                return self._search_helper(node.get_left_child(), search_node)
            return False
        else:  # comparison == 1
            if node.has_right_child():
                return self._search_helper(node.get_right_child(), search_node)
            return False

    def search_with_recursion(self, value):
        """Recursive counterpart of ``search``; returns True/False."""
        if self.root is None:
            return False

        return self._search_helper(self.root, Node(value))

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Each output line is one depth of the tree; missing children are shown
        as ``<empty>`` so the shape of the tree stays visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # A missing child comes back as None and is rendered as <empty>.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        # BFS visits levels in nondecreasing order, so each level is one run.
        rows = []
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                rows[-1].append(str(node))
            else:
                rows.append([str(node)])
                previous_level = level
        return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)


In [20]:
# Build the tree 5, 6, 4, 2 and query it with both search implementations.
tree = Tree()
for value in (5, 6, 4, 2):
    tree.insert(value)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")


print(f"""
Search recursively:
search for 8: {tree.search_with_recursion(8)}
search for 2: {tree.search_with_recursion(2)}
""")
print(tree)


search for 8: False
search for 2: True


Search recursively:
search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Bonus: deletion

Try implementing deletion yourself.  You can also check out this explanation [here](https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/)

In [75]:
class Tree():
    """Binary search tree with insertion, search and deletion.

    A value equal to one already stored overwrites that node's value rather
    than adding a second node.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        self.root = Node(value)

    def get_root(self):
        return self.root

    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert(self, new_value):
        """Insert ``new_value``; an equal value overwrites the existing node's value."""
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while True:
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # Override the existing node. Rebinding the local name `node`
                # would leave the tree untouched, so write the value instead.
                node.set_value(new_node.get_value())
                return
            if comparison == -1:
                # go left
                if not node.has_left_child():
                    node.set_left_child(new_node)
                    return
                node = node.get_left_child()
            else:  # comparison == 1
                # go right
                if not node.has_right_child():
                    node.set_right_child(new_node)
                    return
                node = node.get_right_child()

    def search(self, value):
        """Return True if a node holding ``value`` exists in the tree, else False."""
        search_node = Node(value)
        node = self.get_root()
        # Descend until the value is found or we fall off the tree.
        while node is not None:
            comparison = self.compare(node, search_node)
            if comparison == 0:
                return True
            if comparison == -1:
                node = node.get_left_child()
            else:  # comparison == 1
                node = node.get_right_child()
        return False

    def get_max_val_node(self, node):
        '''
        Return the node with max value found in the tree rooted at node,
        or None when node is None.
        '''
        if node is None:
            return None

        current = node
        # The maximum of a BST is its right-most node.
        while current.has_right_child():
            current = current.get_right_child()
        return current

    def _delete_helper(self, node, d_node):
        '''
        Delete the node whose value equals d_node's value from the tree rooted
        at node (node must not be None), and return the new root of that
        subtree. If the value is absent, the subtree is returned unchanged.
        Inspired by solution given here: https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/
        '''
        comparison = self.compare(node, d_node)
        if comparison == -1:
            if node.has_left_child():
                node.set_left_child(self._delete_helper(node.get_left_child(), d_node))
            return node
        if comparison == 1:
            if node.has_right_child():
                node.set_right_child(self._delete_helper(node.get_right_child(), d_node))
            return node

        # comparison == 0: `node` is the one to remove.
        if not node.has_left_child():
            # No children, or only a right child: splice in the right subtree
            # (None when there are no children).
            return node.get_right_child()
        if not node.has_right_child():
            # Only a left child: splice in the left subtree.
            return node.get_left_child()
        # Two children: copy in the in-order predecessor's value (max of the
        # left subtree), then remove that predecessor, which has no right child.
        predecessor = self.get_max_val_node(node.get_left_child())
        node.set_value(predecessor.get_value())
        node.set_left_child(self._delete_helper(node.get_left_child(), predecessor))
        return node

    def delete(self, value):
        '''
        Delete the node holding ``value``; a missing value is a no-op.
        '''
        if self.root is None:
            return

        self.root = self._delete_helper(self.root, Node(value))

    def __repr__(self):
        """Render the tree level by level (breadth-first).

        Each output line is one depth of the tree; missing children are shown
        as ``<empty>`` so the shape of the tree stays visible.
        """
        q = Queue()
        visit_order = []
        q.enq((self.get_root(), 0))
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # A missing child comes back as None and is rendered as <empty>.
            q.enq((node.get_left_child(), level + 1))
            q.enq((node.get_right_child(), level + 1))

        # BFS visits levels in nondecreasing order, so each level is one run.
        rows = []
        previous_level = -1
        for node, level in visit_order:
            if level == previous_level:
                rows[-1].append(str(node))
            else:
                rows.append([str(node)])
                previous_level = level
        return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)


## Solution notebook
The solution for insertion and search is [here](04%20binary_search_tree_solution.ipynb)

In [76]:
# Build a larger tree, then delete nodes one at a time, printing the tree
# after every deletion to show leaf, one-child and two-child cases.
tree = Tree()
for value in (5, 3, 15, 1, 4, 2, 10, 20, 12):
    tree.insert(value)

print(tree)

for value in (3, 1, 2, 10, 20, 15, 5, 12):
    tree.delete(value)
    print(tree)

Tree

Node(5)
Node(3) | Node(15)
Node(1) | Node(4) | Node(10) | Node(20)
<empty> | Node(2) | <empty> | <empty> | <empty> | Node(12) | <empty> | <empty>
<empty> | <empty> | <empty> | <empty>
Tree

Node(5)
Node(2) | Node(15)
Node(1) | Node(4) | Node(10) | Node(20)
<empty> | <empty> | <empty> | <empty> | <empty> | Node(12) | <empty> | <empty>
<empty> | <empty>
Tree

Node(5)
Node(2) | Node(15)
<empty> | Node(4) | Node(10) | Node(20)
<empty> | <empty> | <empty> | Node(12) | <empty> | <empty>
<empty> | <empty>
Tree

Node(5)
Node(4) | Node(15)
<empty> | <empty> | Node(10) | Node(20)
<empty> | Node(12) | <empty> | <empty>
<empty> | <empty>
Tree

Node(5)
Node(4) | Node(15)
<empty> | <empty> | Node(12) | Node(20)
<empty> | <empty> | <empty> | <empty>
Tree

Node(5)
Node(4) | Node(15)
<empty> | <empty> | Node(12) | <empty>
<empty> | <empty>
Tree

Node(5)
Node(4) | Node(12)
<empty> | <empty> | <empty> | <empty>
Tree

Node(4)
<empty> | Node(12)
<empty> | <empty>
Tree

Node(4)
<empty> | <empty>
