# Binary Search Tree

![image](./images/bst_01.png)

#### Define Node class

In [1]:
# this code makes the tree that we'll traverse

class Node:
    """A binary-tree node holding a value plus optional left and right children."""

    def __init__(self, value=None):
        self.value = value
        self.left = None   # left child Node, or None when absent
        self.right = None  # right child Node, or None when absent

    def set_value(self, value):
        """Replace the value stored in this node."""
        self.value = value

    def get_value(self):
        """Return the value stored in this node."""
        return self.value

    def set_left_child(self, left):
        """Attach (or detach, with None) the left child."""
        self.left = left

    def set_right_child(self, right):
        """Attach (or detach, with None) the right child."""
        self.right = right

    def get_left_child(self):
        """Return the left child, or None if there is none."""
        return self.left

    def get_right_child(self):
        """Return the right child, or None if there is none."""
        return self.right

    def has_left_child(self):
        """Return True if a left child is attached."""
        return self.left is not None

    def has_right_child(self):
        """Return True if a right child is attached."""
        return self.right is not None

    def __repr__(self):
        """Render the node as ``Node(<value>)``; this is what print() shows."""
        return f"Node({self.get_value()})"

    # str() and repr() deliberately render the same text.
    __str__ = __repr__


In [2]:
from collections import deque
class Queue():
    """First-in, first-out queue backed by a deque.

    Items enter on the left (``enq``) and leave from the right (``deq``), so
    the oldest item is always dequeued first.
    """

    def __init__(self):
        self.q = deque()

    def enq(self, value):
        """Add value at the back of the queue."""
        self.q.appendleft(value)

    def deq(self):
        """Remove and return the oldest item, or None when the queue is empty."""
        return self.q.pop() if self.q else None

    def __len__(self):
        return len(self.q)

    def __repr__(self):
        """Draw the queue vertically, newest item at the top."""
        if not self.q:
            return "<queue is empty>"
        divider = "\n_________________\n"
        return divider.join(["<enqueue here>", *map(str, self.q), "<dequeue here>"])

#### Define insert

Let's assume that a duplicate value overrides the existing node instead of being inserted a second time. For plain numbers this gives the same tree as simply ignoring the duplicate. Other options are to keep a counter of duplicate nodes, or to keep a list of duplicate nodes with the same value.

