# In [None]:
"""

Maximum distance between two nodes in a rooted tree (tree diameter)

Given an arbitrary unweighted rooted tree with N nodes, find the largest
distance between two nodes in the tree. The distance between two nodes is the
number of edges on the path between them (the path is unique because the
graph is a tree).

- The nodes will be numbered 0 through N - 1.
- The tree is given as an array A: for every i (0 <= i < N) with A[i] != -1, there is an edge between nodes A[i] and i. Exactly one index i has A[i] equal to -1; that node is the root.

Problem Constraints: 
   1 <= N <= 40000

Input Format: 
   First and only argument is an integer array A of size N.

Output Format: 
   Return a single integer denoting the largest distance between two nodes in a tree.


Example

Input 1:
    A = [-1, 0, 0, 0, 3]

Output 1:
   3

Explanation

 Node 0 is the root, and the whole tree looks like this
 (each "+--" is an edge from a parent to a child):

     0
     +-- 1
     +-- 2
     +-- 3
         +-- 4

 One of the longest paths is 1 -> 0 -> 3 -> 4; it has 3 edges, so the answer is 3.

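### Optimal - Solution

* Two BFS passes:
* First, run BFS from an arbitrary node U to find the farthest node X.
* Then, run BFS from node X to find the farthest node Y.
* The path X ---> Y is a longest path in the tree.
* Equivalent alternative (the approach the code below uses): a single
  bottom-up pass that, at every node, adds the two largest child-subtree
  heights; the maximum of these sums over all nodes is the answer.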
"""

import sys
# Raise the recursion limit so that a recursive traversal of a very deep tree
# (N can be up to 40000, e.g. a single long path) does not hit RecursionError.
# NOTE(review): this does not enlarge the C stack; very deep recursion can
# still crash the interpreter on some Python versions.
sys.setrecursionlimit(50000)
class Solution:
    """Largest distance (tree diameter, in edges) of a tree given as a parent array."""

    def get_max_two(self, a, b, c):
        """Return the two largest of ``a``, ``b`` and ``c`` as ``(largest, second)``."""
        first, second, _ = sorted((a, b, c), reverse=True)
        return first, second

    def solve_diameter(self, root, adj):
        """Return ``(height, diameter)`` for the subtree rooted at ``root``.

        ``height`` is the number of nodes on the longest downward path starting
        at ``root``; ``diameter`` is the largest number of edges on any path
        inside the subtree. ``adj[v]`` must list the children of node ``v``
        (a dict or a list indexed by node both work).

        The traversal is iterative, so very deep trees (e.g. a 40000-node
        path) cannot overflow the interpreter stack.
        """
        # Pre-order DFS: every child is appended after its parent, so walking
        # the order backwards visits children before their parents.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            stack.extend(adj[node])

        height = {}
        diameter = 0
        for node in reversed(order):
            a = b = 0  # two largest child heights (0 when a child is missing)
            for child in adj[node]:
                a, b = self.get_max_two(a, b, height[child])
            height[node] = a + 1
            # A child's height (in nodes) equals the edge count from `node`
            # down to the deepest leaf through that child, so a + b is the
            # longest path that bends at `node`.
            diameter = max(diameter, a + b)
        return height[root], diameter

    # @param A : list of integers
    # @return an integer
    def solve(self, A):
        """Return the largest distance, in edges, between two nodes of the tree.

        ``A[i]`` is the parent of node ``i``; exactly one entry is -1 and
        marks the root. Raises ValueError if no root is present.
        """
        n = len(A)
        if n <= 1:
            return 0
        children = [[] for _ in range(n)]
        root = None
        for node, parent in enumerate(A):
            if parent == -1:
                root = node
            else:
                children[parent].append(node)
        if root is None:
            raise ValueError("A must contain exactly one -1 (the root)")
        # Only the diameter is the answer: the height is counted in nodes and
        # would overshoot by one when the longest path starts at the root.
        return self.solve_diameter(root, children)[1]