diff --git a/HISTORY.rst b/HISTORY.rst
index f34a9c5a..d30b095a 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -3,8 +3,9 @@
 History
 -------
 
-v0.17 (unreleased)
-..................
+v0.17.0 (unreleased)
+....................
+* add ``worker.queue_read_limit``, fix #141, by @rubik
 * custom serializers, eg. to use msgpack rather than pickle, #143 by @rubik
 
 v0.16.1 (2019-08-02)
diff --git a/arq/worker.py b/arq/worker.py
index a3345bcd..66e5089b 100644
--- a/arq/worker.py
+++ b/arq/worker.py
@@ -138,9 +138,13 @@ class Worker:
     :param job_timeout: default job timeout (max run time)
     :param keep_result: default duration to keep job results for
     :param poll_delay: duration between polling the queue for new jobs
+    :param queue_read_limit: the maximum number of jobs to pull from the queue each time it's polled; by default it
+        equals ``max_jobs``
     :param max_tries: default maximum number of times to retry a job
     :param health_check_interval: how often to set the health check key
     :param health_check_key: redis key under which health check is set
+    :param retry_jobs: whether to retry jobs on ``Retry`` or ``CancelledError``
+    :param max_burst_jobs: the maximum number of jobs to process in burst mode (disabled with negative values)
     :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps
     :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads
     """
@@ -160,6 +164,7 @@ def __init__(
         job_timeout: SecondsTimedelta = 300,
         keep_result: SecondsTimedelta = 3600,
         poll_delay: SecondsTimedelta = 0.5,
+        queue_read_limit: Optional[int] = None,
         max_tries: int = 5,
         health_check_interval: SecondsTimedelta = 3600,
         health_check_key: Optional[str] = None,
@@ -184,6 +189,8 @@ def __init__(
         self.job_timeout_s = to_seconds(job_timeout)
         self.keep_result_s = to_seconds(keep_result)
         self.poll_delay_s = to_seconds(poll_delay)
+        self.queue_read_limit = queue_read_limit or max_jobs
+        self._queue_read_offset = 0
         self.max_tries = max_tries
         self.health_check_interval = to_seconds(health_check_interval)
         if health_check_key is None:
@@ -264,18 +271,7 @@ async def main(self):
         await self.on_startup(self.ctx)
 
         async for _ in poll(self.poll_delay_s):  # noqa F841
-            async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
-                now = timestamp_ms()
-                job_ids = await self.pool.zrangebyscore(self.queue_name, max=now)
-            await self.run_jobs(job_ids)
-
-            # required to make sure errors in run_job get propagated
-            for t in self.tasks:
-                if t.done():
-                    self.tasks.remove(t)
-                    t.result()
-
-            await self.heart_beat()
+            await self._poll_iteration()
 
             if self.burst:
                 if (
@@ -287,6 +283,22 @@
                 if queued_jobs == 0:
                     return
 
+    async def _poll_iteration(self):
+        async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
+            now = timestamp_ms()
+            job_ids = await self.pool.zrangebyscore(
+                self.queue_name, offset=self._queue_read_offset, count=self.queue_read_limit, max=now
+            )
+        await self.run_jobs(job_ids)
+
+        # required to make sure errors in run_job get propagated
+        for t in self.tasks:
+            if t.done():
+                self.tasks.remove(t)
+                t.result()
+
+        await self.heart_beat()
+
     async def run_jobs(self, job_ids):
         for job_id in job_ids:
             await self.sem.acquire()
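Note for reviewers (illustrative, not part of the patch): the new setting can be exercised via the same ``Settings``-class pattern the tests below use. ``func``, ``run_worker``, ``max_jobs`` and ``queue_read_limit`` are the real arq names from this diff; the job function and the numbers are made up.

```python
# A minimal sketch: with max_jobs=20 but queue_read_limit=5, each poll
# fetches at most 5 job ids from the queue's sorted set
# (ZRANGEBYSCORE ... LIMIT 0 5), while up to 20 jobs may still run
# concurrently. If queue_read_limit is omitted, it falls back to max_jobs.
from arq.worker import func, run_worker

async def foobar(ctx):
    return 42

class Settings:
    functions = [func(foobar, name='foobar')]
    burst = True            # drain the queue once, then exit
    max_jobs = 20           # concurrency cap (size of the worker semaphore)
    queue_read_limit = 5    # per-poll read cap; defaults to max_jobs

run_worker(Settings)
```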
diff --git a/tests/conftest.py b/tests/conftest.py
index dca8b0a7..06104580 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,9 +37,11 @@ async def arq_redis_msgpack(loop):
 async def worker(arq_redis):
     worker_: Worker = None
 
-    def create(functions=[], burst=True, poll_delay=0, **kwargs):
+    def create(functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs):
         nonlocal worker_
-        worker_ = Worker(functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, **kwargs)
+        worker_ = Worker(
+            functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs
+        )
         return worker_
 
     yield create
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 8f7ca6bf..61165d39 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -28,6 +28,7 @@ class Settings:
         functions = [func(foobar, name='foobar')]
         burst = True
         poll_delay = 0
+        queue_read_limit = 10
 
     loop.run_until_complete(arq_redis.enqueue_job('foobar'))
     worker = run_worker(Settings)
@@ -453,6 +454,56 @@ async def test_repeat_job_result(arq_redis: ArqRedis, worker):
     assert await arq_redis.enqueue_job('foobar', _job_id='job_id') is None
 
 
+async def test_queue_read_limit_equals_max_jobs(arq_redis: ArqRedis, worker):
+    for _ in range(4):
+        await arq_redis.enqueue_job('foobar')
+
+    assert await arq_redis.zcard(default_queue_name) == 4
+    worker: Worker = worker(functions=[foobar], max_jobs=2)
+    assert worker.jobs_complete == 0
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+    await worker._poll_iteration()
+    await asyncio.sleep(0.1)
+    assert await arq_redis.zcard(default_queue_name) == 2
+    assert worker.jobs_complete == 2
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+    await worker._poll_iteration()
+    await asyncio.sleep(0.1)
+    assert await arq_redis.zcard(default_queue_name) == 0
+    assert worker.jobs_complete == 4
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+
+async def test_custom_queue_read_limit(arq_redis: ArqRedis, worker):
+    for _ in range(4):
+        await arq_redis.enqueue_job('foobar')
+
+    assert await arq_redis.zcard(default_queue_name) == 4
+    worker: Worker = worker(functions=[foobar], max_jobs=4, queue_read_limit=2)
+    assert worker.jobs_complete == 0
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+    await worker._poll_iteration()
+    await asyncio.sleep(0.1)
+    assert await arq_redis.zcard(default_queue_name) == 2
+    assert worker.jobs_complete == 2
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+    await worker._poll_iteration()
+    await asyncio.sleep(0.1)
+    assert await arq_redis.zcard(default_queue_name) == 0
+    assert worker.jobs_complete == 4
+    assert worker.jobs_failed == 0
+    assert worker.jobs_retried == 0
+
+
 async def test_custom_serializers(arq_redis_msgpack: ArqRedis, worker):
     j = await arq_redis_msgpack.enqueue_job('foobar', _job_id='job_id')
     worker: Worker = worker(
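Note (illustrative, not part of the patch): ``test_custom_serializers`` relies on the ``arq_redis_msgpack`` fixture from conftest.py, whose body is not shown in this diff. Assuming ``msgpack`` is installed and that ``create_pool`` accepts the serializer hooks documented in ``arq/worker.py`` above (added by the custom serializers change, #143), the wiring is presumably along these lines:

```python
# Sketch of a msgpack-backed pool; the exact fixture code is not part of
# this diff, so treat the create_pool kwargs as an assumption.
import functools

import msgpack
from arq.connections import RedisSettings, create_pool

async def make_msgpack_redis():
    # serialize job payloads with msgpack instead of the default pickle
    return await create_pool(
        RedisSettings(),
        job_serializer=msgpack.packb,
        job_deserializer=functools.partial(msgpack.unpackb, raw=False),
    )
```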