99. Recover Binary Search Tree
Medium

2902

109

Add to List

Share
You are given the root of a binary search tree (BST), where exactly two nodes of the tree were swapped by mistake. Recover the tree without changing its structure.

Follow-up: A solution using O(n) space is pretty straightforward. Could you devise a constant-space solution?

 

Example 1:


Input: root = [1,3,null,null,2]
Output: [3,1,null,null,2]
Explanation: 3 cannot be a left child of 1 because 3 > 1. Swapping 1 and 3 makes the BST valid.
Example 2:


Input: root = [3,1,4,null,null,2]
Output: [2,1,4,null,null,3]
Explanation: 2 cannot be in the right subtree of 3 because 2 < 3. Swapping 2 and 3 makes the BST valid.
 

Constraints:

The number of nodes in the tree is in the range [2, 1000].
$-2^{31} \le \text{Node.val} \le 2^{31} - 1$

In [None]:
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def recoverTree(self, root: "TreeNode") -> None:
        """
        Do not return anything, modify root in-place instead.

        Recursive in-order approach.  The in-order sequence of a valid BST
        is sorted, so the two swapped nodes show up as "drops": in-order
        neighbour pairs ``(a, b)`` with ``a.val > b.val``.  The first
        misplaced node is the left side of the first drop and the second
        is the right side of the last drop; swapping their values restores
        the tree.  If there is no drop (including an empty tree) the tree
        is left unchanged.

        Extra space: O(n) for the node list plus recursion depth equal to
        the tree height.  NOTE: a fully skewed tree near the 1000-node
        upper bound can exceed CPython's default recursion limit (1000).
        """
        nodes = []

        def in_order(node):
            # Append nodes in sorted (in-order) position: left, self, right.
            if node:
                in_order(node.left)
                nodes.append(node)
                in_order(node.right)

        in_order(root)

        # One drop if the swapped nodes are adjacent in in-order, two otherwise.
        drops = [(a, b) for a, b in zip(nodes, nodes[1:]) if a.val > b.val]
        if not drops:
            return  # already a valid BST: nothing to swap
        # drops[-1] is drops[0] when there is a single drop, so this covers both cases.
        x, y = drops[0][0], drops[-1][1]
        x.val, y.val = y.val, x.val

There are 2 cases: in the in-order sequence, the two swapped values are either adjacent or not adjacent. Call a "drop" any position where a value is greater than the value that follows it. If the swapped values are adjacent, there will be one drop; if they're not adjacent, there will be two drops.

adjacent: ... _ < _ < A > B < _ < _ ...
                      ^^^^^
                      drop #1

not adjacent: ... _ < _ < A > X < _ < Y > B < _ < _ ... (X may be the same as Y, but it's irrelevant)
                          ^^^^^       ^^^^^
                          drop #1     drop #2
In both cases, we want to swap A and B.
So the idea is to keep a `drops` list during the in-order traversal and append a tuple (lastNode, currentNode) whenever lastNode.val > currentNode.val. At the end of the traversal, `drops` must contain either 1 or 2 tuples (otherwise, more than 2 nodes would need to be swapped). A is always the first node of the first tuple, and B is the second node of the last tuple.

In [None]:
"""Here's the clear but not-so-clean way to swap them:"""
def swap_drops_explicit(drops):
    """Swap the two misplaced node values recorded in ``drops`` (clear version).

    ``drops`` is the list of ``(previous_node, current_node)`` pairs collected
    during an in-order traversal wherever ``previous_node.val >
    current_node.val``.  It has one entry when the swapped nodes are adjacent
    in in-order, and two entries when they are not.  Nodes are modified in
    place; nothing is returned.
    """
    if len(drops) == 1:  # drops == [(A, B)]
        drops[0][0].val, drops[0][1].val = drops[0][1].val, drops[0][0].val
    else:  # drops == [(A, X), (Y, B)]
        drops[0][0].val, drops[1][1].val = drops[1][1].val, drops[0][0].val


"""Here's the clean but not-so-clear way that gets rid of the conditional branching:"""


def swap_drops(drops):
    """Perform the same swap as ``swap_drops_explicit`` without branching.

    With a single drop, ``drops[-1]`` is ``drops[0]``, so swapping
    ``drops[0][0]`` with ``drops[-1][1]`` handles both cases.  Only one of
    the two variants should be applied: running both would swap the values
    back.
    """
    drops[0][0].val, drops[-1][1].val = drops[-1][1].val, drops[0][0].val

In [None]:
class Solution:
    def recoverTree(self, root: "TreeNode") -> None:
        """Recover a BST in which exactly two node values were swapped.

        Iterative in-order traversal with an explicit stack.  Every "drop"
        (a node whose value is smaller than its in-order predecessor's) is
        recorded as ``(predecessor, node)``.  One drop means the swapped
        nodes are adjacent in in-order and two mean they are not.  Either
        way, the fix swaps the first node of the first drop with the second
        node of the last drop.  If there is no drop (including an empty
        tree) the tree is left unchanged.

        Extra space: O(h) for the stack, where h is the tree height.  That
        is O(log n) for a balanced tree but O(n) for a skewed one.
        Modifies ``root`` in place and returns ``None``.
        """
        drops = []
        stack = []
        prev = None  # in-order predecessor of the node being visited
        curr = root
        while curr or stack:
            # Defer every node on the path down to the leftmost descendant.
            while curr:
                stack.append(curr)
                curr = curr.left
            node = stack.pop()
            if prev is not None and node.val < prev.val:
                drops.append((prev, node))
            prev, curr = node, node.right
        if not drops:
            return  # already a valid BST: nothing to swap
        # drops[-1] is drops[0] when there is a single drop, so this covers both cases.
        first, last = drops[0][0], drops[-1][1]
        first.val, last.val = last.val, first.val

In [None]:
class Solution:
    def recoverTree(self, root: "TreeNode") -> None:
        """Recover a BST in which exactly two node values were swapped.

        Morris in-order traversal: O(1) extra space.  Instead of a stack,
        each node's in-order predecessor is temporarily threaded back to it
        via ``predecessor.right``.  The thread is removed on the second
        visit, so the tree's structure is fully restored once the traversal
        ends.

        Rather than keeping a list of "drops" (in-order pairs with
        ``prev.val > curr.val``), only the two nodes that matter are kept.
        These are the left side of the first drop and the right side of the
        most recent drop, which is exactly ``drops[0][0]`` and
        ``drops[-1][1]``.  If there is no drop (including an empty tree) the
        tree is left unchanged.  Modifies ``root`` in place and returns
        ``None``.
        """
        first = second = None
        prev = None  # in-order predecessor of the node being visited
        curr = root
        while curr:
            if curr.left:
                # Rightmost node of the left subtree = in-order predecessor.
                pred = curr.left
                while pred.right is not None and pred.right is not curr:
                    pred = pred.right
                if pred.right is None:
                    # First visit: thread the predecessor back to curr, go left.
                    pred.right = curr
                    curr = curr.left
                    continue
                # Second visit: left subtree is done; remove the thread.
                pred.right = None
            if prev is not None and curr.val < prev.val:
                if first is None:
                    first = prev
                second = curr
            prev, curr = curr, curr.right
        if first is None:
            return  # already a valid BST: nothing to swap
        first.val, second.val = second.val, first.val