Feat worker limit (#142)
* feat: add queue_read_limit to limit the number of job IDs read at each poll

* fix: set offset to None in case the limit is not specified

* chore: format code

* fix: set queue_read_limit to max_jobs if not specified

* chore: refactor worker poll iteration into a separate method and write tests for queue_read_limit

* chore: add history note

* chore: format code

* chore: format docstring

* fix: remove additional call to Redis introduced by mistake

* chore: remove unneeded call to create_pool in queue_read_limit tests

* chore: assign queue_read_limit in one line

* chore: increase delay in queue_read_limit tests for CI systems

* chore: fix linting errors

* fix remaining sleeps
rubik authored and samuelcolvin committed Aug 11, 2019
1 parent f0fa671 commit 3b35919
Showing 4 changed files with 82 additions and 16 deletions.
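
Editor's note on the change: queue_read_limit caps how many job IDs the worker reads from the queue on each poll, independently of max_jobs (the concurrency cap), and it falls back to max_jobs when not given. A minimal usage sketch in the settings-class style the new tests also use; the process_article coroutine and the default local Redis are assumptions for illustration, not part of this commit:

from arq.worker import func, run_worker

async def process_article(ctx, article_id):
    # hypothetical job body, stands in for real work
    ...

class WorkerSettings:
    functions = [func(process_article, name='process_article')]
    max_jobs = 50           # at most 50 jobs run concurrently
    queue_read_limit = 10   # but only 10 job IDs are read from Redis per poll
    # if queue_read_limit were omitted it would default to max_jobs (50)

run_worker(WorkerSettings)  # blocks and polls the queue until interrupted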
HISTORY.rst (5 changes: 3 additions & 2 deletions)
@@ -3,8 +3,9 @@
History
-------

v0.17 (unreleased)
..................
v0.17.0 (unreleased)
....................
* add ``worker.queue_read_limit``, fix #141, by @rubik
* custom serializers, eg. to use msgpack rather than pickle, #143 by @rubik

v0.16.1 (2019-08-02)
arq/worker.py (36 changes: 24 additions & 12 deletions)
@@ -138,9 +138,13 @@ class Worker:
:param job_timeout: default job timeout (max run time)
:param keep_result: default duration to keep job results for
:param poll_delay: duration between polling the queue for new jobs
:param queue_read_limit: the maximum number of jobs to pull from the queue each time it's polled; by default it
equals ``max_jobs``
:param max_tries: default maximum number of times to retry a job
:param health_check_interval: how often to set the health check key
:param health_check_key: redis key under which health check is set
:param retry_jobs: whether to retry jobs on Retry or CancelledError or not
:param max_burst_jobs: the maximum number of jobs to process in burst mode (disabled with negative values)
:param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
:param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
"""
@@ -160,6 +164,7 @@ def __init__(
job_timeout: SecondsTimedelta = 300,
keep_result: SecondsTimedelta = 3600,
poll_delay: SecondsTimedelta = 0.5,
queue_read_limit: Optional[int] = None,
max_tries: int = 5,
health_check_interval: SecondsTimedelta = 3600,
health_check_key: Optional[str] = None,
@@ -184,6 +189,8 @@ def __init__(
self.job_timeout_s = to_seconds(job_timeout)
self.keep_result_s = to_seconds(keep_result)
self.poll_delay_s = to_seconds(poll_delay)
self.queue_read_limit = queue_read_limit or max_jobs
self._queue_read_offset = 0
self.max_tries = max_tries
self.health_check_interval = to_seconds(health_check_interval)
if health_check_key is None:
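
Editor's note: as documented in the docstring and implemented in the assignment self.queue_read_limit = queue_read_limit or max_jobs above, an unspecified queue_read_limit falls back to max_jobs. A quick sketch of that fallback, illustrative only and not part of the commit; it assumes a trivial foobar coroutine and that a Worker can be constructed with default Redis settings without connecting, as the test fixtures below do:

from arq.worker import Worker, func

async def foobar(ctx):
    return 42

# no queue_read_limit given: it falls back to max_jobs
w1 = Worker(functions=[func(foobar, name='foobar')], max_jobs=20)
assert w1.queue_read_limit == 20

# an explicit limit takes precedence
w2 = Worker(functions=[func(foobar, name='foobar')], max_jobs=20, queue_read_limit=5)
assert w2.queue_read_limit == 5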
@@ -264,18 +271,7 @@ async def main(self):
await self.on_startup(self.ctx)

async for _ in poll(self.poll_delay_s): # noqa F841
async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs
now = timestamp_ms()
job_ids = await self.pool.zrangebyscore(self.queue_name, max=now)
await self.run_jobs(job_ids)

# required to make sure errors in run_job get propagated
for t in self.tasks:
if t.done():
self.tasks.remove(t)
t.result()

await self.heart_beat()
await self._poll_iteration()

if self.burst:
if (
@@ -287,6 +283,22 @@
if queued_jobs == 0:
return

async def _poll_iteration(self):
async with self.sem: # don't bother with zrangebyscore until we have "space" to run the jobs
now = timestamp_ms()
job_ids = await self.pool.zrangebyscore(
self.queue_name, offset=self._queue_read_offset, count=self.queue_read_limit, max=now
)
await self.run_jobs(job_ids)

# required to make sure errors in run_job get propagated
for t in self.tasks:
if t.done():
self.tasks.remove(t)
t.result()

await self.heart_beat()

async def run_jobs(self, job_ids):
for job_id in job_ids:
await self.sem.acquire()
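
Editor's note: the new _poll_iteration passes offset and count to zrangebyscore, so each poll issues a bounded read, roughly ZRANGEBYSCORE <queue> -inf <now> LIMIT <offset> <count>, instead of fetching every due job ID. A rough stand-alone illustration of that read, not arq's code; it uses the synchronous redis-py client, and the queue key and local Redis address are assumptions:

import time

import redis

r = redis.Redis()                  # assumes Redis on localhost:6379
queue_name = 'arq:queue'           # assumed default arq queue key
queue_read_limit = 10

now_ms = int(time.time() * 1000)   # jobs are scored by timestamp in milliseconds
# roughly equivalent to: ZRANGEBYSCORE arq:queue -inf <now_ms> LIMIT 0 10
job_ids = r.zrangebyscore(queue_name, '-inf', now_ms, start=0, num=queue_read_limit)
print(f'{len(job_ids)} job id(s) due on this poll')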
tests/conftest.py (6 changes: 4 additions & 2 deletions)
@@ -37,9 +37,11 @@ async def arq_redis_msgpack(loop):
async def worker(arq_redis):
worker_: Worker = None

def create(functions=[], burst=True, poll_delay=0, **kwargs):
def create(functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs):
nonlocal worker_
worker_ = Worker(functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, **kwargs)
worker_ = Worker(
functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs
)
return worker_

yield create
tests/test_worker.py (51 changes: 51 additions & 0 deletions)
@@ -28,6 +28,7 @@ class Settings:
functions = [func(foobar, name='foobar')]
burst = True
poll_delay = 0
queue_read_limit = 10

loop.run_until_complete(arq_redis.enqueue_job('foobar'))
worker = run_worker(Settings)
@@ -453,6 +454,56 @@ async def test_repeat_job_result(arq_redis: ArqRedis, worker):
assert await arq_redis.enqueue_job('foobar', _job_id='job_id') is None


async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker):
for _ in range(4):
await arq_redis.enqueue_job('foobar')

assert await arq_redis.zcard(default_queue_name) == 4
worker: Worker = worker(functions=[foobar], max_jobs=2)
assert worker.jobs_complete == 0
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0

await worker._poll_iteration()
await asyncio.sleep(0.1)
assert await arq_redis.zcard(default_queue_name) == 2
assert worker.jobs_complete == 2
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0

await worker._poll_iteration()
await asyncio.sleep(0.1)
assert await arq_redis.zcard(default_queue_name) == 0
assert worker.jobs_complete == 4
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0


async def test_custom_queue_read_limit(arq_redis: ArqRedis, worker):
for _ in range(4):
await arq_redis.enqueue_job('foobar')

assert await arq_redis.zcard(default_queue_name) == 4
worker: Worker = worker(functions=[foobar], max_jobs=4, queue_read_limit=2)
assert worker.jobs_complete == 0
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0

await worker._poll_iteration()
await asyncio.sleep(0.1)
assert await arq_redis.zcard(default_queue_name) == 2
assert worker.jobs_complete == 2
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0

await worker._poll_iteration()
await asyncio.sleep(0.1)
assert await arq_redis.zcard(default_queue_name) == 0
assert worker.jobs_complete == 4
assert worker.jobs_failed == 0
assert worker.jobs_retried == 0


async def test_custom_serializers(arq_redis_msgpack: ArqRedis, worker):
j = await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id')
worker: Worker = worker(