Skip to content

Commit

Permalink
Added the ability to fetch workers by queue (#911)
Browse files Browse the repository at this point in the history
* job.exc_info is now compressed.

* job.data is now stored in compressed format.

* Added worker_registration.unregister.

* Added worker_registration.get_keys().

* Modified Worker.all(), Worker.all_keys() and Worker.count() to accept "connection" and "queue" arguments.
  • Loading branch information
selwin committed Dec 18, 2017
1 parent 34c403e commit 7a3c85f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 11 deletions.
33 changes: 24 additions & 9 deletions rq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from redis import WatchError

from . import worker_registration
from .compat import PY2, as_text, string_types, text_type
from .connections import get_current_connection, push_connection, pop_connection
from .defaults import DEFAULT_RESULT_TTL, DEFAULT_WORKER_TTL
Expand All @@ -34,6 +35,7 @@
from .utils import (backend_class, ensure_list, enum,
make_colorizer, utcformat, utcnow, utcparse)
from .version import VERSION
from .worker_registration import get_keys

try:
from procname import setprocname
Expand Down Expand Up @@ -90,7 +92,7 @@ def signal_name(signum):

class Worker(object):
redis_worker_namespace_prefix = 'rq:worker:'
redis_workers_keys = 'rq:workers'
redis_workers_keys = worker_registration.REDIS_WORKER_KEYS
death_penalty_class = UnixSignalDeathPenalty
queue_class = Queue
job_class = Job
Expand All @@ -99,19 +101,32 @@ class Worker(object):
log_result_lifespan = True

@classmethod
def all(cls, connection=None, job_class=None, queue_class=None):
def all(cls, connection=None, job_class=None, queue_class=None, queue=None):
"""Returns an iterable of all Workers.
"""
if connection is None:
if queue:
connection = queue.connection
elif connection is None:
connection = get_current_connection()
reported_working = connection.smembers(cls.redis_workers_keys)

worker_keys = get_keys(queue=queue, connection=connection)
workers = [cls.find_by_key(as_text(key),
connection=connection,
job_class=job_class,
queue_class=queue_class)
for key in reported_working]
for key in worker_keys]
return compact(workers)

@classmethod
def all_keys(cls, connection=None, queue=None):
return [as_text(key)
for key in get_keys(queue=queue, connection=connection)]

@classmethod
def count(cls, connection=None, queue=None):
"""Returns the number of workers by queue or connection"""
return len(get_keys(queue=queue, connection=connection))

