## Segment tree with array (using None as first element, 2n elements)
##### o(n) space, referred from Stable sort's

In [7]:
class segment_tree():
    
    def __init__(self, arr, func, fid=0):
        self.func = func
        self.tree = None
        self.arr = list(arr)
        self.f_identity = fid
    
    def build_tree(self):            #o(n)
        self.tree = [None]*(len(self.arr)) + self.arr
        for i in range(len(self.arr)-1, 0, -1):
            self.tree[i] = self.func(self.tree[2*i], self.tree[2*i+1])
            
    def update(self, index, value):         #o(logn)
        self.arr[index] = value
        index += (len(self.arr))
        self.tree[index] = value   #updating the leaf node
        while (index := (index//2)) > 0:   #updating upper nodes
            updated_internal_node_val = self.func(self.tree[2*index], self.tree[2*index+1])
            if updated_internal_node_val != self.tree[index]:
                self.tree[index] = updated_internal_node_val
            else: 
                break
    
    def query(self, l, r):               # l, r -1, stable_sorts' solution
        if l < 0 or r > len(self.arr):
            return None
        l += (len(self.arr))
        r += (len(self.arr))
        func_res = self.f_identity
        while l < r:
            if l&1 == 1:   #if l is the right-child (odd index in tree)
                func_res = self.func(self.tree[l], func_res)
                l += 1
                
            if r&1 == 1:   #if r is the right-child (odd index in tree), then directly use its parent
                r -= 1
                func_res = self.func(self.tree[r], func_res)
                
            r = r//2
            l = l//2
                          
        return func_res
                            

In [8]:
%%timeit
st = segment_tree([x for x in range(1,int(1e7+1))], lambda x, y : x+y, 0)
st.build_tree()
#st.update(6,8)
#st.tree
#st.query(0,999999)

4.31 s ± 289 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Segment tree with array (with 2n - 1 elements)
### for bottom - up querying 
#### Bottom up querying is snappy, usful when:
#### 1) When the range is known (ie, l and r are known), and/or
#### 2) the function is commutative

In [45]:
''' Segment tree is a full binary tree, that is, all of its internal nodes have 2 or 0 children, and number of 
nodes is 2*l - 1, where l is the number of leaf nodes '''

class SegmentTree():   
    
    def __init__(self, arr, func, fid=0):
        self.func = func
        self.f_identity = fid
        self.length = len(arr)
        self._build_tree(arr)
        
    def _build_tree(self, arr):            # o(n)
        self.tree = [None]*(self.length - 1) + arr
        for i in range(self.length-2, -1, -1):
            self.tree[i] = self.func(self.tree[(i << 1) + 1], self.tree[(i + 1) << 1])
            
    def update(self, index, value):         # o(logn)
        index += (self.length - 1)
        self.tree[index] = value   #updating the leaf node
        while (index := ((index - 1) >> 1)) >= 0:   #updating upper nodes
            upd = self.func(self.tree[(index << 1) + 1], self.tree[(index + 1) << 1])
            if self.tree[index] != upd:
                self.tree[index] = upd
            else:
                break

    def query(self, l, r):                   # O(logn)
        if l < 0 or r >= self.length:
            return None
        l += (self.length - 1)
        r += (self.length - 1)
        res = self.f_identity
        
        while r >= l:   
            
            if not (l & 1) :     #if l is a right node(even indexed in the tree), then add to the result and increament l
                res = self.func(res, self.tree[l])   #changed ordering of the arguments for non commutative combiner functions
                l += 1
                
            if r & 1 :      #if r is left child (odd indexed in the tree), then add to the result and decreament r
                res = self.func(self.tree[r], res)
                r -= 1
            
            l = (l-1) >> 1
            r = (r-1) >> 1        #move l and r to their parents
            
        return res
       

In [46]:
#%%timeit
st2 = SegmentTree([x for x in range(1,int(8))], lambda x, y : x+y, 0)
st2.query(2,5)
#st2.query(99,199)

18

## 2D Segment Tree

In [2]:
class SegmentTree2D():
    
    def __init__(self, mat, size, func, fid = 0):
        self.matrix = mat
        self.func = func
        self.fid = fid
        self.mat_size = size              # m x n
        self.func_map = lambda arr1, arr2 : [self.func(arr1[i], arr2[i]) for i in range(len(arr1))]
        self.trees = [None] * (self.mat_size[0]-1)
        self._build()
        
    def _build(self):
        for i in range(self.mat_size[0]):
            self.trees.append(SegmentTree(self.func, self.matrix[i], self.fid))
        for i in range(self.mat_size[0]-2, -1, -1):
            self.trees[i] = SegmentTree(self.func, self.func_map(self.trees[2*i+1].arr, self.trees[2*i+2].arr))
            
    def update(self, r, c, val):
        self.matrix[r][c] = val
        r += (self.mat_size[0] - 1)
        self.trees[r].update(c,val)
        while (r := (r-1)//2) >= 0:
            self.trees[r].tree = self.func_map(self.trees[2*r+1].tree, self.trees[2*r+2].tree)
            
    def query(self, r1, r2, c1, c2):
        res = self.fid
        r1 += self.mat_size[0] - 1
        r2 += self.mat_size[0] - 1
        
        while r2 >= r1:
            if r2 == r1:
                res = self.func(self.trees[r1].query(c1, c2), res)
                break
            
            if not (r1 & 1):
                res = self.func(self.trees[r1].query(c1, c2), res)
                r1 += 1
            
            if (r2 & 1) == 1:
                res = self.func(self.trees[r2].query(c1, c2), res)
                r2 -= 1
            
            r1 = (r1 - 1) // 2
            r2 = (r2 - 1) // 2
        
        return res

## Segment tree, with number of elements exactly equal to a power of 2
### for top - down querying
#### Top down query is like an exhaustive DFS, usful when:
#### 1) When the range is unknown (ie, l and r are not known), and/or
#### 2) the function is not commutative (as the nodes are always considered from left to right)

In [1]:
import math
class segment_tree_td():
    
    def __init__(self, arr, func, fid = 0):
        self.func = func
        self.fid = fid      #function identity
        self.arr = arr
        self.org_len = len(self.arr)
        self._build()
    
    def _build(self):
        if (tree_len := math.log2(len(self.arr))) != int(tree_len):
            required = 2 ** math.ceil(tree_len)
            extra = required - len(self.arr)
            self.arr += [self.fid] * extra
        self.tree = ([None] * (len(self.arr) - 1)) + self.arr
        for index in range(len(self.arr)-2, -1, -1):
            self.tree[index] = self.func(self.tree[2*index + 1], self.tree[2*index + 2])
            
    def update(self, index, val):
        self.arr[index] = val
        index += (len(self.arr) - 1)
        self.tree[index] = val
        while (index := ((index-1)//2)) >= 0:
            self.tree[index] = self.func(self.tree[2*index+1],self.tree[2*index+2])
            
    def query(self, l, r):
        return self._query_util(l, r, [0,len(self.arr)-1], 0)
            
    def _query_util(self, l, r, el_range, index):
        if l > r:
            return self.fid
        if el_range[0] == l and el_range[1] == r:
            return self.tree[index]
        mid = (el_range[0]+el_range[1])//2
        return (self._query_util(l, min(r, mid),  [el_range[0], mid], index*2+1) +
                self._query_util(max(l, mid+1), r, [mid+1, el_range[1]], index*2+2))
        
    def smallest_index_for_prefix_sum(self, k):
        tl, tr = 0, len(self.arr)-1
        cur_node = 0
        while tr != tl:
            mid = (tl+tr)//2
            if self.tree[2*cur_node + 1] >= k:          #if sum at left child is greater than k, then it is at left subtree
                cur_node = (2*cur_node + 1)
                tr = mid
            else:
                k -= self.tree[2*cur_node - 1]
                cur_node = (2*cur_node + 2)        #if in right subtree, then subtract sum of left subtree from k
                tl = mid+1    
        return tl
                
    ''' for finding an index in range l...r greater than a given sum, use tl >= l and tr <= r 
        the rest is same as above '''

In [2]:
std = segment_tree_td([x for x in range(1,int(9))], lambda x,y : x + y, 0)
#std.query(0,999999)
std.tree
std.smallest_index_for_prefix_sum(10)

2

## Range query for Largest Sum Contiguous Subarray
### Represented as : Query(x,y) = Max { a[i]+a[i+1]+...+a[j] ; x ≤ i ≤ j ≤ y }
#### [SPOJ Sample Problem](https://www.spoj.com/problems/GSS1/)
#### [GFG Article](https://www.geeksforgeeks.org/maximum-subarray-sum-given-range/)
#### [GFG Article for update](https://www.geeksforgeeks.org/range-query-largest-sum-contiguous-subarray/)
### We store 4 values in the segment tree:
#### 1. max_prefix_sum : max(left_child.max_prefix_sum, left_child.total_sum + right_child.max_prefix_sum)
#### 2. max_suffix_sum : max(right_child.max_suffix_sum, right_child.total_sum + left_child.max_suffix_sum)
#### 3. total_sum : left_child.total_sum + right_child.total_sum
#### 1. max_subarray_sum : max(left_child.max_subarray_sum, right_child.max_subarray_sum, left_child.max_suffix_sum + right_child.max_prefix_sum)


In [35]:
from collections import defaultdict as dd

inf = int(1e9)

class Map(dict):
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]
        
def merge(left_child, right_child):
    assert(isinstance(left_child, Map) and isinstance(right_child, Map))
    prefix = max(left_child.max_prefix_sum, left_child.total_sum + right_child.max_prefix_sum)
    suffix = max(right_child.max_suffix_sum, right_child.total_sum + left_child.max_suffix_sum)
    total = left_child.total_sum + right_child.total_sum
    subarray_sum = max(left_child.max_subarray_sum, right_child.max_subarray_sum, 
                      left_child.max_suffix_sum + right_child.max_prefix_sum)
    return Map({'max_prefix_sum':prefix, 'max_suffix_sum':suffix, 'total_sum':total, 'max_subarray_sum':subarray_sum})

def value_node(val):
    return Map({'max_prefix_sum':val, 'max_suffix_sum':val, 'total_sum':val, 'max_subarray_sum':val})

def empty_node():
    return Map({'max_prefix_sum':-inf, 'max_suffix_sum':-inf, 'total_sum':-inf, 'max_subarray_sum':-inf})

class SegmentTree():
    
    def __init__(self, arr, func, fid = 0):
        self.func = func
        self.fid = fid
        self.length = len(arr)
        def _build_tree(arr):            # o(n)
            nonlocal self
            self.tree = [None]*(self.length - 1) + arr
            for i in range(self.length-2, -1, -1):
                self.tree[i] = self.func(self.tree[2*i+1], self.tree[2*i+2])
        _build_tree(arr)

inp =  [1,3,-4,5,-2,7,-6,8]
arr = [value_node(i) for i in inp]
sgt = SegmentTree(arr, merge, -inf)
sgt.tree

[{'max_prefix_sum': 12,
  'max_suffix_sum': 12,
  'total_sum': 12,
  'max_subarray_sum': 12},
 {'max_prefix_sum': 5,
  'max_suffix_sum': 5,
  'total_sum': 5,
  'max_subarray_sum': 5},
 {'max_prefix_sum': 7,
  'max_suffix_sum': 9,
  'total_sum': 7,
  'max_subarray_sum': 9},
 {'max_prefix_sum': 4,
  'max_suffix_sum': 4,
  'total_sum': 4,
  'max_subarray_sum': 4},
 {'max_prefix_sum': 1,
  'max_suffix_sum': 5,
  'total_sum': 1,
  'max_subarray_sum': 5},
 {'max_prefix_sum': 5,
  'max_suffix_sum': 7,
  'total_sum': 5,
  'max_subarray_sum': 7},
 {'max_prefix_sum': 2,
  'max_suffix_sum': 8,
  'total_sum': 2,
  'max_subarray_sum': 8},
 {'max_prefix_sum': 1,
  'max_suffix_sum': 1,
  'total_sum': 1,
  'max_subarray_sum': 1},
 {'max_prefix_sum': 3,
  'max_suffix_sum': 3,
  'total_sum': 3,
  'max_subarray_sum': 3},
 {'max_prefix_sum': -4,
  'max_suffix_sum': -4,
  'total_sum': -4,
  'max_subarray_sum': -4},
 {'max_prefix_sum': 5,
  'max_suffix_sum': 5,
  'total_sum': 5,
  'max_subarray_sum': 5},
 {

defaultdict(<function __main__.empty_node()>, {3: 4})

## Hackerrank Question
#### Given a string of size n, Do the following operations exactly n times:
####  1. Select the charcter with least ASCII value, remove it from the string. If their are multiple chars, remove the one with lowest index.
####  2. Add the index of the removed character to a variable named count. (1 based indexing).
# ------------------------------------------------------------------------------------
### The idea is to use 2 segment trees, 1 for finding the minimum of char from the string, and the other for storing how many characters have been popped, since the indices of chars might change if one of its previous characters is removed. (1 if the char of that index is removed, else keep it 0) (Use a seg_tree of summation for this)

In [28]:
from string import ascii_lowercase as lower
kp = {a:i for a,i in zip(lower,range(26))}
lkp['inf'] = 26

class seg_tree():
    
    @staticmethod
    def find_min_ch(a, b):
        global lkp
        if a[0] != b[0]:
            if lkp[a[0]] < lkp[b[0]]:
                return a
            return b
        else:
            if a[1] < b[1]:   #smaller index
                return a
            return b
    
    def __init__(self, arr):
        self.tree = None
        self.arr = list(arr)
        self.f_identity = 'inf'
        self._build_tree()
    
    def _build_tree(self):           
        self.tree = [None]*(len(self.arr) - 1) + self.arr
        for i in range(len(self.arr)-2, -1, -1):
            self.tree[i] = seg_tree.find_min_ch(self.tree[2*i+1], self.tree[2*i+2])   
            
    def update(self, index, value):         
        self.arr[index][0] = value
        index += (len(self.arr) - 1)
        self.tree[index][0] = value   
        while (index := ((index-1)//2)) >= 0:  
            updated_internal_node_val = seg_tree.find_min_ch(self.tree[2*index+1], self.tree[2*index+2])
            if updated_internal_node_val != self.tree[index]:
                self.tree[index] = updated_internal_node_val
            else: 
                break
    def return_min(self):
        return self.tree[0]

In [33]:
class seg_tree_for():
    
    def __init__(self, arr, func, fid=0):
        self.func = func
        self.tree = None
        self.arr = list(arr)
        self.f_identity = fid
        self._build_tree()
    
    def _build_tree(self):            # o(n)
        self.tree = [None]*(len(self.arr) - 1) + self.arr
        for i in range(len(self.arr)-2, -1, -1):
            self.tree[i] = self.func(self.tree[2*i+1], self.tree[2*i+2])
            
    def update(self, index, value):         # o(logn)
        self.arr[index] = value
        index += (len(self.arr) - 1)
        self.tree[index] = value   
        while (index := ((index-1)//2)) >= 0:   
            updated_internal_node_val = self.func(self.tree[2*index+1], self.tree[2*index+2])
            if updated_internal_node_val != self.tree[index]:
                self.tree[index] = updated_internal_node_val
            else: 
                break

    def query(self, l, r):                  
        if l < 0 or r >= len(self.arr):
            return None
        l += (len(self.arr) - 1)
        r += (len(self.arr) - 1)
        res = self.f_identity
        
        while r >= l:   
            
            if not (l & 1) :     
                res = self.func(self.tree[l], res)
                l += 1
                
            if (r & 1) == 1 :     
                res = self.func(self.tree[r], res)
                r -= 1
            
            r = (r-1) // 2
            l = (l-1) // 2 
            
        return res
       

In [34]:
#max val
n = int(input())
ls = list(input())
for i in range(n):
    ls[i] = [ls[i], i]
sgt_alpha = seg_tree(ls)
count = 0
sgt_indices = seg_tree_for_indices([0]*n,lambda x,y : x+y)
for _ in range(n):
    min_node = sgt_alpha.return_min()
    count += (min_node[1]-sgt_indices.query(0,min_node[1])+1)
    sgt_alpha.update(min_node[1], 'inf')
    sgt_indices.update(min_node[1], 1)
print(count)

4
aabb
4


## Heap / Priority - queue using segment tree

In [26]:
class pri_q(): 
    ''' min heap, can be modified for max heap, max_seg_tree for finding the 
        smallest index where a new element can be pushed
    '''
    @staticmethod
    def minimum(a, b):
        if a[0] < b[0]:
            return a
        elif a[0] > b[0]:
            return b
        else:
            if a[1] < b[1]:
                return a
            return b
        
    @staticmethod  
    def maximum(a, b):    #to find position for m
        if a[0] < b[0]:
            return b
        elif a[0] > b[0]:
            return a
        else:
            if a[1] < b[1]:
                return a
            return b
        
    def __init__(self, arr, max_len = 2*1e5):
        self.length = int(max_len)
        self.arr = [[arr[i] ,i] if i < len(arr) else [int(1e10),i] for i in range(self.length)]
        #[element, index]
        self.tree_min = [None]*(self.length-1) + self.arr
        self.tree_max = [None]*(self.length-1) + self.arr
        self.max_val = int(1e10)
        self._build() 
        
    def _build(self):
        for i in range(self.length-2, -1, -1):
            self.tree_min[i] = pri_q.minimum(self.tree_min[2*i+1], self.tree_min[2*i+2])
            self.tree_max[i] = pri_q.maximum(self.tree_max[2*i+1], self.tree_max[2*i+2])
            
    def _update_min_tree(self, index, value):
        self.arr[index] = [value, index]
        t_index = index+(self.length - 1)
        self.tree_min[t_index] = [value, index]
        while (t_index := ((t_index-1)//2)) >= 0:   
            updated_internal_node_val = pri_q.minimum(self.tree_min[2*t_index+1], self.tree_min[2*t_index+2])
            if updated_internal_node_val != self.tree_min[t_index]:
                self.tree_min[t_index] = updated_internal_node_val
            else: 
                break
                
    def _update_max_tree(self, index, value):
        self.arr[index] = [value, index]
        t_index = index+(self.length - 1)
        self.tree_max[t_index] = [value, index]
        while (t_index := ((t_index-1)//2)) >= 0:   
            updated_internal_node_val = pri_q.maximum(self.tree_max[2*t_index+1], self.tree_max[2*t_index+2])
            if updated_internal_node_val != self.tree_max[t_index]:
                self.tree_max[t_index] = updated_internal_node_val
            else: 
                break
                
    def update(self, index, value):
        self._update_min_tree(index, value)
        self._update_max_tree(index, value)
                
    def get_minimum(self):
        return self.tree_min[0][0]
    
    def pop(self):
        smallest = self.tree_min[0]
        self.update(smallest[1], self.max_val)
        return smallest[0]
    
    def push(self, value):
        index = self.tree_max[0][1]
        self.update(index, value)
        