In [1]:
from typing import Optional

class Node:
    def __init__(self, val: int):
        self.val: int = val
        self.left: Optional[Node] = None
        self.right: Optional[Node] = None

## Binary Tree

- A binary tree is a hierachical data structure where each node has (at most) 2 children, the left and right child
- There are no restrictions for the value of the children vs their parents

- Insertion/deletion from a BT is $O(N)$. Since there are no order-restrictions for the nodes, you basically just traverse the tree until you find an empty spot

- This doesn't guarantee the tree remains balanced or anything fancy. It literally just inserts a value

In [24]:
from collections import deque

def insert(root: Node, insert_val: int):
    '''
    Traverse the height of the tree, so this is O(log N). If balanced tree is needed, then O(N)
    '''
    if root.left is None:
        root.left = Node(insert_val)
        return
    
    if root.right is None:
        root.right = Node(insert_val)
        return
    
    insert(root.left, insert_val)
        

def find(root: Node, find_val: int):
    '''
    Traverse the whole tree in the worst case, so O(N)
    '''
    pq = deque([root])
    while pq:
        curr = pq.pop()
        if curr.val == find_val:
            return True
        if curr.left:
            pq.append(curr.left)
        if curr.right:
            pq.append(curr.right)
    return False
            
def delete(root: Node, delete_val: int):
    '''
    Traverse the whole tree in the worst case, so O(N)
    '''
    ## Find the node with the value to delete. 
    pq = deque([root])
    while pq:
        target = pq.pop()
        if target.val == delete_val:
            break
        if target.left:
            pq.append(target.left)
        if target.right:
            pq.append(target.right)
    
    if target.val != delete_val:
        return "No such value in tree"
    
    ## Once you find the node to delete, find the bottom right node and swap it with the node to delete (i.e. swap values)
    pq = deque([(None, root)])
    while pq:
        parent, curr = pq.pop()
        if curr.left:
            pq.append((curr, curr.left))
        if curr.right:
            pq.append((curr, curr.right))
    
    ## Change target to the value of the last node
    target.val = curr.val
    
    ## Remove the last node from the Tree
    if parent:
        if parent.left == curr:
            parent.left = None
        else:
            parent.right = None
    
    return root

In [12]:
def inorder(root: Node):
    if root.left:
        inorder(root.left)
    print(root.val)
    if root.right:
        inorder(root.right)


In [19]:
root = Node(2)
root.left = Node(3)
root.right = Node(4)
root.left.left = Node(5)
root.left.right = Node(6)
inorder(root)

5
3
6
2
4


In [20]:
insert(root, 10)
inorder(root)

10
5
3
6
2
4


In [22]:
find(root,10)

True

In [23]:
delete(root, 10)
inorder(root)

5
3
6
2
4


## Binary Search Tree

- Similar to a binary search tree is a hierachical data structure where each node has (at most) 2 children, the left and right child
- Unlike a binary tree, however, BSTs enforce that every thing to the left of a parent (all children, grandchildren, etc) must be smaller than the root. And everything to the right must be larger
- Because of this, all subtrees in a BST must also be BSTs themselves
- Also because of this condition, BSTs generally do not accept duplicate values

In [102]:
def insert(root: Node|None, insert_val: int):
    '''
    O(log N) because we can split the tree in half depending on whether insert is larger or smaller than root
    In worst case of unbalanced tree, O(N)
    '''
    if root is None:
        return Node(insert_val)

    if insert_val < root.val:
        root.left = insert(root.left, insert_val)
    elif insert_val > root.val:
        root.right = insert(root.right, insert_val)
    else:
        # do nothing if values match, cannot have duplicates in BST
        return

    return root

def find(root: Node|None, find_val: int):
    '''
    O(log N)
    O(N) if unbalanced tree
    '''
    if not root:
        return None
    
    if root.val == find_val:
        return root
    
    if root.val > find_val:
        return find(root.left, find_val)
    elif root.val < find_val:
        return find(root.right, find_val)
    
def get_in_order_successor(root: Node):
    '''
    O(log N) because you only traverse 1 path
    '''
    curr = root.right
    while curr and curr.left:
        curr = curr.left
    return curr