@classmethod
def find_by_key(cls, worker_key, connection=None, job_class=None,
queue_class=None):
Expand All @@ -121,7 +136,7 @@ def find_by_key(cls, worker_key, connection=None, job_class=None,
"""
prefix = cls.redis_worker_namespace_prefix
if not worker_key.startswith(prefix):
raise ValueError('Not a valid RQ worker key: {0}'.format(worker_key))
raise ValueError('Not a valid RQ worker key: %s' % worker_key)

if connection is None:
connection = get_current_connection()
Expand Down Expand Up @@ -188,7 +203,7 @@ def __init__(self, queues, name=None, default_result_ttl=None, connection=None,
if exc_handler is not None:
self.push_exc_handler(exc_handler)
warnings.warn(
"use of exc_handler is deprecated, pass a list to exception_handlers instead.",
"exc_handler is deprecated, pass a list to exception_handlers instead.",
DeprecationWarning
)
elif isinstance(exception_handlers, list):
Expand Down Expand Up @@ -271,7 +286,7 @@ def register_birth(self):
p.hset(key, 'birth', now_in_string)
p.hset(key, 'last_heartbeat', now_in_string)
p.hset(key, 'queues', queues)
p.sadd(self.redis_workers_keys, key)
worker_registration.register(self, p)
p.expire(key, self.default_worker_ttl)
p.execute()

Expand All @@ -281,7 +296,7 @@ def register_death(self):
with self.connection._pipeline() as p:
# We cannot use self.state = 'dead' here, because that would
# rollback the pipeline
p.srem(self.redis_workers_keys, self.key)
worker_registration.unregister(self, p)
p.hset(self.key, 'death', utcformat(utcnow()))
p.expire(self.key, 60)
p.execute()
Expand Down
45 changes: 45 additions & 0 deletions rq/worker_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from .compat import as_text


WORKERS_BY_QUEUE_KEY = 'rq:workers:%s'
REDIS_WORKER_KEYS = 'rq:workers'


def register(worker, pipeline=None):
"""Store worker key in Redis so we can easily discover active workers."""
connection = pipeline if pipeline is not None else worker.connection
connection.sadd(worker.redis_workers_keys, worker.key)
for name in worker.queue_names():
redis_key = WORKERS_BY_QUEUE_KEY % name
connection.sadd(redis_key, worker.key)


def unregister(worker, pipeline=None):
"""Remove worker key from Redis."""
if pipeline is None:
connection = worker.connection._pipeline()
else:
connection = pipeline

connection.srem(worker.redis_workers_keys, worker.key)
for name in worker.queue_names():
redis_key = WORKERS_BY_QUEUE_KEY % name
connection.srem(redis_key, worker.key)

if pipeline is None:
connection.execute()


def get_keys(queue=None, connection=None):
"""Returnes a list of worker keys for a queue"""
if queue is None and connection is None:
raise ValueError('"queue" or "connection" argument is required')

if queue:
redis = queue.connection
redis_key = WORKERS_BY_QUEUE_KEY % queue.name
else:
redis = connection
redis_key = REDIS_WORKER_KEYS

return {as_text(key) for key in redis.smembers(redis_key)}
29 changes: 27 additions & 2 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ def test_work_and_quit(self):
'Expected at least some work done.'
)

def test_worker_all(self):
"""Worker.all() works properly"""
foo_queue = Queue('foo')
bar_queue = Queue('bar')

w1 = Worker([foo_queue, bar_queue], name='w1')
w1.register_birth()
w2 = Worker([foo_queue], name='w2')
w2.register_birth()

self.assertEqual(
set(Worker.all(connection=foo_queue.connection)),
set([w1, w2])
)
self.assertEqual(set(Worker.all(queue=foo_queue)), set([w1, w2]))
self.assertEqual(set(Worker.all(queue=bar_queue)), set([w1]))

w1.register_death()
w2.register_death()

def test_find_by_key(self):
"""Worker.find_by_key restores queues, state and job_id."""
queues = [Queue('foo'), Queue('bar')]
Expand All @@ -119,7 +139,12 @@ def test_find_by_key(self):
self.assertEqual(worker.queues, queues)
self.assertEqual(worker.get_state(), WorkerStatus.STARTED)
self.assertEqual(worker._job_id, None)
w.register_death()
self.assertTrue(worker.key in Worker.all_keys(worker.connection))

# If worker is gone, its keys should also be removed
worker.connection.delete(worker.key)
Worker.find_by_key(worker.key)
self.assertFalse(worker.key in Worker.all_keys(worker.connection))

def test_worker_ttl(self):
"""Worker ttl."""
Expand Down Expand Up @@ -183,7 +208,7 @@ def test_work_is_unreadable(self):
# importable from the worker process.
job = Job.create(func=div_by_zero, args=(3,))
job.save()

job_data = job.data
invalid_data = job_data.replace(b'div_by_zero', b'nonexisting')
assert job_data != invalid_data
Expand Down
70 changes: 70 additions & 0 deletions tests/test_worker_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from tests import RQTestCase

from rq import Queue, Worker
from rq.worker_registration import (get_keys, register, unregister,
WORKERS_BY_QUEUE_KEY)


class TestWorkerRegistry(RQTestCase):

def test_worker_registration(self):
"""Ensure worker.key is correctly set in Redis."""
foo_queue = Queue(name='foo')
bar_queue = Queue(name='bar')
worker = Worker([foo_queue, bar_queue])

register(worker)
redis = worker.connection

self.assertTrue(redis.sismember(worker.redis_workers_keys, worker.key))
self.assertTrue(
redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key)
)
self.assertTrue(
redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key)
)

unregister(worker)
self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key))
self.assertFalse(
redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key)
)
self.assertFalse(
redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key)
)

def test_get_keys_by_queue(self):
"""get_keys_by_queue only returns active workers for that queue"""
foo_queue = Queue(name='foo')
bar_queue = Queue(name='bar')
baz_queue = Queue(name='baz')

worker1 = Worker([foo_queue, bar_queue])
worker2 = Worker([foo_queue])
worker3 = Worker([baz_queue])

self.assertEqual(set(), get_keys(foo_queue))

register(worker1)
register(worker2)
register(worker3)

# get_keys(queue) will return worker keys for that queue
self.assertEqual(
set([worker1.key, worker2.key]),
get_keys(foo_queue)
)
self.assertEqual(set([worker1.key]), get_keys(bar_queue))

# get_keys(connection=connection) will return all worker keys
self.assertEqual(
set([worker1.key, worker2.key, worker3.key]),
get_keys(connection=worker1.connection)
)

# Calling get_keys without arguments raises an exception
self.assertRaises(ValueError, get_keys)

unregister(worker1)
unregister(worker2)
unregister(worker3)

0 comments on commit 7a3c85f

Please sign in to comment.