Given a root node reference of a BST and a key, delete the node with the given key in the BST. Return the root node reference (possibly updated) of the BST.

Basically, the deletion can be divided into two stages:
1. Search for a node to remove.
2. If the node is found, delete the node.

In [1]:
# O(height) time | O(height) space  (height is ~log n only when the tree is balanced)
def deleteNode(root, key):
    """Delete the node whose value equals ``key`` from the BST rooted at ``root``.

    Nodes are expected to expose ``val``, ``left`` and ``right`` attributes.
    The tree is modified in place; the return value is the root of the
    resulting (sub)tree, which differs from ``root`` when the root itself is
    unlinked. If ``key`` is not present the tree is returned unchanged.
    """
    if not root:
        return None

    # Search phase: descend into the side that can contain the key and
    # re-attach whatever that subtree becomes after the deletion.
    if key < root.val:
        root.left = deleteNode(root.left, key)
        return root
    if key > root.val:
        root.right = deleteNode(root.right, key)
        return root

    # Delete phase: ``root`` is the node to remove.
    if not root.left:
        # No left child: the right subtree (possibly empty) takes its place.
        return root.right
    if not root.right:
        # No right child: the left subtree takes its place.
        return root.left

    # Two children: copy the in-order successor's value (leftmost node of
    # the right subtree) into this node, then remove that successor.
    successor = root.right
    while successor.left:
        successor = successor.left
    root.val = successor.val
    root.right = deleteNode(root.right, root.val)
    return root

Alternative version: the same deletion, with the in-order successor and predecessor lookups separated into helper functions.

In [2]:
def deleteNode(root, key):
    """Delete the node whose value equals ``key`` from the BST rooted at ``root``.

    Nodes are expected to expose ``val``, ``left`` and ``right`` attributes.
    A matching node is never unlinked directly unless it is a leaf: its value
    is overwritten with its in-order successor (or predecessor when it has no
    right child) and that value is then deleted from the matching subtree.
    Returns the root of the resulting (sub)tree, or ``None`` if it is empty.
    """

    def successor_val(node):
        # Smallest value in the right subtree; caller guarantees node.right exists.
        cursor = node.right
        while cursor.left:
            cursor = cursor.left
        return cursor.val

    def predecessor_val(node):
        # Largest value in the left subtree; caller guarantees node.left exists.
        cursor = node.left
        while cursor.right:
            cursor = cursor.right
        return cursor.val

    if not root:
        return None

    if key < root.val:
        # Key can only be in the left subtree.
        root.left = deleteNode(root.left, key)
        return root
    if key > root.val:
        # Key can only be in the right subtree.
        root.right = deleteNode(root.right, key)
        return root

    # ``root`` holds the key.
    if root.right:
        # Take over the successor's value, then remove it from the right subtree.
        root.val = successor_val(root)
        root.right = deleteNode(root.right, root.val)
        return root
    if root.left:
        # No right child: take over the predecessor's value instead.
        root.val = predecessor_val(root)
        root.left = deleteNode(root.left, root.val)
        return root

    # Leaf node: removing it leaves an empty subtree.
    return None