Skip to content

Commit

Permalink
Improve behavior of halt_job_processing for SerialWorker and MainWorker
Browse files Browse the repository at this point in the history
  • Loading branch information
tdg5 committed Apr 23, 2024
1 parent 26b0d72 commit 561d171
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 53 deletions.
8 changes: 4 additions & 4 deletions .meta/coverage/report.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ reqless/queue_resolvers/transforming_queue_resolver.py
reqless/throttle.py 29 0 100%
reqless/util.py 7 0 100%
reqless/workers/__init__.py 2 0 100%
reqless/workers/base_worker.py 69 3 96%
reqless/workers/base_worker.py 65 2 97%
reqless/workers/forking_worker.py 58 3 95%
reqless/workers/gevent_worker.py 53 0 100%
reqless/workers/main_worker.py 8 0 100%
reqless/workers/serial_worker.py 28 0 100%
reqless/workers/main_worker.py 13 0 100%
reqless/workers/serial_worker.py 27 0 100%
reqless/workers/signals.py 13 0 100%
reqless/workers/util.py 39 1 97%
-----------------------------------------------------------------------------------------------------
TOTAL 1357 25 98%
TOTAL 1357 24 98%
4 changes: 2 additions & 2 deletions reqless/workers/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def listen(self, listener: Listener) -> None:
except Exception:
logger.exception("Pubsub error")

def halt_job_processing(self, jid: str) -> None:
def halt_job_processing(self, jid: str) -> None: # pragma: no cover
"""Stop processing the provided jid"""
raise NotImplementedError('Derived classes must override "halt_job_processing"')

def run(self) -> None:
def run(self) -> None: # pragma: no cover
"""Run this worker"""
raise NotImplementedError('Derived classes must override "run"')

Expand Down
11 changes: 11 additions & 0 deletions reqless/workers/main_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import _thread
import threading

from reqless.workers.serial_worker import SerialWorker
from reqless.workers.signals import basic_signal_handler, register_signal_handler

Expand All @@ -9,6 +12,14 @@ class MainWorker(SerialWorker):
def before_run(self) -> None:
    """Register the basic signal handler, using ``self.stop`` as its quit callback."""
    quit_handler = basic_signal_handler(on_quit=self.stop)
    register_signal_handler(handler=quit_handler)

def halt_job_processing(self, jid: str) -> None:
    """Interrupt the main thread so that the job it is running stops.

    This worker expects to occupy the main thread, and this method is most
    likely to be called by the listener thread, so interrupting the main
    thread is how a job is forced to halt processing. When the caller is
    itself the main thread, nothing is done.

    ``jid`` is accepted as part of the worker interface but is not inspected
    here.
    """
    if threading.current_thread() is threading.main_thread():
        # Already on the main thread: there is no other thread to interrupt.
        return
    _thread.interrupt_main()

def run(self) -> None:
    """Call ``before_run`` for setup, then delegate to the parent class's ``run``."""
    self.before_run()
    super().run()
11 changes: 8 additions & 3 deletions reqless/workers/serial_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ def __init__(
)

def halt_job_processing(self, jid: str) -> None:
"""The best way to do this is to fall on our sword"""
if jid == self.jid:
exit(1)
"""Since this method is most likely to be called by the listener, and
the worker is definitely running in a different thread from the
listener, there's not a lot we can reliably do here. Trying to exit
would only kill the listener thread, while the thread doing the actual
work continued. So, in this scenario, we have to depend on the job
doing a good job of heartbeating since that's the best way for the job
to learn that it should halt. As such, do nothing."""
pass

def run(self) -> None:
"""Run jobs, popping one after another"""
Expand Down
21 changes: 21 additions & 0 deletions reqless_test/common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
"""A base class for all of our common tests"""

import json
import logging
import time
import unittest
from os import path
from typing import List

from redis import Redis

import reqless
from reqless import logger
from reqless.abstract import AbstractJob


