diff --git a/tchannel/container/heap.py b/tchannel/container/heap.py index 10ac9355..50a8e6ac 100644 --- a/tchannel/container/heap.py +++ b/tchannel/container/heap.py @@ -23,6 +23,11 @@ import math import six +from collections import deque + + +class NoMatchError(Exception): + pass class HeapOperation(object): @@ -48,6 +53,15 @@ def pop(self): """Pop an item from the heap""" raise NotImplementedError() + def peek(self, i): + """Peek at the item at the given position without removing it from the + heap. + + :param i: + 0-indexed position of the iteam in the heap + """ + raise NotImplementedError() + def swap(self, i, j): """swap items between position i and j of the heap""" raise NotImplementedError() @@ -122,3 +136,42 @@ def down(h, parent, n): h.swap(parent, min_child) parent = min_child + + +def smallest(h, p, current=0): + """Finds the index of the smallest item in the heap that matches the given + predicate. + + :param p: + Function that accepts an item from the heap and returns true or false. + :returns: + Index of the first item for which ``p`` returned true. + :raises NoMatchError: + If no matching items were found. + """ + n = h.size() + + # items contains indexes of items yet to be checked. + items = deque([0]) + while items: + current = items.popleft() + if current >= n: + continue + + if p(h.peek(current)): + return current + + child1 = 2 * current + 1 + child2 = child1 + 1 + + if child1 < n and child2 < n and h.lt(child2, child1): + # make sure we check the smaller child first. + child1, child2 = child2, child1 + + if child1 < n: + items.append(child1) + + if child2 < n: + items.append(child2) + + raise NoMatchError() diff --git a/tchannel/peer_heap.py b/tchannel/peer_heap.py index 41db71eb..38b530ce 100644 --- a/tchannel/peer_heap.py +++ b/tchannel/peer_heap.py @@ -75,6 +75,9 @@ def lt(self, i, j): return self.peers[i].score < self.peers[j].score + def peek(self, i): + return self.peers[i] + def push(self, x): x.index = len(self.peers) self.peers.append(x) diff --git a/tests/container/test_heap.py b/tests/container/test_heap.py index 44bb40ed..e56cb3ae 100644 --- a/tests/container/test_heap.py +++ b/tests/container/test_heap.py @@ -24,9 +24,11 @@ import math import pytest import six +from hypothesis import given +from hypothesis import strategies as st from tchannel.container import heap -from tchannel.container.heap import HeapOperation +from tchannel.container.heap import HeapOperation, NoMatchError class IntHeap(HeapOperation): @@ -36,6 +38,9 @@ def __init__(self): def size(self): return len(self.values) + def peek(self, i): + return self.values[i] + def lt(self, i, j): return self.values[i] < self.values[j] @@ -123,6 +128,40 @@ def test_remove(int_heap, values): verify(int_heap, 0) +def test_smallest_basic(int_heap, values): + for value in values: + heap.push(int_heap, value) + verify(int_heap, 0) + + assert heap.smallest(int_heap, (lambda _: True)) == 0 + + with pytest.raises(NoMatchError): + heap.smallest(int_heap, (lambda _: False)) + + +def test_smallest_empty(int_heap): + with pytest.raises(NoMatchError): + heap.smallest(int_heap, (lambda _: True)) + + +def test_smallest_unordered_children(int_heap): + int_heap.values = [1, 4, 2] + verify(int_heap, 0) + + assert heap.smallest(int_heap, (lambda x: x % 2 == 0)) == 2 + + +@given(st.lists(st.integers(), min_size=1)) +def test_smallest_random(values): + int_heap = IntHeap() + for v in values: + heap.push(int_heap, v) + + target = random.choice(int_heap.values) + valid = [i for (i, v) in enumerate(int_heap.values) if v == target] + assert heap.smallest(int_heap, (lambda x: x == target)) in valid + + @pytest.mark.heapfuzz @pytest.mark.skipif(True, reason='stress test for the value heap operations') def test_heap_fuzz(int_heap):