Skip to content

Commit

Permalink
Make Pool constructor compatible with simpy.Container
Browse files Browse the repository at this point in the history
* Fix unsafe loop in Pool._trigger_get
* Add check that Pool.put and Pool.get take only positive values
* Add tests to validate correct handling of Pool.put(0) and Pool.get(0)
  • Loading branch information
evangelos-vazaios-wdc authored and jpgrayson committed Aug 2, 2018
1 parent 55b9921 commit a33c096
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
39 changes: 21 additions & 18 deletions desmod/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
class PoolPutEvent(Event):
def __init__(self, pool, amount=1):
super(PoolPutEvent, self).__init__(pool.env)
if amount <= 0:
raise ValueError('amount {} must be > 0'.format(amount))
self.pool = pool
self.amount = amount
self.callbacks.append(pool._trigger_get)
Expand All @@ -28,6 +30,11 @@ def cancel(self):
class PoolGetEvent(Event):
def __init__(self, pool, amount=1):
super(PoolGetEvent, self).__init__(pool.env)
if amount <= 0:
raise ValueError('amount {} must be > 0'.format(amount))
assert amount <= pool.capacity, (
"Amount {} greater than pool's {} capacity {}".format(
amount, str(pool.name), pool.capacity))
self.pool = pool
self.amount = amount
self.callbacks.append(pool._trigger_put)
Expand Down Expand Up @@ -97,13 +104,13 @@ class Pool(object):
"""

def __init__(self, env, capacity=float('inf'), hard_cap=False,
init_level=0, name=None):
def __init__(self, env, capacity=float('inf'), init=0, hard_cap=False,
name=None):
self.env = env
#: Capacity of the queue (maximum number of items).
self.capacity = capacity
self._hard_cap = hard_cap
self.level = init_level
self.level = init
self.name = name
self._putters = []
self._getters = []
Expand Down Expand Up @@ -149,8 +156,7 @@ def _trigger_put(self, _=None):
put_ev = self._putters.pop(0)
put_ev.succeed()
self.level += put_ev.amount
if put_ev.amount:
self._trigger_when_new()
self._trigger_when_new()
self._trigger_when_any()
self._trigger_when_full()
if self._put_hook:
Expand All @@ -159,19 +165,16 @@ def _trigger_put(self, _=None):
raise OverflowError()

def _trigger_get(self, _=None):
if self._getters and self.level:
for get_ev in self._getters:
assert get_ev.amount <= self.capacity, (
"Amount {} greater than pool's {} capacity {}".format(
get_ev.amount, str(self.name), self.capacity))
if get_ev.amount <= self.level:
self._getters.remove(get_ev)
self.level -= get_ev.amount
get_ev.succeed(get_ev.amount)
if self._get_hook:
self._get_hook()
else:
break
while self._getters and self.level:
get_ev = self._getters[0]
if get_ev.amount <= self.level:
assert self._getters.pop(0) is get_ev
self.level -= get_ev.amount
get_ev.succeed(get_ev.amount)
if self._get_hook:
self._get_hook()
else:
break

def _trigger_when_new(self):
for when_new_ev in self._new_waiters:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_pool_overflow(env):

def producer(env):
yield env.timeout(1)
for i in range(5):
for i in range(1, 5):
yield pool.put(i)
yield env.timeout(1)

Expand All @@ -87,6 +87,28 @@ def producer(env):
env.run()


def test_pool_put_zero(env):
pool = Pool(env, capacity=5, hard_cap=True)

def producer(env):
yield pool.put(0)

env.process(producer(env))
with raises(ValueError):
env.run()


def test_pool_get_zero(env):
pool = Pool(env, capacity=5, hard_cap=True)

def consumer(env):
yield pool.get(0)

env.process(consumer(env))
with raises(ValueError):
env.run()


def test_pool_get_more(env):
pool = Pool(env, capacity=6, name='foo')

Expand Down

0 comments on commit a33c096

Please sign in to comment.