## LCA using Binary Lifting (DP)

### To find the kth ancestor, use the binary representation of k to jump through the ancestors. For example, if k = 6, then 6 = 0b110 (0 1 1 0 in four bits). Use bitwise operations to find the positions p that hold a set bit, then make one jump of 2^p levels for each such position, in any order, to reach the desired ancestor. For k = 6 the set bits are p = 1 and p = 2, so we jump 2^1 = 2 levels (the 2nd ancestor) and then 2^2 = 4 levels (the 4th ancestor of that node); 2 + 4 = 6 gives the 6th ancestor.

In [1]:
from collections import defaultdict as dd
import math

# Example tree as an adjacency list: graph[u] lists the neighbours of node u.
# The edges are undirected (each appears in both lists); bin_lift roots it at node 0.
graph = [[1, 2], [0, 3, 4], [0, 5, 6], [1, 7, 8], [1, 9], [2], [2, 10], [3], [3, 11, 12, 13], [4], [6], [8], [8], [8]]
class bin_lift():
    """Binary-lifting ancestor table (k-th ancestor and LCA) for a tree.

    ``graph`` is an adjacency list of an undirected tree, rooted here at node 0.
    An Euler tour gives every node distinct ``entry``/``exit`` times, so ``u``
    is an ancestor of ``v`` exactly when ``u``'s [entry, exit] interval
    contains ``v``'s. ``parents[u][i]`` is the 2**i-th ancestor of ``u``;
    jumps that go above the root yield -1.
    """

    def __init__(self, graph):
        self.graph = graph
        self.n = len(graph)
        self.entry = [None] * self.n
        self.exit = [None] * self.n
        # Any missing entry defaults to -1, the "above the root" sentinel.
        self.parents = dd(lambda: dd(lambda: -1))
        # Exact integer ceil(log2(n)): 2**height >= n > depth of any node.
        self.height = max(self.n - 1, 0).bit_length()
        if self.n:
            self._euler_tour(0, -1)  # or (0, 0) to keep the root as its own ancestor

    def _euler_tour(self, root, root_parent):
        """Assign entry/exit times and fill ``parents`` with an iterative DFS.

        An explicit stack of child iterators replaces recursion so deep
        (path-like) trees do not exceed Python's recursion limit; the visiting
        order and timestamps are the same as the recursive formulation.
        """
        timer = 0

        def enter(u, p):
            nonlocal timer
            self.entry[u] = timer
            timer += 1
            self.parents[u][0] = p  # 1st ancestor
            for i in range(1, self.height + 1):
                # 2**i-th ancestor = 2**(i-1)-th ancestor of the 2**(i-1)-th ancestor.
                self.parents[u][i] = self.parents[self.parents[u][i - 1]][i - 1]

        enter(root, root_parent)
        stack = [(root, root_parent, iter(self.graph[root]))]
        while stack:
            u, p, children = stack[-1]
            for v in children:
                if v != p:
                    enter(v, u)
                    stack.append((v, u, iter(self.graph[v])))
                    break
            else:
                # Every child of u is finished: close u's interval.
                stack.pop()
                self.exit[u] = timer
                timer += 1

    def is_ancestor(self, u, v):
        """Return True if ``u`` is an ancestor of ``v`` (a node is its own ancestor).

        -1, the virtual node above the root, is an ancestor of every node.
        Without this check ``entry[-1]`` would silently read the last node's
        timestamps through Python's negative indexing.
        """
        if u == -1:
            return True
        if v == -1:
            return False
        return self.entry[u] <= self.entry[v] and self.exit[u] >= self.exit[v]

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.is_ancestor(u, v):
            return u
        if self.is_ancestor(v, u):
            return v
        # Lift u as high as possible while it is still NOT an ancestor of v;
        # u's parent is then the LCA.
        for i in range(self.height, -1, -1):
            if not self.is_ancestor(self.parents[u][i], v):
                u = self.parents[u][i]
        return self.parents[u][0]

    def k_ancestor(self, node, k):
        """Return the k-th ancestor of ``node`` (k = 0 gives ``node``).

        Jumps 2**i levels for every set bit i of k. Returns -1 when k exceeds
        the node's depth, including when k has bits above ``height`` (such a
        jump always passes the root).

        Raises:
            ValueError: if k is negative.
        """
        if k < 0:
            raise ValueError("k must be non-negative")
        i = 0
        while k and node != -1:
            if k & 1:
                # Levels above ``height`` are absent and default to -1.
                node = self.parents[node][i]
            k >>= 1
            i += 1
        return node

# Build the ancestor table and ask for the 2nd ancestor of node 10 (10 -> 6 -> 2).
go = bin_lift(graph)
go.k_ancestor(10,2)

2

# Tree Flattening 
#### Useful for subtree queries (such as subtree sum), which can then be answered in O(log n) with a Fenwick / segment tree
#### Tree Flattening, as the name says, is converting a tree / graph to a linear structure.
#### We can do this with the help of DFS and a timer array, which stores the entry and exit of each node, which serve as range for a subtree
![lp](tf1.png)
### Points to remember ->

## 1. Use entry[node] = timer; timer+=1; exit[node] = timer; timer+=1 for path queries, such as path sum, and subtree queries such as subtree sum, with queries on the flattened array, which contains both entry and exit points of every node

### The values placed at a node's entry and exit points in the flattened array must be chosen to suit the question. For example, to count the nodes on the path between 2 nodes, put 1 at entry[node] and -1 at exit[node]. For subtree queries, put the node's contribution at the entry point and 0 at the exit point (e.g. entry[node] = 1 and exit[node] = 0 for subtree size, or the node's value at the entry point for subtree sum).

## 2. Use entry[node] = timer; timer+=1; exit[node] = timer - 1; for subtree queries, where having an exit point explicitly in the flattened array is not needed at all.



## 1. Touring, for subtree sum, gcd etc using method 2

### Q. Find the sum of all nodes in a subtree
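### The idea is to do a DFS tour and answer range queries with a segment / Fenwick tree. With method 2, several nodes may share the same exit time (a node's exit equals the entry time of the last node visited in its subtree); this is fine for subtree queries.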
# Subtree query for a node = seg_tree.query(entry[node], exit[node])

## 2. Flattening for path sum queries, etc using method 1
### Q. Find the sum of node values on the path from one node / leaf to another node / leaf

### The idea is to use a flattened array with distinct entry and exit points. At each entry point we place the node's value, and at each exit point we place its negative, so nodes that are both entered and exited inside a range cancel out and only the nodes on the required path remain. To find the sum, we first find the LCA of the two nodes using binary lifting in O(log n) time. Then, because in a tree the path between two nodes is the path from the LCA to the first node joined with the path from the LCA to the second node, we add the two partial sums, counting the LCA only once.

# Path(u, v) = Path(lca, u) + Path(lca, v) - value(lca)
# Path(u, v) = seg_tree.query(entry[lca], entry[u]) + seg_tree.query(entry[lca], entry[v]) - value(lca)   (both partial sums include the LCA, so subtract it once)

In [2]:
from collections import defaultdict as dd
import math

class segtree():
    """Iterative (bottom-up) segment tree over ``arr`` with combiner ``func``.

    The tree is stored implicitly in a list of size ``2n - 1``: internal node
    ``i`` has children ``2i + 1`` and ``2i + 2``, and the leaves (a copy of
    ``arr``) occupy indices ``n - 1 .. 2n - 2``.

    Args:
        arr: initial values.
        func: associative binary combiner, e.g. ``lambda x, y: x + y``. It
            need not be commutative: queries combine elements in array order.
        fid: identity element of ``func``.
    """

    def __init__(self, arr, func, fid=0):
        self.func = func
        self.tree = None
        self.arr = list(arr)
        self.f_identity = fid
        self._build_tree()

    def _build_tree(self):
        """Build every internal node from the leaves, bottom-up."""
        n = len(self.arr)
        self.tree = [None] * (n - 1) + self.arr
        for i in range(n - 2, -1, -1):
            self.tree[i] = self.func(self.tree[2 * i + 1], self.tree[2 * i + 2])

    def update(self, index, value):
        """Set ``arr[index] = value`` and refresh the affected ancestors."""
        self.arr[index] = value
        index += len(self.arr) - 1
        self.tree[index] = value
        while (index := (index - 1) // 2) >= 0:
            combined = self.func(self.tree[2 * index + 1], self.tree[2 * index + 2])
            if combined != self.tree[index]:
                self.tree[index] = combined
            else:
                # An unchanged node means every ancestor above it is unchanged too.
                break

    def query(self, l, r):
        """Return ``func`` folded over ``arr[l..r]`` (both ends inclusive).

        Returns ``None`` if the range falls outside the array, and the
        identity when ``l > r``.
        """
        if l < 0 or r >= len(self.arr):
            return None
        offset = len(self.arr) - 1
        l += offset
        r += offset
        # Separate accumulators for the two boundaries keep elements in array
        # order, so non-commutative combiners (concatenation, matrix product)
        # give correct results.
        left_res = self.f_identity
        right_res = self.f_identity
        while l <= r:
            if not (l & 1):  # even index = right child: take it, move right
                left_res = self.func(left_res, self.tree[l])
                l += 1
            if r & 1:  # odd index = left child: take it, move left
                right_res = self.func(self.tree[r], right_res)
                r -= 1
            l = (l - 1) // 2
            r = (r - 1) // 2
        return self.func(left_res, right_res)
    
class flattened_path_queries():
    """Path-sum queries on a tree via Euler-tour flattening and binary lifting.

    The tree (adjacency list ``graph``) is rooted at node 0 and flattened into
    an array of length ``2n`` holding ``+values[u]`` at ``entry[u]`` and
    ``-values[u]`` at ``exit[u]``. Summing that array from an ancestor's entry
    to a descendant's entry leaves exactly the nodes on the path between them,
    because every subtree fully visited in between cancels out.

    Args:
        graph: adjacency list of an undirected tree.
        values: optional per-node values (all zeros when omitted or empty).
            The list is copied, so ``update_node`` never mutates the caller's list.
    """

    def __init__(self, graph, values=None):
        self.graph = graph
        self.n = len(graph)
        self.entry = [None] * self.n
        self.exit = [None] * self.n
        # parents[u][i] is the 2**i-th ancestor of u. The root is its own
        # parent, so jumps never leave the tree.
        self.parents = dd(lambda: dd(lambda: -1))
        # Exact integer ceil(log2(n)).
        self.height = max(self.n - 1, 0).bit_length()
        self.values = [0] * self.n if not values else list(values)
        self.flattened_arr = [None] * (2 * self.n)
        if self.n:
            self._euler_tour(0, 0)
        for i in range(self.n):
            self.flattened_arr[self.entry[i]] = self.values[i]
            self.flattened_arr[self.exit[i]] = -self.values[i]
        self.sgt = segtree(self.flattened_arr, lambda x, y: x + y)

    def _euler_tour(self, root, root_parent):
        """Assign entry/exit times and fill ``parents`` with an iterative DFS.

        An explicit stack of child iterators replaces recursion so deep
        (path-like) trees do not exceed Python's recursion limit; the visiting
        order and timestamps are the same as the recursive formulation.
        """
        timer = 0

        def enter(u, p):
            nonlocal timer
            self.entry[u] = timer
            timer += 1
            self.parents[u][0] = p  # 1st ancestor
            for i in range(1, self.height + 1):
                self.parents[u][i] = self.parents[self.parents[u][i - 1]][i - 1]

        enter(root, root_parent)
        stack = [(root, root_parent, iter(self.graph[root]))]
        while stack:
            u, p, children = stack[-1]
            for v in children:
                if v != p:
                    enter(v, u)
                    stack.append((v, u, iter(self.graph[v])))
                    break
            else:
                # Every child of u is finished: close u's interval.
                stack.pop()
                self.exit[u] = timer
                timer += 1

    def is_ancestor(self, u, v):
        """Return True if ``u`` is an ancestor of ``v`` (a node is its own ancestor)."""
        return self.entry[u] <= self.entry[v] and self.exit[u] >= self.exit[v]

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.is_ancestor(u, v):
            return u
        if self.is_ancestor(v, u):
            return v
        # Lift u as high as possible while it is still NOT an ancestor of v.
        for i in range(self.height, -1, -1):
            if not self.is_ancestor(self.parents[u][i], v):
                u = self.parents[u][i]
        return self.parents[u][0]

    def update_node(self, node, val):
        """Set ``node``'s value to ``val``, updating both its +val and -val slots."""
        self.sgt.update(self.entry[node], val)
        self.sgt.update(self.exit[node], -val)
        self.values[node] = val

    def path_query(self, u, v):
        """Return the sum of node values on the path between ``u`` and ``v`` (inclusive)."""
        lca = self.lca(u, v)
        lca_entry = self.entry[lca]
        l = self.entry[u]
        r = self.entry[v]
        # Both partial sums include the LCA, so subtract it once.
        return self.sgt.query(lca_entry, l) + self.sgt.query(lca_entry, r) - self.values[lca]

# Same example tree; every node value starts at 0.
graph = [[1, 2], [0, 3, 4], [0, 5, 6], [1, 7, 8], [1, 9], [2], [2, 10], [3], [3, 11, 12, 13], [4], [6], [8], [8], [8]]
go = flattened_path_queries(graph)
go.update_node(1, 1)
go.update_node(6, 100)
# Path 9 -> 4 -> 1 -> 0 -> 2 includes node 1 (value 1) but not node 6, so the sum is 1.
go.path_query(9,2)

1

## Node to Node Path - Tree flattening
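### To get the nodes on a path, we can use a modified segment tree where each leaf holds a set containing a single element, and the combiner function is symmetric difference. Note that this is not O(log n) per query: the query combines O(log n) stored sets, but each set operation costs time proportional to the set sizes. The result is an unordered set of the path's nodes, not the path in order.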

### At the entry and exit points of the node, use the node number itself. 
### If a node is both entered and exited inside the queried range, symmetric difference cancels its two occurrences, so only the nodes that appear exactly once remain. These are the nodes on the path.

In [3]:
from collections import defaultdict as dd
import math

class segtree():
    """Iterative (bottom-up) segment tree over ``arr`` with combiner ``func``.

    The tree is stored implicitly in a list of size ``2n - 1``: internal node
    ``i`` has children ``2i + 1`` and ``2i + 2``, and the leaves (a copy of
    ``arr``) occupy indices ``n - 1 .. 2n - 2``.

    Args:
        arr: initial values.
        func: associative binary combiner, e.g. symmetric difference of sets.
            It need not be commutative: queries combine elements in array order.
        fid: identity element of ``func``.
    """

    def __init__(self, arr, func, fid=0):
        self.func = func
        self.tree = None
        self.arr = list(arr)
        self.f_identity = fid
        self._build_tree()

    def _build_tree(self):
        """Build every internal node from the leaves, bottom-up."""
        n = len(self.arr)
        self.tree = [None] * (n - 1) + self.arr
        for i in range(n - 2, -1, -1):
            self.tree[i] = self.func(self.tree[2 * i + 1], self.tree[2 * i + 2])

    def update(self, index, value):
        """Set ``arr[index] = value`` and refresh the affected ancestors."""
        self.arr[index] = value
        index += len(self.arr) - 1
        self.tree[index] = value
        while (index := (index - 1) // 2) >= 0:
            combined = self.func(self.tree[2 * index + 1], self.tree[2 * index + 2])
            if combined != self.tree[index]:
                self.tree[index] = combined
            else:
                # An unchanged node means every ancestor above it is unchanged too.
                break

    def query(self, l, r):
        """Return ``func`` folded over ``arr[l..r]`` (both ends inclusive).

        Returns ``None`` if the range falls outside the array, and the
        identity when ``l > r``.
        """
        if l < 0 or r >= len(self.arr):
            return None
        offset = len(self.arr) - 1
        l += offset
        r += offset
        # Separate accumulators for the two boundaries keep elements in array
        # order, so non-commutative combiners give correct results.
        left_res = self.f_identity
        right_res = self.f_identity
        while l <= r:
            if not (l & 1):  # even index = right child: take it, move right
                left_res = self.func(left_res, self.tree[l])
                l += 1
            if r & 1:  # odd index = left child: take it, move left
                right_res = self.func(self.tree[r], right_res)
                r -= 1
            l = (l - 1) // 2
            r = (r - 1) // 2
        return self.func(left_res, right_res)
    
class flattened_path_queries():
    """Report the set of nodes on a tree path via Euler-tour flattening.

    The tree (adjacency list ``graph``) is rooted at node 0. Every node ``u``
    contributes the singleton ``{u}`` at both ``entry[u]`` and ``exit[u]`` of
    a length-``2n`` array, and a segment tree combines ranges with symmetric
    difference. Between an ancestor's entry and a descendant's entry, nodes
    whose subtree is fully visited appear twice and cancel, so only the nodes
    on the path between the two remain.
    """

    def __init__(self, graph):
        self.graph = graph
        self.n = len(graph)
        self.entry = [None] * self.n
        self.exit = [None] * self.n
        # parents[u][i] is the 2**i-th ancestor of u. The root is its own
        # parent, so jumps never leave the tree.
        self.parents = dd(lambda: dd(lambda: -1))
        # Exact integer ceil(log2(n)).
        self.height = max(self.n - 1, 0).bit_length()
        self.path_arr = [None] * (2 * self.n)
        if self.n:
            self._euler_tour(0, 0)
        for i in range(self.n):
            # Two distinct set objects so the slots never share state.
            self.path_arr[self.entry[i]], self.path_arr[self.exit[i]] = {i}, {i}
        self.pt = segtree(self.path_arr, lambda x, y: x.symmetric_difference(y), set())

    def _euler_tour(self, root, root_parent):
        """Assign entry/exit times and fill ``parents`` with an iterative DFS.

        An explicit stack of child iterators replaces recursion so deep
        (path-like) trees do not exceed Python's recursion limit; the visiting
        order and timestamps are the same as the recursive formulation.
        """
        timer = 0

        def enter(u, p):
            nonlocal timer
            self.entry[u] = timer
            timer += 1
            self.parents[u][0] = p  # 1st ancestor
            for i in range(1, self.height + 1):
                self.parents[u][i] = self.parents[self.parents[u][i - 1]][i - 1]

        enter(root, root_parent)
        stack = [(root, root_parent, iter(self.graph[root]))]
        while stack:
            u, p, children = stack[-1]
            for v in children:
                if v != p:
                    enter(v, u)
                    stack.append((v, u, iter(self.graph[v])))
                    break
            else:
                # Every child of u is finished: close u's interval.
                stack.pop()
                self.exit[u] = timer
                timer += 1

    def is_ancestor(self, u, v):
        """Return True if ``u`` is an ancestor of ``v`` (a node is its own ancestor)."""
        return self.entry[u] <= self.entry[v] and self.exit[u] >= self.exit[v]

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.is_ancestor(u, v):
            return u
        if self.is_ancestor(v, u):
            return v
        # Lift u as high as possible while it is still NOT an ancestor of v.
        for i in range(self.height, -1, -1):
            if not self.is_ancestor(self.parents[u][i], v):
                u = self.parents[u][i]
        return self.parents[u][0]

    def get_path(self, v, u):
        """Return the set of nodes on the path between ``v`` and ``u``, endpoints included.

        The result is unordered. Each set operation costs time proportional
        to the set sizes.
        """
        lca = self.lca(u, v)
        lca_entry = self.entry[lca]
        l = self.entry[u]
        r = self.entry[v]
        return self.pt.query(lca_entry, l).union(self.pt.query(lca_entry, r))
    
# Same example tree as above.
graph = [[1, 2], [0, 3, 4], [0, 5, 6], [1, 7, 8], [1, 9], [2], [2, 10], [3], [3, 11, 12, 13], [4], [6], [8], [8], [8]]
go = flattened_path_queries(graph)
# Nodes on the path 7 -> 3 -> 8 -> 12 (LCA is 3), returned as an unordered set.
go.get_path(7,12)

{3, 7, 8, 12}