# Trees

## Binary Tree

In [None]:
class TreeNode:
    """A binary-tree node: one value plus links to a left and a right child."""

    def __init__(self, val):
        self.val = val
        # A freshly created node is a leaf, so both child links start out empty.
        self.left = self.right = None

## Binary Search Tree

In [None]:
def search(root, target):
    """Report whether `target` occurs in the BST rooted at `root`.

    Returns `True` when a node whose `val` equals `target` is reached, and
    `False` when the search runs off the bottom of the tree.
    """
    if root is None: # Ran out of nodes: the target is not in the tree.
        return False

    current_val = root.val
    if target > current_val: # Larger values can only be in the right sub-tree.
        return search(root.right, target)
    if target < current_val: # Smaller values can only be in the left sub-tree.
        return search(root.left, target)
    return True # Neither larger nor smaller: this node holds the target.

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(log\ n)$. (The space complexity of a recursive algorithm is the maximum number of elements in the call stack during the execution of the algorithm.)

**Note:** This is one-branch recursion, since only one of the recursive calls gets executed.

### Search in a Binary Search Tree

In [None]:
def searchBST(root, val):
    """Return the node of the BST rooted at `root` whose value is `val`.

    Returns `None` when no such node exists (including for an empty tree).
    """
    if root is None:
        return None

    here = root.val
    if val > here:
        # Anything larger than this node lives in its right sub-tree.
        return searchBST(root.right, val)
    if val < here:
        # Anything smaller than this node lives in its left sub-tree.
        return searchBST(root.left, val)
    return root

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(log\ n)$.

## BST Insert and Remove

In [None]:
def insert(root, val):
    """Insert `val` into the BST rooted at `root` and return that (sub-)tree's root.

    A new leaf is created at the spot where the search for `val` falls off
    the tree. If `val` is already present, the tree is left unchanged
    (duplicates are not inserted).
    """
    if root is None: # Base case: reached an empty spot.
        return TreeNode(val) # Create a new node and return it.

    # Pick a side by comparing against the current node's value (not against its right child).
    if val > root.val: # Recursive case.
        root.right = insert(root.right, val) # Insert into the right sub-tree and re-attach its root as the right child.
    elif val < root.val: # Recursive case.
        root.left = insert(root.left, val) # Insert into the left sub-tree and re-attach its root as the left child.
    # Implicitly, if `val == root.val`, nothing is inserted.
    return root # Return the current node.

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(log\ n)$.

