### Segment Tree

* A segment tree is a data structure for range queries, for example finding the sum or the minimum of a range of array elements.
* It is a binary tree whose leaves are the array elements; every other node, up to the root, covers a contiguous range and stores the query value (e.g. the sum) for that range.
* Construction: O(n), Update: O(log n), Query: O(log n).
* Space: the tree has about n + n/2 + n/4 + ... + 1 ≈ 2n nodes when n is a power of two; rounding n up to the next power of two gives the safe bound of 4n array slots, i.e. O(n) space.
* Because it is a binary tree, it can be stored in a plain array: node i has children 2i and 2i + 1.

### Building Seg tree for sum query

In [74]:
# Sample input array; build_tree reads it as a module-level global.
a = [5, 3, 1, 2, 6, 5, 4]

Each node is written as index(start:end) = sum of a[start..end].

                               1(0:6) = 26
                  /                                  \
           2(0:3) = 11                            3(4:6) = 15
          /           \                          /           \
    4(0:1) = 8     5(2:3) = 3            6(4:5) = 11      7(6:6) = 4
     /      \        /      \             /       \
 8(0:0)   9(1:1)  10(2:2)  11(3:3)    12(4:4)   13(5:5)
   = 5      = 3     = 1      = 2        = 6       = 5
          

In [75]:
n = len(a)  # number of array elements, i.e. leaves of the segment tree

In [76]:
# Flat storage for the 1-indexed recursive tree; 4n slots are always enough.
tree = [None] * (4 * n)

In [77]:
def build_tree(start, end, curr):
    """Recursively fill the global ``tree`` with range sums of the global ``a``.

    Node ``curr`` covers ``a[start..end]`` (inclusive); its children are
    ``2*curr`` (left half) and ``2*curr + 1`` (right half).  Call as
    ``build_tree(0, len(a) - 1, 1)`` so that node 1 is the root.
    """
    if start == end:
        # Leaf: stores the array element itself.
        tree[curr] = a[start]
    else:
        l = curr * 2
        r = curr * 2 + 1
        # Floor division keeps the midpoint an int; ``/`` yields a float in
        # Python 3, which breaks list indexing and the ``start == end`` base case.
        m = (start + end) // 2
        build_tree(start, m, l)
        build_tree(m + 1, end, r)
        tree[curr] = tree[l] + tree[r]


In [78]:
def update_tree(start, end, curr, idx, value):
    """Set leaf ``idx`` of the global ``tree`` to ``value`` and refresh its ancestors.

    Node ``curr`` covers indices ``start..end`` (inclusive); call with
    ``update_tree(0, n - 1, 1, idx, value)`` to start at the root.  Only
    ``tree`` is modified, not the source array.
    """
    if start == end:
        tree[curr] = value
    else:
        l = curr * 2
        r = curr * 2 + 1
        # Floor division keeps the midpoint an int (``/`` gives a float in Python 3).
        m = (start + end) // 2
        if idx <= m:
            update_tree(start, m, l, idx, value)
        else:
            update_tree(m + 1, end, r, idx, value)
        # Recompute this node from its (possibly changed) children.
        tree[curr] = tree[l] + tree[r]
        

In [94]:
def query_tree(start, end, curr, l, r):
    """Return the sum of the leaves ``l..r`` (inclusive) stored in the global ``tree``.

    Node ``curr`` covers ``start..end``; call with ``query_tree(0, n - 1, 1, l, r)``.
    Requires ``start <= l`` and ``r <= end``; an empty range (``l > r``) sums to 0.
    """
    if l > r:
        return 0
    if start == l and r == end:
        # The query range exactly matches this node's range.
        return tree[curr]
    l1, r1 = curr * 2, curr * 2 + 1
    # Floor division keeps the midpoint an int (``/`` gives a float in Python 3).
    m = (start + end) // 2
    # Split [l, r] at m; a side that does not overlap becomes empty (l > r).
    return (
        query_tree(start, m, l1, l, min(m, r)) +
        query_tree(m + 1, end, r1, max(l, m + 1), r)
    )


In [88]:
# Build the whole tree starting from the root (node 1 covers indices 0..n-1).
build_tree(start=0, end=n - 1, curr=1)

In [89]:
# Dump the flat tree array; index 0 stays None because the root is node 1.
print(tree)

[None, 26, 11, 15, 8, 3, 11, 4, 5, 3, 1, 2, 6, 5, None, None, None, None, None, None, None, None, None, None, None, None, None, None]


