diff --git a/desmod/pool.py b/desmod/pool.py index 58f53ec5..ff3f8723 100644 --- a/desmod/pool.py +++ b/desmod/pool.py @@ -13,6 +13,8 @@ class PoolPutEvent(Event): def __init__(self, pool, amount=1): super(PoolPutEvent, self).__init__(pool.env) + if amount <= 0: + raise ValueError('amount {} must be > 0'.format(amount)) self.pool = pool self.amount = amount self.callbacks.append(pool._trigger_get) @@ -28,6 +30,11 @@ def cancel(self): class PoolGetEvent(Event): def __init__(self, pool, amount=1): super(PoolGetEvent, self).__init__(pool.env) + if amount <= 0: + raise ValueError('amount {} must be > 0'.format(amount)) + assert amount <= pool.capacity, ( + "Amount {} greater than pool's {} capacity {}".format( + amount, str(pool.name), pool.capacity)) self.pool = pool self.amount = amount self.callbacks.append(pool._trigger_put) @@ -97,13 +104,13 @@ class Pool(object): """ - def __init__(self, env, capacity=float('inf'), hard_cap=False, - init_level=0, name=None): + def __init__(self, env, capacity=float('inf'), init=0, hard_cap=False, + name=None): self.env = env #: Capacity of the queue (maximum number of items). self.capacity = capacity self._hard_cap = hard_cap - self.level = init_level + self.level = init self.name = name self._putters = [] self._getters = [] @@ -149,8 +156,7 @@ def _trigger_put(self, _=None): put_ev = self._putters.pop(0) put_ev.succeed() self.level += put_ev.amount - if put_ev.amount: - self._trigger_when_new() + self._trigger_when_new() self._trigger_when_any() self._trigger_when_full() if self._put_hook: @@ -159,19 +165,16 @@ def _trigger_put(self, _=None): raise OverflowError() def _trigger_get(self, _=None): - if self._getters and self.level: - for get_ev in self._getters: - assert get_ev.amount <= self.capacity, ( - "Amount {} greater than pool's {} capacity {}".format( - get_ev.amount, str(self.name), self.capacity)) - if get_ev.amount <= self.level: - self._getters.remove(get_ev) - self.level -= get_ev.amount - get_ev.succeed(get_ev.amount) - if self._get_hook: - self._get_hook() - else: - break + while self._getters and self.level: + get_ev = self._getters[0] + if get_ev.amount <= self.level: + assert self._getters.pop(0) is get_ev + self.level -= get_ev.amount + get_ev.succeed(get_ev.amount) + if self._get_hook: + self._get_hook() + else: + break def _trigger_when_new(self): for when_new_ev in self._new_waiters: diff --git a/tests/test_pool.py b/tests/test_pool.py index b2852552..7134de39 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -78,7 +78,7 @@ def test_pool_overflow(env): def producer(env): yield env.timeout(1) - for i in range(5): + for i in range(1, 5): yield pool.put(i) yield env.timeout(1) @@ -87,6 +87,28 @@ def producer(env): env.run() +def test_pool_put_zero(env): + pool = Pool(env, capacity=5, hard_cap=True) + + def producer(env): + yield pool.put(0) + + env.process(producer(env)) + with raises(ValueError): + env.run() + + +def test_pool_get_zero(env): + pool = Pool(env, capacity=5, hard_cap=True) + + def consumer(env): + yield pool.get(0) + + env.process(consumer(env)) + with raises(ValueError): + env.run() + + def test_pool_get_more(env): pool = Pool(env, capacity=6, name='foo')