In [32]:
# https://www.youtube.com/watch?v=-dUiRtJ8ot0

A segment tree answers range queries — for example, the minimum or maximum of a subarray — in O(log n) time per query, after a one-time O(n) build.

The brute-force approach takes O(n) per query, because it has to scan every element in the requested range of the array.


<img src="./../../../images/Screenshot 2023-10-27 at 1.18.28 PM.png" width="700" >

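Useful identity (geometric series):

$$2^0 + 2^1 + 2^2 + \dots + 2^h = 2^{h+1} - 1$$

If the deepest level of the tree is level $h$ (the root is level 0), level $k$ holds at most $2^k$ nodes. So the whole tree has at most $2^{h+1} - 1$ nodes.

With 0-based indexing (the children of node $i$ are at $2i+1$ and $2i+2$), the largest index used is at most $2^{h+1} - 2$. An array of length $2^{h+1}$ is therefore always enough.

The depth is $h = \lceil \log_2 n \rceil$, where $n$ is the number of elements in the input array (the tree has $h + 1$ levels).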

In [33]:
from math import ceil, inf, log2
from typing import Callable, List


def formSegmentTree(
    arr: List[int],
    combine: Callable[[float, float], float] = max,
    identity: float = -inf,
) -> Callable[[int, int], float]:
    """Build a segment tree over ``arr`` and return a range-query function.

    By default the tree answers range-maximum queries. Passing another
    associative ``combine`` together with its matching ``identity`` reuses
    the same structure for other queries, e.g. ``combine=min, identity=inf``
    for range minimum, or ``combine=lambda a, b: a + b, identity=0`` for
    range sum.

    Args:
        arr: The values to index. Must be non-empty.
        combine: Associative binary function that merges two child results.
        identity: Value returned for a segment lying entirely outside the
            query range; must satisfy ``combine(identity, x) == x``.

    Returns:
        A function ``findMaxInRange(l, r)`` that returns the combined value
        of ``arr[l..r]`` (both ends inclusive) in O(log n) time. Parts of the
        query range outside ``[0, len(arr) - 1]`` are ignored, and ``l > r``
        yields ``identity``.

    Raises:
        ValueError: If ``arr`` is empty.
    """
    n = len(arr)
    if n == 0:
        raise ValueError('cannot build a segment tree over an empty array')

    # Depth of the tree, i.e. ceil(log2(n)), computed with exact integer
    # arithmetic instead of floating-point log2.
    depth = (n - 1).bit_length()
    # Levels 0..depth hold at most 2^(depth+1) - 1 nodes, so the largest
    # 0-based node index is at most 2^(depth+1) - 2.
    seg = [identity] * (1 << (depth + 1))

    def build(ind: int, low: int, high: int) -> None:
        # Store in seg[ind] the combined value of arr[low..high].
        if low == high:
            seg[ind] = arr[low]
            return
        mid = (low + high) // 2
        build(2 * ind + 1, low, mid)
        build(2 * ind + 2, mid + 1, high)
        seg[ind] = combine(seg[2 * ind + 1], seg[2 * ind + 2])

    build(0, 0, n - 1)

    def query(ind: int, low: int, high: int, l: int, r: int) -> float:
        # Combined value of the overlap between node range [low, high]
        # and query range [l, r].
        if l <= low and high <= r:
            # Node range lies completely inside the query range.
            return seg[ind]
        if high < l or low > r:
            # Node range is disjoint from the query range; return the
            # identity so it does not affect the combined result
            # (-inf for the default max).
            return identity
        # Partial overlap: split and combine both halves.
        mid = (low + high) // 2
        left = query(2 * ind + 1, low, mid, l, r)
        right = query(2 * ind + 2, mid + 1, high, l, r)
        return combine(left, right)

    def findMaxInRange(l: int, r: int) -> float:
        # Query the inclusive index range [l, r] starting from the root.
        return query(0, 0, n - 1, l, r)

    return findMaxInRange

In [34]:
# idx: 0  1  2  3  4  5  6  7  8   9
arr = [8, 2, 5, 1, 4, 5, 3, 9, 6, 10]
findMaxInRange = formSegmentTree(arr)

# largest of arr[5], arr[6], arr[7] (range ends are inclusive)
findMaxInRange(l=5, r=7)

9

In [35]:
# verification: compare every inclusive range [i, j] against a brute-force
# max over the corresponding slice, and report the first mismatch precisely.

n = len(arr)
for i in range(n):
    for j in range(i, n):
        expected = max(arr[i:j + 1])
        actual = findMaxInRange(i, j)
        if actual != expected:
            raise AssertionError(
                f'results are different for range [{i}, {j}]: '
                f'segment tree gave {actual}, brute force gave {expected}'
            )