In [90]:
# Set position 2 to 11 in the tree (the list ``a`` itself is not modified).
# Use n - 1 instead of a hard-coded 6 so the call matches the build above.
update_tree(0, n - 1, 1, 2, 11)

After setting a[2] = 11, nodes 10, 5, 2 and 1 change (index(start:end) = sum):

                               1(0:6) = 36
                  /                                  \
           2(0:3) = 21                            3(4:6) = 15
          /           \                          /           \
    4(0:1) = 8     5(2:3) = 13           6(4:5) = 11      7(6:6) = 4
     /      \        /      \             /       \
 8(0:0)   9(1:1)  10(2:2)  11(3:3)    12(4:4)   13(5:5)
   = 5      = 3     = 11     = 2        = 6       = 5
          

In [91]:
# Dump the tree again after the point update; index 0 is still unused.
print(tree)

[None, 36, 21, 15, 8, 13, 11, 4, 5, 3, 11, 2, 6, 5, None, None, None, None, None, None, None, None, None, None, None, None, None, None]


In [93]:
# Sum of positions 1..5 (inclusive), querying from the root over 0..n-1.
query_tree(0, n - 1, 1, 1, 5)

27

In [108]:
class SegTreeSum:
    """Recursive segment tree for range-sum queries with point updates.

    Node 1 is the root and covers ``values[0..n-1]``; node ``i`` has children
    ``2*i`` (left half) and ``2*i + 1`` (right half).  The ``end`` argument of
    ``update_tree``/``query_tree`` is the last index of the root range and
    defaults to ``n - 1``; ``start`` and ``curr`` are recursion state.
    """

    def __init__(self, values):
        # Copy the input so update_tree does not silently mutate the caller's list.
        self.values = list(values)
        self.n = len(self.values)
        # 4n slots always suffice for this 1-indexed recursive layout.
        self.tree = [None] * 4 * self.n
        if self.n:
            # Skipped for empty input, where build_tree(0, -1, 1) never terminates.
            self.build_tree(0, self.n - 1, 1)  # vertex 1 is the root vertex

    def build_tree(self, start, end, curr):
        """Store in node ``curr`` the sum of ``values[start..end]`` (inclusive)."""
        if start == end:
            self.tree[curr] = self.values[start]
        else:
            # ``//`` keeps the midpoint an int; ``/`` gives a float in Python 3.
            mid = (start + end) // 2
            l, r = 2 * curr, 2 * curr + 1
            self.build_tree(start, mid, l)
            self.build_tree(mid + 1, end, r)
            self.tree[curr] = self.tree[l] + self.tree[r]

    def update_tree(self, idx, new_value, end=None, start=0, curr=1):
        """Set ``values[idx] = new_value`` and refresh every node covering it.

        Raises:
            IndexError: if ``idx`` lies outside ``[start, end]`` (e.g. negative).
        """
        if end is None:
            end = self.n - 1
        if not start <= idx <= end:
            raise IndexError("index out of range")
        if start == end:
            self.values[idx] = new_value
            self.tree[curr] = new_value
        else:
            mid = (start + end) // 2
            l, r = 2 * curr, 2 * curr + 1
            if idx <= mid:
                self.update_tree(idx, new_value, mid, start, l)
            else:
                self.update_tree(idx, new_value, end, mid + 1, r)
            self.tree[curr] = self.tree[l] + self.tree[r]

    def query_tree(self, l, r, end=None, start=0, curr=1):
        """Return the sum of ``values[l..r]`` (both ends inclusive).

        An empty range (``l > r``) sums to 0.

        Raises:
            IndexError: if a non-empty range extends outside ``[start, end]``.
        """
        if end is None:
            end = self.n - 1
        if l > r:
            return 0
        if l < start or r > end:
            raise IndexError("query range out of bounds")
        if start == l and end == r:
            return self.tree[curr]
        l1, r1 = 2 * curr, 2 * curr + 1
        mid = (start + end) // 2
        # Split [l, r] at mid; a side that does not overlap becomes empty (l > r).
        return (self.query_tree(l, min(r, mid), mid, start, l1)
                + self.query_tree(max(mid + 1, l), r, end, mid + 1, r1))

In [109]:
# Sum tree over six sample values (indices 0..5).
sTs = SegTreeSum(values=[10, 2, 3, 5, 9, 1])

In [110]:
# Sum of positions 2..4 inclusive; end=5 is the last index of the tree.
sTs.query_tree(l=2, r=4, end=5)

17

In [111]:
# Point update: position 2 becomes 11.
sTs.update_tree(idx=2, new_value=11, end=5)

