In [2]:
class SegmentTree():
    """Iterative (bottom-up) segment tree over a fixed-length array.

    The tree is a flat list in heap order: internal nodes occupy indices
    ``0 .. length-2`` and the leaves (the array itself) occupy
    ``length-1 .. 2*length-2``; node ``i`` has children ``2*i+1`` and ``2*i+2``.

    Args:
        arr: initial values (copied; the caller's sequence is not modified).
        func: associative binary combiner (e.g. addition, ``min``, string
            concatenation). It need not be commutative: ``query`` combines the
            pieces in array order.
        fid: identity element of ``func``; returned for empty ranges.
    """

    def __init__(self, arr, func, fid=0):
        self.func = func
        self.tree = None
        self.f_identity = fid
        self.length = len(arr)
        self._build_tree(arr)

    def _build_tree(self, arr):            # O(n)
        """Place ``arr`` at the leaves and fill every internal node bottom-up."""
        self.tree = [None] * (self.length - 1) + list(arr)
        for i in range(self.length - 2, -1, -1):
            self.tree[i] = self.func(self.tree[2*i+1], self.tree[2*i+2])

    def update(self, index, value):         # O(log n)
        """Set element ``index`` to ``value`` and refresh its ancestors.

        Raises:
            IndexError: if ``index`` is outside ``[0, length)``.
        """
        if not 0 <= index < self.length:
            # A negative index would otherwise land on an internal node.
            raise IndexError("SegmentTree index out of range")
        node = index + self.length - 1
        self.tree[node] = value                       # update the leaf
        while node > 0:                               # then every ancestor
            node = (node - 1) // 2
            combined = self.func(self.tree[2*node+1], self.tree[2*node+2])
            if combined != self.tree[node]:
                self.tree[node] = combined
            else:
                # Ancestors depend only on their children, so nothing above changes.
                break

    def query(self, l, r):                   # O(log n)
        """Combine elements ``l..r`` (inclusive) in array order.

        Returns:
            ``None`` if ``l < 0`` or ``r >= length``; the combination of
            ``fid`` with itself (i.e. ``fid``) if ``l > r``; otherwise
            ``func`` folded over ``arr[l..r]`` left to right.
        """
        if l < 0 or r >= self.length:
            return None
        l += self.length - 1
        r += self.length - 1
        # Two accumulators: pieces found at the left boundary are appended to
        # left_res, pieces at the right boundary are prepended to right_res, so
        # the final result keeps array order even for non-commutative func.
        left_res = self.f_identity
        right_res = self.f_identity
        while r >= l:
            if not (l & 1):     # l is a right child (even index): take it and step right
                left_res = self.func(left_res, self.tree[l])
                l += 1
            if r & 1:           # r is a left child (odd index): take it and step left
                right_res = self.func(self.tree[r], right_res)
                r -= 1
            l = (l - 1) // 2    # move both boundaries to their parents
            r = (r - 1) // 2
        return self.func(left_res, right_res)
       

In [23]:
from collections import defaultdict as dd
import sys
#input = sys.stdin.readline