**Note:** Consider the case `val > root.val`. From the perspective of the current node, if the right child is not `None`, then the assignment of the right pointer is a re-assignment. (It doesn't really do anything new.) Only when the right child is `None` does the right pointer get assigned to a new node - the node that just got created.

In [None]:
def min_node(root):
    """Return the left-most node of the tree rooted at `root` (`None` for an empty tree)."""
    if root is None:
        return None

    leftmost = root
    # Keep stepping left until there is no further left child to follow.
    while leftmost.left is not None:
        leftmost = leftmost.left
    return leftmost

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(1)$.

In [None]:
def remove(root, val):
    """Remove `val` from the BST rooted at `root` and return the (sub-)tree's new root.

    If `val` is absent, the tree is returned unchanged. A node with two
    children is not unlinked; instead it takes over the value of the smallest
    node in its right sub-tree, and that value is then removed from the right
    sub-tree.
    """
    if root is None: # Either empty tree or node not found.
        return None

    # Keep searching while the current node is not the one to remove.
    if val > root.val:
        root.right = remove(root.right, val)
        return root
    if val < root.val:
        root.left = remove(root.left, val)
        return root

    # Node found. With at most one child, that child (or `None`) simply takes its place.
    if root.right is None:
        return root.left
    if root.left is None:
        return root.right

    # Two children: overwrite with the smallest value of the right sub-tree, then remove that value from there.
    successor = min_node(root.right)
    root.val = successor.val
    root.right = remove(root.right, successor.val)
    return root

For a balanced tree, this algorithm has a time complexity of $O(3.log\ n) = O(log\ n)$. There are 3 parts:-

1. $O(log\ n)$ to search for the target node.
2. If the target node has two children (the worst case): $O(log\ n)$ to find the inorder successor of the target node.
3. If the target node has two children (the worst case): $O(log\ n)$ to remove the inorder successor (from the right sub-tree of the target node) after copying the value of the inorder successor into the target node.

Similarly, for a balanced tree, it has a space complexity of $O(2.log\ n + 1)$ = $O(log\ n)$. There are 3 parts:-

1. $O(log\ n)$ to search for the target node.
2. If the target node has two children (the worst case): $O(1)$ to find the inorder successor of the target node.
3. If the target node has two children (the worst case): $O(log\ n)$ to remove the inorder successor (from the right sub-tree of the target node) after copying the value of the inorder successor into the target node.

### Insert into a Binary Search Tree

In [None]:
def insertIntoBST(root, val):
    """Insert `val` into the BST rooted at `root`; return the root of the updated (sub-)tree.

    A value that is already present is not inserted again.
    """
    if root is None:
        # Fell off the tree: this empty spot is where the new leaf belongs.
        return TreeNode(val)

    if val > root.val:
        root.right = insertIntoBST(root.right, val)
        return root
    if val < root.val:
        root.left = insertIntoBST(root.left, val)
    return root

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(log\ n)$.

### Delete Node in a BST

In [None]:
def minNode(root):
    """Follow left links from `root` and return the last node reached (`None` if `root` is `None`)."""
    node = root
    while node is not None:
        if node.left is None:
            # No further left child: this is the left-most node.
            return node
        node = node.left
    return None

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(1)$.

In [None]:
def deleteNode(root, key):
    """Delete the node holding `key` from the BST rooted at `root`.

    Returns the root of the resulting (sub-)tree, which is `None` when the
    tree becomes (or already was) empty. A missing `key` leaves the tree as is.
    """
    if root is None:
        return None

    if key > root.val:
        root.right = deleteNode(root.right, key)
    elif key < root.val:
        root.left = deleteNode(root.left, key)
    elif root.right is None:
        # Target found with no right child: its left child (possibly `None`) replaces it.
        return root.left
    elif root.left is None:
        # Target found with only a right child: that child replaces it.
        return root.right
    else:
        # Target found with two children: take over the in-order successor's value,
        # then delete that value from the right sub-tree.
        successor = minNode(root.right)
        root.val = successor.val
        root.right = deleteNode(root.right, successor.val)
    return root

For a balanced tree, this algorithm has a time complexity of $O(log\ n)$ and a space complexity of $O(log\ n)$.

## Depth-First Search

---

Exploratory:

In [None]:
def my_func():
    """Print a greeting, then leave the function through a bare `return`."""
    print("Hello!")
    return  # No value is given, so the caller receives `None`.

The Python interpreter is being told: "*Get out of this function.*" On this occasion, the function would have returned anyway after the `print("Hello!")` line.

In [None]:
# Keep whatever `my_func` hands back so that it can be inspected.
x = my_func()
x

Hello!


In [None]:
x is None

True

In [None]:
my_func()

Hello!


In [None]:
# Equivalent to:
def my_func():
    """Print a greeting and return `None` explicitly."""
    print("Hello!")
    return None

my_func()

Hello!


In [None]:
# Also equivalent to:
def my_func():
    """Print a greeting; with no `return` statement at all, the function simply ends here."""
    print("Hello!")

my_func()

Hello!


**Note:** In C, C++ and Java, such functions have a return type of `void`.

---

In [None]:
def inorder(root):
    """Print every value of the tree rooted at `root`, one per line, in inorder.

    Inorder means: the whole left sub-tree first, then the node itself, then
    the whole right sub-tree. An empty tree prints nothing.
    """
    if root is not None:
        inorder(root.left)
        print(root.val)
        inorder(root.right)

The time complexity of the algorithm is $O(n)$, because all $n$ nodes need to be traversed.

The space complexity of the algorithm is $O(h)$ where $h$ is the height of the tree, which would be $O(log\ n)$ for a balanced binary tree and $O(n)$ for a skewed tree. (The maximum number of items in the call stack - during the execution of the algorithm - is $h$. Even though we visit $n$ nodes, for a balanced tree, the number of elements in the call stack never exceeds $h$. If you're not convinced about this, take an example BST, and do a dry run while keeping track of the call stack - with pen and paper.)

From Perplexity.ai:-

> The space complexity of Depth-First Search (DFS) is $O(V)$, where $V$ represents the number of vertices in the graph. In a recursive implementation of DFS, the space complexity is determined by the maximum depth of the recursion. In the worst case, this can be equal to the total number of vertices in the graph. The algorithm only needs to store nodes along the current path (in the call stack). Unlike Breadth-First Search (BFS), DFS does not need to store all nodes at a given level. **The space requirement is proportional to the length of the longest path in the graph.** In an extreme case like a linear graph (similar to a linked list), the space complexity remains $O(V)$, where the algorithm might need to store references to every node in the stack. Key insight: DFS is generally more memory-efficient compared to BFS, especially in graphs with a large branching factor, as it only maintains the nodes on the current exploration path in memory.

**Question:** Is depth-first search an example of one-branch recursion or two-branch recursion?

**Answer from ChatGPT:**

---

Depth-First Search (DFS) can be implemented using both one-branch recursion and two-branch recursion, depending on the structure of the search.

One-Branch Recursion: When DFS is applied to a graph (an adjacency list), the recursion follows one child at a time. Example: Recursive DFS for a graph typically visits one adjacent node at a time in a loop, making it effectively one-branch recursion.

Two-Branch Recursion: When DFS is applied to a binary tree, each node has up to two children (left and right), leading to two recursive calls. Example: A standard DFS traversal for a binary tree is two-branch recursion.

---

So DFS in the context of (i) binary trees and (ii) recursion trees (of which decision trees are a special case) are examples of two-branch recursion. But DFS in the context of graphs is an example of one-branch recursion.

In [None]:
def preorder(root):
    """Print every value of the tree rooted at `root`, one per line, in preorder.

    Preorder means: the node itself first, then its left sub-tree, then its
    right sub-tree. An empty tree prints nothing.
    """
    if root is not None:
        print(root.val)
        # Children are handled after the node, left before right.
        for child in (root.left, root.right):
            preorder(child)

In [None]:
def postorder(root):
    """Print every value of the tree rooted at `root`, one per line, in postorder.

    Postorder means: the left sub-tree first, then the right sub-tree, and the
    node itself last. An empty tree prints nothing.
    """
    if root is not None:
        # Both children are finished (left before right) before the node is printed.
        for child in (root.left, root.right):
            postorder(child)
        print(root.val)

In [None]:
def reverse_order(root):
    """Print every value of the tree rooted at `root`, one per line, in reverse inorder.

    Reverse inorder means: the right sub-tree first, then the node itself,
    then the left sub-tree. An empty tree prints nothing.
    """
    if root is not None:
        reverse_order(root.right)
        print(root.val)
        reverse_order(root.left)

ALL four of the above functions are examples of DFS.

**Question:** Why is traversal called "search"?

**Answer:** Because we can use traversal to search for a number that satisfies some criterion (e.g., the kth smallest number).

**Note:** We have encountered two different flavors of tree search:-

1. Search for a number (value) -> search in a BST -> $O(log\ n)$ time complexity & $O(log\ n)$ space complexity (for a balanced tree).
2. Search for a number that satisfies some criterion -> e.g., DFS and BFS -> $O(n)$ time complexity; the space complexity (for a balanced tree) is $O(log\ n)$ for DFS and $O(n)$ for BFS.

Sorting an array of numbers (in descending order) using a BST and reverse order traversal:-

In [None]:
# Unsorted input; the values are all distinct.
arr = [9, 1, 8, 2, 7, 3, 6, 4, 5]

class TreeNode:
    """Node of a binary tree, holding a single value."""

    def __init__(self, val):
        # Children are attached later; a new node has none.
        self.left = None
        self.right = None
        self.val = val

def insert_bst(root, val):
    """Insert `val` into the BST rooted at `root` and return the (sub-)tree's root.

    Values already in the tree are ignored. The recursion goes one level
    deeper per call, so for a balanced tree both the time and the call-stack
    space are O(log(n)).
    """
    if root is None:
        return TreeNode(val)

    if val > root.val:
        root.right = insert_bst(root.right, val)
        return root
    if val < root.val:
        root.left = insert_bst(root.left, val)
    return root

# Build the BST by inserting the values one at a time, starting from an empty tree.
root = None
for val in arr: # Time: O(n*log(n)); space: O(n) for a balanced tree. Latter: because space needs to be allocated for `n` nodes.
    root = insert_bst(root, val)

def reverse_order(root):
    """Return the values of the tree rooted at `root` as a list, in reverse inorder.

    The right sub-tree's values come first, then the node's own value, then
    the left sub-tree's values; for a BST this is descending order. Every call
    builds a new list out of its children's lists, and the result holds all
    `n` values, so the space is O(n).
    """
    if root is None:
        return []

    return [*reverse_order(root.right), root.val, *reverse_order(root.left)]

# Reading the BST in reverse inorder yields the input values in descending order.
sorted_arr = reverse_order(root)
sorted_arr

[9, 8, 7, 6, 5, 4, 3, 2, 1]

In [None]:
# Alt:
def sort_descending(root):
    """Return the values of the BST rooted at `root` as a list in descending order.

    Time: O(n). Space: O(n) for the result list plus the call stack, which is
    O(log(n)) deep for a balanced tree.
    """
    collected = []

    def visit(node):
        # Reverse inorder: right sub-tree, then the node, then the left sub-tree.
        if node is not None:
            visit(node.right)
            collected.append(node.val)
            visit(node.left)

    visit(root)
    return collected

# Same tree as before, sorted with the closure-based variant.
sorted_arr = sort_descending(root)
sorted_arr

[9, 8, 7, 6, 5, 4, 3, 2, 1]

Overall time complexity: $O(n.log\ n + n) = O(n.log\ n)$ for a balanced tree.

This is as efficient as merge sort and quick sort. But it has a big drawback: BSTs (by definition) can't contain duplicates. So the above method won't work on an array that contains duplicates.

Overall space complexity: $O(log\ n + n + n) = O(n)$ for a balanced tree.

**Note:** The reason we don't get an `UnboundLocalError` when modifying a list within a function (even if we don't declare it as `global` / `nonlocal`) is because lists are mutable objects in Python.

<ins>Mutable vs. immutable:</ins>

When we modify a list (e.g., append / remove / change an element), we're modifying the object itself, not creating a new one. This is different from immutable objects like integers and strings, where any modification creates a new object. (And if no other variable is pointing to the original immutable object, then the garbage collector collects it.)

<ins>Scoping rules:</ins>

Python's scoping rules state that if we assign a value to a variable within a function, it's considered a local variable unless explicitly declared as `global` / `nonlocal`. However, when we modify a mutable object like a list, we're not creating a new variable, so these rules don't apply in the same way.

In [None]:
def modify_list():
    """Append 4 to the module-level `my_list`."""
    # `my_list` is only read here (never assigned), so the name resolves to the
    # global list and `append` changes that list in place.
    my_list.append(4)

# The global list exists before `modify_list` is called; that is when the name gets looked up.
my_list = [1, 2, 3]
modify_list()
my_list

[1, 2, 3, 4]

In the above example, even though `my_list` is not declared as `global`, we can modify it within the function without an `UnboundLocalError`. This is because we're working with the same list object, not creating a new local variable.

In fact, this is the same reason why the `train_epoch` and `validate_epoch` functions (that we write so often) are able to access and modify the PyTorch model, optimizer state, learning rate scheduler state, etc. without using the `global` keyword.

### Binary Tree Inorder Traversal

In [None]:
def inorderTraversal(root):
    """Return the values of the tree rooted at `root` as a list, in inorder.

    The left sub-tree's values come first, then the node's own value, then the
    right sub-tree's values. An empty tree gives an empty list.
    """
    if root is None:
        return []

    return inorderTraversal(root.left) + [root.val] + inorderTraversal(root.right)

inorderTraversal(root)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

In [None]:
# Alt:
def inorderTraversal(root):
    """Return the values of the tree rooted at `root` as a list, in inorder.

    A nested helper walks the tree and appends each value to one shared list,
    so no intermediate lists are created.
    """
    values = []

    def walk(node):
        # Left sub-tree, then the node itself, then the right sub-tree.
        if node is not None:
            walk(node.left)
            values.append(node.val)
            walk(node.right)

    walk(root)
    return values

inorderTraversal(root)

[1, 2, 3, 4, 5, 6, 7, 8, 9]

This algorithm has a time complexity of $O(n)$ and a space complexity of $O(n + h) = O(n)$ where $h = log\ n$ for a balanced tree and $h = n$ for a skewed tree. ($n$ is the memory consumed by the `arr` array, and $h$ is the memory consumed by the call stack.)

**Note:** To calculate space complexity, we must add the memory consumed by any arrays (or other variable-length data structures) to the memory consumed by the call stack.

### Kth Smallest Element in a BST

In [None]:
# Inorder traversal based solution:
def kthSmallest(root, k):
    """Return the `k`-th smallest value (1-indexed) of the BST rooted at `root`.

    All values are first collected in inorder (ascending for a BST), then the
    one at position `k - 1` is picked out.
    """
    ascending = []

    def collect(node):
        if node is not None:
            collect(node.left)
            ascending.append(node.val)
            collect(node.right)

    collect(root)
    return ascending[k - 1]

kthSmallest(root, 5)

5

This algorithm has a time complexity of $O(n)$ and a space complexity of $O(n + h) = O(n)$ where $h = log\ n$ for a balanced tree and $h = n$ for a skewed tree. ($n$ is the memory consumed by the `arr` array, and $h$ is the memory consumed by the call stack.)

In [None]:
# DFS based solution:
def kthSmallest(root, k):
    """Return the `k`-th smallest value (1-indexed) of the BST rooted at `root`.

    Nodes are counted in inorder and the value of the `k`-th one is recorded.
    Returns `None` if the tree has fewer than `k` nodes.
    """
    visited = 0
    answer = None

    def walk(node):
        nonlocal visited, answer

        if node is None:
            return

        walk(node.left)
        visited += 1
        if visited != k:
            walk(node.right)
        else:
            # The k-th node's own right sub-tree is skipped, but callers higher
            # up the stack still carry on with their remaining work.
            answer = node.val

    walk(root)
    return answer

kthSmallest(root, 5)

5

This algorithm has a time complexity of $O(n)$ and a space complexity of $O(h)$ where $h = log\ n$ for a balanced tree and $h = n$ for a skewed tree. (**Note:** The DFS based solution has a smaller space complexity than the inorder traversal based solution.)

**Exercise:** Do a FULL dry run (by drawing the call stack) with `k=2` on the following BST: `[4, 3, 5, 2, None, None, None]`. You will notice that the traversal still continues after the algorithm backtracks from the `3` node to the `4` (root) node.

### Construct Binary Tree from Preorder and Inorder Traversal

In [None]:
# Exploratory:
# 'b' appears twice in this list (at positions 1 and 3).
my_list = ['a', 'b', 'c', 'b']
my_list.index('b')

1

**Note:** The `index` method returns the position at which the argument occurs for the first time.

In [None]:
def buildTree(preorder, inorder):
    """Rebuild a binary tree from its preorder and inorder value sequences.

    The first preorder value is the root. Its position in `inorder` separates
    the left sub-tree's values (everything before it) from the right
    sub-tree's values (everything after it). Returns the root node, or `None`
    for empty input.
    """
    if len(preorder) == 0: # `inorder` is empty in exactly the same calls.
        return None

    root_val = preorder[0]
    root = TreeNode(root_val)

    # Number of values before the root in `inorder` == size of the left sub-tree.
    left_size = inorder.index(root_val)

    # After the root, `preorder` lists the whole left sub-tree and then the whole right sub-tree.
    root.left = buildTree(preorder[1:1 + left_size], inorder[:left_size])
    root.right = buildTree(preorder[1 + left_size:], inorder[left_size + 1:])
    return root

The time complexity of this algorithm is $O(n^2)$, because the `inorder.index(root_val)` method call costs $O(n)$, and we're calling it $n$ times (once for creating each node).

The space complexity of this algorithm is $O(n.h)$, because we're passing subarrays of `preorder` and `inorder` (which cost $O(n)$) in each recursive function call, and there are a maximum of $h$ recursive function calls in the call stack, where $h = log\ n$ for a balanced tree and $h = n$ for a skewed tree.

In [None]:
# Test:
# Expected shape: 1 is the root, 2 is its left child, 3 is its right child and 4 is the right child of 3.
root = buildTree([1, 2, 3, 4], [2, 1, 3, 4])
root

<__main__.TreeNode at 0x7e52a67237d0>

In [None]:
preorder(root)

1
2
3
4


In [None]:
inorder(root)

2
1
3
4


## Breadth-First Search

In [None]:
from collections import deque

def bfs(root):
    """Print the tree rooted at `root` level by level, top to bottom.

    For each level this prints a "Level: <number>" header, then the level's
    values from left to right (one per line), then a "---" separator. An empty
    tree prints nothing.
    """
    pending = deque()
    if root is not None:
        pending.append(root)

    depth = 0
    while pending:
        print("Level:", depth)
        # Only the nodes queued right now belong to this level; children appended
        # during the loop are left for the next pass.
        for _ in range(len(pending)):
            node = pending.popleft()
            print(node.val)
            for child in (node.left, node.right):
                if child is not None:
                    pending.append(child)
        print("---")
        depth += 1

The time complexity of BFS is $O(n)$ where $n$ is the number of nodes in the tree. This is because we visit every node exactly once.

The space complexity of BFS is $O(n)$ where $n$ is the number of nodes in the tree. This is because we will store an entire level of the tree in the queue at a time. In the worst case, the last level may be roughly half the size of the tree, so the space complexity is $O(n)$.

**Exercise:** Do a dry run of level-order traversal (BFS) for the binary tree `[4, 3, 6, 2, None, 5, 7]` (by drawing the queue and writing out the output).

Let's test out our `bfs` function. We can use our earlier `buildTree` function to create the above tree.

In [None]:
# Test:
# Rebuild the tree from the exercise above out of its preorder and inorder sequences, then print it level by level.
root = buildTree(preorder=[4, 3, 2, 6, 5, 7], inorder=[2, 3, 4, 5, 6, 7])
bfs(root)

Level: 0
4
---
Level: 1
3
6
---
Level: 2
2
5
7
---


**Note:** The list representation of a binary tree (e.g., `[4, 3, 6, 2, None, 5, 7]`) can be thought of as its level-order traversal that includes `None` values.

### Binary Tree Level Order Traversal

In [None]:
from collections import deque

def levelOrder(root):
    """Return the values of the tree rooted at `root`, grouped by level.

    The result is a list with one inner list per level (top to bottom), each
    holding that level's values from left to right. An empty tree gives `[]`.
    """
    levels = []
    frontier = deque()
    if root is not None:
        frontier.append(root)

    while frontier:
        current_level = []
        # Exactly the nodes queued at this moment make up the current level.
        for _ in range(len(frontier)):
            node = frontier.popleft()
            current_level.append(node.val)
            for child in (node.left, node.right):
                if child is not None:
                    frontier.append(child)
        levels.append(current_level)

    return levels

This algorithm has a time complexity of $O(n)$ and a space complexity of $O(n)$.

In [None]:
# Test:
levelOrder(root)

[[4], [3, 6], [2, 5, 7]]

### Binary Tree Right Side View

In [None]:
from collections import deque

def rightSideView(root):
    """Return the last (right-most) value of every level of the tree, top to bottom.

    An empty tree gives `[]`.
    """
    visible = []
    frontier = deque()
    if root is not None:
        frontier.append(root)

    while frontier:
        rightmost = None
        # Levels are processed left to right, so the value seen last is the right-most one.
        for _ in range(len(frontier)):
            node = frontier.popleft()
            rightmost = node.val
            for child in (node.left, node.right):
                if child is not None:
                    frontier.append(child)
        visible.append(rightmost)

    return visible

This algorithm has a time complexity of $O(n)$ and a space complexity of $O(n)$.

In [None]:
# Test:
rightSideView(root)

[4, 6, 7]

## BST Sets and Maps

A set is a collection of unique elements (no duplicates allowed). It can be implemented using a BST (since BSTs also don't contain duplicates). Such a set is known as a tree set / sorted set. (Each node of a tree set only contains a key.)

The main advantage of tree sets (vis-a-vis dynamic arrays) is that we can search for an element, insert an element and delete an element in $O(log\ n)$ time.

A map is a set of key-value pairs, where the keys are unique (no duplicates allowed). It can be implemented using a BST. Such a map is known as a tree map / sorted map / sorted dict. (Each node of a tree map contains a key and one or more values.)

Tree maps have some tradeoffs (vis-a-vis hash maps). We shall examine these tradeoffs in the **Hashing** chapter.

Python implementation of tree map:-

In [None]:
!pip install -q sortedcontainers

In [None]:
from sortedcontainers import SortedDict

# The keys are given out of order on purpose; the displayed map lists them in sorted order.
tree_map = SortedDict({'c': 3, 'a': 1, 'b': 2})
tree_map

SortedDict({'a': 1, 'b': 2, 'c': 3})

### Design Binary Search Tree

In [None]:
class TreeNode:
    """Node of a tree map: a key, the value stored under it, and two child links."""

    def __init__(self, key, val):
        self.key, self.val = key, val
        # New nodes start out without children.
        self.left = self.right = None

class TreeMap:
    """Key-value map stored in a binary search tree that is ordered by key.

    Lookups of missing keys, and min/max queries on an empty map, return -1.
    """

    def __init__(self):
        self.root = None # Empty map: there are no nodes yet.

    def insert(self, key, val):
        """Store `val` under `key`; an existing key gets its value overwritten."""
        if self.root is None:
            self.root = TreeNode(key, val)
            return

        node = self.root
        while True:
            if key > node.key:
                if node.right is None:
                    node.right = TreeNode(key, val)
                    return
                node = node.right
            elif key < node.key:
                if node.left is None:
                    node.left = TreeNode(key, val)
                    return
                node = node.left
            else:
                # Key already present: only the stored value changes.
                node.val = val
                return

    def get(self, key):
        """Return the value stored under `key`, or -1 if the key is absent."""
        node = self.root
        while node is not None:
            if key > node.key:
                node = node.right
            elif key < node.key:
                node = node.left
            else:
                return node.val
        return -1

    def getMin(self):
        """Return the value stored under the smallest key, or -1 for an empty map."""
        node = self.root
        if node is None:
            return -1
        # The smallest key sits at the end of the chain of left children.
        while node.left is not None:
            node = node.left
        return node.val

    def getMax(self):
        """Return the value stored under the largest key, or -1 for an empty map."""
        node = self.root
        if node is None:
            return -1
        # The largest key sits at the end of the chain of right children.
        while node.right is not None:
            node = node.right
        return node.val

    def remove(self, key):
        """Delete `key` (and its value) from the map; a missing key is ignored."""
        def leftmost(node):
            while node is not None and node.left is not None:
                node = node.left
            return node

        def prune(node, target):
            # Returns the root of this sub-tree after `target` has been removed from it.
            if node is None:
                return None

            if target > node.key:
                node.right = prune(node.right, target)
                return node
            if target < node.key:
                node.left = prune(node.left, target)
                return node

            # Target found. With at most one child, that child (or `None`) takes its place.
            if node.right is None:
                return node.left
            if node.left is None:
                return node.right

            # Two children: take over the in-order successor's key and value,
            # then remove the successor's key from the right sub-tree.
            successor = leftmost(node.right)
            node.key = successor.key
            node.val = successor.val
            node.right = prune(node.right, successor.key)
            return node

        self.root = prune(self.root, key)

    def getInorderKeys(self):
        """Return all keys as a list in ascending order."""
        keys = []
        stack = []
        node = self.root
        while stack or node is not None:
            # Go as far left as possible, remembering the nodes passed on the way down.
            while node is not None:
                stack.append(node)
                node = node.left
            node = stack.pop()
            keys.append(node.key)
            node = node.right
        return keys