In [112]:
# Same range query as before, now reflecting the update.
sTs.query_tree(l=2, r=4, end=5)

25

### Building a segment tree to find the n-th free index

In [232]:
class CountSegTree:
    """Segment tree over ``n`` slots (all free at first) that finds the k-th free slot.

    Each node stores ``(free_count, last_free)``: how many slots in the node's
    range are still free, and the index of the rightmost free slot in that range
    (only meaningful while ``free_count > 0``).  Node 1 is the root; node ``i``
    has children ``2*i`` and ``2*i + 1``.  The ``end`` argument of the public
    methods is the last slot index and defaults to ``n - 1``.
    """

    def __init__(self, n):
        self.n = n
        self.tree = [(None, None)] * 4 * n
        if n > 0:
            # Skipped for n == 0, where build_tree(0, -1, 1) never terminates.
            self.build_tree(0, n - 1, 1)

    def build_tree(self, start, end, curr):
        """Initialise node ``curr`` (covering slots ``start..end``) as all free."""
        if start == end:
            self.tree[curr] = (1, end)
        else:
            l, r = 2 * curr, 2 * curr + 1
            # ``//`` keeps the midpoint an int; ``/`` gives a float in Python 3.
            mid = (start + end) // 2
            self.build_tree(start, mid, l)
            self.build_tree(mid + 1, end, r)
            # Every slot is free, so the rightmost free slot is ``end``.
            self.tree[curr] = (self.tree[l][0] + self.tree[r][0], end)

    def update_tree(self, idx, end=None, start=0, curr=1):
        """Mark slot ``idx`` as used and refresh its ancestors.

        Raises:
            IndexError: if ``idx`` lies outside ``[start, end]`` (e.g. negative).
        """
        if end is None:
            end = self.n - 1
        if not start <= idx <= end:
            raise IndexError("slot index out of range")
        if start == end:
            self.tree[curr] = (0, end)
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self.update_tree(idx, mid, start, 2 * curr)
            else:
                self.update_tree(idx, end, mid + 1, 2 * curr + 1)
            l_value, r_value = self.tree[2 * curr], self.tree[2 * curr + 1]
            # The rightmost free slot comes from the right child whenever it has one.
            self.tree[curr] = (l_value[0] + r_value[0],
                               r_value[1] if r_value[0] else l_value[1])

    def query_tree(self, count, end=None, start=0, curr=1):
        """Return the index of the ``count``-th free slot (1-based, left to right).

        Raises:
            ValueError: if ``count`` is not between 1 and the number of free slots.
        """
        if end is None:
            end = self.n - 1
        available = self.tree[curr][0] if self.n > 0 else 0
        if not 1 <= count <= available:
            raise ValueError("count must be between 1 and the number of free slots")
        # A range with exactly ``count`` free slots: the answer is its rightmost free one.
        if available == count:
            return self.tree[curr][1]
        mid = (start + end) // 2
        l_count = self.tree[2 * curr][0]
        if count <= l_count:
            return self.query_tree(count, mid, start, 2 * curr)
        return self.query_tree(count - l_count, end, mid + 1, 2 * curr + 1)

In [246]:
# sTs = CountSegTree(6)

# print(sTs.tree)

# sTs.query_tree(4, 5)

# sTs.update_tree(3, 5); print(sTs.tree)

# sTs.query_tree(3, 5)

# sTs.update_tree(2, 5); print(sTs.tree)

# sTs.query_tree(2, 5)

# sTs.update_tree(1, 5); print(sTs.tree)

# sTs.query_tree(3, 5)

# sTs.update_tree(5, 5); print(sTs.tree)

# sTs.query_tree(1, 5)

# sTs.update_tree(0, 5); print(sTs.tree)

# sTs.query_tree(1, 5)

In [1]:
'''
We have discussed the recursive segment tree implementation above. This section discusses an iterative implementation.
Consider the following problem to understand segment trees.
We have an array arr[0 .. n-1]. We should be able to:
1. Find the minimum of the elements from index l to r, where 0 <= l <= r <= n-1.
2. Change the value of a specified element of the array to a new value x, i.e. set arr[i] = x where 0 <= i <= n-1.
Examples: 
 

Input : 2, 6, 7, 5, 18, 86, 54, 2
        minimum(2, 7)  
        update(3, 4)
        minimum(2, 6) 
Output : Minimum in range 2 to 7 is 2.
         Minimum in range 2 to 6 is 4.
 

The iterative version of the segment tree uses the fact that, for an index i,
the left child is 2 * i and the right child is 2 * i + 1. The parent of index i
is therefore i // 2, so we can easily move up and down the levels of the tree
one by one. The leaves are stored at indices n .. 2n-1, and index 0 is unused.
To build the tree, we compute the minimum of each node's two children, starting
just above the leaf level and climbing towards the root. Queries use the same
idea: both range ends start at the leaf level and move up one level per
iteration, folding in the boundary nodes along the way. Since there are
O(log n) levels, a query takes O(log n) time. To update an index, we set the
leaf value and then recompute every ancestor of that leaf on the way up to the
root; this also takes O(log n) time because exactly one node per level changes.



 
'''

