# Diameter of a Binary Tree

The diameter of a tree (sometimes called the width) is the number of nodes on the longest path between two end nodes of the tree.

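The function below runs in O(n) time, no matter the shape of the tree, because it computes the diameter and the height together in a single traversal.
The original function is O(n^2) in the worst case, when the tree is very unbalanced, because it recalculates subtree heights repeatedly.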

In [15]:
class Node:
    """A binary-tree node: a payload plus links to a left and a right child."""

    def __init__(self, data):
        # A freshly created node is a leaf; children are attached afterwards
        # by assigning to ``left`` / ``right``.
        self.data = data
        self.left = self.right = None

        
# returns the diameter and the height of the tree        
def diameter_height(node):
    """Return ``(diameter, height)`` of the binary tree rooted at ``node``.

    Both values are counted in nodes: the diameter is the number of nodes on
    the longest path in the tree, the height is the number of nodes on the
    longest path from ``node`` down to a leaf.  An empty tree (``None``)
    yields ``(0, 0)``.

    The tree is walked post-order with an explicit stack instead of
    recursion, so a very unbalanced (list-like) tree cannot exceed the
    interpreter's recursion limit.  A constant amount of work is done per
    node, so the running time is O(n).

    Args:
        node: Root of the tree -- an object with ``left`` and ``right``
            attributes -- or ``None`` for an empty tree.

    Returns:
        A ``(diameter, height)`` tuple of ints.
    """
    # Work list of (subtree root, "children already processed?") pairs.
    pending = [(node, False)]
    # Finished (diameter, height) pairs, in post-order.
    results = []

    while pending:
        current, children_done = pending.pop()
        if current is None:
            # Empty subtree: contributes no nodes.
            results.append((0, 0))
        elif children_done:
            # The left subtree was processed before the right one, so the
            # right result is on top of the results stack.
            r_diameter, r_height = results.pop()
            l_diameter, l_height = results.pop()
            results.append((
                # Longest path either passes through this node or lies
                # entirely inside one of the two subtrees.
                max(l_height + r_height + 1, l_diameter, r_diameter),
                1 + max(l_height, r_height),
            ))
        else:
            # Revisit this node after both children; left is pushed last so
            # that it is popped (and therefore processed) first.
            pending.append((current, True))
            pending.append((current.right, False))
            pending.append((current.left, False))

    return results.pop()

# This function just computes the diameter (by discarding the height).
def find_tree_diameter(node):
    """Return only the diameter of the tree rooted at ``node``.

    Delegates to ``diameter_height`` and keeps the first element of the
    pair it returns, dropping the height.
    """
    return diameter_height(node)[0]

# Sample tree for the demo:
#         1
#        / \
#       2   3
#      / \
#     4   5
root = Node(1)
root.left, root.right = Node(2), Node(3)
root.left.left, root.left.right = Node(4), Node(5)

print("Diameter of given binary tree is %d" % find_tree_diameter(root))

Diameter of given binary tree is 4


In [13]:
class Node:
    """Single node of a binary tree.

    Stores the value passed in as ``data``; ``left`` and ``right`` start out
    as ``None`` and are filled in by the code that builds the tree.
    """

    def __init__(self, data):
        self.data, self.left, self.right = data, None, None

        
# returns the diameter and the height of the tree        
def diameter_height(node):
    """Return ``(diameter, height)`` of the binary tree rooted at ``node``.

    Both values are counted in nodes: the diameter is the number of nodes on
    the longest path in the tree, the height is the number of nodes on the
    longest path from ``node`` down to a leaf.  An empty tree (``None``)
    yields ``(0, 0)``.

    The tree is walked post-order with an explicit stack instead of
    recursion, so a very unbalanced (list-like) tree cannot exceed the
    interpreter's recursion limit.  A constant amount of work is done per
    node, so the running time is O(n).

    Args:
        node: Root of the tree -- an object with ``left`` and ``right``
            attributes -- or ``None`` for an empty tree.

    Returns:
        A ``(diameter, height)`` tuple of ints.
    """
    # Work list of (subtree root, "children already processed?") pairs.
    pending = [(node, False)]
    # Finished (diameter, height) pairs, in post-order.
    results = []

    while pending:
        current, children_done = pending.pop()
        if current is None:
            # Empty subtree: contributes no nodes.
            results.append((0, 0))
        elif children_done:
            # The left subtree was processed before the right one, so the
            # right result is on top of the results stack.
            r_diameter, r_height = results.pop()
            l_diameter, l_height = results.pop()
            results.append((
                # Longest path either passes through this node or lies
                # entirely inside one of the two subtrees.
                max(l_height + r_height + 1, l_diameter, r_diameter),
                1 + max(l_height, r_height),
            ))
        else:
            # Revisit this node after both children; left is pushed last so
            # that it is popped (and therefore processed) first.
            pending.append((current, True))
            pending.append((current.right, False))
            pending.append((current.left, False))

    return results.pop()

# This function just computes the height (by discarding the diameter).
def find_tree_height(node):
    """Return only the height of the tree rooted at ``node``.

    Delegates to ``diameter_height`` and keeps the second element of the
    pair it returns, dropping the diameter.
    """
    return diameter_height(node)[1]

# Sample tree for the demo:
#         1
#        / \
#       2   3
#      / \
#     4   5
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)

# find_tree_height returns the height, so the message is labelled "Height"
# rather than "Diameter".
print ("Height of given binary tree is %d" %(find_tree_height(root)))

Diameter of given binary tree is 3