class BlockingJob:
    """Job that can block until a given file is removed from the file system"""

    @staticmethod
    def process(job: AbstractJob) -> None:
        """Wait for the optional blocker file to disappear, then complete the job.

        The job's ``data`` is parsed as JSON. When it contains a truthy
        ``"blocker_file"`` entry, that path is polled every 0.1 seconds until
        it no longer exists. Afterwards the job is completed; a failure to
        complete is logged (with traceback) rather than propagated.
        """
        blocker_file = json.loads(job.data).get("blocker_file")
        if blocker_file:
            while path.exists(blocker_file):
                time.sleep(0.1)
        try:
            job.complete()
        except Exception:
            # Best-effort completion: log and carry on. The jid is passed as a
            # logging argument so the logging module performs the formatting.
            logger.exception("Unable to complete job %s", job.jid)


class NoopJob:
@staticmethod
def process(job: AbstractJob) -> None:
Expand Down
3 changes: 1 addition & 2 deletions reqless_test/workers/test_base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def test_resume(self) -> None:
"""We should be able to resume jobs"""
queue = self.worker.client.queues["foo"]
queue.put("reqless_test.common.NoopJob", "{}")
job = self.client.queues["foo"].peek()
assert isinstance(job, AbstractJob)
job = self.pop_one(self.client, "foo")
# Now, we'll create a new worker and make sure it gets that job first
worker = BaseWorker(["foo"], self.client, resume=[job])
job_from_worker = next(worker.jobs())
Expand Down
37 changes: 32 additions & 5 deletions reqless_test/workers/test_main_worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json
import time
from tempfile import NamedTemporaryFile
from threading import Thread
from typing import Generator, Optional
from unittest.mock import patch

import pytest

from reqless.abstract import AbstractJob
from reqless.workers.main_worker import MainWorker
from reqless_test.common import NoopJob, TestReqless
from reqless_test.common import BlockingJob, TestReqless


class ShortLivedMainWorker(MainWorker):
Expand All @@ -15,9 +21,6 @@ def jobs(self) -> Generator[Optional[AbstractJob], None, None]:
for _ in range(5):
yield next(generator)

def halt_job_processing(self, jid: str) -> None:
raise KeyboardInterrupt()


class TestMainWorker(TestReqless):
"""Test the worker"""
Expand All @@ -31,7 +34,7 @@ def tearDown(self) -> None:

def test_signal_handlers_are_registered(self) -> None:
"""Test signal handler would be registered"""
jids = [self.queue.put(NoopJob, "{}") for _ in range(5)]
jids = [self.queue.put(BlockingJob, "{}") for _ in range(5)]
with patch(
"reqless.workers.main_worker.register_signal_handler",
) as register_signal_handler_mock:
Expand All @@ -43,3 +46,27 @@ def test_signal_handlers_are_registered(self) -> None:
states.append(job.state)
self.assertEqual(states, ["complete"] * 5)
register_signal_handler_mock.assert_called_once()

def test_halt_job_processing(self) -> None:
    """The worker should be able to stop processing if need be"""
    temp_file = NamedTemporaryFile()
    jid = self.queue.put(BlockingJob, json.dumps({"blocker_file": temp_file.name}))

    def job_killer() -> None:
        """Time the job out once it is running, then remove the blocker file."""
        try:
            job = self.client.jobs[jid]
            assert job is not None and isinstance(job, AbstractJob)
            # Now, we'll timeout one of the jobs and ensure that
            # halt_job_processing is invoked
            while job.state != "running":
                time.sleep(0.01)
                job = self.client.jobs[jid]
                assert job is not None and isinstance(job, AbstractJob)
            job.timeout()
        finally:
            # Closing the NamedTemporaryFile deletes it. Doing this even when
            # the steps above fail keeps the blocking job from waiting on the
            # file forever, so the test fails instead of hanging.
            temp_file.close()

    thread = Thread(target=job_killer)
    thread.start()
    with pytest.raises(KeyboardInterrupt) as exinfo:
        ShortLivedMainWorker(["foo"], self.client, interval=0.2).run()
    assert KeyboardInterrupt == exinfo.type
    thread.join()