# Python3 program to implement
# iterative segment tree.
def construct_segment_tree(segtree, a, n):
    """Fill ``segtree`` bottom-up as a min-segment tree over ``a[:n]``.

    Leaves occupy ``segtree[n:2*n]``; internal node ``i`` (1 <= i < n) holds
    the minimum of its children ``2*i`` and ``2*i + 1``.  ``segtree[0]`` is
    never written, and ``segtree`` must have at least ``2*n`` slots.
    """
    # Copy the input values into the leaf level.
    for leaf in range(n, 2 * n):
        segtree[leaf] = a[leaf - n]

    # Visit internal nodes from the highest index down to the root, so both
    # children are already final when their parent is computed.
    for node in reversed(range(1, n)):
        first_child = 2 * node
        segtree[node] = min(segtree[first_child], segtree[first_child + 1])
                         
def range_query(segtree, left, right, n):
    """Return the minimum of the original values at indices ``left .. right - 1``.

    ``right`` is exclusive; callers wanting an inclusive upper bound pass
    ``right + 1``.  ``segtree`` must be laid out with leaves at
    ``segtree[n:2*n]`` and node ``i`` having children ``2*i`` and ``2*i + 1``.
    An empty range (``left >= right``) returns ``float("inf")``.
    """
    # Move both bounds to the leaf level.
    left += n
    right += n

    # +infinity is the identity for min, so stored values of any size work;
    # a finite sentinel such as 1e9 would cap every result at that value.
    mi = float("inf")
    while left < right:
        # An odd ``left`` is a right child: its parent also covers an index
        # before the range, so take this node on its own and step past it.
        if left & 1:
            mi = min(mi, segtree[left])
            left += 1
        # An odd (exclusive) ``right`` means node ``right - 1`` is a left child
        # whose parent also covers the excluded node ``right``; take it alone.
        if right & 1:
            right -= 1
            mi = min(mi, segtree[right])
        # Move both bounds up to their parents.
        left //= 2
        right //= 2
    return mi
 
def update(segtree, pos, value, n):
    """Set original index ``pos`` to ``value`` and recompute its ancestors' minima.

    ``segtree`` uses the layout from ``construct_segment_tree``: leaf for
    index ``pos`` at ``segtree[pos + n]``, parent of node ``i`` at ``i // 2``.
    """
    node = pos + n
    segtree[node] = value
    # Climb one level per iteration until the root (node 1) is refreshed.
    while node > 1:
        node //= 2
        children = segtree[2 * node], segtree[2 * node + 1]
        segtree[node] = min(children)
 
# Driver code: build a min-segment tree over a sample list, query it,
# apply one point update, then query again.
a = [2, 6, 10, 4, 7, 28, 9, 11, 6, 33]
n = len(a)

# Leaves live in segtree[n:2n]; internal nodes fill segtree[1:n].
segtree = [0] * (2 * n)
construct_segment_tree(segtree, a, n)

# range_query treats its upper bound as exclusive, hence the "+ 1".
left, right = 0, 5
print("Minimum in range", left, "to", right, "is",
      range_query(segtree, left, right + 1, n))

# Point update of index 3 to 1; only the tree changes, the list ``a`` does not.
# Logical contents afterwards: {2, 6, 10, 1, 7, 28, 9, 11, 6, 33}
index, value = 3, 1
update(segtree, index, value, n)

left, right = 2, 6
print("Minimum in range", left, "to", right, "is",
      range_query(segtree, left, right + 1, n))
        
'''
Output: 
Minimum in range 0 to 5 is 2
Minimum in range 2 to 6 is 1
 

Time Complexity: O(n) to build, O(log n) per query or update
Auxiliary Space: O(n)
'''

Minimum in range 0 to 5 is 2
Minimum in range 2 to 6 is 1


'\nOutput: \nMinimum in range 0 to 5 is 2\nMinimum in range 2 to 6 is 1\n \n\nTime Complexity :(n log n) \nAuxiliary Space : (n)\n'