-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
segment_tree.py
212 lines (170 loc) · 7.68 KB
/
segment_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import operator
from typing import Any, Optional
class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as a regular array, but with two important differences:

    a) Setting an item's value is slightly slower: it is O(lg capacity)
       instead of O(1).
    b) Offers an efficient `reduce` operation which reduces the tree's
       values over some specified contiguous subsequence of items in the
       array. The operation could be e.g. min/max/sum.

    The data is stored in a flat list of length 2 * capacity. The second
    half of the list stores the actual values for each index, so if
    capacity=8, values are stored at internal indices 8 to 15. The first
    half of the list contains the reduced values of the different (binary
    divided) segments, e.g. (capacity=4):

        0   = not used
        1   = reduced value over all elements (internal indices 4 to 7).
        2   = reduced value over internal indices 4 and 5.
        3   = reduced value over internal indices 6 and 7.
        4-7 = values of the tree.

    NOTE that the values of the tree are accessed by indices starting at 0,
    so `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(
        self, capacity: int, operation: Any, neutral_element: Optional[Any] = None
    ):
        """Initializes a Segment Tree object.

        Args:
            capacity: Total size of the array - must be a positive power of
                two.
            operation: Callable (obj, obj) -> obj used for combining
                elements (e.g. `operator.add`, `max`, `min`). It must be
                associative and commutative and have `neutral_element` as
                its identity: `reduce` folds partial segments into one
                accumulator in an order that is not strictly left-to-right.
            neutral_element: The neutral element for `operation`. If None,
                a value is picked automatically: 0.0 for `operator.add`,
                float("-inf") for `max`, and float("inf") for everything
                else (which is only correct for min-like operations, so
                pass it explicitly for custom operations).

        Raises:
            AssertionError: If `capacity` is not a positive power of two.
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "Capacity must be positive and a power of 2!"
        self.capacity = capacity

        if neutral_element is None:
            # Identity checks (not equality): only the exact stdlib
            # callables are recognized here.
            if operation is operator.add:
                neutral_element = 0.0
            elif operation is max:
                neutral_element = float("-inf")
            else:
                neutral_element = float("inf")
        self.neutral_element = neutral_element

        # Slot 0 is unused, slots [1, capacity) hold segment reductions and
        # slots [capacity, 2 * capacity) hold the leaves.
        self.value = [self.neutral_element] * (2 * capacity)
        self.operation = operation

    def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Applies `self.operation` to a contiguous subsequence of values.

        The subsequence includes `start` and excludes `end`, i.e. the result
        is the reduction of arr[start], arr[start + 1], ..., arr[end - 1].

        `start` and `end` are interpreted like the bounds of a Python slice
        over a sequence of length `self.capacity`: negative values count
        from the end and out-of-range values are clamped. An empty range
        (including `start >= end`) yields the neutral element.

        Args:
            start: Start index to apply the reduction to (included).
            end: End index to apply the reduction to (excluded). None means
                up to and including the last element.

        Returns:
            The result of reducing `self.operation` over the specified
            range of `self.value` leaf elements.
        """
        # Resolve the bounds exactly like `arr[start:end]` would. Without
        # this, a negative `start` would address a segment node in the
        # first half of the array and an `end` beyond `capacity` would read
        # unrelated nodes (or run off the list), giving silently wrong
        # results.
        start, end, _ = slice(start, end).indices(self.capacity)

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        start += self.capacity
        end += self.capacity

        # Example:
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3
        #
        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.
        #
        # Iterate over the tree starting in the actual-values (second half)
        # section.
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value[6]=2: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller than end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value[2]=1: result=3
        #    NOTE: This adds the sum of indices 4 and 5 to the result.
        # 7) int-divide again: start=1, end=1 -> done.
        while start < end:
            # If start is odd, it is a right child: its parent would also
            # cover the element left of the range, so consume this node
            # directly and move start to the next even index.
            if start & 1:
                result = self.operation(result, self.value[start])
                start += 1
            # If end is odd: Move end to the previous even index, then
            # consume that node. NOTE: This takes care of excluding `end`
            # in any situation.
            if end & 1:
                end -= 1
                result = self.operation(result, self.value[end])
            # Divide both start and end by 2 to make them "jump" into the
            # next upper level of the reduce-index space.
            start //= 2
            end //= 2
        return result

    def __setitem__(self, idx: int, val: float) -> None:
        """Inserts/overwrites a value in the tree.

        Args:
            idx: The index to insert to. Must be in [0, `self.capacity`).
            val: The value to insert.

        Raises:
            AssertionError: If `idx` is outside [0, `self.capacity`).
        """
        assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"

        # Index of the leaf to insert into (always insert in the "second
        # half" of the array; the first half is reserved for already
        # calculated reduction values).
        idx += self.capacity
        self.value[idx] = val

        # Recalculate all affected reduction values (in the "first half"),
        # walking from the leaf's parent up to the root at index 1.
        idx >>= 1
        while idx >= 1:
            left_child = 2 * idx
            self.value[idx] = self.operation(
                self.value[left_child], self.value[left_child + 1]
            )
            idx >>= 1

    def __getitem__(self, idx: int) -> Any:
        """Returns the value stored at (0-based) index `idx`.

        Raises:
            AssertionError: If `idx` is outside [0, `self.capacity`).
        """
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]

    def get_state(self):
        """Returns the internal value list (the list itself, not a copy)."""
        return self.value

    def set_state(self, state):
        """Replaces the internal value list with `state`.

        Args:
            state: Sequence of length 2 * `self.capacity`, laid out like the
                list returned by `get_state`. It is stored as-is (not
                copied).

        Raises:
            AssertionError: If `state` does not have length
                2 * `self.capacity`.
        """
        assert len(state) == self.capacity * 2
        self.value = state
class SumSegmentTree(SegmentTree):
    """A SegmentTree whose reduction `operation` is `operator.add`."""

    def __init__(self, capacity: int):
        super().__init__(capacity=capacity, operation=operator.add)

    def sum(self, start: int = 0, end: Optional[Any] = None) -> Any:
        """Returns the sum over the sub-segment [start, end) of the tree."""
        total = self.reduce(start, end)
        return total

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        """Finds the highest i for which sum(arr[0..i - 1]) <= prefixsum.

        Args:
            prefixsum: Upper bound on the prefix sum in the above
                constraint. Must lie in [0, total sum + 1e-5].

        Returns:
            int: Largest possible index (i) satisfying the above constraint.

        Raises:
            AssertionError: If `prefixsum` is outside the allowed range.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5

        # Descend from the root (global sum node) until a leaf is reached;
        # every node below index `capacity` is an inner node.
        node = 1
        while node < self.capacity:
            left = 2 * node
            left_mass = self.value[left]
            if left_mass > prefixsum:
                # The target lies inside the left subtree.
                node = left
            else:
                # Skip the whole left subtree and look for the remainder
                # in the right subtree.
                prefixsum -= left_mass
                node = left + 1

        # Translate the leaf's internal position into a 0-based index.
        return node - self.capacity
class MinSegmentTree(SegmentTree):
    """A SegmentTree whose reduction `operation` is the builtin `min`."""

    def __init__(self, capacity: int):
        super().__init__(capacity=capacity, operation=min)

    def min(self, start: int = 0, end: Optional[Any] = None) -> Any:
        """Returns min(arr[start], ..., arr[end - 1]); `end` is excluded."""
        smallest = self.reduce(start, end)
        return smallest