Skip to content

Commit

Permalink
Merge pull request #178 from zuoxingdong/step_info_trajectory
Browse files Browse the repository at this point in the history
Update transform: add SegmentTree/SumTree/MinTree
  • Loading branch information
zuoxingdong committed May 9, 2019
2 parents 536b033 + d881d49 commit 218a840
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/source/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,12 @@ lagom.transform: Transformations
:members:

.. autofunction:: smooth_filter

.. autoclass:: SegmentTree
    :members:

.. autoclass:: SumTree
    :members:

.. autoclass:: MinTree
    :members:
3 changes: 3 additions & 0 deletions lagom/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
from .rank_transform import rank_transform
from .polyak_average import PolyakAverage
from .running_mean_var import RunningMeanVar
from .segment_tree import SegmentTree
from .segment_tree import SumTree
from .segment_tree import MinTree
from .smooth_filter import smooth_filter
118 changes: 118 additions & 0 deletions lagom/transform/segment_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import operator


class SegmentTree(object):
    r"""Defines a segment tree data structure.

    It can be regarded as a regular array, but with two major differences:

    - Value modification is slower: O(ln(capacity)) instead of O(1)
    - Efficient reduce operation over a contiguous subarray: O(ln(segment size))

    The tree is stored in the flat list ``value`` of length ``2*capacity``: the root
    lives at index 1, the children of node ``i`` are ``2*i`` and ``2*i + 1``, and the
    leaf holding array element ``k`` is at ``capacity + k``. Slot 0 is unused.

    Args:
        capacity (int): total number of elements, it must be a power of two.
        operation (lambda): associative binary operation, e.g. sum, min
        identity_element (object): identity element of the operation, e.g. 0 for sum,
            ``inf`` for min. It is the initial value of every element and the result
            of reducing an empty segment.
    """
    def __init__(self, capacity, operation, identity_element):
        assert capacity > 0 and capacity & (capacity - 1) == 0, 'capacity must be positive and a power of 2.'
        self.capacity = capacity
        self.operation = operation
        # Kept so that :meth:`reduce` can answer empty segments without touching the tree.
        self.identity_element = identity_element
        self.value = [identity_element for _ in range(2*capacity)]

    def _reduce(self, start, end, node, node_start, node_end):
        r"""Reduce the inclusive range [start, end] inside the subtree rooted at ``node``.

        ``node`` covers the inclusive leaf range [node_start, node_end]. The caller must
        guarantee ``node_start <= start <= end <= node_end``, otherwise the recursion
        never reaches its base case.
        """
        if start == node_start and end == node_end:
            return self.value[node]
        mid = (node_start + node_end)//2
        if end <= mid:  # query lies entirely in the left child
            return self._reduce(start, end, 2*node, node_start, mid)
        if start >= mid + 1:  # query lies entirely in the right child
            return self._reduce(start, end, 2*node + 1, mid + 1, node_end)
        # query straddles both children: split it at the midpoint
        return self.operation(self._reduce(start, mid, 2*node, node_start, mid),
                              self._reduce(mid + 1, end, 2*node + 1, mid + 1, node_end))

    def reduce(self, start=0, end=None):
        r"""Returns result of operation(A[start], operation(A[start+1], operation(... A[end - 1]))).

        Args:
            start (int): start of segment (inclusive), must be non-negative.
            end (int): end of segment (exclusive). ``None`` means ``capacity``; a negative
                value is counted from the end of the array, i.e. ``capacity`` is added to it.

        Returns:
            object: result of the reduce operation, or the identity element when the
            segment is empty (``start >= end``).

        Raises:
            IndexError: if ``start`` is negative or ``end`` falls outside ``[0, capacity]``
                after resolving a negative value.
        """
        if end is None:
            end = self.capacity
        if end < 0:
            end += self.capacity
        if start < 0 or end < 0 or end > self.capacity:
            raise IndexError('segment bounds out of range')
        if start >= end:
            # Empty segment: without this guard _reduce would recurse forever.
            return self.identity_element
        # _reduce works on inclusive bounds
        return self._reduce(start, end - 1, 1, 0, self.capacity - 1)

    def __setitem__(self, index, value):
        r"""Set A[index] = value and refresh every ancestor of that leaf.

        Raises:
            IndexError: if ``index`` is outside ``[0, capacity)``. Negative indices are
                rejected because they would address an internal node, not a leaf.
        """
        if not 0 <= index < self.capacity:
            raise IndexError('index out of range')
        # index of leaf
        index += self.capacity
        self.value[index] = value
        index //= 2
        # walk up to the root, recomputing each ancestor from its two children
        while index >= 1:
            self.value[index] = self.operation(self.value[2*index], self.value[2*index + 1])
            index //= 2

    def __getitem__(self, index):
        r"""Return A[index]; ``index`` must lie in ``[0, capacity)``."""
        assert 0 <= index < self.capacity
        return self.value[index + self.capacity]