import math
class sgtx():
    """Lazy segment tree with range-add updates and range-sum queries.

    The input is padded with ``fid`` to the next power of two and stored in
    heap order: internal nodes at ``0 .. length-2``, leaves at
    ``length-1 .. 2*length-2``; node ``i`` has children ``2*i+1`` and
    ``2*i+2`` and covers array positions ``range[i] = [lo, hi]``.

    Invariants:
      * ``tree[i]`` is the exact aggregate of ``range[i]`` once every delta
        pending on ``i``'s ancestors has been pushed down to ``i``.
      * ``lazy[i] = [True, d]`` means ``d`` is already counted in ``tree[i]``
        but has not yet been added to ``i``'s children.

    NOTE: updates add ``delta * segment_size`` to nodes directly, so they
    assume ``func`` is addition and ``fid`` is 0. Another cell in this
    notebook also defines ``sgtx`` (range assignment); whichever cell ran
    last is the one in scope.
    """
    @staticmethod
    def ceilp2(a):   # smallest power of 2 >= a (1 when a <= 1); exact integer arithmetic
        return 1 << (max(a, 1) - 1).bit_length()

    def __init__(self, arr, func = lambda x, y : x + y, fid = 0):
        self.func = func
        self.fid = fid
        self.lazy = dd(lambda : [False, self.fid])
        self.range = dd(list)
        self._build_tree(arr)

    def _build_tree(self, arr):          # O(n)
        """Build the leaves from a padded *copy* of ``arr`` and fill internal nodes."""
        # Copy before padding so the caller's list is not modified.
        leaves = list(arr) + [self.fid] * (self.ceilp2(len(arr)) - len(arr))
        self.tree = [self.fid] * (len(leaves) - 1) + leaves
        self.length = len(leaves)
        for i in range(self.length - 1, len(self.tree)):
            pos = i - self.length + 1
            self.range[i] = [pos, pos]
        for i in range(self.length - 2, -1, -1):
            self.tree[i] = self.func(self.tree[2*i+1], self.tree[2*i+2])
            self.range[i] = [self.range[2*i+1][0], self.range[2*i+2][1]]

    def _segment_size(self, index):
        """Number of array positions covered by node ``index``."""
        lo, hi = self.range[index]
        return hi - lo + 1

    def _add_to_node(self, index, val):
        """Add ``val`` to every element under ``index`` in O(1), deferring the children."""
        self.tree[index] += val * self._segment_size(index)
        if not self.is_leaf(index):
            # Children have not seen this delta yet; remember it for a later push.
            self.pass_down(index, val)

    def propagate_up(self, index):       # O(log n)
        """Recompute every ancestor of ``index`` from its two children.

        ``tree[index]`` itself is left as is. A delta still pending on an
        ancestor is pushed to its children first, so the recomputed value
        keeps including it.
        """
        while (index := (index - 1) // 2) >= 0:
            if self.lazy[index][0]:
                self.propagate_down(index)
            self.tree[index] = self.func(self.tree[2*index+1], self.tree[2*index+2])

    def propagate_down(self, index):       # O(1)
        """Push the delta pending on ``index`` into both children, then clear it."""
        if not self.is_leaf(index) and self.lazy[index][0]:
            delta = self.lazy[index][1]
            self._add_to_node(2*index+1, delta)
            self._add_to_node(2*index+2, delta)
        self.lazy[index][0], self.lazy[index][1] = False, self.fid

    def pass_down(self, child_node, val):
        """Accumulate ``val`` into the pending delta of ``child_node``."""
        self.lazy[child_node][0] = True
        self.lazy[child_node][1] += val

    def point_update(self):
        """Placeholder: not implemented (currently a no-op)."""
        pass

    def range_update(self, l, r, val):    # O(log^2 n)
        """Add ``val`` to every element in ``l..r`` (inclusive).

        Does nothing and returns ``None`` if ``l < 0`` or ``r >= length``.
        """
        if l < 0 or r >= self.length:
            return None
        l += self.length - 1
        r += self.length - 1
        while r >= l:
            if not (l & 1):               # l is a right child (even index): a whole piece
                self._add_to_node(l, val)
                self.propagate_up(l)
                l += 1
            if r & 1:                     # r is a left child (odd index): a whole piece
                self._add_to_node(r, val)
                self.propagate_up(r)
                r -= 1
            l = (l - 1) // 2              # move l and r to their parents
            r = (r - 1) // 2

    def query(self, l, r):
        """Sum of elements ``l..r`` inclusive.

        Returns ``fid`` when ``l > r`` and ``None`` when the range leaves the
        array (which previously recursed without terminating).
        """
        if l > r:
            return self.fid
        if l < 0 or r >= self.length:
            return None
        return self._query_util(l, r, [0, self.length - 1], 0)

    def _query_util(self, l, r, el_range, index):            # O(log n)
        """Aggregate ``l..r`` inside node ``index``, which covers ``el_range``."""
        if l > r:
            return self.fid
        if el_range[0] == l and el_range[1] == r:
            return self.tree[index]
        if self.lazy[index][0]:    # children are stale: push the pending delta first
            self.propagate_down(index)
        mid = (el_range[0] + el_range[1]) // 2
        return self.func(self._query_util(l, min(r, mid), [el_range[0], mid], index*2+1),
                         self._query_util(max(l, mid+1), r, [mid+1, el_range[1]], index*2+2))

    def is_leaf(self, index):
        """True when node ``index`` covers a single array position."""
        return self.range[index][0] == self.range[index][1]
    
bs = sgtx(arr=[1] * 16)              # sixteen 1s
bs.range_update(l=4, r=10, val=1)
bs.range_update(l=8, r=13, val=1)

In [24]:
bs.query(l=7, r=12)   # sum over positions 7..12 after the two +1 range updates

15

## Assignment modifications, sum queries

In [15]:
from collections import defaultdict as dd
import math
import sys
#input = sys.stdin.readline

class sgtx():
    """Lazy segment tree with range-assignment updates and range-sum queries.

    The input is padded with ``fid`` to the next power of two and stored in
    heap order: internal nodes at ``0 .. length-2``, leaves at
    ``length-1 .. 2*length-2``; node ``i`` has children ``2*i+1`` and
    ``2*i+2`` and covers array positions ``range[i] = [lo, hi]``.

    Invariants:
      * ``tree[i]`` is the exact sum of ``range[i]`` once every assignment
        pending on ``i``'s ancestors has been pushed down to ``i``.
      * ``lazy[i] = [True, v]`` means every element under ``i`` was set to
        ``v``; ``tree[i]`` reflects it, ``i``'s children do not yet.

    NOTE: assignments set ``tree[i] = v * segment_size`` directly, so they
    assume ``func`` is addition. This cell redefines the range-add ``sgtx``
    from an earlier cell; whichever cell ran last is the one in scope.
    """
    @staticmethod
    def ceilp2(a):   # smallest power of 2 >= a (1 when a <= 1); exact integer arithmetic
        return 1 << (max(a, 1) - 1).bit_length()

    def __init__(self, arr, func = lambda x, y : x + y, fid = 0):
        self.func = func
        self.fid = fid
        self.lazy = dd(lambda : [False, self.fid])
        self.range = dd(list)
        self._build_tree(arr)

    def _build_tree(self, arr):          # O(n)
        """Build the leaves from a padded *copy* of ``arr`` and fill internal nodes."""
        # Copy before padding so the caller's list is not modified.
        leaves = list(arr) + [self.fid] * (self.ceilp2(len(arr)) - len(arr))
        self.tree = [self.fid] * (len(leaves) - 1) + leaves
        self.length = len(leaves)
        for i in range(self.length - 1, len(self.tree)):
            pos = i - self.length + 1
            self.range[i] = [pos, pos]
        for i in range(self.length - 2, -1, -1):
            self.tree[i] = self.func(self.tree[2*i+1], self.tree[2*i+2])
            self.range[i] = [self.range[2*i+1][0], self.range[2*i+2][1]]

    def _segment_size(self, index):
        """Number of array positions covered by node ``index``."""
        lo, hi = self.range[index]
        return hi - lo + 1

    def _push_path(self, index):
        """Push pending assignments on every ancestor of ``index``, root first."""
        ancestors = []
        while index > 0:
            index = (index - 1) // 2
            ancestors.append(index)
        for node in reversed(ancestors):
            if self.lazy[node][0]:
                self.propagate_down(node)

    def propagate_up(self, index):           # O(log n)
        """Materialise the assignment pending on ``index`` and recompute its ancestors.

        Ancestors' older pending assignments are pushed down *before* the new
        value is written; pushing them afterwards would overwrite it.
        """
        if self.lazy[index][0]:
            pending = self.lazy[index][1]
            self._push_path(index)           # may overwrite index with an older ancestor value
            self.apply_lazy(index, pending)  # so re-assert the newer assignment
            self.tree[index] = pending * self._segment_size(index)
            while (index := (index - 1) // 2) >= 0:
                self.tree[index] = self.func(self.tree[2*index+1], self.tree[2*index+2])

    def propagate_down(self, index):       # O(1)
        """Copy the assignment pending on ``index`` into both children, then clear it."""
        # Guard on the flag: pushing an unflagged node would reset its children to fid.
        if not self.is_leaf(index) and self.lazy[index][0]:
            value = self.lazy[index][1]
            for child in (2*index+1, 2*index+2):
                self.tree[child] = value * self._segment_size(child)
                self.apply_lazy(child, value)
        self.lazy[index][0], self.lazy[index][1] = False, self.fid

    def apply_lazy(self, child_node, val):
        """Mark ``child_node`` as assigned to ``val`` (replacing any older pending value)."""
        self.lazy[child_node][0] = True
        self.lazy[child_node][1] = val

    def range_update(self, l, r, val):    # O(log^2 n)
        """Set every element in ``l..r`` (inclusive) to ``val``.

        Does nothing and returns ``None`` if ``l < 0`` or ``r >= length``.
        """
        if l < 0 or r >= self.length:
            return None
        l += self.length - 1
        r += self.length - 1
        while r >= l:
            if not (l & 1):               # l is a right child (even index): a whole piece
                self.apply_lazy(l, val)
                self.propagate_up(l)
                l += 1
            if r & 1:                     # r is a left child (odd index): a whole piece
                self.apply_lazy(r, val)
                self.propagate_up(r)
                r -= 1
            l = (l - 1) // 2              # move l and r to their parents
            r = (r - 1) // 2

    def is_leaf(self, index):
        """True when node ``index`` covers a single array position."""
        return self.range[index][0] == self.range[index][1]

    def query(self, l, r):
        """Sum of elements ``l..r`` inclusive.

        Returns ``fid`` when ``l > r`` and ``None`` when the range leaves the
        array (which previously recursed without terminating).
        """
        if l > r:
            return self.fid
        if l < 0 or r >= self.length:
            return None
        return self._query_util(l, r, [0, self.length - 1], 0)

    def _query_util(self, l, r, el_range, index):            # O(log n)
        """Aggregate ``l..r`` inside node ``index``, which covers ``el_range``."""
        if l > r:
            return self.fid
        if el_range[0] == l and el_range[1] == r:
            return self.tree[index]
        if self.lazy[index][0]:    # children are stale: push the pending assignment first
            self.propagate_down(index)
        mid = (el_range[0] + el_range[1]) // 2
        return self.func(self._query_util(l, min(r, mid), [el_range[0], mid], index*2+1),
                         self._query_util(max(l, mid+1), r, [mid+1, el_range[1]], index*2+2))
    
bs = sgtx(arr=[1] * 8)   # eight 1s
bs.tree

[8, 4, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1]

In [16]:
bs.range_update(l=2, r=5, val=1)   # assign 1 to positions 2..5
bs.tree

[8, 4, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1]

In [17]:
bs.query(l=3, r=6)   # NOTE(review): result is not displayed; only bs.tree below is shown
bs.tree

[8, 4, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1]

In [30]:
class flattened_subtree_queries():
    """Subtree updates/queries on a rooted tree via a pre-order (Euler-tour) flattening.

    A pre-order DFS from ``root`` gives node ``u`` the entry time ``entry[u]``
    and the last entry time in its subtree ``exit[u]``. The nodes of ``u``'s
    subtree therefore occupy the contiguous slots ``entry[u] .. exit[u]`` of
    ``flattened_arr``, which is wrapped in an ``sgtx`` segment tree.

    Args:
        graph: adjacency lists indexed by node id (e.g. a list of lists or a
            ``defaultdict(list)``); edges are treated as undirected.
        values: initial value of each node, indexed by node id.
        root: node the DFS starts from.
    """
    def __init__(self, graph, values, root = 0):
        # len(values) covers a single-node tree, whose defaultdict graph has no keys.
        self.n = max(len(graph), len(values))
        self.entry = [None] * self.n
        self.exit = [None] * self.n
        self.values = values
        self.flattened_arr = []
        self._euler_tour(graph, root)
        #print(self.flattened_arr)
        self.sgt = sgtx(self.flattened_arr)

    def _euler_tour(self, graph, root):
        """Fill ``entry``/``exit``/``flattened_arr`` with an iterative pre-order DFS.

        An explicit stack of child iterators replaces recursion so deep trees
        (e.g. long paths) do not hit Python's recursion limit; nodes are visited
        in exactly the same order as the recursive version.
        """
        timer = 0

        def enter(u):
            nonlocal timer
            self.entry[u] = timer
            self.flattened_arr.append(self.values[u])
            timer += 1

        enter(root)
        stack = [(root, -1, iter(graph[root]))]
        while stack:
            u, parent, children = stack[-1]
            for v in children:
                if v != parent:
                    enter(v)
                    stack.append((v, u, iter(graph[v])))
                    break
            else:
                # All children finished: u's subtree ends at the last timestamp handed out.
                self.exit[u] = timer - 1
                stack.pop()

    def subtree_update(self, node, value):
        """Call ``sgt.range_update`` with ``value`` over slots ``entry[node]+1 .. exit[node]``.

        The range starts after ``node``'s own slot, so only its descendants are
        touched. NOTE(review): confirm excluding ``node`` itself is intended.
        """
        self.sgt.range_update(self.entry[node]+1, self.exit[node], value)

    def subtree_query(self, node):
        """Return ``sgt.query`` over slots ``entry[node]+1 .. exit[node]`` (``node`` itself excluded)."""
        return self.sgt.query(self.entry[node]+1, self.exit[node])
        

In [31]:
n = int(input())
graph = dd(list)
root = None
par = list(map(int, input().split()))
for i in range(n):
    if par[i] == 0:          # a parent of 0 marks the root
        root = i
        continue
    # parents are 1-indexed in the input; store the edge in both directions
    graph[i].append(par[i] - 1)
    graph[par[i] - 1].append(i)

3
2 0 1


In [32]:
fsq = flattened_subtree_queries(graph, [1] * n, root)
for _ in range(int(input())):
    qt, sup = map(int, input().split())
    if qt in (1, 2):
        # type 1 writes 1, type 2 writes 0 (nodes are 1-indexed in the input)
        fsq.subtree_update(sup - 1, 1 if qt == 1 else 0)
    else:
        print(fsq.subtree_query(sup - 1))

3
3 1
1
2 1
3 1
0


In [10]:
#Map class:

class Map(dict):
    """
    A dict whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute returns ``None`` (via ``dict.get``) instead of
    raising ``AttributeError``. Every entry is mirrored into the instance
    ``__dict__`` so normal attribute lookup finds it; the mutating dict
    methods below are overridden so that mirror never goes stale.

    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        # Mirror whatever dict.__init__ stored: a mapping, an iterable of pairs, or kwargs.
        for k, v in dict.items(self):
            self.__dict__[k] = v

    def __getattr__(self, attr):
        # Only reached when normal lookup fails, i.e. for names that are
        # neither dict methods nor mirrored keys.
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        # pop with a default so a key that was never mirrored cannot raise here
        self.__dict__.pop(key, None)

    def update(self, *args, **kwargs):
        """``dict.update`` that routes every key through ``__setitem__`` (keeps the mirror)."""
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def setdefault(self, key, default=None):
        """``dict.setdefault`` that keeps the attribute mirror in sync."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        """``dict.pop`` that also drops the attribute mirror entry."""
        self.__dict__.pop(key, None)
        return super(Map, self).pop(key, *default)

    def popitem(self):
        """``dict.popitem`` that also drops the attribute mirror entry."""
        key, value = super(Map, self).popitem()
        self.__dict__.pop(key, None)
        return key, value

    def clear(self):
        """``dict.clear`` that also empties the attribute mirror."""
        super(Map, self).clear()
        self.__dict__.clear()


In [13]:
inf = 10 ** 9
m = Map(dict.fromkeys(('max_prefix_sum', 'max_suffix_sum', 'total_sum', 'max_subarray_sum'), -inf))
m.total_sum

-1000000000