In [26]:
from math import ceil, log2

class SegmentationTree():
    def __init__(self, input_list):
        self._input_list = input_list[:] # copy of input_list
        self._init_tree()
        self._is_propagated = True
        
    def _init_tree(self):
        
        self._n = len(self._input_list)
        
        height = ceil(log2(self._n))
        n_nodes = 2 * (2 ** height) - 1
        self._seg_tree = [None] * n_nodes
        
        arr_left = 0
        arr_right = self._n - 1
        seg_node_index = 0
        self._propagate(arr_left, arr_right, seg_node_index)
        
    
    def query(self, query_left, query_right):
        arr_left = 0
        arr_right = self._n - 1
        seg_node_index = 0    
        
        if not self._is_propagate:
            self._propagate(arr_left, arr_right, seg_node_index)
            self._is_propagate = True
        
        return self._query_helper(query_left, query_right, arr_left, arr_right, seg_node_index)


    def _query_helper(self, query_left, query_right, arr_left, arr_right, seg_node_index):
        if arr_right < arr_left:
            return 0

        if query_right < arr_left or arr_right < query_left:
            return 0

        if query_left <= arr_left and arr_right <= query_right:
            return self._seg_tree[seg_node_index]


        midpoint = (arr_left + arr_right) // 2

        left_seg_node_index = seg_node_index * 2 + 1
        left_node_arr_left = arr_left
        left_node_arr_right = midpoint
        left_val = self._query_helper(query_left, query_right, left_node_arr_left, left_node_arr_right, left_seg_node_index)

        right_seg_node_index = seg_node_index * 2 + 2
        right_node_arr_left = midpoint + 1
        right_node_arr_right = arr_right
        right_val = self._query_helper(query_left, query_right, right_node_arr_left, right_node_arr_right, right_seg_node_index)

        return left_val + right_val

        
    def _propagate(self, arr_left, arr_right, seg_node_index):
        if arr_right < arr_left:
            return
        
        if arr_left == arr_right:
            value = self._input_list[arr_left]
            self._seg_tree[seg_node_index] = value
            return
        
        midpoint = (arr_left + arr_right) // 2
        
        left_seg_node_index = seg_node_index * 2 + 1
        left_node_arr_left = arr_left
        left_node_arr_right = midpoint
        self._propagate(left_node_arr_left, left_node_arr_right, left_seg_node_index)
        
        right_seg_node_index = seg_node_index * 2 + 2
        right_node_arr_left = midpoint + 1
        right_node_arr_right = arr_right
        self._propagate(right_node_arr_left, right_node_arr_right, right_seg_node_index)
        
        left_val = self._seg_tree[left_seg_node_index]
        right_val = self._seg_tree[right_seg_node_index]
        self._seg_tree[seg_node_index] = left_val + right_val
        
    def lazy_update(self, arr_index, new_value):
        old_value = self._input_list[arr_index]
        if old_value != new_value:  
            self._input_list[arr_index] = new_value
            self._is_propagate = False
            
            
#     def update(self, arr_index, new_value):
#         def inner(node, left, right):
#             if arr_index < left or arr_index > right:
#                 return self._seg_tree[node]
#             if left == right:
#                 self._seg_tree[node] = new_value
#                 return self._seg_tree[node]
#             midpoint = (left + right) // 2
#             left_value = inner(node * 2 + 1, left, midpoint)
#             right_value = inner(node * 2 + 2, midpoint + 1, right)  
            
#             self.seg_tree[node] = left_value + right_value
#         inner(0, 0, self._n - 1)

In [27]:
nums = [1, 3, 5]
st = SegmentationTree(nums)
print(st.query(0, 2))
assert st.query(0, 2) == 9
st.update(1, 2)
print(st.query(0, 2))
assert st.query(0, 2) == 8

9
8


In [29]:
import time
while True:
    time.sleep(0.1)

KeyboardInterrupt: 