From afea00c064b99b6b541dbd5ab80c6e919f291653 Mon Sep 17 00:00:00 2001 From: Pete Grayson Date: Thu, 7 Feb 2019 15:55:55 -0500 Subject: [PATCH] Pool at-most and at-least events Pool is updated to use PoolWhenAtMostEvent and PoolWhenAtLeastEvent as building blocks for other events--just like is now done with Queue. Unit tests are updated to achieve full coverage. --- desmod/pool.py | 276 +++++++++++++++++++++++++------------------- tests/test_pool.py | 138 ++++++++++++++++++---- tests/test_probe.py | 16 +-- 3 files changed, 284 insertions(+), 146 deletions(-) diff --git a/desmod/pool.py b/desmod/pool.py index 302ee2b5..746ea142 100644 --- a/desmod/pool.py +++ b/desmod/pool.py @@ -8,17 +8,23 @@ `int` or `float`. """ +from sys import float_info import heapq from simpy import Event from simpy.core import BoundClass -class PoolEvent(Event): - def __init__(self, pool): - super(PoolEvent, self).__init__(pool.env) +class PoolPutEvent(Event): + def __init__(self, pool, amount=1): + if not (0 < amount <= pool.capacity): + raise ValueError('amount must be in (0, capacity]') + super(PoolPutEvent, self).__init__(pool.env) self.pool = pool - pool._waiters.setdefault(type(self), []).append(self) + self.amount = amount + self.callbacks.extend([pool._trigger_when_at_least, pool._trigger_get]) + pool._put_waiters.append(self) + pool._trigger_put() def __enter__(self): return self @@ -28,62 +34,101 @@ def __exit__(self, exc_type, exc_value, traceback): def cancel(self): if not self.triggered: - self.pool._waiters[type(self)].remove(self) + self.pool._put_waiters.remove(self) self.callbacks = None -class PoolPutEvent(PoolEvent): +class PoolGetEvent(Event): def __init__(self, pool, amount=1): if not (0 < amount <= pool.capacity): raise ValueError('amount must be in (0, capacity]') + super(PoolGetEvent, self).__init__(pool.env) + self.pool = pool self.amount = amount - super(PoolPutEvent, self).__init__(pool) - self.callbacks.extend( - [ - pool._trigger_when_full, - pool._trigger_when_new, - pool._trigger_when_any, - pool._trigger_get, - ] - ) - pool._trigger_put() + self.callbacks.extend([pool._trigger_when_at_most, pool._trigger_put]) + pool._get_waiters.append(self) + pool._trigger_get() + + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + self.cancel() -class PoolGetEvent(PoolEvent): - def __init__(self, pool, amount=1): - if not (0 < amount <= pool.capacity): - raise ValueError('amount must be in (0, capacity]') + def cancel(self): + if not self.triggered: + self.pool._get_waiters.remove(self) + self.callbacks = None + + +class PoolWhenAtMostEvent(Event): + def __init__(self, pool, amount): + super(PoolWhenAtMostEvent, self).__init__(pool.env) + self.pool = pool self.amount = amount - super(PoolGetEvent, self).__init__(pool) - self.callbacks.extend( - [ - pool._trigger_when_not_full, - pool._trigger_put, - ] - ) - pool._trigger_get() + heapq.heappush(pool._at_most_waiters, self) + pool._trigger_when_at_most() + def __lt__(self, other): + return self.amount > other.amount -class PoolWhenNewEvent(PoolEvent): - pass + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + self.cancel() -class PoolWhenAnyEvent(PoolEvent): - def __init__(self, pool): - super(PoolWhenAnyEvent, self).__init__(pool) - pool._trigger_when_any() + def cancel(self): + if not self.triggered: + self.pool._at_most_waiters.remove(self) + heapq.heapify(self.pool._at_most_waiters) + self.callbacks = None + + +class PoolWhenAtLeastEvent(Event): + def __init__(self, pool, amount): + super(PoolWhenAtLeastEvent, self).__init__(pool.env) + self.pool = pool + self.amount = amount + heapq.heappush(pool._at_least_waiters, self) + pool._trigger_when_at_least() + + def __lt__(self, other): + return self.amount < other.amount + + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + self.cancel() + + def cancel(self): + if not self.triggered: + self.pool._at_least_waiters.remove(self) + heapq.heapify(self.pool._at_least_waiters) + self.callbacks = None -class PoolWhenFullEvent(PoolEvent): + +class PoolWhenAnyEvent(PoolWhenAtLeastEvent): + def __init__(self, pool, epsilon=float_info.epsilon): + super(PoolWhenAnyEvent, self).__init__(pool, amount=epsilon) + + +class PoolWhenFullEvent(PoolWhenAtLeastEvent): def __init__(self, pool): - super(PoolWhenFullEvent, self).__init__(pool) - pool._trigger_when_full() + super(PoolWhenFullEvent, self).__init__(pool, amount=pool.capacity) + + +class PoolWhenNotFullEvent(PoolWhenAtMostEvent): + def __init__(self, pool, epsilon=float_info.epsilon): + super(PoolWhenNotFullEvent, self).__init__( + pool, amount=pool.capacity - epsilon + ) -class PoolWhenNotFullEvent(PoolEvent): +class PoolWhenEmptyEvent(PoolWhenAtMostEvent): def __init__(self, pool): - super(PoolWhenNotFullEvent, self).__init__(pool) - pool._trigger_when_not_full() + super(PoolWhenEmptyEvent, self).__init__(pool, amount=0) class Pool(object): @@ -114,7 +159,10 @@ def __init__(self, env, capacity=float('inf'), init=0, hard_cap=False, self.level = init self._hard_cap = hard_cap self.name = name - self._waiters = {} + self._put_waiters = [] + self._get_waiters = [] + self._at_most_waiters = [] + self._at_least_waiters = [] self._put_hook = None self._get_hook = None BoundClass.bind_early(self) @@ -140,25 +188,30 @@ def is_full(self): #: Get amount from the pool. get = BoundClass(PoolGetEvent) + #: Return and event triggered when the pool has at least `amount` items. + when_at_least = BoundClass(PoolWhenAtLeastEvent) + + #: Return and event triggered when the pool has at most `amount` items. + when_at_most = BoundClass(PoolWhenAtMostEvent) + #: Return an event triggered when the pool is non-empty. when_any = BoundClass(PoolWhenAnyEvent) - #: Return an event triggered when items are put in pool - when_new = BoundClass(PoolWhenNewEvent) - #: Return an event triggered when the pool becomes full. when_full = BoundClass(PoolWhenFullEvent) #: Return an event triggered when the pool becomes not full. when_not_full = BoundClass(PoolWhenNotFullEvent) + #: Return an event triggered when the pool becomes empty. + when_empty = BoundClass(PoolWhenEmptyEvent) + def _trigger_put(self, _=None): - waiters = self._waiters.get(PoolPutEvent) idx = 0 - while waiters and idx < len(waiters): - put_ev = waiters[idx] + while self._put_waiters and idx < len(self._put_waiters): + put_ev = self._put_waiters[idx] if self.capacity - self.level >= put_ev.amount: - waiters.pop(idx) + self._put_waiters.pop(idx) self.level += put_ev.amount put_ev.succeed() if self._put_hook: @@ -169,12 +222,11 @@ def _trigger_put(self, _=None): idx += 1 def _trigger_get(self, _=None): - waiters = self._waiters.get(PoolGetEvent) idx = 0 - while waiters and idx < len(waiters): - get_ev = waiters[idx] + while self._get_waiters and idx < len(self._get_waiters): + get_ev = self._get_waiters[idx] if get_ev.amount <= self.level: - waiters.pop(idx) + self._get_waiters.pop(idx) self.level -= get_ev.amount get_ev.succeed(get_ev.amount) if self._get_hook: @@ -182,33 +234,21 @@ def _trigger_get(self, _=None): else: idx += 1 - def _trigger_when_new(self, _=None): - waiters = self._waiters.get(PoolWhenNewEvent) - if waiters: - for when_new_ev in waiters: - when_new_ev.succeed() - del waiters[:] - - def _trigger_when_any(self, _=None): - waiters = self._waiters.get(PoolWhenAnyEvent) - if waiters and self.level: - for when_any_ev in waiters: - when_any_ev.succeed() - del waiters[:] - - def _trigger_when_full(self, _=None): - waiters = self._waiters.get(PoolWhenFullEvent) - if waiters and self.level >= self.capacity: - for when_full_ev in waiters: - when_full_ev.succeed() - del waiters[:] - - def _trigger_when_not_full(self, _=None): - waiters = self._waiters.get(PoolWhenNotFullEvent) - if waiters and self.level < self.capacity: - for when_not_full_ev in waiters: - when_not_full_ev.succeed() - del waiters[:] + def _trigger_when_at_least(self, _=None): + while ( + self._at_least_waiters + and self.level >= self._at_least_waiters[0].amount + ): + when_at_least_ev = heapq.heappop(self._at_least_waiters) + when_at_least_ev.succeed() + + def _trigger_when_at_most(self, _=None): + while ( + self._at_most_waiters + and self.level <= self._at_most_waiters[0].amount + ): + at_most_ev = heapq.heappop(self._at_most_waiters) + at_most_ev.succeed() def __repr__(self): return ( @@ -217,14 +257,18 @@ def __repr__(self): ).format(self) -class PriorityPoolEvent(Event): - def __init__(self, pool, priority): - super(PriorityPoolEvent, self).__init__(pool.env) +class PriorityPoolPutEvent(Event): + def __init__(self, pool, amount=1, priority=0): + if not (0 < amount <= pool.capacity): + raise ValueError('amount must be in (0, capacity]') + super(PriorityPoolPutEvent, self).__init__(pool.env) self.pool = pool + self.amount = amount self.key = priority, pool._event_count pool._event_count += 1 - waiters = pool._waiters.setdefault(type(self), []) - heapq.heappush(waiters, self) + self.callbacks.extend([pool._trigger_when_at_least, pool._trigger_get]) + heapq.heappush(pool._put_waiters, self) + pool._trigger_put() def __lt__(self, other): return self.key < other.key @@ -237,42 +281,38 @@ def __exit__(self, exc_type, exc_value, traceback): def cancel(self): if not self.triggered: - waiters = self.pool._waiters[type(self)] - waiters.remove(self) - heapq.heapify(waiters) + self.pool._put_waiters.remove(self) + heapq.heapify(self.pool._put_waiters) self.callbacks = None -class PriorityPoolPutEvent(PriorityPoolEvent): +class PriorityPoolGetEvent(Event): def __init__(self, pool, amount=1, priority=0): if not (0 < amount <= pool.capacity): raise ValueError('amount must be in (0, capacity]') + super(PriorityPoolGetEvent, self).__init__(pool.env) + self.pool = pool self.amount = amount - super(PriorityPoolPutEvent, self).__init__(pool, priority) - self.callbacks.extend( - [ - pool._trigger_when_full, - pool._trigger_when_new, - pool._trigger_when_any, - pool._trigger_get, - ] - ) - pool._trigger_put() + self.key = priority, pool._event_count + pool._event_count += 1 + self.callbacks.extend([pool._trigger_when_at_most, pool._trigger_put]) + heapq.heappush(pool._get_waiters, self) + pool._trigger_get() + def __lt__(self, other): + return self.key < other.key -class PriorityPoolGetEvent(PriorityPoolEvent): - def __init__(self, pool, amount=1, priority=0): - if not (0 < amount <= pool.capacity): - raise ValueError('amount must be in (0, capacity]') - self.amount = amount - super(PriorityPoolGetEvent, self).__init__(pool, priority) - self.callbacks.extend( - [ - pool._trigger_when_not_full, - pool._trigger_put, - ] - ) - pool._trigger_get() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.cancel() + + def cancel(self): + if not self.triggered: + self.pool._get_waiters.remove(self) + heapq.heapify(self.pool._get_waiters) + self.callbacks = None class PriorityPool(Pool): @@ -297,11 +337,10 @@ def __init__( get = BoundClass(PriorityPoolGetEvent) def _trigger_put(self, _=None): - waiters = self._waiters.get(PriorityPoolPutEvent) - while waiters: - put_ev = waiters[0] + while self._put_waiters: + put_ev = self._put_waiters[0] if self.capacity - self.level >= put_ev.amount: - heapq.heappop(waiters) + heapq.heappop(self._put_waiters) self.level += put_ev.amount put_ev.succeed() if self._put_hook: @@ -312,11 +351,10 @@ def _trigger_put(self, _=None): break def _trigger_get(self, _=None): - waiters = self._waiters.get(PriorityPoolGetEvent) - while waiters: - get_ev = waiters[0] + while self._get_waiters: + get_ev = self._get_waiters[0] if get_ev.amount <= self.level: - heapq.heappop(waiters) + heapq.heappop(self._get_waiters) self.level -= get_ev.amount get_ev.succeed(get_ev.amount) if self._get_hook: diff --git a/tests/test_pool.py b/tests/test_pool.py index 257716bb..8f41c60f 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,7 +1,8 @@ -from desmod.pool import Pool, PriorityPool from pytest import raises import pytest +from desmod.pool import Pool, PriorityPool + @pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) def test_pool(env, PoolClass): @@ -42,9 +43,6 @@ def proc(env, pool): when_any = pool.when_any() assert not when_any.triggered - when_new = pool.when_new() - assert not when_new.triggered - with pool.when_not_full() as when_not_full: yield when_not_full assert when_not_full.triggered @@ -62,29 +60,25 @@ def proc(env, pool): assert put_one.triggered assert not when_any.triggered - assert not when_new.triggered assert not get_two.triggered assert not when_full.triggered assert pool.level == 1 yield put_one assert when_any.triggered - assert when_new.triggered yield env.timeout(1) - when_full2 = pool.when_full() - assert not when_full2.triggered + with pool.when_full() as when_full2: + assert not when_full2.triggered put_one = pool.put(1) assert put_one.triggered assert not when_full.triggered - assert not when_full2.triggered yield put_one assert when_full.triggered - assert when_full2.triggered assert get_two.triggered assert pool.level == 0 @@ -93,6 +87,10 @@ def proc(env, pool): when_not_full = pool.when_not_full() assert not when_not_full.triggered + with pool.when_any() as when_any2: + yield when_any2 + assert when_any2.triggered + yield pool.get(1) assert when_not_full.triggered @@ -174,24 +172,27 @@ def proc(env): @pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) def test_pool_cancel(env, PoolClass): - pool = PoolClass(env, capacity=2) - - def proc(env): + def proc(env, pool): get_ev = pool.get(2) full_ev = pool.when_full() any_ev = pool.when_any() - new_ev = pool.when_new() + empty_ev = pool.when_empty() + + assert not any_ev.triggered + assert empty_ev.triggered yield env.timeout(1) any_ev.cancel() - new_ev.cancel() - yield pool.put(1) + with pool.put(1) as put_ev: + yield put_ev assert not get_ev.triggered assert not any_ev.triggered - assert not new_ev.triggered + + with pool.when_empty() as empty_ev: + assert not empty_ev.triggered get_ev.cancel() full_ev.cancel() @@ -208,10 +209,81 @@ def proc(env): yield env.timeout(1) put_ev.cancel() - yield pool.get(1) + with pool.get(1) as get_ev2: + yield get_ev2 assert not put_ev.triggered - env.process(proc(env)) + env.process(proc(env, PoolClass(env, capacity=2))) + env.run() + + +@pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) +def test_pool_when_at_most(env, PoolClass): + def proc(env, pool): + yield pool.put(3) + at_most = {} + at_most[0] = pool.when_at_most(0) + at_most[3] = pool.when_at_most(3) + at_most[1] = pool.when_at_most(1) + at_most[2] = pool.when_at_most(2) + assert not at_most[0].triggered + assert not at_most[1].triggered + assert not at_most[2].triggered + assert at_most[3].triggered + + yield pool.get(1) + assert pool.level == 2 + assert not at_most[0].triggered + assert not at_most[1].triggered + assert at_most[2].triggered + + yield pool.get(1) + assert pool.level == 1 + assert not at_most[0].triggered + assert at_most[1].triggered + + yield pool.get(1) + assert pool.level == 0 + assert at_most[0].triggered + + env.process(proc(env, PoolClass(env))) + env.run() + + +@pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) +def test_when_at_least(env, PoolClass): + def proc(env, pool): + at_least = {} + at_least[3] = pool.when_at_least(3) + at_least[0] = pool.when_at_least(0) + at_least[2] = pool.when_at_least(2) + at_least[1] = pool.when_at_least(1) + assert at_least[0].triggered + assert not at_least[1].triggered + assert not at_least[2].triggered + assert not at_least[3].triggered + + yield pool.put(1) + assert at_least[1].triggered + assert not at_least[2].triggered + assert not at_least[3].triggered + + yield pool.get(1) + assert not at_least[2].triggered + assert not at_least[3].triggered + + yield pool.put(1) + assert not at_least[2].triggered + assert not at_least[3].triggered + + yield pool.put(1) + assert at_least[2].triggered + assert not at_least[3].triggered + + yield pool.put(1) + assert at_least[3].triggered + + env.process(proc(env, PoolClass(env))) env.run() @@ -223,7 +295,7 @@ def test_pool_check_str(env, PoolClass): ) -def test_priority_pool(env): +def test_priority_pool_gets(env): pool = PriorityPool(env) def producer(env, pool): @@ -255,3 +327,29 @@ def consumer(get_event): env.run(until=10.1) assert get1_p1_a.triggered assert not get1_p1_b.triggered + + +def test_priority_pool_puts(env): + def proc(env, pool): + put_ev = {} + put_ev[2] = pool.put(1, priority=2) + put_ev[0] = pool.put(1, priority=0) + put_ev[1] = pool.put(1, priority=1) + assert not put_ev[0].triggered + assert not put_ev[1].triggered + assert not put_ev[2].triggered + + yield pool.get(1) + assert put_ev[0].triggered + assert not put_ev[1].triggered + assert not put_ev[2].triggered + + yield pool.get(1) + assert put_ev[1].triggered + assert not put_ev[2].triggered + + yield pool.get(1) + assert put_ev[2].triggered + + env.process(proc(env, PriorityPool(env, capacity=2, init=2))) + env.run() diff --git a/tests/test_probe.py b/tests/test_probe.py index 13aaa32c..009f81fc 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -1,8 +1,8 @@ +import pytest + +from desmod.pool import Pool, PriorityPool from desmod.probe import attach from desmod.queue import Queue -from desmod.pool import Pool - -import pytest import simpy @@ -137,9 +137,10 @@ def proc(): assert values == [9, 8, 7, 8] -def test_attach_pool_level(env): +@pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) +def test_attach_pool_level(env, PoolClass): values = [] - pool = Pool(env) + pool = PoolClass(env) attach('scope', pool, [values.append]) def proc(): @@ -154,9 +155,10 @@ def proc(): assert values == [1, 2, 3, 2] -def test_attach_pool_remaining(env): +@pytest.mark.parametrize('PoolClass', [Pool, PriorityPool]) +def test_attach_pool_remaining(env, PoolClass): values = [] - pool = Pool(env, capacity=10) + pool = PoolClass(env, capacity=10) attach('scope', pool, [values.append], trace_remaining=True)