In [1]:
import typing
import numpy as np
from dataclasses import dataclass, field

@dataclass
class Node:
    value: np.float32
    start: int
    end: int
    left_child: typing.Any = field(repr=False, default=None)
    right_child: typing.Any = field(repr=False, default=None)
    parent: typing.Any = field(repr=False, default=None)
        
a = Node(98.19, 0, 0)
b = Node(88.41, 1, 1)
root = Node(144, 0, 1, a, b)

In [11]:
def overlaps(x, y, A, B):
    """Returns true if [x, y] is completely included within [A, B].
    """
    return A <= x <= y <= B
    

class SumTree:
    def __init__(self, capacity):
        assert capacity & (capacity - 1) == 0  # capacity must be a power of two
        self.capacity = capacity
        self.values = np.zeros(2 * capacity, dtype=np.int)
        
    
    def _query_recursively(self, start, end, node, node_start, node_end):
        if start == node_start and end == node_end:
            return self.values[node]
        
        mid = (node_start + node_end) // 2
        if overlaps(start, end, node_start, mid):
            return self._query_recursively(
                start, end, 2*node, node_start, mid
            )
        elif overlaps(start, end, mid + 1, node_end):
            return self._query_recursively(
                start, end, 2*node + 1, mid + 1, node_end
            )
        else:
            # [start, end] partially overlaps both left/right children
            return self._query_recursively(start, mid, 2*node, node_start, mid) + self._query_recursively(mid + 1, end, 2*node + 1, mid + 1, node_end)
    
    def query(self, start=0, end=None):
        """Returns the sum in the interval given by start and end.
        """
        if end is None:
            end = self.capacity - 1
        return self._query_recursively(start, end, 1, 0, self.capacity - 1)
    
    def __setitem__(self, i, value):
        adjusted_i = i + self.capacity
        self.values[adjusted_i] = value
        
        # Propagate upwards
        parent = adjusted_i // 2
        while parent >= 1:
            self.values[parent] = self.values[2 * parent] + self.values[2 * parent + 1]
            parent //= 2
    
    def __getitem__(self, i):
        assert 0 <= i < self.capacity
        return self.values[i + self.capacity]


In [16]:
sum_tree = SumTree(16)
data = [np.random.randint(0, 100) for _ in range(16)]
for i, x in enumerate(data):
    sum_tree[i] = x

def dumb_sum(data, x, y):
    _sum = 0
    for i in range(x, y):
        _sum += data[i]    
    return _sum

print(sum_tree.query(2))
print(dumb_sum(data, 2, len(data)))

2 15 1 0 15
2 7 2 0 7
2 3 4 0 3
2 3 9 2 3
4 7 5 4 7
8 15 3 8 15
696
696
