## Binary Search Tree

### Introduction to BST

* What is a Binary Search Tree (BST)?
  + A binary search tree (BST) is a special form of binary tree that satisfies the binary search property:
    + The value in each node must be greater than (or equal to) any value stored in its left subtree.
    + The value in each node must be less than (or equal to) any value stored in its right subtree.
* An inorder traversal of a BST visits the values in ascending order. Therefore, inorder traversal is the most frequently used traversal method for a BST.
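
The ascending-order property can be checked on a small tree. The sketch below is only an illustration: the `Node` class and the `insert` and `inorder` helpers are hypothetical names introduced here, not part of the notes.

```python
from typing import List, Optional


class Node:
    """Minimal BST node used only for this illustration."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def insert(root: Optional[Node], val: int) -> Node:
    """Insert val into the BST rooted at root and return the root."""
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(root: Optional[Node]) -> List[int]:
    """Return the values in left -> node -> right order."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


values = [5, 3, 8, 1, 4, 7, 9]
root = None
for v in values:
    root = insert(root, v)

print(inorder(root))  # [1, 3, 4, 5, 7, 8, 9]
```

Whatever the insertion order, the inorder output is sorted, which is why many BST problems reduce to an inorder walk.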

#### Leetcode 285. Inorder Successor in BST
* Overview 
  + Given the root of a binary search tree and a node p in it, return the in-order successor of that node in the BST. If the given node has no in-order successor in the tree, return null.
  + The successor of a node p is the node with the smallest key greater than p.val.
* Algorithm
  + recursive implementation
    + we use the BST property that every node's value is greater than the values in its left subtree and smaller than the values in its right subtree
    + define a traverse(node) function
      + if node is None, return None
      + if node.val <= p.val
        + return traverse(node.right), since the successor can only be in the right subtree
      + otherwise
        + return traverse(node.left) or node (node is the answer unless its left subtree holds a smaller value that is still greater than p.val)
    + return traverse(root)
    + time complexity
      + O(h), which is O(N) in the worst case and O(logN) when the BST is balanced
    + space complexity
      + O(h) for the recursion stack: O(N) in the worst case and O(logN) for a balanced BST
  + iterative implementation
    + set successor = None
    + while root
      + if root.val <= p.val, root = root.right
      + else
        + successor = root
        + root = root.left (a later candidate found in this left subtree is smaller, so it overwrites successor)
    + return successor
    + O(h) time and O(1) space
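
To make the iterative steps concrete, here is a small trace on a six-node tree. It is a sketch under one simplifying assumption: the hypothetical `inorder_successor` helper takes the value of p (`p_val`) rather than the node itself, and `Node` is a stand-in for the LeetCode `TreeNode`.

```python
from typing import Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def inorder_successor(root: Optional[Node], p_val: int) -> Optional[Node]:
    """Return the node with the smallest value greater than p_val."""
    successor = None
    while root:
        if root.val <= p_val:
            # root and its whole left subtree are too small
            root = root.right
        else:
            # root is a candidate; look for a smaller one on the left
            successor = root
            root = root.left
    return successor


#        5
#       / \
#      3   8
#     / \  /
#    1  4 7
tree = Node(5, Node(3, Node(1), Node(4)), Node(8, Node(7)))

for val in (1, 4, 5, 8):
    nxt = inorder_successor(tree, val)
    print(val, "->", nxt.val if nxt else None)
# 1 -> 3
# 4 -> 5
# 5 -> 7
# 8 -> None
```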
        
      
    
      

In [2]:
from typing import Optional

# recursive implementation

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

class Solution:
    def inorderSuccessor(self, root: TreeNode, p: TreeNode) -> Optional[TreeNode]:
        if root is None or p is None:
            return None
        
        def traverse(node: Optional[TreeNode]) -> Optional[TreeNode]:
            if node is None:
                return None
            
            # if node.val <= p.val, the successor can only be in the right subtree
            if node.val <= p.val:
                return traverse(node.right)
            
            # otherwise, node is a candidate; it is the answer unless
            # its left subtree contains a smaller node with val > p.val
            left = traverse(node.left)
            return left or node
        
        return traverse(root)

# iterative implementation 
class Solution:
    def inorderSuccessor(self, root: TreeNode, p: TreeNode) -> Optional[TreeNode]:
        if root is None or p is None:
            return None
        
        # initialize successor as None
        successor = None
        
        # traverse descendants of root node
        while root:
            # if root.val <= p.val, traverse right branch
            if root.val <= p.val:
                root = root.right
                
            # otherwise, record root as the best candidate so far and
            # explore its left branch; if nothing closer is found there,
            # the recorded candidate is returned
            else:
                successor = root
                root = root.left
        
        return successor            

#### Leetcode 173. Binary Search Tree Iterator
* Overview
  + Implement the BSTIterator class that represents an iterator over the in-order traversal of a binary search tree (BST):
    + BSTIterator(TreeNode root) Initializes an object of the BSTIterator class. The root of the BST is given as part of the constructor. The pointer should be initialized to a non-existent number smaller than any element in the BST.
    + boolean hasNext() Returns true if there exists a number in the traversal to the right of the pointer, otherwise returns false.
    + int next() Moves the pointer to the right, then returns the number at the pointer.
  + Notice that by initializing the pointer to a non-existent smallest number, the first call to next() will return the smallest element in the BST.
  + You may assume that next() calls will always be valid. That is, there will be at least a next number in the in-order traversal when next() is called.
  
* Algorithm
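  + we split the iterative inorder traversal template (while curr or stack: ...) into three parts
  + in the __init__ function, we set self.stack to an empty list and self.curr = root, then push the left branch of root onto the stack
  + in next(), we run one iteration of the loop body: pop a node, move to its right child, push that child's left branch, and return the popped value
  + in hasNext(), we only check the loop condition, that is, whether the stack still holds nodes
* Time complexity
  + next() is O(h) in the worst case but amortized O(1), because each node is pushed and popped exactly once; hasNext() is O(1)
* Space complexity
  + O(h) to store one root-to-leaf branch in self.stack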
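
For reference, the inorder template that the iterator splits apart looks roughly like the sketch below. The `Node` class and the `inorder_iterative` name are assumptions made for this example, and the comments mark which iterator method each part would correspond to.

```python
from typing import List, Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def inorder_iterative(root: Optional[Node]) -> List[int]:
    """Iterative inorder traversal with an explicit stack."""
    result: List[int] = []
    stack: List[Node] = []
    curr = root
    while curr or stack:        # loop condition -> hasNext()
        while curr:             # push the whole left branch
            stack.append(curr)
            curr = curr.left
        node = stack.pop()      # smallest unvisited node -> next()
        result.append(node.val)
        curr = node.right       # continue in the right subtree
    return result


#      7
#     / \
#    3   15
#       /  \
#      9    20
tree = Node(7, Node(3), Node(15, Node(9), Node(20)))
print(inorder_iterative(tree))  # [3, 7, 9, 15, 20]
```

Holding `curr` and `stack` as instance attributes and running one pass of the outer loop per call is what turns this function into an iterator.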

In [4]:
# Definition for a binary tree node.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
class BSTIterator:

    def __init__(self, root: Optional[TreeNode]):
        
        self.curr = root
        self.stack = []
        
        self.traverse_left()

    # traverse and push nodes on left branch to self.stack
    def traverse_left(self) -> None:
        while self.curr:
            self.stack.append(self.curr)
            self.curr = self.curr.left
            
    def next(self) -> int:        
        
        # pop the node from left branch
        rs = self.stack.pop()
        
        # traverse its right branch and push
        # nodes on its left branch to self.stack
        self.curr = rs.right
        self.traverse_left()
        
        # return inorder node's value
        return rs.val
        

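    def hasNext(self) -> bool:
        # self.curr is always None between calls (traverse_left exhausts it),
        # so the stack alone tells us whether any node is left to visit
        return bool(self.stack)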

#### Leetcode 701. Insert into a Binary Search Tree
* Overview
  + You are given the root node of a binary search tree (BST) and a value to insert into the tree. Return the root node of the BST after the insertion. It is guaranteed that the new value does not exist in the original BST.
  + Notice that there may exist multiple valid ways for the insertion, as long as the tree remains a BST after insertion. You can return any of them.
  
* Algorithm
  + follow the BST ordering rule and search for the position of the new value
  + when the search reaches a None position, create the new node there and return it
  + on the way back, reconnect the returned subtree to the left or right child of the calling node and return that node
* time complexity
  + O(h), which is O(N) for a skewed tree and O(logN) for a balanced one
* space complexity
  + O(h) for the recursive implementation and O(1) for the iterative implementation
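
Because the cost is O(h), the insertion order matters. The following sketch (with hypothetical `Node`, `insert`, `build`, and `height` helpers) inserts the same seven values in two different orders and compares the resulting heights.

```python
from typing import Iterable, Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def insert(root: Optional[Node], val: int) -> Node:
    """Iteratively insert val and return the (unchanged) root."""
    if root is None:
        return Node(val)
    curr = root
    while True:
        if val < curr.val:
            if curr.left is None:
                curr.left = Node(val)
                break
            curr = curr.left
        else:
            if curr.right is None:
                curr.right = Node(val)
                break
            curr = curr.right
    return root


def build(values: Iterable[int]) -> Optional[Node]:
    """Insert the values one by one into an initially empty BST."""
    root = None
    for v in values:
        root = insert(root, v)
    return root


def height(root: Optional[Node]) -> int:
    """Number of nodes on the longest root-to-leaf path."""
    if root is None:
        return 0
    return 1 + max(height(root.left), height(root.right))


skewed = build([1, 2, 3, 4, 5, 6, 7])    # sorted input: a right-leaning chain
balanced = build([4, 2, 6, 1, 3, 5, 7])  # median first: a complete tree

print(height(skewed))    # 7
print(height(balanced))  # 3
```

Sorted input degrades each insertion to O(N), while a median-first order keeps it at O(logN).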

In [5]:
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def insertIntoBST(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
        def traverse(node: Optional[TreeNode]) -> TreeNode:
            # if we reach a None position, return the new TreeNode
            # (this also covers an empty tree)
            if node is None:
                return TreeNode(val)
            
            # if node.val < val, recurse into node.right and
            # connect the returned node to node's right branch
            if node.val < val:
                node.right = traverse(node.right)
            
            # if node.val > val, recurse into node.left and
            # connect the returned node to node's left branch
            elif node.val > val:
                node.left = traverse(node.left)
            
            # return node so that the parent can reconnect it
            return node
        
        return traverse(root)

# iterative implementation
class Solution:
    def insertIntoBST(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
        if root is None:
            return TreeNode(val)
                        
        # set curr = root
        curr = root
        
        # traverse the BST using iteration
        while curr:
            # if the next position is None, insert the new node
            # otherwise, set curr = the next position
            if curr.val < val:
                if curr.right:
                    curr = curr.right
                # if the new node is inserted, break the loop
                else:
                    curr.right = TreeNode(val)
                    break
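            elif curr.val > val:
                if curr.left:
                    curr = curr.left
                else:
                    curr.left = TreeNode(val)
                    break
            # curr.val == val: the problem guarantees this never happens,
            # but stop here so the loop cannot spin forever
            else:
                break
        return root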

#### Leetcode 450. Delete Node in a BST
* Overview
  + Given a root node reference of a BST and a key, delete the node with the given key in the BST. Return the root node reference (possibly updated) of the BST.
  + Basically, the deletion can be divided into two stages:
    + Search for a node to remove.
    + If the node is found, delete the node.
* Algorithm
  + use the BST property and traverse the tree to find the node whose val equals the target value
    + if node.val > target, node.left = self.deleteNode(node.left, target)
    + if node.val < target, node.right = self.deleteNode(node.right, target)
  + when the target node is found
    + if node is a leaf node, return None
    + if exactly one of its children is None, return the non-None child so that it takes the node's place
    + if both child nodes are not None, find the min node in the right branch (the inorder successor), copy its value into the current node, and recursively call deleteNode on the right branch to remove that min node
  + note: when traversing branches, we modify the branch structure, assign the modified child branch to the appropriate child of the current node, and return the current node
* time and space complexity
  + O(h): we follow only one child branch at each step, so both time and recursion-stack space are bounded by the height of the BST. In the worst case, it is O(N), and for a balanced BST, it is O(logN)
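
The three deletion cases can be observed on a small tree. This is a compact sketch with hypothetical names (`Node`, `delete`, `inorder`, `make_tree`); it folds the leaf case into the one-child case, since returning the other child of a leaf yields None anyway.

```python
from typing import List, Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def delete(root: Optional[Node], key: int) -> Optional[Node]:
    """Delete key from the BST and return the new root of this subtree."""
    if root is None:
        return None
    if key < root.val:
        root.left = delete(root.left, key)
    elif key > root.val:
        root.right = delete(root.right, key)
    else:
        # zero or one child: the other child (possibly None) replaces root
        if root.left is None:
            return root.right
        if root.right is None:
            return root.left
        # two children: copy the inorder successor, then remove it
        succ = root.right
        while succ.left:
            succ = succ.left
        root.val = succ.val
        root.right = delete(root.right, succ.val)
    return root


def inorder(root: Optional[Node]) -> List[int]:
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


def make_tree() -> Node:
    #      5
    #     / \
    #    3   6
    #   / \   \
    #  2   4   7
    return Node(5, Node(3, Node(2), Node(4)), Node(6, None, Node(7)))


print(inorder(delete(make_tree(), 2)))  # leaf:         [3, 4, 5, 6, 7]
print(inorder(delete(make_tree(), 6)))  # one child:    [2, 3, 4, 5, 7]
print(inorder(delete(make_tree(), 3)))  # two children: [2, 4, 5, 6, 7]
```

In each case the inorder output stays sorted, which confirms that the BST ordering survives the deletion.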

In [6]:
from typing import Optional
# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
class Solution:
    def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
        if root is None:
            return None
        
        def find_min(node: TreeNode) -> Optional[TreeNode]:
            while node.left:
                node = node.left                
            return node
        
        # if root.val > key, the target is in the left subtree: recurse
        # and reattach the (possibly modified) subtree to root.left
        if root.val > key:
            root.left = self.deleteNode(root.left, key)
                    
        # if root.val < key, the target is in the right subtree: recurse
        # and reattach the (possibly modified) subtree to root.right
        elif root.val < key:
            root.right = self.deleteNode(root.right, key)
        
        # if root.val == key, delete the current node while
        # preserving the BST ordering
        else:
            # leaf node: simply remove it
            if root.left is None and root.right is None:
                return None

            # one child: that child takes the node's place
            if root.left is None:
                return root.right

            if root.right is None:
                return root.left

            # both children exist, so find the minimum (leftmost)
            # node of the right subtree, i.e. the inorder successor
            successor = find_min(root.right)

            # copy the successor's value into root, then delete
            # the successor by recursively calling the function
            # on the right branch
            root.val = successor.val

            root.right = self.deleteNode(root.right, successor.val)
        
        return root
        
        

#### Leetcode 703. Kth Largest Element in a Stream
* Overview
  + Design a class to find the kth largest element in a stream. Note that it is the kth largest element in the sorted order, not the kth distinct element.
  + Implement KthLargest class:
    + KthLargest(int k, int\[\] nums) Initializes the object with the integer k and the stream of integers nums.
    + int add(int val) Appends the integer val to the stream and returns the element representing the kth largest element in the stream.
* Algorithm
  + min heap
    + concept
      + to track the k largest values, we use a min heap of size k so that small values are popped and larger values are kept; the root heap[0] is then the kth largest
      + when adding a new value to a full heap, if the new value <= heap[0], ignore it since it is not among the k largest
    + initialization
      + use a heap rather than a BST traversal
      + create a heap array from nums and heapify it in O(N)
      + pop from the heap until its size == k; this takes N-k pops at O(logN) each, i.e. O((N-k)logN), which is O(NlogN) in the worst case
    + add
      + if the heap size < k, heappush the val to the heap
      + if the heap size == k and heap[0] < val, heapreplace(heap, val)
      + return heap[0]
      + O(log(k)) for each add (heapreplace pops the root and pushes the new value in one O(log(k)) step)
      + altogether, for m insertions, it is O(mlog(k))
    + Time complexity
      + O(NlogN + mlog(k))
    + Space complexity
      + O(N) to hold all the numbers in the heap during initialization, O(k) afterwards
  + BST
    + implement a BST to store the values
      + if self.root is None, set self.root = TreeNode(val) and return
      + use an iterative implementation to insert elements under self.root
        + while node (starting from node = self.root)
          + node.count += 1, since the new value will end up in this node's subtree
          + if node.val < val, add the new node as node.right if node.right is None, otherwise node = node.right
          + if node.val >= val, add the new node as node.left if node.left is None, otherwise node = node.left (equal values always go into the left subtree)
      + each node stores the size of its subtree (itself plus all of its descendants)
      + so the count of a node's right child tells how many nodes in that subtree are larger than the current node
      + if a node's right child is None, it is the largest node in its subtree, and it is the answer when k == 1
    + search(node, k)
      + let m = node.right.count (0 if node.right is None)
        + if m == k-1, return node.val: the current node is the kth largest
        + if m >= k, continue to traverse the right branch with the same k
        + if m < k-1, set k -= m+1 and traverse the left branch (this subtracts the m nodes of the right branch and the current node itself)
    + add(val)
      + insert(val)
      + search(self.root, self.size), where self.size stores k
    + Time complexity
      + O(Nh) for inserting N nodes into a BST of height h, and O(h) for each add. The worst case for construction is O(N^2), or O(NlogN) with a balanced tree
    + Space complexity
      + O(N) to store all nodes in the BST
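
One way to gain confidence in the heap approach is to compare it with a brute-force reference that re-sorts the stream on every add. The sketch below does that; the function names `kth_largest_heap` and `kth_largest_bruteforce` are hypothetical, and it assumes the stream holds at least k elements whenever a result is recorded.

```python
import heapq
from typing import List


def kth_largest_heap(k: int, nums: List[int], stream: List[int]) -> List[int]:
    """Min heap of size k; heap[0] is the kth largest seen so far."""
    heap = list(nums)           # copy so the caller's list is untouched
    heapq.heapify(heap)
    while len(heap) > k:
        heapq.heappop(heap)
    out = []
    for val in stream:
        if len(heap) < k:
            heapq.heappush(heap, val)
        elif val > heap[0]:
            heapq.heapreplace(heap, val)
        out.append(heap[0])
    return out


def kth_largest_bruteforce(k: int, nums: List[int], stream: List[int]) -> List[int]:
    """Reference answer: sort everything seen so far on every add."""
    seen = list(nums)
    out = []
    for val in stream:
        seen.append(val)
        out.append(sorted(seen)[-k])
    return out


k, nums, stream = 3, [4, 5, 8, 2], [3, 5, 10, 9, 4]
print(kth_largest_heap(k, nums, stream))        # [4, 5, 5, 8, 8]
print(kth_largest_bruteforce(k, nums, stream))  # [4, 5, 5, 8, 8]
```

The brute-force version costs O(NlogN) per add, so it is only useful as a test oracle.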

In [7]:
from heapq import heappop, heappush, heapify, heapreplace
from typing import List

# min-heap implementation
class KthLargest:

    def __init__(self, k: int, nums: List[int]):
        self.heap = nums
        self.size = k
        
        # initialize min heap and pop elements
        # to maintain the heap size == k
        heapify(self.heap)
        
        while len(self.heap) > self.size:
            heappop(self.heap)        

    def add(self, val: int) -> int:
        if len(self.heap) < self.size:
            heappush(self.heap, val)
            
        elif len(self.heap) == self.size and self.heap[0] < val:
            heapreplace(self.heap, val)
            
        return self.heap[0]      
    
# implementation by BST
class TreeNode:
    
    def __init__(self, val: int):
        self.left = None
        self.right = None
        self.val = val
        self.count = 1
            
class KthLargest:

    # construct the BST tree by adding each num to root
    def __init__(self, k: int, nums: List[int]):
        self.root = None
        self.size = k
        for num in nums:
            self.insert(num)
        
    # if val == node.val, the new node goes into the left subtree.
    # The search still finds it: once the right branch (m nodes) and the
    # current node are ruled out with k -= m+1, the search continues
    # in the left subtree, where the equal-valued node is treated like
    # any other node. It is returned when the count of its own right
    # child equals k-1.
    def insert(self, val) -> None:
        if self.root is None:
            self.root = TreeNode(val)
            return 
            
        node = self.root
        # each node records the size of its subtree (its descendants plus itself),
        # so when a node's right child has a count of m, the node is the (m+1)th
        # largest in its subtree. Every node on the insertion path gains one
        # descendant, hence node.count += 1 at each step of the loop
        while node:
            node.count += 1                  
            if node.val >= val:
                if node.left:
                    node = node.left                    
                else:
                    node.left = TreeNode(val)
                    return
            elif node.val < val:
                if node.right:
                    node = node.right                    
                else:
                    node.right = TreeNode(val)
                    return
            
        
    def search(self, node, k) -> int:
        # m is the number of nodes larger than the current node within
        # its subtree; a missing right child contributes 0, which lets
        # the largest remaining node be returned when k == 1
        m = node.right.count if node.right else 0
        
        if m + 1 == k:
            return node.val
        
        # if at least k nodes are larger than the current node,
        # the answer lies in the right branch
        if k <= m:
            return self.search(node.right, k)
        
        # otherwise, the right branch has fewer than k-1 nodes: we
        # subtract its m nodes and the current node from k (k -= m+1)
        # and continue in the left branch
        k -= m + 1
        return self.search(node.left, k)
        
    def add(self, val: int) -> int:                
        self.insert(val)
               
        return self.search(self.root, self.size)

# Your KthLargest object will be instantiated and called as such:
# obj = KthLargest(k, nums)
# param_1 = obj.add(val)


#### Leetcode 235. Lowest Common Ancestor of a Binary Search Tree
* Overview
  + Given a binary search tree (BST), find the lowest common ancestor (LCA) node of two given nodes in the BST.
  + According to the definition of LCA on Wikipedia: “The lowest common ancestor is defined between two nodes p and q as the lowest node in T that has both p and q as descendants (where we allow a node to be a descendant of itself).”
* Algorithm
  + use the properties of the BST and the LCA
    + the LCA is the node where p and q split: one lies in its left branch and the other in its right branch, or the LCA is p or q itself
    + based on the BST property, the LCA's value must lie between the values of p and q (inclusive)
  + the basic idea is to
    + search the node's left branch if its value is bigger than the values of both p and q
    + search the node's right branch if its value is smaller than the values of both p and q
    + return the node if its value is between the values of p and q
  + recursive implementation
    + time complexity
      + O(N) worst case, O(logN) for balanced BST
    + space complexity
      + O(N) worst case, O(logN) for balanced BST due to the recursive stack
  + iterative implementation
    + O(N) time in the worst case, when the tree is skewed and every node is visited; O(logN) for a balanced tree
    + O(1) space complexity
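
A short trace shows the split-point idea, including the case where the LCA is one of the two nodes. The sketch assumes both values exist in the tree; the `lca` helper works on values rather than nodes, and the `Node` class is introduced here only for illustration.

```python
from typing import Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def lca(root: Optional[Node], a: int, b: int) -> Optional[Node]:
    """Return the first node whose value lies in [min(a, b), max(a, b)]."""
    lo, hi = min(a, b), max(a, b)
    while root:
        if root.val < lo:      # both targets are larger: go right
            root = root.right
        elif root.val > hi:    # both targets are smaller: go left
            root = root.left
        else:                  # split point (or one of the targets)
            return root
    return None


#          6
#        /   \
#       2     8
#      / \   / \
#     0   4 7   9
#        / \
#       3   5
tree = Node(6,
            Node(2, Node(0), Node(4, Node(3), Node(5))),
            Node(8, Node(7), Node(9)))

print(lca(tree, 2, 8).val)  # 6
print(lca(tree, 2, 4).val)  # 2
print(lca(tree, 3, 5).val)  # 4
```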

In [8]:
class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        if root is None:
            return None
        
        if p is None and q is None:
            return None
        
        if p is None:
            return q
        if q is None:
            return p
        
        # the LCA is the node where p and q split into different
        # branches (or it is p or q itself). Based on the BST
        # property, its value must lie between p.val and q.val.
        # therefore, if root.val < both values of p and q, traverse its right
        # child. If root.val > both values of p and q, traverse its left child
        # otherwise, root.val is between the values of p and q, return the node
        if root.val < p.val and root.val < q.val:
            return self.lowestCommonAncestor(root.right, p, q)
        if root.val > p.val and root.val > q.val:
            return self.lowestCommonAncestor(root.left, p, q)
        return root
    
# iterative implementation 
class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        if root is None:
            return None
        
        if p is None and q is None:
            return None
        
        if p is None:
            return q
        if q is None:
            return p
        
        while root:
            if root.val < p.val and root.val < q.val:
                root = root.right
            elif root.val > p.val and root.val > q.val:
                root = root.left
            else:
                return root
        return None     
        

#### Leetcode 220. Contains Duplicate III
* Overview
  + You are given an integer array nums and two integers indexDiff and valueDiff.
  + Find a pair of indices (i, j) such that:
    + i != j,
    + abs(i - j) <= indexDiff, and
    + abs(nums[i] - nums[j]) <= valueDiff.
  + Return true if such pair exists or false otherwise.

* Algorithm
  + we bucket the values and keep a sliding window over the indices
  + we define bucket size = valueDiff + 1, so any two values in the same bucket differ by at most valueDiff
    + if the current value falls into a bucket that already holds an element of the window, return True
    + for neighboring buckets (bucket id +/- 1), since the window holds at most one element per bucket, we just compare the current num with the value stored in each neighboring bucket, and if the difference <= valueDiff, return True
  + at the end of each iteration, we check whether i >= indexDiff. If so, we remove the bucket of nums[i-indexDiff], since this element has been tested against every element within indexDiff of it and all later elements will exceed that index difference
  + return False after traversing the whole array
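
The bucket arithmetic relies on Python's floor division, which rounds toward negative infinity and therefore keeps buckets contiguous for negative numbers. A minimal sketch, assuming valueDiff = 3 (the `bucket_id` helper is a name made up for this example):

```python
value_diff = 3
bucket_size = value_diff + 1  # each bucket covers 4 consecutive integers


def bucket_id(num: int) -> int:
    """Floor division: [-4, -1] -> -1, [0, 3] -> 0, [4, 7] -> 1, ..."""
    return num // bucket_size


for num in (-5, -1, 0, 3, 4, 7):
    print(num, "->", bucket_id(num))
# -5 -> -2
# -1 -> -1
# 0 -> 0
# 3 -> 0
# 4 -> 1
# 7 -> 1

# same bucket: the difference is at most value_diff
print(abs(0 - 3) <= value_diff)  # True

# neighboring buckets: the difference may or may not qualify,
# so the stored value has to be compared explicitly
print(abs(3 - 4) <= value_diff)  # True
print(abs(0 - 7) <= value_diff)  # False
```

With truncating division (as in C or Java integer division), -1 and 1 would both land in bucket 0 alongside 0..3, so the same-bucket guarantee would break around zero.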

In [9]:
class Solution:
    def containsNearbyAlmostDuplicate(self, nums: List[int], indexDiff: int, valueDiff: int) -> bool:
        
        
        # use a hashmap as the buckets. Each bucket covers bucket_size
        # (valueDiff + 1) consecutive integers, so two values in the same
        # bucket differ by at most valueDiff. We keep a sliding window of
        # width indexDiff over the indices: if two elements of the window
        # drop into the same bucket, we return True. If nothing is returned
        # after the entire traversal, return False
        buckets = {}
        bucket_size = valueDiff + 1
        
        for i, num in enumerate(nums):
            bucket_id = num // bucket_size
            if bucket_id in buckets:
                return True
            for nb_id in {bucket_id + 1, bucket_id - 1}:
                if nb_id in buckets and abs(buckets[nb_id] - num) <= valueDiff:
                    return True
            
            buckets[bucket_id] = num
            
            # once i >= indexDiff, the element at index i-indexDiff
            # leaves the window: the next element will be more than
            # indexDiff away from it. Remove its single-element bucket,
            # whose id is nums[i-indexDiff] // bucket_size. At most
            # one bucket is removed per iteration, starting from
            # i == indexDiff
            if i >= indexDiff:
                del buckets[nums[i-indexDiff]//bucket_size]
                
        return False            

#### Leetcode 110. Balanced Binary Tree
* Overview
  + Given a binary tree, determine if it is height-balanced.
* Algorithm
  + use postorder traversal
    + return (0, True) for an empty node
    + recursively call node.left and node.right
    + if either child returns a False result, or abs(left depth - right depth) > 1, return (0, False); the depth no longer matters because the False result is passed all the way back to the root
    + otherwise return (max(left depth, right depth) + 1, True)
* Time complexity
  + O(N)
* Space complexity
  + O(N) in the worst case for the recursion stack (O(h) in general)
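
A common variation on the same postorder idea replaces the (depth, bool) tuple with a single integer, using -1 as a sentinel for "unbalanced". This is a different encoding from the one described above, sketched here with hypothetical names (`Node`, `height_or_unbalanced`, `is_balanced`).

```python
from typing import Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def height_or_unbalanced(node: Optional[Node]) -> int:
    """Return the subtree height, or -1 if the subtree is unbalanced."""
    if node is None:
        return 0
    left = height_or_unbalanced(node.left)
    if left == -1:
        return -1               # stop early: no need to visit the right side
    right = height_or_unbalanced(node.right)
    if right == -1 or abs(left - right) > 1:
        return -1
    return 1 + max(left, right)


def is_balanced(root: Optional[Node]) -> bool:
    return height_or_unbalanced(root) != -1


# [3, 9, 20, null, null, 15, 7]: balanced
good = Node(3, Node(9), Node(20, Node(15), Node(7)))
# [1, 2, 2, 3, 3, null, null, 4, 4]: left side is two levels deeper
bad = Node(1, Node(2, Node(3, Node(4), Node(4)), Node(3)), Node(2))

print(is_balanced(good))  # True
print(is_balanced(bad))   # False
```

The sentinel version also skips the right subtree once the left one is known to be unbalanced, which the tuple version does not do.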

In [None]:
from typing import Optional, Tuple

# Definition for a binary tree node.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
class Solution:
    def isBalanced(self, root: Optional[TreeNode]) -> bool:
        if root is None:
            return True
        
                
        # use bottom up 
        def traverse(node: Optional[TreeNode]) -> Tuple[int, bool]:
            # return 0 depth and true as a height balanced tree
            if node is None:
                return (0, True)            
            
            # recursively call left and right child nodes
            left_depth, left = traverse(node.left)
            right_depth, right = traverse(node.right)
            
            # if either subtree is unbalanced, or the depth difference > 1,
            # return False; the depth value returned no longer matters
            if not left or not right or abs(left_depth - right_depth) > 1:
                return (0, False)
            
            # otherwise, return the depth of the current node
            # and the current node is a balanced tree
            return (max(left_depth, right_depth) + 1, True)
        
        # call the root node and return its boolean value
        return traverse(root)[1]
       

#### Leetcode 108. Convert Sorted Array to Binary Search Tree
* Overview
  + Given an integer array nums where the elements are sorted in ascending order, convert it to a height-balanced binary search tree.

* Algorithm 
  + select the mid element (the left mid when the range has even length) as the root, then recursively call the function to build the left and right subtrees. Connect the child nodes to root and return root
  + if start index > end index, return an empty node
  + if start == end, return a single tree node
* time complexity
  + O(N) since every element becomes a node exactly once
* space complexity
  + O(logN) for the recursion stack, since we only descend the height of the balanced tree (the output tree itself takes O(N))
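
The result can be sanity-checked by confirming that the inorder traversal reproduces the input and that the height stays logarithmic. The sketch below uses hypothetical helper names (`Node`, `sorted_array_to_bst`, `inorder`, `height`) and drops the start == end shortcut, which the general case already covers.

```python
from typing import List, Optional


class Node:
    """Stand-in for TreeNode, used only in this example."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def sorted_array_to_bst(nums: List[int]) -> Optional[Node]:
    def build(start: int, end: int) -> Optional[Node]:
        if start > end:
            return None
        mid = start + (end - start) // 2   # left mid for even-length ranges
        root = Node(nums[mid])
        root.left = build(start, mid - 1)
        root.right = build(mid + 1, end)
        return root

    return build(0, len(nums) - 1)


def inorder(root: Optional[Node]) -> List[int]:
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


def height(root: Optional[Node]) -> int:
    if root is None:
        return 0
    return 1 + max(height(root.left), height(root.right))


nums = [-10, -3, 0, 5, 9]
tree = sorted_array_to_bst(nums)

print(tree.val)       # 0
print(inorder(tree))  # [-10, -3, 0, 5, 9]
print(height(tree))   # 3
```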

In [None]:
from typing import List, Optional

# Definition for a binary tree node.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
class Solution:
    def sortedArrayToBST(self, nums: List[int]) -> Optional[TreeNode]:
        
        def build_tree(start: int, end: int) -> Optional[TreeNode]:
            # if start > end, return None
            if start > end:
                return None
            
            # single element list return a single TreeNode
            if start == end:
                return TreeNode(nums[start])
            
            # get the mid as the root
            mid = start + (end - start) // 2
            root = TreeNode(nums[mid])
            
            # build left and right nodes
            # all left nodes have smaller values than root
            # and right nodes have bigger values than root
            # and return the root
            root.left = build_tree(start, mid-1)
            root.right = build_tree(mid+1, end)
            return root
        
        return build_tree(0, len(nums)-1)
        