def delete(root: Node|None, delete_val: int):
    '''
    O(log N) from the get_in_order_successtor
    + O(log N) from the recursive delete
    '''
    
    if not root:
        return None
    
    #If delete val is smaller than root, search on left subtree
    if root.val > delete_val:
        root.left = delete(root.left, delete_val)

    #If delete val is larger than root, search on right subtree
    elif root.val < delete_val:
        root.right = delete(root.right, delete_val)
    
    # if value matches, then we have found the node to delete
    else:
        if (not root.left) and (not root.right):
            # if node to delete has no children, just return none. The parent.left/parent.right will simply be assigned None value
            return None
        elif (not root.left):
            # if node to delete has 1 child, return the child. This will be assigned as one of the children of current node's parent
            return root.right
        elif (not root.right):
            return root.left
        else:
            # if node to delete has 2 children, then it gets trickier
            # because the node's children can also have their own children, and the BST structure must be preserved
            # Therefore, when deleting a node with multiple children, you want to find the **in-order successor**
            # That is, what is the next largest value that can take the place of the deleted node
            successor = get_in_order_successor(root)
            if successor:
                root.val = successor.val
                root.right = delete(root.right, successor.val)
        
    return root

def inorder(root: Node):
    if root.left:
        inorder(root.left)
    print(root.val)
    if root.right:
        inorder(root.right)

def make_bst(arr: list[int]):
    

In [106]:
root = Node(10)
root.left = Node(5)
root.right = Node(15)
root.right.left = Node(12)
root.right.right = Node(18)

In [107]:
insert(root, 1)
inorder(root)

1
5
10
12
15
18


In [82]:
find(root, 1)

In [108]:
delete(root, 12)
inorder(root)

1
5
10
15
18


### Constructing a BST

- There are 2 ways to construct a BST given an array of integers
    - Recursively, in $O(N \log N)$ time and $O(N)$ space
    - Iteratively, in $O(N)$ time and $O(N)$ space

In [121]:
def inorder(root: Node):
    if root.left:
        inorder(root.left)
    print(root.val)
    if root.right:
        inorder(root.right)

In [125]:
arr = [10,5,2,8,3,6,]

def add_val_to_tree(root: Node|None, val: int) -> Node|None:
    '''
    Time complexity: O(log N) for each val, because you traverse the height of the tree at worst. At worst, tree is very inbalanced, so O(N)
    Space complexity: O(log N) if tree is balanced, O(N) is tree is inbalanced, because stack memory is used in the recursion
    '''
    if not root:
        return Node(val)
    if val < root.val:
        root.left = make_bst(root.left, val)
    elif val > root.val:
        root.right = make_bst(root.right, val)
    
    return root

def make_bst_recursive(arr: list[int]) -> Node|None:    
    '''
    Time complexity: Each call to `add_val_to_tree` is O(log N) on average. Since there are N vals in arr, then it is O(N log N). In the worst case, this may be N^2
    Space complexity: Same as `add_val_to_tree`; O(log N) if tree is balanced, O(N) is tree is inbalanced, because stack memory is used in the recursion
    '''
    root = None
    for val in arr:
        root = add_val_to_tree(root, val)
    return root

root = make_bst_recursive(arr)
inorder(root)

2
3
5
6
8
10


In [137]:
def make_bst_iterative(arr: list[int]) -> Node|None:
    '''
    Time complexity: 
    Space complexity: 
    '''
    root = None
    for val in arr:
        root = bst_iterative_insert(root, val)
    return root

def bst_iterative_insert(root: Node|None, val: int) -> Node:
    if not root:
        return Node(val)

    curr = root
    while curr:
        if val < curr.val:
            if not curr.left:
                curr.left = Node(val)
                return root
            else:
                curr = curr.left
        if val > curr.val:
            if not curr.right:
                curr.right = Node(val)
                return root
            else:
                curr = curr.right

def inorder(root: Node):
    if root.left:
        inorder(root.left)
    print(root.val)
    if root.right:
        inorder(root.right)

In [139]:
arr = [10,5,2,8,3,6,]
root = make_bst_iterative(arr)
inorder(root)

2
3
5
6
8
10