In [3]:
class Tree():
    """Binary search tree of Node objects, with two alternative insert methods.

    Smaller values go to the left subtree, larger values to the right; a value
    that is already present is left unchanged.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        """Replace the root with a new, childless Node holding value."""
        self.root = Node(value)

    def get_root(self):
        """Return the root Node, or None for an empty tree."""
        return self.root

    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    # Exercise: define insert using a loop (try one or both ways).
    def insert_with_loop(self, new_value):
        """Insert new_value iteratively; a value already in the tree is ignored."""
        node = self.get_root()
        if node is None:
            self.set_root(new_value)
            return
        while True:
            if new_value < node.get_value():
                if not node.has_left_child():
                    node.set_left_child(Node(new_value))
                    return
                node = node.get_left_child()
            elif new_value > node.get_value():
                if not node.has_right_child():
                    node.set_right_child(Node(new_value))
                    return
                node = node.get_right_child()
            else:
                return  # duplicate: nothing to insert

    # Exercise: define insert using recursion (try one or both ways).
    def insert_with_recursion(self, value):
        """Insert value recursively; a value already in the tree is ignored."""
        def insert(node, value):
            # Returns the (possibly new) root of the subtree so the caller can
            # re-link it, which also handles the empty-tree case.
            if node is None:
                return Node(value)
            if value < node.get_value():
                node.set_left_child(insert(node.get_left_child(), value))
            elif value > node.get_value():
                node.set_right_child(insert(node.get_right_child(), value))
            return node
        self.root = insert(self.get_root(), value)

    def __repr__(self):
        """Render the tree breadth-first, one line per level.

        Entries on a level are joined with " | " and missing children are shown
        as "<empty>" so the shape of the tree stays visible.
        """
        q = Queue()
        q.enq((self.get_root(), 0))
        visit_order = []
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # Enqueue a None placeholder for a missing child so it prints as <empty>.
            left = node.get_left_child() if node.has_left_child() else None
            right = node.get_right_child() if node.has_right_child() else None
            q.enq((left, level + 1))
            q.enq((right, level + 1))

        # BFS emits levels in non-decreasing order: group each level into a row.
        rows = []
        previous_level = -1
        for item, level in visit_order:
            if level == previous_level:
                rows[-1].append(str(item))
            else:
                rows.append([str(item)])
                previous_level = level
        return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)


In [4]:
tree = Tree()
# 5 appears twice: the second insert is a duplicate
for value in (5, 6, 4, 2, 5):
    tree.insert_with_loop(value)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


In [5]:
tree = Tree()
# 5 appears twice: the second insert is a duplicate
for value in (5, 6, 4, 2, 5):
    tree.insert_with_recursion(value)
print(tree)

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Search

Define a search function that takes a value, and returns true if a node containing that value exists in the tree, otherwise false.

In [6]:
class Tree():
    """Binary search tree of Node objects supporting insert and search.

    Smaller values go to the left subtree, larger values to the right; inserting
    a value that is already present overwrites the stored value in place.
    """

    def __init__(self):
        self.root = None

    def set_root(self, value):
        """Replace the root with a new, childless Node holding value."""
        self.root = Node(value)

    def get_root(self):
        """Return the root Node, or None for an empty tree."""
        return self.root

    def compare(self, node, new_node):
        """
        0 means new_node equals node
        -1 means new node less than existing node
        1 means new node greater than existing node
        """
        if new_node.get_value() == node.get_value():
            return 0
        elif new_node.get_value() < node.get_value():
            return -1
        else:
            return 1

    def insert(self, new_value):
        """Insert new_value; an equal value already in the tree is overridden."""
        new_node = Node(new_value)
        node = self.get_root()
        if node is None:
            self.root = new_node
            return

        while True:
            comparison = self.compare(node, new_node)
            if comparison == 0:
                # Override the existing node's value with the new one. (Merely
                # rebinding the local `node` would not change the tree.)
                node.set_value(new_value)
                return
            elif comparison == -1:
                # go left
                if node.has_left_child():
                    node = node.get_left_child()
                else:
                    node.set_left_child(new_node)
                    return
            else:  # comparison == 1
                # go right
                if node.has_right_child():
                    node = node.get_right_child()
                else:
                    node.set_right_child(new_node)
                    return

    def search(self, value):
        """Return True if a node holding value exists in the tree, else False."""
        node = self.get_root()
        while node is not None and node.get_value() != value:
            # BST property: only one subtree can contain value.
            if value < node.get_value():
                node = node.get_left_child()
            else:
                node = node.get_right_child()
        return node is not None

    def __repr__(self):
        """Render the tree breadth-first, one line per level.

        Entries on a level are joined with " | " and missing children are shown
        as "<empty>" so the shape of the tree stays visible.
        """
        q = Queue()
        q.enq((self.get_root(), 0))
        visit_order = []
        while len(q) > 0:
            node, level = q.deq()
            if node is None:
                visit_order.append(("<empty>", level))
                continue
            visit_order.append((node, level))
            # Enqueue a None placeholder for a missing child so it prints as <empty>.
            left = node.get_left_child() if node.has_left_child() else None
            right = node.get_right_child() if node.has_right_child() else None
            q.enq((left, level + 1))
            q.enq((right, level + 1))

        # BFS emits levels in non-decreasing order: group each level into a row.
        rows = []
        previous_level = -1
        for item, level in visit_order:
            if level == previous_level:
                rows[-1].append(str(item))
            else:
                rows.append([str(item)])
                previous_level = level
        return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)


In [7]:
tree = Tree()
for value in (5, 6, 4, 2):
    tree.insert(value)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")
print(tree)


search for 8: False
search for 2: True

Tree

Node(5)
Node(4) | Node(6)
Node(2) | <empty> | <empty> | <empty>
<empty> | <empty>


## Bonus: deletion

Try implementing deletion yourself.  You can also check out this explanation [here](https://www.geeksforgeeks.org/binary-search-tree-set-2-delete/)

In [9]:
# Redefine the tree, starting by pulling all the ugly __repr__ code out into a standalone helper

def tree_to_string(root):
    '''Convert the tree to printable form. Rather than define this in __repr__.

    Nodes are visited breadth-first starting from root. Each level of the tree
    becomes one output line whose entries are joined with " | "; a missing
    child is rendered as "<empty>" so the shape of the tree stays visible.
    The output starts with "Tree\\n" and each level is preceded by a newline,
    so an empty tree (root is None) renders as "Tree\\n\\n<empty>".

    root is expected to be None or a Node-like object providing
    has_left_child(), get_left_child(), has_right_child() and get_right_child().
    '''
    # Breadth-first traversal; deque gives O(1) append / popleft.
    pending = deque([(root, 0)])
    visit_order = []
    while pending:
        node, level = pending.popleft()
        if node is None:
            visit_order.append(("<empty>", level))
            continue
        visit_order.append((node, level))
        # Enqueue a None placeholder for a missing child so it prints as <empty>.
        left = node.get_left_child() if node.has_left_child() else None
        right = node.get_right_child() if node.has_right_child() else None
        pending.append((left, level + 1))
        pending.append((right, level + 1))

    # BFS emits levels in non-decreasing order: group each level into a row.
    rows = []
    previous_level = -1
    for item, level in visit_order:
        if level == previous_level:
            rows[-1].append(str(item))
        else:
            rows.append([str(item)])
            previous_level = level
    return "Tree\n" + "".join("\n" + " | ".join(row) for row in rows)
    

In [38]:
class Tree():
    '''A BST of Node objects: smaller values to the left, larger to the right.'''

    def __init__(self):
        self.root = None

    def get_root(self):
        '''Return the root Node, or None for an empty tree.'''
        return self.root

    def insert(self, new_value):
        '''Add a new value to the tree. A value already present is ignored.'''
        node = self.get_root()
        if node is None:
            self.root = Node(new_value)
            return
        while True:
            if new_value < node.get_value():
                if not node.has_left_child():
                    node.set_left_child(Node(new_value))
                    return
                node = node.get_left_child()
            elif new_value > node.get_value():
                if not node.has_right_child():
                    node.set_right_child(Node(new_value))
                    return
                node = node.get_right_child()
            else:
                return  # duplicate: nothing to insert

    def search(self, value):
        '''Return True if value is in the tree, False if not.'''
        return self._find_node(value) is not None

    def _find_node(self, value, start=None):
        '''Return the Node holding value, or None if it is absent.

        The search starts at start when given, otherwise at the root.
        '''
        node = self.get_root() if start is None else start
        while node is not None and node.get_value() != value:
            if value < node.get_value():
                node = node.get_left_child()
            else:
                node = node.get_right_child()
        return node

    def _find_parent(self, node, start=None):
        '''Find the parent Node of node.

        Returns None if node is the root, or if node is start itself.
        start allows the search to begin at some intermediate node.
        node must be part of the tree: otherwise the node that would have
        been its parent is returned.
        '''
        parent = None
        current = self.get_root() if start is None else start
        # Nodes are compared by identity; values only steer the descent.
        while current is not None and current is not node:
            parent = current
            if node.value < current.value:
                current = current.left
            else:
                current = current.right
        return parent

    def delete_value(self, value):
        '''Delete value from the tree; does nothing if value is absent.'''
        self.root = self._delete_value_recursive(self.root, value)

    def _delete_value_recursive(self, root, value):
        '''Delete value from the subtree rooted at root, recursively.

        Returns the new root of that subtree so the caller can re-link it.
        '''
        if root is None:
            return None

        if value < root.value:
            root.left = self._delete_value_recursive(root.left, value)
            return root
        elif value > root.value:
            root.right = self._delete_value_recursive(root.right, value)
            return root
        else:
            # This is the node to delete
            if root.left is not None and root.right is not None:
                # Two children: copy in the in-order successor (smallest value
                # of the right subtree), then delete that successor instead.
                replacement = root.right
                while replacement.left is not None:
                    replacement = replacement.left
                root.right = self._delete_value_recursive(root.right, replacement.value)
                root.value = replacement.value
                return root
            # Zero or one child: splice the child (possibly None) into place.
            return root.left if root.left is not None else root.right

    def _delete_node(self, node, start=None):
        '''Delete node from the tree (iterative alternative to delete_value).

        node must be part of the tree. Deleting the root is supported: when
        node has no parent, the tree's root is re-pointed at its child.
        '''
        if node.left is not None and node.right is not None:
            # Two children: find the smallest value in the right subtree, remove
            # that node (it has no left child), then copy its value into node.
            replacement = node.right
            while replacement.left is not None:
                replacement = replacement.left
            self._delete_node(replacement, start=node)
            node.value = replacement.value
            return

        # Zero or one child: splice the child (possibly None) into node's place.
        child = node.left if node.left is not None else node.right
        parent = self._find_parent(node, start=start)
        if parent is None and node is not self.root:
            # start was node itself, so the real parent lies above start.
            parent = self._find_parent(node)
        if parent is None:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child

    def __repr__(self):
        return tree_to_string(self.get_root())


In [40]:
# test the tree
tree = Tree()
tree.insert(5)
tree.insert(3)
tree.insert(7)
tree.insert(4)
tree.insert(6)
tree.insert(10)
tree.insert(9)
tree.insert(8)
tree.insert(8.5)

print(f"""
search for 8: {tree.search(8)}
search for 2: {tree.search(2)}
""")
print(tree)

print()

# test out parent and node find
node = tree._find_node(7)
parent = tree._find_parent(node)
print(parent)

node = tree._find_node(8.5)
parent = tree._find_parent(node)
print(parent)

node = tree._find_node(5)
parent = tree._find_parent(node)
print(parent)

subtree = tree._find_node(7)
target = tree._find_node(8.5)
parent = tree._find_parent(target, start=subtree)
print(parent)

print()
print(tree.search(4))
tree.delete_value(4)
print(tree.search(4))
print(tree)

print()
print(tree.search(3))
tree.delete_value(3)
print(tree.search(3))
print(tree)

print()
print(tree.search(7))
tree.delete_value(7)
print(tree.search(7))
print(tree)

print()
print(tree.search(5))
tree.delete_value(5)
print(tree.search(5))
print(tree)



search for 8: True
search for 2: False

Tree

Node(5)
Node(3) | Node(7)
<empty> | Node(4) | Node(6) | Node(10)
<empty> | <empty> | <empty> | <empty> | Node(9) | <empty>
Node(8) | <empty>
<empty> | Node(8.5)
<empty> | <empty>

Node(5)
Node(8)
None
Node(8)

True
False
Tree

Node(5)
Node(3) | Node(7)
<empty> | <empty> | Node(6) | Node(10)
<empty> | <empty> | Node(9) | <empty>
Node(8) | <empty>
<empty> | Node(8.5)
<empty> | <empty>

True
False
Tree

Node(5)
<empty> | Node(7)
Node(6) | Node(10)
<empty> | <empty> | Node(9) | <empty>
Node(8) | <empty>
<empty> | Node(8.5)
<empty> | <empty>

True
False
Tree

Node(5)
<empty> | Node(8)
Node(6) | Node(10)
<empty> | <empty> | Node(9) | <empty>
Node(8.5) | <empty>
<empty> | <empty>

True
False
Tree

Node(8)
Node(6) | Node(10)
<empty> | <empty> | Node(9) | <empty>
Node(8.5) | <empty>
<empty> | <empty>


## Solution notebook
The solution for insertion and search is [here](04%20binary_search_tree_solution.ipynb).