class SumTree(SegmentTree):
    r"""Segment tree specialised to addition, used for storing replay priorities.

    Every leaf holds one priority value and every internal node holds the sum of the
    leaves below it, so the root holds the total priority.
    """
    def __init__(self, capacity):
        super().__init__(capacity=capacity, operation=operator.add, identity_element=0.0)

    def sum(self, start=0, end=None):
        r"""Return A[start] + ... + A[end - 1], with bounds handled as in ``reduce``."""
        return super().reduce(start=start, end=end)

    def find_prefixsum_index(self, prefixsum):
        r"""Find the highest index `i` in the array such that
        sum(A[0] + A[1] + ... + A[i - 1]) <= prefixsum

        When the stored values are (unnormalised) probabilities, calling this with a
        value drawn uniformly between 0 and the total sum selects indices according to
        that discrete distribution.

        Args:
            prefixsum (float): prefix sum, between 0 and the total sum (a slack of
                1e-5 above the total is tolerated).

        Returns:
            int: highest index satisfying the prefixsum constraint
        """
        assert prefixsum >= 0 and prefixsum <= self.sum() + 1e-5
        node = 1
        remaining = prefixsum
        # Descend from the root; nodes numbered below `capacity` are internal.
        while node < self.capacity:
            left_child = 2*node
            left_total = self.value[left_child]
            if left_total > remaining:
                # target mass lies inside the left subtree
                node = left_child
            else:
                # skip the whole left subtree and continue on the right
                remaining -= left_total
                node = left_child + 1
        return node - self.capacity


class MinTree(SegmentTree):
    r"""Segment tree specialised to ``min``.

    Every element starts at ``float('inf')``, the identity element for ``min``.
    """
    def __init__(self, capacity):
        super().__init__(capacity=capacity, operation=min, identity_element=float('inf'))

    def min(self, start=0, end=None):
        r"""Returns min(A[start], ..., A[end - 1]), with bounds handled as in ``reduce``."""
        return super().reduce(start=start, end=end)
91 changes: 91 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from lagom.transform import rank_transform
from lagom.transform import PolyakAverage
from lagom.transform import RunningMeanVar
from lagom.transform import SumTree
from lagom.transform import MinTree
from lagom.transform import smooth_filter


Expand Down Expand Up @@ -161,7 +163,96 @@ def test_running_mean_var():
assert np.allclose(f.var, x.var(0))
assert np.allclose(np.sqrt(f.var + 1e-8), x.std(0))
assert f.n == 1000


def test_sum_tree():
    """Check SumTree range sums and prefix-sum index lookup on 4-leaf trees."""

    def build(writes):
        # Apply the (index, value) writes, in order, to a fresh 4-leaf tree.
        fresh = SumTree(4)
        for position, priority in writes:
            fresh[position] = priority
        return fresh

    # Naive test: two different leaves populated
    tree = build([(2, 1.0), (3, 3.0)])
    range_cases = [
        ((), 4.0),
        ((0, 2), 0.0),
        ((0, 3), 1.0),
        ((2, 3), 1.0),
        ((2, -1), 1.0),
        ((2, 4), 4.0),
    ]
    for bounds, expected in range_cases:
        assert np.allclose(tree.sum(*bounds), expected)

    # Overwriting the same element: only the last write counts
    tree = build([(2, 1.0), (2, 3.0)])
    range_cases = [
        ((), 3.0),
        ((2, 3), 3.0),
        ((2, -1), 3.0),
        ((2, 4), 3.0),
        ((1, 2), 0.0),
    ]
    for bounds, expected in range_cases:
        assert np.allclose(tree.sum(*bounds), expected)

    # prefixsum index: v1 (leading leaves are empty)
    tree = build([(2, 1.0), (3, 3.0)])
    lookup_cases = [
        (0.0, 2),
        (0.5, 2),
        (0.99, 2),
        (1.01, 3),
        (3.00, 3),
        (4.00, 3),
    ]
    for mass, expected_index in lookup_cases:
        assert tree.find_prefixsum_index(mass) == expected_index

    # prefixsum index: v2 (every leaf populated)
    tree = build([(0, 0.5), (1, 1.0), (2, 1.0), (3, 3.0)])
    lookup_cases = [
        (0.00, 0),
        (0.55, 1),
        (0.99, 1),
        (1.51, 2),
        (3.00, 3),
        (5.50, 3),
    ]
    for mass, expected_index in lookup_cases:
        assert tree.find_prefixsum_index(mass) == expected_index


def test_min_tree():
    """Check MinTree range minima as the value stored at index 2 is rewritten."""

    def check(min_tree, cases):
        # Each case is (positional bounds passed to min(), expected minimum).
        for bounds, expected in cases:
            assert np.allclose(min_tree.min(*bounds), expected)

    tree = MinTree(4)
    for position, entry in [(0, 1.0), (2, 0.5), (3, 3.0)]:
        tree[position] = entry

    check(tree, [
        ((), 0.5),
        ((0, 2), 1.0),
        ((0, 3), 0.5),
        ((0, -1), 0.5),
        ((2, 4), 0.5),
        ((3, 4), 3.0),
    ])

    # Raise index 2, still the overall minimum
    tree[2] = 0.7

    check(tree, [
        ((), 0.7),
        ((0, 2), 1.0),
        ((0, 3), 0.7),
        ((0, -1), 0.7),
        ((2, 4), 0.7),
        ((3, 4), 3.0),
    ])

    # Raise index 2 above every other stored value
    tree[2] = 4.0

    check(tree, [
        ((), 1.0),
        ((0, 2), 1.0),
        ((0, 3), 1.0),
        ((0, -1), 1.0),
        ((2, 4), 3.0),
        ((2, 3), 4.0),
        ((2, -1), 4.0),
        ((3, 4), 3.0),
    ])


def test_smooth_filter():
with pytest.raises(AssertionError):
Expand Down

0 comments on commit 218a840

Please sign in to comment.