## Problem statement

Given the root of a binary tree, find the diameter.

*Note: The diameter of a binary tree is the maximum distance between any two nodes, measured as the number of edges on the path between them.*

In [11]:
class BinaryTreeNode:
    """A binary tree node: a ``data`` payload plus optional left/right children."""

    def __init__(self, data):
        # Children start out empty; callers attach subtrees by assignment.
        self.left = None
        self.right = None
        self.data = data

    def __repr__(self):
        # Compact form used when nodes are printed, e.g. Node(5).
        return f'Node({self.data})'


In [28]:
# My solution was here, and clearly I have misunderstood.
#
# I was thinking the diameter was going to be the width of a level of a tree.
# You know, like a damn diameter. So that would be the distance along a level
# from the first non-empty node to the last non-empty node.

def strip_level(nodes, level):
    """Return the slots recorded for *level*, trimmed of outer ``None`` gaps.

    :param nodes: ``(node, level)`` pairs in breadth-first order, where a
        ``node`` of ``None`` is a placeholder for a missing child
        (``diameter_of_binary_tree_wrong`` numbers the root as level 1).
    :param level: the level to extract.
    :return: the slots on *level* from the first real node through the last
        real node inclusive (``None`` gaps between them are kept), or ``[]``
        when the level holds no real nodes.
    """
    row = [node for node, node_level in nodes if node_level == level]
    # Indices of the real (non-placeholder) nodes on this level.
    present = [i for i, node in enumerate(row) if node is not None]
    if not present:
        return []
    # +1 makes the slice end inclusive so the last real node is kept.
    return row[present[0]:present[-1] + 1]
        

def diameter_of_binary_tree_wrong(root, *, debug=False):
    """Return the widest level of the tree minus one (a misread "diameter").

    This was a first attempt based on a misunderstanding of the problem: it
    treats the diameter as the width of a level, measured on the entries that
    ``strip_level`` returns for that level. It is kept for reference only;
    ``diameter_of_binary_tree_two`` computes the real diameter.

    :param root: root node of the tree, or ``None`` for an empty tree.
    :param debug: when True, print the breadth-first ``(node, level)`` list
        and each stripped level (this output used to be unconditional).
    :return: the largest ``len(level_nodes) - 1`` seen before the first
        level for which ``strip_level`` returns nothing, or 0.
    """
    # Breadth-first list of (node, level) pairs. The None children of real
    # nodes are recorded too, as placeholders for gaps inside a level.
    nodes = [(root, 1)]
    pos = 0
    while pos < len(nodes):
        node, node_level = nodes[pos]
        pos += 1
        if node:
            nodes.append((node.left, node_level + 1))
            nodes.append((node.right, node_level + 1))

    if debug:
        print(nodes)

    max_diameter = 0
    level = 1
    while True:
        level_nodes = strip_level(nodes, level)
        if debug:
            print(level, level_nodes)
        # Stop at the first level with nothing left after stripping.
        if not level_nodes:
            break
        max_diameter = max(max_diameter, len(level_nodes) - 1)
        level += 1

    return max_diameter

In [29]:
# Second attempt now that I understand the problem
#
# A path between two nodes means that the two share a common ancestor
# So there is some node that is that ancestor for the longest path
# And for that node the longest path is the height of the left subtree
# + the height of the right subtree. So at the root the longest path
# is either through the root or is the longest path somewhere in the left
# or the right subtrees.
# 
# so this sounds like a recursive solution

def diameter_of_binary_tree_two(root):
    """Return the diameter (longest node-to-node path, in edges) of *root*'s tree."""
    diameter, _height = diameter_and_height_of_tree(root)
    return diameter

def diameter_and_height_of_tree(root):
    '''Return (diameter, height) for the tree rooted at *root*.

    ``height`` counts the nodes on the longest root-to-leaf path (0 for an
    empty tree, 1 for a single node). ``diameter`` is the number of edges on
    the longest path between any two nodes. At each node it is the largest of
    the left diameter, the right diameter and ``left_height + right_height``
    (the longest path that bends at that node).

    The tree is walked in post-order with an explicit stack instead of
    recursion, so very deep trees (e.g. list-shaped ones) do not hit
    Python's recursion limit.
    '''
    if not root:
        return (0, 0)

    # (diameter, height) of every finished subtree, keyed by node identity
    # so that node equality/hashing never matters.
    done = {}

    def result_of(child):
        # Missing children contribute diameter 0 and height 0.
        return done[id(child)] if child else (0, 0)

    stack = [(root, False)]
    while stack:
        node, children_ready = stack.pop()
        if not children_ready:
            # Revisit this node once both of its subtrees have been finished.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child:
                    stack.append((child, False))
            continue

        left_d, left_h = result_of(node.left)
        right_d, right_h = result_of(node.right)
        height = max(left_h, right_h) + 1
        diameter = max(left_d, right_d, left_h + right_h)
        done[id(node)] = (diameter, height)

    return done[id(root)]

You can use the following function to test your code with custom test cases. The function `convert_arr_to_binary_tree` takes an array input representing level-order traversal of the binary tree.


