Permalink
Browse files

implement tunnel which links to known address

  • Loading branch information...
1 parent d7524d8 commit 4cfb40a55c92c995cf2c57cd2f87bd6268f250b3 @sublee committed Apr 5, 2013
Showing with 98 additions and 74 deletions.
  1. +1 −1 setup.py
  2. +94 −70 zeronimo.py
  3. +3 −3 zeronimotests.py
View
@@ -51,7 +51,7 @@ def whoami(self):
customer = zeronimo.Customer()
customer.bind('ipc://customer')
- with customer.link(['ipc://worker'], ['ipc://worker_fanout']) as tunnel:
+ with customer.link('ipc://worker', 'ipc://worker_fanout') as tunnel:
for result in tunnel(fanout=True).whoami():
print 'hostname=', result.next()
print 'public address=', result.next()
View
@@ -6,10 +6,15 @@
:copyright: (c) 2013 by Heungsub Lee
:license: BSD, see LICENSE for more details.
"""
-from collections import namedtuple, Iterable, Sequence, Set, Mapping
+from collections import (
+ defaultdict, namedtuple, Iterable, Sequence, Set, Mapping)
from contextlib import contextmanager, nested
import functools
import hashlib
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
from types import MethodType
import uuid
@@ -24,24 +29,17 @@
__all__ = []
-# socket type helpers
-SOCKET_TYPE_NAMES = {zmq.REQ: 'REQ', zmq.REP: 'REP', zmq.DEALER: 'DEALER',
- zmq.ROUTER: 'ROUTER', zmq.PUB: 'PUB', zmq.SUB: 'SUB',
- zmq.XPUB: 'XPUB', zmq.XSUB: 'XSUB', zmq.PUSH: 'PUSH',
- zmq.PULL: 'PULL', zmq.PAIR: 'PAIR'}
-REQ_REP = (zmq.REQ, zmq.REP)
-PUB_SUB = (zmq.PUB, zmq.SUB)
-PUSH_PULL = (zmq.PUSH, zmq.PULL)
+# exceptions
-# exceptions
class ZeronimoError(RuntimeError): pass
class ZeronimoWarning(RuntimeWarning): pass
class AcceptanceError(ZeronimoError): pass
-class SubscriptionWarning(ZeronimoWarning): pass
# message frames
+
+
class Invocation(namedtuple('Invocation', [
'name', 'args', 'kwargs', 'task_id', 'customer_addr'])):
@@ -61,7 +59,7 @@ def __repr__(self):
return '{}({}, {!r}, {!r}, {!r}, {!r})'.format(*args)
-# methods
+# reply methods
ACCEPT = 1
REJECT = 0
RETURN = 10
@@ -107,6 +105,22 @@ def should_yield(val):
not isinstance(val, (Sequence, Set, Mapping)))
+def ensure_list(val):
+ if val is None:
+ return []
+ return list(val) if isinstance(val, (Sequence, Set)) else [val]
+
+
+def send(sock, obj, flags=0, topic=''):
+ msg = pickle.dumps(obj)
+ return sock.send(topic + msg, flags)
+
+
+def recv(sock, flags=0):
+ msg = sock.recv(flags)
+ return pickle.loads(msg)
+
+
class Base(object):
"""Manages ZeroMQ sockets."""
@@ -162,27 +176,23 @@ class Worker(Base):
functions = None
fanout_sock = None
fanout_addrs = None
- fanout_filters = None
+ fanout_topic = None
- def __init__(self, obj, bind=None, bind_fanout=None, subscribe=None,
+ def __init__(self, obj, bind=None, bind_fanout=None, fanout_topic='',
**kwargs):
super(Worker, self).__init__(**kwargs)
self.functions = collect_remote_functions(obj)
bind and self.bind(bind)
bind_fanout and self.bind_fanout(bind_fanout)
- if subscribe is not None:
- self.subscribe(subscribe)
-
- def possible_addrs(self, sock_type):
- return self.addrs if sock_type == zmq.PULL else self.fanout_addrs
+ self.fanout_topic = fanout_topic
+ self.subscribe(fanout_topic)
def reset_sockets(self):
super(Worker, self).reset_sockets()
if self.fanout_sock is not None:
self.fanout_sock.close()
self.fanout_sock = self.context.socket(zmq.SUB)
self.fanout_addrs = set()
- self.fanout_filters = set()
def bind_fanout(self, addr):
self.fanout_sock.bind(addr)
@@ -192,16 +202,11 @@ def unbind_fanout(self, addr):
self.fanout_sock.unbind(addr)
self.fanout_addrs.remove(addr)
- def subscribe(self, fanout_filter):
- self.fanout_sock.setsockopt(zmq.SUBSCRIBE, fanout_filter)
- self.fanout_filters.add(fanout_filter)
-
- def unsubscribe(self, fanout_filter):
- self.fanout_sock.setsockopt(zmq.UNSUBSCRIBE, fanout_filter)
- try:
- self.fanout_filters.remove(fanout_filter)
- except KeyError:
- pass
+ def subscribe(self, fanout_topic):
+ if self.fanout_topic is not None:
+ self.fanout_sock.setsockopt(zmq.UNSUBSCRIBE, self.fanout_topic)
+ self.fanout_sock.setsockopt(zmq.SUBSCRIBE, fanout_topic)
+ self.fanout_topic = fanout_topic
def run_task(self, invocation):
run_id = uuid_str()
@@ -215,35 +220,32 @@ def run_task(self, invocation):
else:
sock = self.context.socket(zmq.PUSH)
sock.connect(invocation.customer_addr)
- sock.send_pyobj(Reply(ACCEPT, None, *meta))
+ send(sock, Reply(ACCEPT, None, *meta))
try:
val = self.functions[name](*args, **kwargs)
except Exception, error:
print 'worker send %r' % (Reply(RAISE, error, *meta),)
- sock and sock.send_pyobj(Reply(RAISE, error, *meta))
+ sock and send(sock, Reply(RAISE, error, *meta))
raise
if should_yield(val):
try:
for item in val:
print 'worker send %r' % (Reply(YIELD, item, *meta),)
- sock and sock.send_pyobj(Reply(YIELD, item, *meta))
+ sock and send(sock, Reply(YIELD, item, *meta))
except Exception, error:
print 'worker send %r' % (Reply(RAISE, error, *meta),)
- sock and sock.send_pyobj(Reply(RAISE, error, *meta))
+ sock and send(sock, Reply(RAISE, error, *meta))
else:
print 'worker send %r' % (Reply(BREAK, None, *meta),)
- sock and sock.send_pyobj(Reply(BREAK, None, *meta))
+ sock and send(sock, Reply(BREAK, None, *meta))
else:
print 'worker send %r' % (Reply(RETURN, val, *meta),)
- sock and sock.send_pyobj(Reply(RETURN, val, *meta))
+ sock and send(sock, Reply(RETURN, val, *meta))
def run(self):
- if not self.fanout_filters:
- from warnings import warn
- warn('Didn\'t subscribe any topic', SubscriptionWarning)
def serve(sock):
while self.running:
- spawn(self.run_task, sock.recv_pyobj())
+ spawn(self.run_task, recv(sock))
joinall([spawn(serve, self.sock), spawn(serve, self.fanout_sock)])
@@ -304,7 +306,7 @@ def _restore_missing_messages(self, task):
def run(self):
while self.tunnels:
try:
- reply = self.sock.recv_pyobj()
+ reply = recv(self.sock)
except zmq.ZMQError:
continue
task_id = reply.task_id
@@ -338,22 +340,40 @@ class Tunnel(object):
request of RPC through the customer's sockets.
:param customer: the :class:`Customer` object.
- :param workers: the :class:`Worker` objects.
- :param return_task: if set to ``True``, the remote functions return a
- :class:`Task` object instead of received value.
- :type return_task: bool
+ :param dests: the destinations.
+ :param fanout_topic: the filter the workers are subscribing.
+
+ :type dests: array of :class:`Worker` or address
"""
- def __init__(self, customer, workers,
+ def __init__(self, customer, dests, fanout_topic='',
wait=True, fanout=False, as_task=False):
self._znm_customer = customer
- self._znm_workers = workers
+ self._znm_addrs, self._znm_fanout_addrs = self._znm_merge_dests(dests)
+ self._znm_fanout_topic = fanout_topic
self._znm_sockets = {}
# options
self._znm_wait = wait
self._znm_fanout = fanout
self._znm_as_task = as_task
+ def _znm_merge_dests(self, dests):
+ dests = ensure_list(dests)
+ merged_addrs = set()
+ merged_fanout_addrs = set()
+ for dest in dests:
+ if isinstance(dest, Worker):
+ addrs = dest.addrs
+ fanout_addrs = dest.fanout_addrs
+ elif isinstance(dest, tuple):
+ addrs, fanout_addrs = dest
+ else:
+ raise TypeError('{!r} is not allowed destination '
+ 'type'.format(dest))
+ merged_addrs.update(addrs)
+ merged_fanout_addrs.update(fanout_addrs)
+ return merged_addrs, merged_fanout_addrs
+
def _znm_is_alive(self):
return self in self._znm_customer.tunnels
@@ -363,15 +383,36 @@ def _znm_invoke(self, name, *args, **kwargs):
if self._znm_wait else None
task = Task(self)
sock = self._znm_sockets[zmq.PUB if self._znm_fanout else zmq.PUSH]
- print 'tunnel send %r' % (Invocation(name, args, kwargs, task.id, customer_addr),)
- sock.send_pyobj(Invocation(name, args, kwargs, task.id, customer_addr))
+ invocation = Invocation(name, args, kwargs, task.id, customer_addr)
+ print 'tunnel send %r' % (invocation,)
+ topic = self._znm_fanout_topic if self._znm_fanout else ''
+ send(sock, invocation, topic=topic)
if not self._znm_wait:
# immediately if workers won't wait
return
if not self._znm_customer.running:
spawn(self._znm_customer.run)
return task.collect()
+ def __getattr__(self, attr):
+ return functools.partial(self._znm_invoke, attr)
+
+ def __enter__(self):
+ self._znm_customer.register_tunnel(self)
+ for socket_type, addrs in [(zmq.PUSH, self._znm_addrs),
+ (zmq.PUB, self._znm_fanout_addrs)]:
+ sock = self._znm_customer.context.socket(socket_type)
+ self._znm_sockets[socket_type] = sock
+ for addr in addrs:
+ sock.connect(addr)
+ return self
+
+ def __exit__(self, error, error_type, traceback):
+ for sock in self._znm_sockets.viewvalues():
+ sock.close()
+ self._znm_sockets.clear()
+ self._znm_customer.unregister_tunnel(self)
+
def __call__(self, wait=None, fanout=None, as_task=None):
"""Creates a :class:`Tunnel` object which follows same consumer and
workers but replaced options.
@@ -383,29 +424,12 @@ def __call__(self, wait=None, fanout=None, as_task=None):
if as_task is None:
as_task = self._znm_as_task
opts = (wait, fanout, as_task)
- tunnel = Tunnel(self._znm_customer, self._znm_workers, *opts)
+ tunnel = Tunnel(self._znm_customer, [], self._znm_fanout_topic, *opts)
+ tunnel._znm_addrs = self._znm_addrs
+ tunnel._znm_fanout_addrs = self._znm_fanout_addrs
tunnel._znm_sockets = self._znm_sockets
return tunnel
- def __getattr__(self, attr):
- return functools.partial(self._znm_invoke, attr)
-
- def __enter__(self):
- self._znm_customer.register_tunnel(self)
- for send_type, recv_type in [PUSH_PULL, PUB_SUB]:
- sock = self._znm_customer.context.socket(send_type)
- for worker in self._znm_workers:
- for addr in worker.possible_addrs(recv_type):
- sock.connect(addr)
- self._znm_sockets[send_type] = sock
- return self
-
- def __exit__(self, error, error_type, traceback):
- for sock in self._znm_sockets.viewvalues():
- sock.close()
- self._znm_sockets.clear()
- self._znm_customer.unregister_tunnel(self)
-
class Task(object):
@@ -429,7 +453,7 @@ def collect(self, timeout=0.01):
break
self.customer.unregister_task(self)
if not replies:
- raise AcceptanceError('No workers which accepted')
+ raise AcceptanceError('Failed to find workers that accepted')
if self.tunnel._znm_fanout:
tasks = []
for reply in replies:
View
@@ -14,7 +14,7 @@
zmq_context = zmq.Context()
-gevent.hub.get_hub().print_exception = lambda *a, **k: 'do not print exception'
+#gevent.hub.get_hub().print_exception = lambda *a, **k: 'do not print exception'
@decorator
@@ -287,7 +287,7 @@ def test_slow(customer, worker):
@green
-def _test_link_to_addrs(customer, worker):
+def test_link_to_addrs(customer, worker):
start_workers([worker])
- with customer.link(worker.addrs) as tunnel:
+ with customer.link([(worker.addrs, worker.fanout_addrs)]) as tunnel:
assert tunnel.add(1, 1) == 'cutie'

0 comments on commit 4cfb40a

Please sign in to comment.