## 617. Merge Two Binary Trees

Given two binary trees, imagine placing one on top of the other: some nodes of the two trees overlap, while others do not.

Merge them into a new binary tree. If two nodes overlap, the merged node's value is the sum of their values; otherwise, the non-null node is used as the node of the new tree.

Example 1:

Input:
```
      Tree 1          Tree 2
        1               2
       / \             / \
      3   2           1   3
     /                 \   \
    5                   4   7
```
Output: 
Merged tree:
```
        3
       / \
      4   5
     / \   \
    5   4   7
```

Note: The merging process must start from the root nodes of both trees.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def levelTraversal(root):
    # Returns node values grouped by depth, left to right within each level.
    output = []

    def helper(head, level):
        if head:
            # First node reached at this depth: open a new list for the level.
            if level == len(output):
                output.append([])

            output[level].append(head.val)
            helper(head.left, level + 1)
            helper(head.right, level + 1)

    helper(root, 0)
    return output

def mergeTrees(r1, r2):
    # Builds a brand-new tree; r1 and r2 are left unchanged.

    def helper(h1, h2):
        if h1 or h2:
            # Overlapping nodes are summed; a missing node contributes 0.
            h = TreeNode((h1.val if h1 else 0) + (h2.val if h2 else 0))
            h.left = helper(h1.left if h1 else None,
                            h2.left if h2 else None)
            h.right = helper(h1.right if h1 else None,
                             h2.right if h2 else None)
            return h

        # Neither tree has a node at this position.
        return None

    return helper(r1, r2)
```

```python
r1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
r2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))

mTree = mergeTrees(r1, r2)
print(levelTraversal(mTree))
```

```
[[3], [4, 5], [5, 4, 7]]
```
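`levelTraversal` groups values by depth through a recursive preorder walk that appends each node to the list for its level. A queue-based breadth-first traversal produces the same grouping without recursion. The sketch below uses `collections.deque`; the name `level_order` is a hypothetical helper, not part of the notebook's code.

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def level_order(root):
    """Breadth-first traversal returning node values grouped by depth."""
    if root is None:
        return []

    levels = []
    queue = deque([root])
    while queue:
        # Everything currently in the queue belongs to the same depth.
        level = []
        for _ in range(len(queue)):
            node = queue.popleft()
            level.append(node.val)
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
        levels.append(level)
    return levels


# The merged tree from Example 1, built by hand.
merged = TreeNode(3,
                  TreeNode(4, TreeNode(5), TreeNode(4)),
                  TreeNode(5, None, TreeNode(7)))
print(level_order(merged))  # [[3], [4, 5], [5, 4, 7]]
```

Because `range(len(queue))` is evaluated once per outer iteration, children appended during the inner loop are deferred to the next level.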


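## Optimized version

This version allocates no new nodes: it writes each sum into `t1` and, wherever only one tree has a node, reuses that subtree directly, so it only descends into positions where both trees have a node. The trade-off is that `t1` is modified in place and the result shares nodes with `t2`.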

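```python
def mergeTreesOpt(t1, t2):
    # If one side is missing, reuse the other subtree as-is (no copy).
    if t1 is None:
        return t2

    if t2 is None:
        return t1

    # Both nodes exist: accumulate the sum into t1, then merge the children.
    t1.val = t1.val + t2.val
    t1.left = mergeTreesOpt(t1.left, t2.left)
    t1.right = mergeTreesOpt(t1.right, t2.right)

    return t1
```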

```python
# Fresh inputs: mergeTreesOpt writes its result into r1.
r1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
r2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))

mTree = mergeTreesOpt(r1, r2)
print(levelTraversal(mTree))
```

```
[[3], [4, 5], [5, 4, 7]]
```
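Both versions above are recursive, so their call depth grows with the height of the trees, and a very deep tree (for example, a long chain of left children) can exceed Python's default recursion limit. A minimal sketch of the same in-place merge with an explicit stack follows. It assumes the same `TreeNode` shape, and `merge_trees_iterative` is a hypothetical name. Like `mergeTreesOpt`, it mutates `t1` and grafts subtrees of `t2` wherever `t1` has no child.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def merge_trees_iterative(t1, t2):
    """In-place merge of t2 into t1 using an explicit stack (mutates t1)."""
    if t1 is None:
        return t2
    if t2 is None:
        return t1

    # Invariant: every pair on the stack is two overlapping, non-None nodes.
    stack = [(t1, t2)]
    while stack:
        n1, n2 = stack.pop()
        n1.val += n2.val

        # Left child: graft t2's subtree if t1 has none, else merge deeper.
        if n1.left is None:
            n1.left = n2.left
        elif n2.left is not None:
            stack.append((n1.left, n2.left))

        # Right child: same rule.
        if n1.right is None:
            n1.right = n2.right
        elif n2.right is not None:
            stack.append((n1.right, n2.right))
    return t1


# Example 1 from the problem statement.
r1 = TreeNode(1, TreeNode(3, TreeNode(5)), TreeNode(2))
r2 = TreeNode(2, TreeNode(1, None, TreeNode(4)), TreeNode(3, None, TreeNode(7)))
merged = merge_trees_iterative(r1, r2)
print(merged is r1)                        # True: r1 was modified in place
print(merged.left.right is r2.left.right)  # True: this subtree is shared with r2

# A 10,000-node chain of left children; no recursion is involved here.
a, b = None, None
for _ in range(10_000):
    a, b = TreeNode(1, a), TreeNode(2, b)
deep = merge_trees_iterative(a, b)
print(deep.val)  # 3
```

Using a stack rather than a queue only changes the order in which overlapping pairs are processed; the merged tree is the same either way.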