<img src='./resources/01-binary-tree.png'>

The above tree would be represented as `arr = [1, 2, 3, 4, None, 5, None, None, None, None, None]`

Notice that the level order traversal of the above tree would be `[1, 2, 3, 4, 5]`. 

Note the following points about this tree:
* `None` represents the lack of a node. For example, `2` only has a left node; therefore, the next node after `4` (in level order) is represented as `None`
* Similarly, `3` only has a left node; hence, the next node after `5` (in level order) is represented as `None`.
* Also, `4` and `5` don't have any children. Therefore, the spots for their children in level order are represented by four `None` values (one for each child of `4` and `5`).

In [22]:
from collections import deque
from queue import Queue

def convert_arr_to_binary_tree(arr):
    """
    Build a binary tree from *arr*, its level-order representation.

    Each real node consumes the next two slots of ``arr`` for its left and
    right child, and a ``None`` slot means "no child". Slots past the end of
    ``arr`` are treated as ``None``, so trailing ``None`` values may be
    omitted.

    :param arr: level-order values; ``arr[0]`` is the root.
    :return: the root ``BinaryTreeNode``, or ``None`` when ``arr`` is empty
        or its first entry is ``None`` or ``-1`` (both mean "empty tree").
    """
    if len(arr) == 0 or arr[0] is None or arr[0] == -1:
        return None

    def value_at(i):
        # Missing trailing slots behave like explicit None entries.
        return arr[i] if i < len(arr) else None

    root = BinaryTreeNode(arr[0])
    index = 1
    pending = deque([root])  # nodes whose children have not been read yet

    while pending:
        current_node = pending.popleft()
        left_value = value_at(index)
        right_value = value_at(index + 1)
        index += 2

        if left_value is not None:
            current_node.left = BinaryTreeNode(left_value)
            pending.append(current_node.left)
        if right_value is not None:
            current_node.right = BinaryTreeNode(right_value)
            pending.append(current_node.right)

    return root

    
    

In [30]:
# Build four sample trees from their level-order arrays.
first = convert_arr_to_binary_tree([1, None, None])
second = convert_arr_to_binary_tree([1, 2, 3, None, None, None, None])
third = convert_arr_to_binary_tree([1, 2, 3, 4, None, 5, None, None, None, None, None])
fourth = convert_arr_to_binary_tree([1, 2, 3, 4, None, 5, 6, None, None, None, 7, None, None, None, None])

# Print each tree's diameter on its own line.
print(*(diameter_of_binary_tree_two(t) for t in (first, second, third, fourth)), sep='\n')


0
2
4
5


In [35]:
# Solution
def diameter_of_binary_tree(root):
    """Return the diameter of *root*'s tree (the second item of the helper's result)."""
    _height, diameter = diameter_of_binary_tree_func(root)
    return diameter
    
def diameter_of_binary_tree_func(root):
    """
    Compute the height and diameter of the subtree rooted at *root*.

    The diameter at a node is the largest of:
        1. the diameter of its left subtree,
        2. the diameter of its right subtree,
        3. left height + right height (the longest path bending at the node).
    :param root: subtree root, or None for an empty subtree
    :return: (height, diameter) tuple; (0, 0) for an empty subtree
    """
    if root is None:
        return 0, 0

    lh, ld = diameter_of_binary_tree_func(root.left)
    rh, rd = diameter_of_binary_tree_func(root.right)

    through_root = lh + rh
    return 1 + max(lh, rh), max(ld, rd, through_root)

In [38]:
def test_function(test_case, diameter_func=None):
    """Run one diameter test case, printing the output and then Pass/Fail.

    :param test_case: ``[arr, solution]``, where ``arr`` is a level-order
        array for ``convert_arr_to_binary_tree`` and ``solution`` is the
        expected diameter.
    :param diameter_func: the implementation to check; defaults to
        ``diameter_of_binary_tree_two``. Pass e.g. ``diameter_of_binary_tree``
        to check another implementation without editing this function.
    """
    if diameter_func is None:
        diameter_func = diameter_of_binary_tree_two

    arr = test_case[0]
    solution = test_case[1]
    root = convert_arr_to_binary_tree(arr)
    output = diameter_func(root)
    print(output)
    print("Pass" if output == solution else "Fail")

In [39]:
# Longest path is e.g. 4-2-1-3: three edges.
arr, solution = [1, 2, 3, 4, 5, None, None, None, None, None, None], 3

test_case = [arr, solution]
test_function(test_case)

3
Pass


In [40]:
# Longest path is 4-2-1-3-5: four edges.
arr, solution = [1, 2, 3, 4, None, 5, None, None, None, None, None], 4

test_case = [arr, solution]
test_function(test_case)

4
Pass


In [41]:
# Longest path is e.g. 11-10-6-4-3-5-7: six edges, not through the root.
arr, solution = [1, 2, 3, None, None, 4, 5, 6, None, 7, 8, 9, 10, None, None, None, None, None, None, 11, None, None, None], 6

test_case = [arr, solution]
test_function(test_case)

6
Pass
