# In [None]:
"""
You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.

You have to perform m independent queries on the tree where in the ith query you do the following:

Remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.
Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.

Note:

The queries are independent, so the tree returns to its initial state after each query.
The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.
 

Example 1:


Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]
Output: [2]
Explanation: The diagram above shows the tree after removing the subtree rooted at node with value 4.
The height of the tree is 2 (The path 1 -> 3 -> 2).
Example 2:


Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]
Output: [3,2,3,2]
Explanation: We have the following queries:
- Removing the subtree rooted at node with value 3. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 4).
- Removing the subtree rooted at node with value 2. The height of the tree becomes 2 (The path 5 -> 8 -> 1).
- Removing the subtree rooted at node with value 4. The height of the tree becomes 3 (The path 5 -> 8 -> 2 -> 6).
- Removing the subtree rooted at node with value 8. The height of the tree becomes 2 (The path 5 -> 9 -> 3).
 

Constraints:

The number of nodes in the tree is n.
2 <= n <= 10^5
1 <= Node.val <= n
All the values in the tree are unique.
m == queries.length
1 <= m <= min(n, 10^4)
1 <= queries[i] <= n
queries[i] != root.val
"""

'''
EULER TOUR: primarily used for problems involving subtree removal.
Record the tour while performing a DFS traversal and store it.

for a trivial tree: root
                    /  \
                  left right
the tour will be: root -> left -> left -> right -> right -> root
so every node appears twice in the tour array: once when it is first
visited (before any of its left or right subtree has been visited) and
once when it is left for the last time (after both its left and right
subtrees have been visited).

NOW, WHEN A SUBTREE IS REMOVED, THE REMAINING TREE CONSISTS OF THE SUBARRAY
BEFORE THE FIRST OCCURRENCE OF THE ROOT OF THE SUBTREE AND THE SUBARRAY
AFTER THE LAST OCCURRENCE OF THE ROOT OF THE SUBTREE

Now we build a prefix-max-height array and a suffix-max-height array for
this tour array; after the removal of a subtree the resulting maximum
height is max(
        prefix_max(first_occurrence(subtree_root) - 1),
        suffix_max(last_occurrence(subtree_root) + 1)
    )

https://www.youtube.com/watch?v=s62a0uxeRkE -- best
'''

'''
So basically we create an Euler tour array of (element, depth from root)
pairs; for example, for the tree shown below:
      "1"
   " /   \"
  "2      3"
 "/"
"4"

euler tour array is [(1,0),(2,1),(4,2),(4,2),(2,1),(3,1),(3,1),(1,0)]

Now, if we remove the subtree rooted at 2, the array from index 1 to 4 is
removed. The maximum height of the remaining tree is then the larger of
the maximum depth over indices [0, 1) (index 1 excluded) and the maximum
depth over indices [5, 7], so we can maintain prefix-maximum and
suffix-maximum arrays to answer this for any split point.

LET ME EXPLAIN MY ROUGH ALGORITHM NOW

- create an Euler tour array as shown above
- for each node, store its starting and ending index in the Euler tour array
- pre and suf hold the prefix maximum up to i and the suffix maximum from i
  to the end
- for a particular query, look up the starting and ending index of the
  subtree rooted at that element, and combine the prefix maximum before the
  start with the suffix maximum after the end
'''

from typing import Optional, List
from collections import defaultdict

# Definition for a binary tree node.

class TreeNode:
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        # Fields are bound in the order val, left, right.
        self.val, self.left, self.right = val, left, right

class Solution:
    """Answer subtree-removal queries via the top-two subtree heights per level.

    For every node we record its level (root = 0) and the height, in edges,
    of the subtree rooted at it. For every level we keep the two largest of
    those heights. Removing the subtree of a node at level ``lvl`` leaves:

    * ``lvl - 1`` if the node is the only one on its level (every deeper
      node is then one of its descendants, while its parent survives);
    * otherwise ``lvl + h``, where ``h`` is the largest subtree height among
      the other nodes on that level.
    """

    def treeQueries(self, root: Optional[TreeNode], queries: List[int]) -> List[int]:
        """Return the tree height after each independent subtree removal.

        Args:
            root: Root of the tree; node values are unique.
            queries: Values of the nodes whose subtrees are removed, one per
                query. A value not present in the tree raises ``KeyError``.

        Returns:
            A list whose i-th entry is the height (in edges) of the tree
            after removing the subtree rooted at ``queries[i]``.
        """
        lvl_hts = defaultdict(list)  # level -> two largest subtree heights, descending
        node_lvl = {}  # node value -> (level, subtree height)
        heights = {}  # node -> subtree height, consumed when its parent is done

        # Iterative post-order DFS. A recursive version overflows Python's
        # default recursion limit on skewed trees, and n may be up to 10^5.
        # Each stack entry is (node, level, children_already_pushed).
        stack = [(root, 0, False)] if root else []
        while stack:
            node, lvl, expanded = stack.pop()
            if not expanded:
                # Revisit this node after both children have been processed.
                stack.append((node, lvl, True))
                for child in (node.right, node.left):
                    if child:
                        stack.append((child, lvl + 1, False))
                continue
            # A missing child contributes -1, so a leaf gets height 0.
            cur_ht = 1 + max(heights.pop(node.left, -1), heights.pop(node.right, -1))
            heights[node] = cur_ht
            top = lvl_hts[lvl]
            top.append(cur_ht)
            top.sort(reverse=True)
            del top[2:]
            node_lvl[node.val] = (lvl, cur_ht)

        answer = []
        for val in queries:
            lvl, ht = node_lvl[val]
            hts = lvl_hts[lvl]
            if len(hts) == 1:
                # Only node on its level: the parent level becomes the deepest.
                answer.append(lvl - 1)
            else:
                # Skip the removed node's entry when it is the maximum; on a
                # tie hts[1] == ht, which is right because another node of
                # that height remains.
                answer.append(lvl + (hts[1] if hts[0] == ht else hts[0]))
        return answer

