Merge pull request #32 from tillahoffmann/context
Multiple improvements and bug fixes
tillahoffmann committed Jun 14, 2018
2 parents fea2461 + 1182216 commit ed28c8c
Showing 10 changed files with 192 additions and 51 deletions.
55 changes: 41 additions & 14 deletions pythonflow/core.py
@@ -20,6 +20,7 @@
import functools
import importlib
import operator
import traceback
import uuid

from .util import _noop_callback
@@ -159,7 +160,7 @@ def apply(self, fetches, context=None, *, callback=None, **kwargs):

fetches = [self.normalize_operation(operation) for operation in fetches]
context = self.normalize_context(context, **kwargs)
values = [fetch.evaluate(context, callback=callback) for fetch in fetches]
values = [fetch.evaluate_operation(fetch, context, callback=callback) for fetch in fetches]
return values[0] if single else tuple(values)

__call__ = apply
@@ -189,7 +190,7 @@ def get_active_graph(graph=None):
return graph


class Operation: # pylint:disable=too-few-public-methods
class EvaluationError(RuntimeError):
"""
Failed to evaluate an operation.
"""


class Operation: # pylint:disable=too-few-public-methods,too-many-instance-attributes
"""
Base class for operations.
@@ -221,6 +228,8 @@ def __init__(self, *args, length=None, graph=None, name=None, dependencies=None,
# Get a list of all dependencies relevant to this operation
self.dependencies = [] if dependencies is None else dependencies
self.dependencies.extend(self.graph.dependencies)
# Get the stack context so we can report where the operation was defined
self._stack = traceback.extract_stack()

def __getstate__(self):
return self.__dict__
@@ -322,18 +331,36 @@ def evaluate_operation(cls, operation, context, **kwargs):
"""
Evaluate an operation or constant given a context.
"""
if isinstance(operation, Operation):
return operation.evaluate(context, **kwargs)
partial = functools.partial(cls.evaluate_operation, context=context, **kwargs)
if isinstance(operation, tuple):
return tuple(partial(element) for element in operation)
if isinstance(operation, list):
return [partial(element) for element in operation]
if isinstance(operation, dict):
return {partial(key): partial(value) for key, value in operation.items()}
if isinstance(operation, slice):
return slice(*[partial(getattr(operation, attr)) for attr in ['start', 'stop', 'step']])
return operation
try:
if isinstance(operation, Operation):
return operation.evaluate(context, **kwargs)
partial = functools.partial(cls.evaluate_operation, context=context, **kwargs)
if isinstance(operation, tuple):
return tuple(partial(element) for element in operation)
if isinstance(operation, list):
return [partial(element) for element in operation]
if isinstance(operation, dict):
return {partial(key): partial(value) for key, value in operation.items()}
if isinstance(operation, slice):
return slice(*[partial(getattr(operation, attr))
for attr in ['start', 'stop', 'step']])
return operation
except Exception as ex:
stack = []
interactive = False
for frame in reversed(operation._stack): # pylint: disable=protected-access
# Do not capture any internal stack traces
if 'pythonflow' in frame.filename:
continue
# Stop tracing at the last interactive cell
if interactive and not frame.filename.startswith('<'):
break # pragma: no cover
interactive = frame.filename.startswith('<')
stack.append(frame)

stack = "".join(traceback.format_list(reversed(stack)))
message = "Failed to evaluate operation `%s` defined at:\n\n%s" % (operation, stack)
raise ex from EvaluationError(message)

def __bool__(self):
return True
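The change above captures the stack when an operation is defined and, on failure, re-raises the original exception chained onto an `EvaluationError` describing that location. A minimal sketch of how this surfaces to user code (the graph and the failing expression are illustrative, not taken from this diff):

```python
import pythonflow as pf

with pf.Graph() as graph:
    x = pf.placeholder(name='x')
    # The stack captured in Operation.__init__ points at this line.
    y = x / 0

try:
    graph(y, {x: 1})
except ZeroDivisionError as ex:
    # The original exception is re-raised with an EvaluationError as its cause,
    # whose message includes the stack recorded when `y` was defined.
    print(ex.__cause__)
```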
19 changes: 11 additions & 8 deletions pythonflow/operations.py
@@ -50,14 +50,15 @@ def __init__(self, predicate, x, y=None, *, length=None, name=None, dependencies

def evaluate(self, context, callback=None):
# Evaluate all dependencies first
self.evaluate_dependencies(context)
callback = callback or _noop_callback
self.evaluate_dependencies(context, callback)

predicate, x, y = self.args # pylint: disable=E0632,C0103
# Evaluate the predicate and pick the right operation
predicate = self.evaluate_operation(predicate, context)
callback = callback or _noop_callback
predicate = self.evaluate_operation(predicate, context, callback=callback)
with callback(self, context):
context[self] = value = self.evaluate_operation(x if predicate else y, context)
value = self.evaluate_operation(x if predicate else y, context, callback=callback)
context[self] = value
return value


@@ -83,20 +84,22 @@ def __init__(self, operation, except_=None, finally_=None, **kwargs):

def evaluate(self, context, callback=None):
# Evaluate all dependencies first
self.evaluate_dependencies(context)
callback = callback or _noop_callback
self.evaluate_dependencies(context, callback=callback)

operation, except_, finally_ = self.args # pylint: disable=E0632,C0103
callback = callback or _noop_callback
with callback(self, context):
try:
context[self] = value = self.evaluate_operation(operation, context)
value = self.evaluate_operation(operation, context, callback=callback)
context[self] = value
return value
except:
# Check the exceptions
_, ex, _ = sys.exc_info()
for type_, alternative in except_:
if isinstance(ex, type_):
context[self] = value = self.evaluate_operation(alternative, context)
value = self.evaluate_operation(alternative, context, callback=callback)
context[self] = value
return value
raise
finally:
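Both operations in this file now forward the evaluation callback into the operations they evaluate internally, so instrumentation sees nested evaluations too. A sketch using a hand-rolled timing callback; the callback protocol (a callable taking an operation and a context and returning a context manager) matches the code above, while the graph itself is illustrative:

```python
import contextlib
import time

import pythonflow as pf

times = {}

@contextlib.contextmanager
def timing(operation, context):
    # Called as a context manager around the evaluation of each operation.
    start = time.time()
    yield
    times[operation] = time.time() - start

with pf.Graph() as graph:
    x = pf.placeholder(name='x')
    y = pf.conditional(x > 0, x, -x)

graph(y, {x: -3}, callback=timing)
# With the callback forwarded into the conditional, the predicate and the
# branch that was taken are recorded as well, not just `y` itself.
for op, duration in times.items():
    print(op, duration)
```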
2 changes: 1 addition & 1 deletion pythonflow/pfmq/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .task import Task, SerializationError
from .task import Task, SerializationError, apply
from .broker import Broker
from .worker import Worker
17 changes: 15 additions & 2 deletions pythonflow/pfmq/_base.py
@@ -50,6 +50,13 @@ def __init__(self, start):
}
STATUS.update({value: key for key, value in STATUS.items()})

def __enter__(self):
self.run_async()
return self

def __exit__(self, *_):
self.cancel()

@property
def is_alive(self):
"""
@@ -65,13 +72,19 @@ def cancel(self, timeout=None):
----------
timeout : float
Timeout for joining the background thread.
Returns
-------
cancelled : bool
Whether the background thread was cancelled. `False` if the background thread was not
running.
"""
if self.is_alive:
self._cancel_parent.send_multipart([b''])
self._thread.join(timeout)
self._cancel_parent.close()
else:
raise RuntimeError('background thread is not running')
return True
return False

def run_async(self):
"""
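The new `__enter__`/`__exit__` methods make anything built on this base class usable as a context manager, and `cancel` now reports whether a background thread was actually stopped instead of raising. A sketch; how the broker is constructed is not shown in this diff, so the constructor argument and address are assumptions:

```python
from pythonflow import pfmq

# The constructor argument and address are assumptions for illustration.
broker = pfmq.Broker('tcp://127.0.0.1:5555')

with broker:                  # __enter__ starts the background thread via run_async()
    assert broker.is_alive
    # hand broker.frontend_address to tasks or remote clients here
# __exit__ calls cancel(), joining the background thread.

# cancel() now returns False instead of raising if nothing is running.
assert broker.cancel() is False
```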
45 changes: 31 additions & 14 deletions pythonflow/pfmq/broker.py
@@ -19,7 +19,7 @@
import zmq

from ._base import Base
from .task import Task
from .task import Task, apply


LOGGER = logging.getLogger(__name__)
@@ -92,7 +92,7 @@ def run(self): # pylint: disable=too-many-statements,too-many-locals,too-many-b
_, identifier, status, response = message
LOGGER.debug(
'received RESPONSE with identifier %s from %s for %s with status %s',
int.from_bytes(identifier, 'little'), worker, client,
int.from_bytes(identifier, 'little'), worker.hex(), client.hex(),
self.STATUS[status]
)
# Try to forward the message to a waiting client
@@ -103,29 +103,29 @@
except KeyError:
cache.setdefault(client, []).append((identifier, status, response))
else:
LOGGER.debug('received SIGN-UP message from %s; now %d workers', worker,
len(workers))
del worker
LOGGER.debug('received SIGN-UP message from %s; now %d workers',
worker.hex(), len(workers))

# Receive requests from the frontend, forward to the workers, and return responses
if sockets.get(frontend) == zmq.POLLIN:
client, _, identifier, *request = frontend.recv_multipart()
LOGGER.debug('received REQUEST with byte identifier %s from %s',
identifier, client)
identifier, client.hex())

if identifier:
worker = workers.pop()
backend.send_multipart([worker, _, client, _, identifier, *request])
LOGGER.debug('forwarded REQUEST with identifier %s from %s to %s',
int.from_bytes(identifier, 'little'), client, worker)
int.from_bytes(identifier, 'little'), client.hex(),
worker.hex())

try:
self._forward_response(frontend, client, *cache[client].pop(0))
except (KeyError, IndexError):
# Send a dispatch notification if the task sent a new message
if identifier:
frontend.send_multipart([client, _, _])
LOGGER.debug('notified %s of REQUEST dispatch', client)
LOGGER.debug('notified %s of REQUEST dispatch', client.hex())
# Add the task to the list of tasks waiting for responses otherwise
else:
clients.add(client)
@@ -140,20 +140,37 @@ def _forward_response(cls, frontend, client, identifier, status, response): # p

def imap(self, requests, **kwargs):
"""
Convenience method for applying a target to requests remotely.
Process a sequence of requests remotely.
Parameters
----------
requests : iterable
Sequence of requests to process.
Returns
-------
task : Task
Remote task that can be iterated over.
"""
if not self.is_alive:
raise RuntimeError("broker is not running")
return Task(requests, self.frontend_address, **kwargs)

def apply(self, request, **kwargs):
"""
Convenience method for applying a target to a request remotely.
Process a request remotely.
Parameters
----------
request : object
Request to process.
Returns
-------
result : object
Result of remotely-processed request.
"""
if not self.is_alive:
raise RuntimeError("broker is not running")

task = self.imap([request], start=False, **kwargs)
task.run()
for result in task.iter_results(timeout=0):
return result
return apply(request, self.frontend_address, **kwargs)
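`imap` submits a whole sequence of requests and returns a `Task` to iterate over, while `apply` is the single-request form that now delegates to the module-level helper; both raise `RuntimeError` when the broker is not running. A usage sketch; the broker constructor argument, the address, and the request payloads are assumptions, and a worker whose target understands these requests would need to be connected:

```python
from pythonflow import pfmq

# Constructor argument, address, and payloads are illustrative only.
broker = pfmq.Broker('tcp://127.0.0.1:5555')

with broker:
    # Single request; raises RuntimeError if the broker were not running.
    answer = broker.apply({'values': [1, 2, 3]})

    # Many requests: imap returns a Task whose results can be iterated over
    # as the workers send them back.
    task = broker.imap([{'values': [i]} for i in range(10)])
    for result in task:
        print(result)
```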
22 changes: 22 additions & 0 deletions pythonflow/pfmq/task.py
@@ -234,3 +234,25 @@ def iter_results(self, timeout=None):

def __iter__(self):
return self.iter_results()


def apply(request, frontend_address, **kwargs):
"""
Process a request remotely.
Parameters
----------
request : object
Request to process.
frontend_address : str
Address of the broker frontend.
Returns
-------
result : object
Result of remotely-processed request.
"""
task = Task([request], frontend_address, start=False, **kwargs)
task.run()
for result in task.iter_results(timeout=0):
return result
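The module-level `apply` mirrors `Broker.apply` but needs only the frontend address, so a client process does not have to hold a `Broker` instance. A sketch; the address and payload are placeholders:

```python
from pythonflow.pfmq import apply

# The address stands in for a running broker's frontend; a worker able to
# process the request must also be connected.
result = apply({'values': [1, 2, 3]}, 'tcp://127.0.0.1:5555')
print(result)
```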
8 changes: 4 additions & 4 deletions pythonflow/pfmq/worker.py
@@ -107,7 +107,7 @@ def run(self): # pylint: disable=too-many-locals
if sockets.get(socket) == zmq.POLLIN:
client, _, identifier, *request = socket.recv_multipart()
LOGGER.debug('received REQUEST with identifier %d from %s',
int.from_bytes(identifier, 'little'), client)
int.from_bytes(identifier, 'little'), client.hex())

try:
response = self.target(self.loads(*request))
@@ -117,22 +117,22 @@
response = value, "".join(traceback.format_exception(etype, value, tb))
status = self.STATUS['error']
LOGGER.exception("failed to process REQUEST with identifier %d from %s",
int.from_bytes(identifier, 'little'), client)
int.from_bytes(identifier, 'little'), client.hex())

try:
response = self.dumps(response)
except Exception: # pylint: disable=broad-except
LOGGER.exception(
"failed to serialise RESPONSE with identifier %d for %s",
int.from_bytes(identifier, 'little'), client
int.from_bytes(identifier, 'little'), client.hex()
)
response = b""
status = self.STATUS['serialization_error']

socket.send_multipart([client, b'', identifier, status, response])
LOGGER.debug(
'sent RESPONSE with identifier %s to %s with status %s',
int.from_bytes(identifier, 'little'), client, self.STATUS[status]
int.from_bytes(identifier, 'little'), client.hex(), self.STATUS[status]
)

LOGGER.error("maximum number of retries (%d) for %s exceeded", self.max_retries,