Browse files

implement fanout

  • Loading branch information...
1 parent 61364aa commit 846fec04fa3a550bc1361b21495dadbf932158cd @sublee committed Apr 4, 2013
Showing with 276 additions and 104 deletions.
  1. +186 −64 zeronimo/core.py
  2. +16 −0 zeronimo/exceptions.py
  3. +3 −0 zeronimo/functional.py
  4. +71 −40 zeronimotests.py
View
250 zeronimo/core.py
@@ -10,14 +10,15 @@
from contextlib import contextmanager, nested
import functools
from types import MethodType
-from uuid import uuid4
+import uuid
-from gevent import joinall, spawn
+from gevent import joinall, spawn, Timeout
from gevent.coros import Semaphore
from gevent.event import AsyncResult
-from gevent.queue import Queue
+from gevent.queue import Queue, Empty
import zmq.green as zmq
+from .exceptions import ZeronimoError, NoWorkerError
from .functional import extract_blueprint, make_fingerprint, should_yield
@@ -39,6 +40,23 @@
PUSH_PULL = (zmq.PUSH, zmq.PULL)
+def st(x):
+ """Returns the socket type name stored in ``SOCKET_TYPE_NAMES`` for the
+ socket type ``x``.
+ """
+ return SOCKET_TYPE_NAMES[x]
+
+
+def dt(x):
+ """Returns the readable name of the reply code ``x``.  It is used by the
+ debugging ``print`` statements.
+
+ NOTE(review): the literal keys 1-5 presumably mirror the ``ACK``,
+ ``RETURN``, ``RAISE``, ``YIELD`` and ``BREAK`` constants; confirm against
+ their definitions.
+ """
+ return {1: 'ACK', 2: 'RETURN', 3: 'RAISE', 4: 'YIELD', 5: 'BREAK'}[x]
+
+
+def generate_inproc_addr():
+ """Returns a new ``inproc://`` address whose name comes from
+ :func:`uuid_str`.
+ """
+ return 'inproc://{0}'.format(uuid_str())
+
+
+def uuid_str():
+    """Returns a short random identifier of 6 lowercase hexadecimal
+    characters.
+
+    The identifier is the head of a random UUID4.  Hashing the UUID again
+    would not add any randomness, so the hex digits are used directly.  Only
+    24 bits are kept, so the result is short and readable in debugging
+    output but is not guaranteed to be unique.
+    """
+    return uuid.uuid4().hex[:6]
+
+
class Communicator(object):
"""Manages ZeroMQ sockets."""
@@ -47,14 +65,19 @@ class Communicator(object):
def __new__(cls, *args, **kwargs):
obj = super(Communicator, cls).__new__(cls)
+ obj._running_lock = Semaphore()
def run(self):
- obj.running += 1
+ if self._running_lock.locked():
+ return
try:
- obj._actual_run()
+ with self._running_lock:
+ obj.running += 1
+ rv = obj._run()
finally:
obj.running -= 1
assert obj.running >= 0
- obj.run, obj._actual_run = MethodType(run, obj), obj.run
+ return rv
+ obj.run, obj._run = MethodType(run, obj), obj.run
return obj
def __init__(self, context=None):
@@ -69,88 +92,107 @@ def __del__(self):
class Worker(Communicator):
- addr = None
+ addrs = None
+ fanout_addrs = None
+ fanout_filters = None
blueprint = None
-
- def __init__(self, obj, addr=None, **kwargs):
- if addr is None:
- addr = 'inproc://{0}'.format(str(uuid4()))
- self.addr = addr
+ fingerprint = None
+
+ def __init__(self, obj, addrs=None, fanout_addrs=None, fanout_filters='',
+ **kwargs):
+ if addrs is None:
+ addrs = [generate_inproc_addr()]
+ if fanout_addrs is None:
+ fanout_addrs = [generate_inproc_addr()]
+ self.addrs = addrs
+ self.fanout_addrs = fanout_addrs
+ self.fanout_filters = fanout_filters
self.blueprint = extract_blueprint(obj)
self.fingerprint = make_fingerprint(self.blueprint)
super(Worker, self).__init__(**kwargs)
def possible_addrs(self, socket_type):
if socket_type == zmq.PULL:
- return [self.addr]
+ return self.addrs
elif socket_type == zmq.SUB:
- return []
- raise NotImplementedError
+ return self.fanout_addrs
else:
socket_type_name = SOCKET_TYPE_NAMES[socket_type]
raise ValueError('{!r} is not acceptable'.format(socket_type_name))
- def run(self):
- joinall([spawn(self.serve, zmq.PULL, self.addr)])
- #spawn(self.serve, zmq.SUB, self.fanout_addr)])
-
- def serve(self, sock_type, addr):
- sock = self.context.socket(sock_type)
- sock.bind(addr)
- while self.running:
- spawn(self.task_received, *sock.recv_pyobj())
-
- def task_received(self, customer_addr, task_id, fn, args, kwargs):
+ def run_task(self, customer_addr, task_id, fn, args, kwargs):
spec = self.blueprint[fn]
+ run_id = uuid_str()
+ print 'worker recv %s%r from %s:%s of %s' % \
+ (fn, args, task_id, run_id, customer_addr)
if spec.reply:
sock = self.context.socket(zmq.PUSH)
sock.connect(customer_addr)
- sock.send_pyobj((task_id, ACK, self.addr))
+ # TODO: reply with a dedicated public_addr instead of addrs[0]
+ sock.send_pyobj((task_id, run_id, ACK, (self.addrs[0], run_id)))
else:
sock = False
try:
val = spec.func(*args, **kwargs)
except Exception, error:
- sock and sock.send_pyobj((task_id, RAISE, error))
+ print 'worker %s %r to %s:%s' % \
+ (dt(RAISE), error, task_id, run_id)
+ sock and sock.send_pyobj((task_id, run_id, RAISE, error))
raise
if should_yield(val):
try:
for item in val:
- sock and sock.send_pyobj((task_id, YIELD, item))
+ print 'worker %s %r to %s:%s' % \
+ (dt(YIELD), item, task_id, run_id)
+ sock and sock.send_pyobj((task_id, run_id, YIELD, item))
except Exception, error:
- sock and sock.send_pyobj((task_id, RAISE, error))
+ print 'worker %s %r to %s:%s' % \
+ (dt(RAISE), error, task_id, run_id)
+ sock and sock.send_pyobj((task_id, run_id, RAISE, error))
else:
- sock and sock.send_pyobj((task_id, BREAK, None))
+ print 'worker %s %r to %s:%s' % \
+ (dt(BREAK), None, task_id, run_id)
+ sock and sock.send_pyobj((task_id, run_id, BREAK, None))
else:
- sock and sock.send_pyobj((task_id, RETURN, val))
+ print 'worker %s %r to %s:%s' % \
+ (dt(RETURN), val, task_id, run_id)
+ sock and sock.send_pyobj((task_id, run_id, RETURN, val))
+
+ def run(self):
+ """Creates a PULL socket and a SUB socket, binds them to ``addrs`` and
+ ``fanout_addrs``, and serves both of them until the serving greenlets
+ finish.
+ """
+ self.sock = self.context.socket(zmq.PULL)
+ self.fanout_sock = self.context.socket(zmq.SUB)
+ # an empty prefix subscribes to every message
+ # NOTE(review): ``fanout_filters`` is stored by ``__init__`` but it is
+ # not applied here
+ self.fanout_sock.setsockopt(zmq.SUBSCRIBE, '')
+ # bind addresses
+ for addr in self.addrs:
+ self.sock.bind(addr)
+ for addr in self.fanout_addrs:
+ self.fanout_sock.bind(addr)
+ # serve both sockets
+ def serve(sock):
+ # each received object is unpacked as the arguments of ``run_task``,
+ # which runs in its own greenlet
+ while self.running:
+ spawn(self.run_task, *sock.recv_pyobj())
+ joinall([spawn(serve, self.sock), spawn(serve, self.fanout_sock)])
class Customer(Communicator):
- replies = None
+ addr = None
+ sock = None
+ tunnels = None
+ tasks = None
def __init__(self, addr=None, **kwargs):
if addr is None:
- addr = 'inproc://{0}'.format(str(uuid4()))
+ addr = 'inproc://{0}'.format(uuid_str())
self.addr = addr
self.tunnels = set()
- self.lock = Semaphore()
self.tasks = {}
+ self._missing_tasks = {}
super(Customer, self).__init__(**kwargs)
def link(self, *args, **kwargs):
return Tunnel(self, *args, **kwargs)
- def run(self):
- if self.lock.locked():
- return
- with self.lock:
- sock = self.context.socket(zmq.PULL)
- sock.bind(self.addr)
- while self.running:
- task_id, do, val = sock.recv_pyobj()
- self.tasks[task_id].put(do, val)
-
def register_tunnel(self, tunnel):
"""Registers the :class:`Tunnel` object to run and ensures a socket
which pulls replies.
@@ -162,12 +204,64 @@ def register_tunnel(self, tunnel):
def unregister_tunnel(self, tunnel):
"""Unregisters the :class:`Tunnel` object."""
self.tunnels.remove(tunnel)
+ if self.sock is not None and not self.tunnels:
+ self.sock.close()
def register_task(self, task):
- self.tasks[task.id] = task
+ try:
+ self.tasks[task.id][task.run_id] = task
+ except KeyError:
+ self.tasks[task.id] = {task.run_id: task}
+ self._restore_missing_messages(task)
def unregister_task(self, task):
- assert self.tasks.pop(task.id) is task
+ assert self.tasks[task.id].pop(task.run_id) is task
+ if task.run_id is None or not self.tasks[task.id]:
+ del self.tasks[task.id]
+
+ def _restore_missing_messages(self, task):
+ """Moves the messages which were collected in ``_missing_tasks`` for the
+ id and the run id of ``task`` into the queue of ``task``.  It does
+ nothing if nothing was collected for them.
+ """
+ try:
+ missing = self._missing_tasks[task.id].pop(task.run_id)
+ except KeyError:
+ return
+ # forget the task id once no collected run remains under it
+ if not self._missing_tasks[task.id]:
+ del self._missing_tasks[task.id]
+ try:
+ while missing.queue:
+ task.put(*missing.queue.get(block=False))
+ except Empty:
+ # the non-blocking ``get()`` found no more collected messages
+ pass
+
+ def run(self):
+ """Binds a PULL socket to ``addr`` and passes each received reply to the
+ task it belongs to, while at least one tunnel is registered.
+ """
+ assert self.sock is None
+ self.sock = self.context.socket(zmq.PULL)
+ self.sock.bind(self.addr)
+ while self.tunnels:
+ try:
+ task_id, run_id, do, val = self.sock.recv_pyobj()
+ except zmq.ZMQError:
+ # NOTE(review): presumably raised when ``unregister_tunnel`` closes
+ # the socket; the loop condition then ends the loop -- confirm
+ continue
+ if do == ACK:
+ # an ACK goes to the task registered without a run id, which is
+ # the one waiting in ``Task.acknowledge``
+ run_id = None
+ try:
+ tasks = self.tasks[task_id]
+ except KeyError:
+ # no task is registered for this task id; drop the message
+ continue
+ try:
+ task = tasks[run_id]
+ except KeyError:
+ try:
+ task = self._missing_tasks[task_id][run_id]
+ except KeyError:
+ # a placeholder task which collects the messages arriving before
+ # the task of this run id is registered
+ task = Task(None, task_id, run_id)
+ if task_id not in self._missing_tasks:
+ self._missing_tasks[task_id] = {run_id: task}
+ elif run_id not in self._missing_tasks[task_id]:
+ self._missing_tasks[task_id][run_id] = task
+ task.put(do, val)
+ # reset so that the assertion at the top passes on the next run
+ self.sock = None
class Tunnel(object):
@@ -183,20 +277,16 @@ class Tunnel(object):
"""
def __init__(self, customer, workers, return_task=False):
+ self._znm_verify_workers(workers)
self._znm_customer = customer
- workers = self._znm_verify_workers(workers)
self._znm_workers = workers
self._znm_blueprint = workers[0].blueprint
self._znm_return_task = return_task
self._znm_sockets = {}
self._znm_reflect(self._znm_blueprint)
def _znm_verify_workers(self, workers):
- if isinstance(workers, Worker):
- worker = workers
- workers = [worker]
- else:
- worker = workers[0]
+ worker = workers[0]
for other_worker in workers[1:]:
if worker.fingerprint != other_worker.fingerprint:
raise ValueError('All workers must have same fingerprint')
@@ -210,16 +300,22 @@ def _znm_reflect(self, blueprint):
setattr(self, fn, functools.partial(self._znm_invoke, fn))
def _znm_invoke(self, fn, *args, **kwargs):
+ """Invokes remote function."""
+ ack_task = Task(self)
spec = self._znm_blueprint[fn]
- task = Task(self)
sock = self._znm_sockets[zmq.PUB if spec.fanout else zmq.PUSH]
- sock.send_pyobj((self._znm_customer.addr, task.id, fn, args, kwargs))
+ sock.send_pyobj((
+ self._znm_customer.addr, ack_task.id, fn, args, kwargs))
if not spec.reply:
+ # return immediately if workers won't reply
return
if not self._znm_customer.running:
spawn(self._znm_customer.run)
- task.prepare()
- return task if self._znm_return_task else task()
+ tasks = ack_task.acknowledge(fanout=spec.fanout)
+ if self._znm_return_task:
+ return tasks if spec.fanout else tasks[0]
+ else:
+ return (task() for task in tasks) if spec.fanout else tasks[0]()
def __enter__(self):
self._znm_customer.register_tunnel(self)
@@ -240,25 +336,51 @@ def __exit__(self, error, error_type, traceback):
class Task(object):
- def __init__(self, tunnel, id=None):
+ def __init__(self, tunnel, id=None, run_id=None):
self.tunnel = tunnel
- self.customer = tunnel._znm_customer
- self.id = str(uuid4()) if id is None else id
+ self.customer = tunnel._znm_customer if tunnel is not None else None
+ self.id = uuid_str() if id is None else id
+ self.run_id = run_id
self.queue = Queue()
- def prepare(self):
+ def acknowledge(self, fanout=False, timeout=0.01):
self.customer.register_task(self)
- do, val = self.queue.get()
- assert do == ACK
- self.worker_addr = val
+ msgs = []
+ with Timeout(timeout, False):
+ while True:
+ msgs.append(self.queue.get())
+ if not fanout:
+ break
+ self.customer.unregister_task(self)
+ if not msgs:
+ raise NoWorkerError('There are no workers which respond')
+ if fanout:
+ tasks = []
+ for do, (worker_addr, run_id) in msgs:
+ assert do == ACK
+ each_task = Task(self.tunnel, self.id, run_id)
+ each_task.worker_addr = worker_addr
+ tasks.append(each_task)
+ self.customer.register_task(each_task)
+ return tasks
+ else:
+ do, val = msgs[0]
+ assert len(msgs) == 1
+ assert do == ACK
+ self.worker_addr, self.run_id = val
+ self.customer.register_task(self)
+ return [self]
def put(self, do, val):
+ print 'task(%s:%s) recv %s %r' % \
+ (self.id, self.run_id, dt(do), val)
self.queue.put((do, val))
def __call__(self):
do, val= self.queue.get()
if do in (RETURN, RAISE):
self.customer.unregister_task(self)
+ assert do != ACK
if do == RETURN:
return val
elif do == RAISE:
View
16 zeronimo/exceptions.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+"""
+ zeronimo.exceptions
+ ~~~~~~~~~~~~~~~~~~~
+
+ :copyright: (c) 2013 by Heungsub Lee
+ :license: BSD, see LICENSE for more details.
+"""
+import gevent
+
+
+#: the base class of the exceptions defined in this module.
+class ZeronimoError(RuntimeError): pass
+#: raised by ``Task.acknowledge`` when no worker responds before the timeout.
+class NoWorkerError(ZeronimoError): pass
+
+#: an alias for :exc:`gevent.Timeout`.
+TimeoutError = gevent.Timeout
View
3 zeronimo/functional.py
@@ -9,8 +9,11 @@
:license: BSD, see LICENSE for more details.
"""
from collections import Iterable, Sequence, Set, Mapping, namedtuple
+import functools
import hashlib
+from gevent.coros import Semaphore
+
Spec = namedtuple('Spec', ['func', 'fanout', 'reply'])
View
111 zeronimotests.py
@@ -1,22 +1,24 @@
import functools
import os
+import textwrap
import uuid
from decorator import decorator
import gevent
-from gevent import joinall, killall, spawn
+from gevent import joinall, killall, spawn, Timeout
import pytest
import zmq.green as zmq
import zeronimo
zmq_context = zmq.Context()
-gevent.hub.get_hub().print_exception = lambda *a, **k: 'do not print exception'
+#gevent.hub.get_hub().print_exception = lambda *a, **k: 'do not print exception'
@decorator
def green(f, *args, **kwargs):
+ print
return spawn(f, *args, **kwargs).get()
@@ -28,7 +30,7 @@ def busywait(while_, until=False, timeout=None):
gevent.sleep(0.001)
-def zmq_bound(addr, socket_type, context=zmq_context):
+def is_connectable(addr, socket_type, context=zmq_context):
"""Checks that the address is connectable."""
try:
context.socket(socket_type).connect(addr)
@@ -38,9 +40,13 @@ def zmq_bound(addr, socket_type, context=zmq_context):
return True
-def ensure_worker(worker):
- spawn(worker.run)
- busywait(lambda: zmq_bound(worker.addr, zmq.PUSH), until=True)
+def start_workers(workers):
+    """Spawns every worker and waits until the first address of each worker
+    is connectable.
+
+    :param workers: the workers to run.  Each of them provides ``run()`` and
+                    ``addrs``.
+    """
+    waits = []
+    for worker in workers:
+        spawn(worker.run)
+        # bind the current worker as a default argument.  The waiters run
+        # after this loop has finished, so a plain closure over ``worker``
+        # would make every waiter probe the last worker only.
+        until = lambda worker=worker: is_connectable(worker.addrs[0],
+                                                     zmq.PUSH)
+        waits.append(spawn(busywait, until, until=True, timeout=1))
+    joinall(waits)
'''
@@ -113,33 +119,28 @@ def launch_rocket(self):
@zeronimo.register(fanout=True)
def rycbar123(self):
- for word in 'Run, you clever boy; and remember.'.split():
+ for word in 'run, you clever boy; and remember.'.split():
yield word
-
-# fixtures
-
-
-@pytest.fixture
-def worker():
- app = Application()
- return zeronimo.Worker(app, context=zmq_context)
-
-
-@pytest.fixture
-def worker2():
- app = Application()
- return zeronimo.Worker(app, context=zmq_context)
+ @zeronimo.register
+ def sleep(self):
+ gevent.sleep(0.1)
+ return 'slept'
-@pytest.fixture
-def customer():
- return zeronimo.Customer(context=zmq_context)
+# fixtures
-@pytest.fixture
-def customer2():
- return zeronimo.Customer(context=zmq_context)
+for x in xrange(4):
+ exec(textwrap.dedent('''
+ @pytest.fixture
+ def worker{x}():
+ app = Application()
+ return zeronimo.Worker(app, context=zmq_context)
+ @pytest.fixture
+ def customer{x}():
+ return zeronimo.Customer(context=zmq_context)
+ ''').format(x=x if x else ''))
# tests
@@ -190,7 +191,7 @@ class Nothing(object): pass
def test_default_addr(customer, worker):
- assert worker.addr.startswith('inproc://')
+ assert worker.addrs[0].startswith('inproc://')
assert customer.addr.startswith('inproc://')
@@ -207,19 +208,25 @@ def run(self):
@green
def test_tunnel(customer, worker):
- ensure_worker(worker)
+ start_workers([worker])
assert len(customer.tunnels) == 0
with customer.link([worker]) as tunnel:
assert len(customer.tunnels) == 1
assert len(customer.tunnels) == 0
- with customer.link([worker]) as tunnel1, customer.link([worker]) as tunnel2:
+ with customer.link([worker]) as tunnel1, \
+ customer.link([worker]) as tunnel2:
+ assert not customer.running
assert len(customer.tunnels) == 2
+ tunnel1.add(0, 0)
+ assert customer.running
assert len(customer.tunnels) == 0
+ busywait(lambda: customer.running, timeout=1)
+ assert not customer.running
@green
def test_return(customer, worker):
- ensure_worker(worker)
+ start_workers([worker])
with customer.link([worker]) as tunnel:
assert tunnel.add(1, 1) == 'cutie'
assert tunnel.add(2, 2) == 'cutie'
@@ -232,7 +239,7 @@ def test_return(customer, worker):
@green
def test_yield(customer, worker):
- ensure_worker(worker)
+ start_workers([worker])
with customer.link([worker]) as tunnel:
assert len(list(tunnel.jabberwocky())) == 4
assert list(tunnel.xrange()) == [0, 1, 2, 3, 4]
@@ -249,7 +256,7 @@ def test_yield(customer, worker):
@green
def test_raise(customer, worker):
- ensure_worker(worker)
+ start_workers([worker])
with customer.link([worker]) as tunnel:
with pytest.raises(ZeroDivisionError):
tunnel.divide_by_zero()
@@ -262,21 +269,22 @@ def test_raise(customer, worker):
@green
-def test_2to1(customer, customer2, worker):
- ensure_worker(worker)
+def test_2to1(customer1, customer2, worker):
+ start_workers([worker])
def test(tunnel):
assert tunnel.add(1, 1) == 'cutie'
assert len(list(tunnel.jabberwocky())) == 4
with pytest.raises(ZeroDivisionError):
tunnel.divide_by_zero()
- with customer.link([worker]) as tunnel, customer2.link([worker]) as tunnel2:
- joinall([spawn(test, tunnel), spawn(test, tunnel2)])
+ with customer1.link([worker]) as tunnel1, \
+ customer2.link([worker]) as tunnel2:
+ joinall([spawn(test, tunnel1), spawn(test, tunnel2)])
@green
-def test_1to2(customer, worker, worker2):
- joinall([spawn(ensure_worker, worker), spawn(ensure_worker, worker2)])
- with customer.link([worker, worker2], return_task=True) as tunnel:
+def test_1to2(customer, worker1, worker2):
+ start_workers([worker1, worker2])
+ with customer.link([worker1, worker2], return_task=True) as tunnel:
task1 = tunnel.add(1, 1)
task2 = tunnel.add(2, 2)
assert task1() == 'cutie'
@@ -285,6 +293,29 @@ def test_1to2(customer, worker, worker2):
@green
+def test_fanout(customer, worker1, worker2):
+ """Calls the fanout function ``rycbar123`` through a tunnel linked to two
+ workers and checks that every returned result yields all the words in
+ order.
+ """
+ start_workers([worker1, worker2])
+ with customer.link([worker1, worker2]) as tunnel:
+ # a fanout call returns one result per worker which acknowledged
+ for rycbar123 in tunnel.rycbar123():
+ assert rycbar123.next() == 'run,'
+ assert rycbar123.next() == 'you'
+ assert rycbar123.next() == 'clever'
+ assert rycbar123.next() == 'boy;'
+ assert rycbar123.next() == 'and'
+ assert rycbar123.next() == 'remember.'
+
+
+@green
+def test_slow(customer, worker):
+ """Checks that a call interrupted by a timeout raises :exc:`Timeout` and
+ that the next call still returns its value.
+ """
+ start_workers([worker])
+ with customer.link([worker]) as tunnel:
+ with pytest.raises(Timeout):
+ # ``sleep`` takes 0.1 seconds on the worker, as long as this timeout
+ with Timeout(0.1):
+ tunnel.sleep()
+ assert tunnel.sleep() == 'slept'
+
+
+@green
def test_link_to_different_workers(customer, worker):
worker2 = zeronimo.Worker(2)
with pytest.raises(ValueError):

0 comments on commit 846fec0

Please sign in to comment.