Merge pull request #29 from tillahoffmann/pfmq-exceptions
Propagate exceptions from workers; improve testing.
tillahoffmann committed Jun 6, 2018
2 parents 8699a2f + c493f2b commit 926a39c
Showing 11 changed files with 189 additions and 48 deletions.
3 changes: 1 addition & 2 deletions .coveragerc
@@ -4,8 +4,7 @@ exclude_lines =
     # Have to re-enable the standard pragma
     pragma: no cover
 
-    # Don't complain if tests don't raise exceptions
-    raise
+    raise NotImplementedError
 
     # Don't complain about representations not being covered
     __repr__
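With the narrower pattern, coverage keeps excluding `raise NotImplementedError` stubs but ordinary `raise` statements are now counted, so tests must exercise every other exception path. A hypothetical illustration (class and method names are made up):

    class Base:
        def frobnicate(self):
            raise NotImplementedError  # still excluded from coverage

        def validate(self, value):
            if value < 0:
                raise ValueError(value)  # now counted; a test must hit this line
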
4 changes: 2 additions & 2 deletions Makefile
@@ -2,13 +2,13 @@
 
 all : tests docs
 
-tests : lint_tests code_tests
+tests : code_tests lint_tests
 
 lint_tests :
 	pylint pythonflow
 
 code_tests :
-	py.test --cov pythonflow --cov-fail-under=100 --cov-report=term-missing --cov-report=html --verbose --durations=5
+	py.test --cov pythonflow --cov-fail-under=100 --cov-report=term-missing --cov-report=html --verbose --durations=5 -s
 
 docs :
 	sphinx-build -b doctest docs build
2 changes: 1 addition & 1 deletion pythonflow/pfmq/__init__.py
@@ -14,6 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .task import Task
+from .task import Task, SerializationError
 from .broker import Broker
 from .worker import Worker
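Re-exporting the new exception lets client code catch serialisation failures at the package level; a minimal sketch, assuming a broker and workers are already running (the address is illustrative):

    from pythonflow import pfmq

    broker = pfmq.Broker('tcp://127.0.0.1:5555')  # illustrative address
    broker.run_async()
    try:
        broker.apply({'fetches': 'not_pickleable', 'context': {}})
    except pfmq.SerializationError:
        print('the worker could not pickle the requested result')
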
9 changes: 9 additions & 0 deletions pythonflow/pfmq/_base.py
@@ -41,6 +41,15 @@ def __init__(self, start):
         if start:
             self.run_async()
 
+    STATUS = {
+        'ok': b'\x00',
+        'end': b'\x01',
+        'error': b'\x02',
+        'timeout': b'\x03',
+        'serialization_error': b'\x04',
+    }
+    STATUS.update({value: key for key, value in STATUS.items()})
+
     @property
     def is_alive(self):
         """
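The update call turns the status table into a bidirectional mapping: the same dict encodes status names to wire bytes and decodes received bytes back to names. A quick sketch of the resulting behaviour:

    STATUS = {
        'ok': b'\x00',
        'end': b'\x01',
        'error': b'\x02',
        'timeout': b'\x03',
        'serialization_error': b'\x04',
    }
    STATUS.update({value: key for key, value in STATUS.items()})

    assert STATUS['error'] == b'\x02'  # encode: status name -> wire byte
    assert STATUS[b'\x02'] == 'error'  # decode: wire byte -> status name
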
34 changes: 20 additions & 14 deletions pythonflow/pfmq/broker.py
@@ -83,28 +83,31 @@ def run(self): # pylint: disable=too-many-statements,too-many-locals,too-many-b
                 LOGGER.debug('received CANCEL signal on %s', self._cancel_address)
                 break
 
-            # Receive results or sign-up messages from the backend
+            # Receive responses or sign-up messages from the backend
             if sockets.get(backend) == zmq.POLLIN:
                 worker, _, client, *message = backend.recv_multipart()
                 workers.add(worker)
 
                 if client:
-                    _, identifier, result = message
-                    LOGGER.debug('received RESPONSE with identifier %s from %s for %s',
-                                 int.from_bytes(identifier, 'little'), worker, client)
+                    _, identifier, status, response = message
+                    LOGGER.debug(
+                        'received RESPONSE with identifier %s from %s for %s with status %s',
+                        int.from_bytes(identifier, 'little'), worker, client,
+                        self.STATUS[status]
+                    )
                     # Try to forward the message to a waiting client
                     try:
                         clients.remove(client)
-                        frontend.send_multipart([client, _, identifier, result])
+                        self._forward_response(frontend, client, identifier, status, response)
                     # Add it to the cache otherwise
                     except KeyError:
-                        cache.setdefault(client, []).append((identifier, result))
+                        cache.setdefault(client, []).append((identifier, status, response))
                 else:
                     LOGGER.debug('received SIGN-UP message from %s; now %d workers', worker,
                                  len(workers))
                     del worker
 
-            # Receive requests from the frontend and forward to the workers or return results
+            # Receive requests from the frontend, forward to the workers, and return responses
             if sockets.get(frontend) == zmq.POLLIN:
                 client, _, identifier, *request = frontend.recv_multipart()
                 LOGGER.debug('received REQUEST with byte identifier %s from %s',
@@ -117,21 +120,24 @@ def run(self): # pylint: disable=too-many-statements,too-many-locals,too-many-b
                              int.from_bytes(identifier, 'little'), client, worker)
 
                 try:
-                    identifier, result = cache[client].pop(0)
-                    frontend.send_multipart([client, _, identifier, result])
-                    LOGGER.debug('forwarded RESPONSE with identifier %s to %s',
-                                 int.from_bytes(identifier, 'little'), client)
+                    self._forward_response(frontend, client, *cache[client].pop(0))
                 except (KeyError, IndexError):
                     # Send a dispatch notification if the task sent a new message
                     if identifier:
                         frontend.send_multipart([client, _, _])
                         LOGGER.debug('notified %s of REQUEST dispatch', client)
-                    # Add the task to the list if tasks otherwise
+                    # Add the task to the list of tasks waiting for responses otherwise
                     else:
                         clients.add(client)
 
         LOGGER.debug('exiting communication loop')
 
+    @classmethod
+    def _forward_response(cls, frontend, client, identifier, status, response): # pylint: disable=too-many-arguments
+        frontend.send_multipart([client, b'', identifier, status, response])
+        LOGGER.debug('forwarded RESPONSE with identifier %s to %s with status %s',
+                     int.from_bytes(identifier, 'little'), client, cls.STATUS[status])
+
     def imap(self, requests, **kwargs):
         """
         Convenience method for applying a target to requests remotely.
Expand All @@ -144,5 +150,5 @@ def apply(self, request, **kwargs):
"""
task = self.imap([request], start=False, **kwargs)
task.run()
_, result = task.results.get()
return result
for result in task.iter_results(timeout=0):
return result
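Responses now travel as five-frame multipart messages, assembled in `_forward_response` above. A sketch of the framing with illustrative values (only the empty delimiter and the status byte are fixed by the code above; the other frame contents are examples):

    client = b'client-routing-id'           # ZMQ identity of the requesting task
    identifier = (7).to_bytes(4, 'little')  # request counter; the width is illustrative
    status = b'\x00'                        # STATUS['ok']
    response = b'<pickled result>'          # payload produced by the worker

    frames = [client, b'', identifier, status, response]
    # The receiving task decodes the counter and status the same way the broker
    # does: int.from_bytes(identifier, 'little') and STATUS[status].
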
53 changes: 38 additions & 15 deletions pythonflow/pfmq/task.py
@@ -29,6 +29,12 @@
 LOGGER = logging.getLogger(__name__)
 
 
+class SerializationError(RuntimeError):
+    """
+    Error serialising a remote result.
+    """
+
+
 class Task(Base):
     """
     A task that is executed remotely.
@@ -122,7 +128,7 @@ def run(self): # pylint: disable=too-many-statements,too-many-branches,too-many
                     LOGGER.debug('no more requests; waiting for responses')
 
                 socket.send_multipart(message)
-                LOGGER.debug('sent request with identifier %s: %s', identifier, message)
+                LOGGER.debug('sent REQUEST with identifier %s: %s', identifier, message)
 
                 # Add to the list of pending requests
                 if identifier is not None:
@@ -146,7 +152,7 @@ def run(self): # pylint: disable=too-many-statements,too-many-branches,too-many
                         message = "maximum number of retries (%d) for %s exceeded" % \
                             (self.max_retries, self.address)
                         LOGGER.error(message)
-                        self.results.put((None, TimeoutError(message)))
+                        self.results.put(('timeout', TimeoutError(message)))
                         return
                     break
 
@@ -155,34 +161,40 @@ def run(self): # pylint: disable=too-many-statements,too-many-branches,too-many
 
             # Cancel the communication thread
            if sockets.get(cancel) == zmq.POLLIN:
-                LOGGER.debug('received cancel signal on %s', self._cancel_address)
+                LOGGER.debug('received CANCEL signal on %s', self._cancel_address)
                 return
 
             if sockets.get(socket) == zmq.POLLIN:
-                identifier, *result = socket.recv_multipart()
+                identifier, *response = socket.recv_multipart()
                 if not identifier:
                     LOGGER.debug('received dispatch notification')
                     continue
 
                 # Decode the identifier and remove the corresponding request from `pending`
                 identifier = int.from_bytes(identifier, 'little')
-                LOGGER.debug('received result for identifier %d (next: %d, end: %s)',
+                LOGGER.debug('received RESPONSE for identifier %d (next: %d, end: %s)',
                              identifier, next_identifier, last_identifier)
                 pending.pop(identifier, None)
 
                 # Drop the message if it is outdated
                 if identifier < next_identifier: # pragma: no cover
-                    LOGGER.debug('dropped message with identifier %d (next: %d)',
+                    LOGGER.debug('dropped RESPONSE with identifier %d (next: %d)',
                                  identifier, next_identifier)
                     continue
 
                 # Add the message to the cache
-                cache[identifier] = self.loads(*result)
+                cache[identifier] = response
                 while True:
                     try:
-                        self.results.put((next_identifier, cache.pop(next_identifier)))
+                        status, response = cache.pop(next_identifier)
+                        status = self.STATUS[status]
+                        self.results.put(
+                            (status, identifier if status == 'serialization_error' else
+                             self.loads(response))
+                        )
+
                         if next_identifier == last_identifier:
-                            self.results.put((None, None))
+                            self.results.put(('end', None))
                             return
                         next_identifier += 1
                     except KeyError:
@@ -198,12 +210,23 @@ def iter_results(self, timeout=None):
             Timeout for getting results.
         """
         while True:
-            identifier, result = self.results.get(timeout=timeout)
-            if identifier is None:
-                if result:
-                    raise result
-                break
-            yield result
+            status, result = self.results.get(timeout=timeout)
+            if status == 'ok':
+                yield result
+            elif status == 'error':
+                value, tb = result # pylint: disable=invalid-name
+                LOGGER.error(tb)
+                raise value
+            elif status == 'timeout':
+                raise result
+            elif status == 'end':
+                return
+            elif status == 'serialization_error':
+                raise SerializationError(
+                    "failed to serialize result for request with identifier %s" % result
+                )
+            else:
+                raise KeyError(status) # pragma: no cover
 
     def __iter__(self):
         return self.iter_results()
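A sketch of how these statuses surface to client code, using the graph from the test suite (`z = x / y`) and assuming a running broker with workers:

    requests = [{'fetches': 'z', 'context': {'x': 1, 'y': y}} for y in (2, 1, 0)]
    task = broker.imap(requests)
    try:
        for result in task:  # __iter__ delegates to iter_results()
            print(result)    # 'ok' responses are yielded; 'end' stops the loop
    except ZeroDivisionError:
        # 'error': the exception raised on the worker is re-raised here after
        # its remote traceback has been logged.
        pass
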
38 changes: 31 additions & 7 deletions pythonflow/pfmq/worker.py
@@ -16,6 +16,8 @@
 
 import logging
 import pickle
+import sys
+import traceback
 import uuid
 
 import zmq
@@ -64,7 +66,7 @@ def __init__(self, target, address, dumps=None, loads=None, start=False, timeout
 
         super(Worker, self).__init__(start)
 
-    def run(self):
+    def run(self): # pylint: disable=too-many-locals
         context = zmq.Context.instance()
         # Use a specific identity for the worker such that reconnects don't change it.
         identity = uuid.uuid4().bytes
@@ -104,13 +106,35 @@ def run(self):
             # Process messages
             if sockets.get(socket) == zmq.POLLIN:
                 client, _, identifier, *request = socket.recv_multipart()
-                LOGGER.debug('received request with identifier %d from %s',
-                             int.from_bytes(identifier, 'little'), client)
-                result = self.dumps(self.target(self.loads(*request)))
-                socket.send_multipart([client, _, identifier, result])
-                LOGGER.debug('sent result with identifier %s to %s',
-                             int.from_bytes(identifier, 'little'), client)
+                LOGGER.debug('received REQUEST with identifier %d from %s',
+                             int.from_bytes(identifier, 'little'), client)
+
+                try:
+                    response = self.target(self.loads(*request))
+                    status = self.STATUS['ok']
+                except Exception: # pylint: disable=broad-except
+                    etype, value, tb = sys.exc_info() # pylint: disable=invalid-name
+                    response = value, "".join(traceback.format_exception(etype, value, tb))
+                    status = self.STATUS['error']
+                    LOGGER.exception("failed to process REQUEST with identifier %d from %s",
+                                     int.from_bytes(identifier, 'little'), client)
+
+                try:
+                    response = self.dumps(response)
+                except Exception: # pylint: disable=broad-except
+                    LOGGER.exception(
+                        "failed to serialise RESPONSE with identifier %d for %s",
+                        int.from_bytes(identifier, 'little'), client
+                    )
+                    response = b""
+                    status = self.STATUS['serialization_error']
+
+                socket.send_multipart([client, b'', identifier, status, response])
+                LOGGER.debug(
+                    'sent RESPONSE with identifier %s to %s with status %s',
+                    int.from_bytes(identifier, 'little'), client, self.STATUS[status]
+                )
 
         LOGGER.error("maximum number of retries (%d) for %s exceeded", self.max_retries,
                      self.address)

@@ -129,5 +153,5 @@ def _target(request):
                 return graph(request['fetches'], request['context'])
             elif 'contexts' in request:
                 return [graph(request['fetches'], context) for context in request['contexts']]
-            raise KeyError
+            raise KeyError("`context` or `contexts` must be in the request")
         return cls(_target, *args, **kwargs)
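Any exception raised by the target now travels back to the requesting task; a minimal sketch of the error path, with a hypothetical target and address:

    from pythonflow import pfmq

    def fragile_target(request):
        # Raising here exercises the new except-branch in Worker.run: the
        # exception and its formatted traceback are pickled and sent back
        # under the 'error' status byte, and the client re-raises it.
        raise ValueError('cannot handle %r' % request)

    worker = pfmq.Worker(fragile_target, 'tcp://127.0.0.1:5556', start=True)
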
36 changes: 30 additions & 6 deletions tests/test_pfmq.py
@@ -1,4 +1,3 @@
-import multiprocessing
 import time
 import random
 import uuid
@@ -17,7 +16,8 @@ def broker(backend_address):
     b = pfmq.Broker(backend_address)
     b.run_async()
     yield b
-    b.cancel()
+    if b.is_alive:
+        b.cancel()
 
 
 @pytest.fixture
@@ -27,7 +27,9 @@ def workers(broker):
     y = pf.placeholder('y')
     sleep = pf.func_op(time.sleep, pf.func_op(random.uniform, 0, .1))
     with pf.control_dependencies([sleep]):
-        z = (x + y).set_name('z')
+        (x / y).set_name('z')
+    # Can't pickle entire modules
+    pf.constant(pf).set_name('not_pickleable')
 
     # Create multiple workers
     _workers = []
@@ -53,13 +55,19 @@ def test_workers_running(workers):
 def test_apply(broker, workers):
     request = {'fetches': 'z', 'context': {'x': 1, 'y': 3}}
     result = broker.apply(request)
-    assert result == 4
+    assert result == 1 / 3
+
+
+def test_apply_error(broker, workers):
+    request = {'fetches': 'z', 'context': {'x': 1, 'y': 0}}
+    with pytest.raises(ZeroDivisionError):
+        broker.apply(request)
 
 
 def test_apply_batch(broker, workers):
     request = {'fetches': 'z', 'contexts': [{'x': 1, 'y': 3 + i} for i in range(5)]}
     result = broker.apply(request)
-    assert result == [4 + i for i in range(5)]
+    assert result == [1 / (3 + i) for i in range(5)]
 
 
 def test_cancel_task():
@@ -72,7 +80,7 @@ def test_imap(broker, workers):
     requests = [{'fetches': 'z', 'context': {'x': 1, 'y': 3 + i}} for i in range(200)]
     task = broker.imap(requests)
     for i, result in enumerate(task):
-        assert result == i + 4
+        assert result == 1 / (3 + i)
     # Make sure the task finishes
     task._thread.join()
@@ -93,3 +101,19 @@ def test_task_timeout(backend_address):
     assert duration > .3
     with pytest.raises(TimeoutError):
         list(task)
+
+
+def test_cancel_not_running(broker):
+    broker.cancel()
+    with pytest.raises(RuntimeError):
+        broker.cancel()
+
+
+def test_not_pickleable(broker, workers):
+    with pytest.raises(pfmq.SerializationError):
+        broker.apply({'fetches': 'not_pickleable', 'context': {}})
+
+
+def test_no_context(broker, workers):
+    with pytest.raises(KeyError):
+        broker.apply({})