# In [None]:
# Explanation:
# Find height of each node
# Perform a dfs, the parameters are current node, current depth and max height without current node
# Max height without current node will be max height elsewhere in the tree or height of sibling node + curr depth
from functools import cache
class Solution:
    """Answer subtree-removal queries by pushing the best height elsewhere down.

    ``height[node]`` counts nodes on the longest downward path (``None`` is
    0). Walking top-down, each node receives ``mx``: the tree height if that
    node's subtree were removed. For a child, that is the larger of its
    parent's ``mx`` and the deepest path through its sibling, i.e.
    ``parent_level + height[sibling]`` (just ``parent_level`` when there is
    no sibling, because the parent itself remains).
    """

    def treeQueries(self, root: Optional[TreeNode], queries: List[int]) -> List[int]:
        """Return the tree height after each independent subtree removal.

        Args:
            root: Root of the tree; node values are unique.
            queries: Values of the nodes whose subtrees are removed.

        Returns:
            The height (in edges) of the tree after each removal, in query
            order. A value not present in the tree raises ``KeyError``.
        """
        # Both passes are iterative: recursion would exceed Python's default
        # recursion limit on skewed trees, and n may be up to 10^5.

        # Pre-order listing; walking it in reverse visits children before parents.
        order = []
        stack = [root] if root else []
        while stack:
            node = stack.pop()
            order.append(node)
            stack.extend(child for child in (node.left, node.right) if child)

        # node -> number of nodes on its longest downward path; None -> 0.
        height = {None: 0}
        for node in reversed(order):
            height[node] = 1 + max(height[node.left], height[node.right])

        # Top-down pass; each entry is (node, level, height without node).
        result = {}
        stack = [(root, 0, 0)] if root else []
        while stack:
            node, lvl, mx = stack.pop()
            result[node.val] = mx
            if node.left:
                stack.append((node.left, lvl + 1, max(mx, lvl + height[node.right])))
            if node.right:
                stack.append((node.right, lvl + 1, max(mx, lvl + height[node.left])))

        return [result[v] for v in queries]

# In [None]:
# Euler Tour

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Answer subtree-removal queries with an Euler tour plus prefix/suffix maxima.

    The tour records every node twice, as ``(value, level)`` with the root
    at level 1: once on entry and once after both children are finished.
    The entries between a node's first and last occurrence are exactly its
    subtree. Removing it therefore leaves the entries before the first and
    after the last occurrence. The answer is the maximum level among those,
    minus 1 to turn the 1-based level into a height in edges.
    """

    def treeQueries(self, root: Optional[TreeNode], queries: List[int]) -> List[int]:
        """Return the tree height after each independent subtree removal.

        Args:
            root: Root of the tree; node values are unique.
            queries: Values of non-root nodes whose subtrees are removed.

        Returns:
            The height (in edges) of the tree after each removal, in query
            order. A value not present in the tree raises ``KeyError``.
        """
        from itertools import accumulate

        # Iterative DFS producing the same tour as a recursive walk, without
        # overflowing Python's default recursion limit on skewed trees
        # (n may be up to 10^5). Stack entries: (node, level, leaving).
        euler_tour = []
        stack = [(root, 1, False)] if root else []
        while stack:
            node, level, leaving = stack.pop()
            euler_tour.append((node.val, level))
            if leaving:
                continue
            # Push the exit marker first so it pops after both subtrees;
            # push right before left so the left subtree is toured first.
            stack.append((node, level, True))
            if node.right:
                stack.append((node.right, level + 1, False))
            if node.left:
                stack.append((node.left, level + 1, False))

        # value -> index of its first / last occurrence in the tour
        first, last = {}, {}
        for i, (val, _) in enumerate(euler_tour):
            first.setdefault(val, i)
            last[val] = i

        levels = [level for _, level in euler_tour]
        # pre[i] = max level in euler_tour[:i]; suf[i] = max level in
        # euler_tour[i:]; both are 0 for an empty range (real levels are >= 1).
        pre = list(accumulate(levels, max, initial=0))
        suf = list(accumulate(reversed(levels), max, initial=0))[::-1]

        return [max(pre[first[v]], suf[last[v] + 1]) - 1 for v in queries]