49 changes: 12 additions & 37 deletions reqless_test/workers/test_serial_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,28 @@

import json
import time
from os import path
from tempfile import NamedTemporaryFile
from threading import Thread
from typing import Generator, Optional

from reqless import logger
from reqless.abstract import AbstractJob
from reqless.listener import Listener
from reqless.workers.serial_worker import SerialWorker
from reqless_test.common import TestReqless
from reqless_test.common import BlockingJob, TestReqless


class SerialJob:
    """Dummy class"""

    @staticmethod
    def foo(job: AbstractJob) -> None:
        """Wait for the optional blocker file to disappear, then complete the job.

        The job's ``data`` is parsed as JSON. When it contains a truthy
        ``"blocker_file"`` entry, that path is polled every 0.1 seconds until
        it no longer exists. Afterwards the job is completed; a failure to
        complete is logged (with traceback) rather than propagated.
        """
        blocker_file = json.loads(job.data).get("blocker_file")
        if blocker_file:
            while path.exists(blocker_file):
                time.sleep(0.1)
        try:
            job.complete()
        except Exception:
            # Best-effort completion: log and carry on. The jid is passed as a
            # logging argument so the logging module performs the formatting.
            logger.exception("Unable to complete job %s", job.jid)


class Worker(SerialWorker):
"""A worker that limits the number of jobs it runs"""

class ShortLivedSerialWorker(SerialWorker):
def jobs(self) -> Generator[Optional[AbstractJob], None, None]:
"""Yield only a few jobs"""
generator = SerialWorker.jobs(self)
for _ in range(5):
yield next(generator)

def halt_job_processing(self, jid: str) -> None:
"""We'll push a message to the database instead of falling on our sword"""
self.client.database.rpush("foo", jid)
raise KeyboardInterrupt()


class NoListenWorker(Worker):
"""A worker that just won't listen"""

class NoListenSerialWorker(ShortLivedSerialWorker):
def listen(self, listener: Listener) -> None:
"""Don't listen for lost locks"""
pass


Expand All @@ -69,8 +42,8 @@ def tearDown(self) -> None:

def test_basic(self) -> None:
"""Can complete jobs in a basic way"""
jids = [self.queue.put(SerialJob, "{}") for _ in range(5)]
NoListenWorker(["foo"], self.client, interval=0.2).run()
jids = [self.queue.put(BlockingJob, "{}") for _ in range(5)]
NoListenSerialWorker(["foo"], self.client, interval=0.2).run()
states = []
for jid in jids:
job = self.client.jobs[jid]
Expand All @@ -80,22 +53,24 @@ def test_basic(self) -> None:

def test_jobs(self) -> None:
"""The jobs method yields None if there are no jobs"""
worker = NoListenWorker(["foo"], self.client, interval=0.2)
worker = NoListenSerialWorker(["foo"], self.client, interval=0.2)
self.assertEqual(next(worker.jobs()), None)

def test_sleeps(self) -> None:
"""Make sure the client sleeps if there aren't jobs to be had"""
for _ in range(4):
self.queue.put(SerialJob, "{}")
self.queue.put(BlockingJob, "{}")
before = time.time()
NoListenWorker(["foo"], self.client, interval=0.2).run()
NoListenSerialWorker(["foo"], self.client, interval=0.2).run()
self.assertGreater(time.time() - before, 0.2)

def test_lost_locks(self) -> None:
"""The worker should be able to stop processing if need be"""
temp_file = NamedTemporaryFile()
jid = self.queue.put(SerialJob, json.dumps({"blocker_file": temp_file.name}))
self.thread = Thread(target=Worker(["foo"], self.client, interval=0.2).run)
jid = self.queue.put(BlockingJob, json.dumps({"blocker_file": temp_file.name}))
self.thread = Thread(
target=ShortLivedSerialWorker(["foo"], self.client, interval=0.2).run
)
self.thread.start()
job = self.client.jobs[jid]
assert job is not None and isinstance(job, AbstractJob)
Expand Down

0 comments on commit 561d171

Please sign in to comment.