/
__init__.py
144 lines (125 loc) · 4.89 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-
import asyncio
import aiozmq.rpc
import aioredis
import logging
from collections import defaultdict
from ircb.config import settings
# Module-level logger shared by Handler and Dispatcher; named 'dispatcher'.
logger = logging.getLogger(name='dispatcher')
class Handler(aiozmq.rpc.AttrHandler):
    """RPC handler served on a dispatcher's pub/sub subscriber socket.

    ``send`` fans an incoming signal out to the callbacks registered on the
    owning dispatcher; ``register_sub`` connects the dispatcher's publisher
    to a peer's subscriber address and acknowledges that through Redis.
    """

    def __init__(self, dispatcher):
        """
        :param dispatcher: owning dispatcher; its ``_signal_listeners`` and
            ``publisher`` attributes are read by the RPC methods.
        """
        self._dispatcher = dispatcher
        # Held for the whole of register_sub so concurrent registration
        # requests are processed one at a time.
        self._lock = asyncio.Lock()

    @aiozmq.rpc.method
    def send(self, signal, data, taskid=None):
        """Deliver a published signal to the local listeners.

        Callbacks registered for ``signal`` run first, then those registered
        for the wildcard ``'__all__'``; each is called as
        ``callback(signal, data, taskid)``. A callback that raises is logged
        and does not stop the remaining callbacks from running.
        """
        listeners = self._dispatcher._signal_listeners
        for s in (signal, '__all__'):
            # Iterate over a snapshot: a callback may register new listeners,
            # and mutating a set while iterating over it raises RuntimeError.
            for callback in list(listeners.get(s, ())):
                try:
                    callback(signal, data, taskid)
                except Exception as e:
                    logger.error('SEND ERROR: {} {} {} {}'.format(
                        e, signal, data, taskid), exc_info=True)
        logger.debug('SEND: {} {} {}'.format(signal, data, taskid))

    @asyncio.coroutine
    @aiozmq.rpc.method
    def register_sub(self, subscriber_addr, key):
        """Connect our publisher to ``subscriber_addr`` and acknowledge it.

        The acknowledgement is written to Redis as ``key = 1``; the
        requesting dispatcher keeps re-sending this request until it reads
        ``b'1'`` back. If the publisher is already connected to
        ``subscriber_addr`` it is not connected again, but the
        acknowledgement is still written so the requester stops retrying.

        :param subscriber_addr: ZeroMQ endpoint of the peer's subscriber.
        :param key: Redis key the peer polls for the acknowledgement.
        """
        yield from self._lock.acquire()
        try:
            publisher = self._dispatcher.publisher
            if publisher is None:
                # Our own publisher is not created yet (setup_pubsub serves
                # the subscriber first). Don't acknowledge; the requester
                # re-sends the request and a later attempt will succeed.
                return
            transport = publisher.transport
            if subscriber_addr not in transport.connections():
                transport.connect(subscriber_addr)
            redis = yield from aioredis.create_redis(
                (settings.REDIS_HOST, settings.REDIS_PORT)
            )
            try:
                yield from redis.set(key, 1)
            finally:
                redis.close()
        finally:
            # The only release: runs exactly once on every exit path.
            self._lock.release()
class Dispatcher(object):
    """Process-local endpoint of the ZeroMQ pub/sub signal bus.

    Outgoing signals are queued by :meth:`send` and published once
    :meth:`setup_pubsub` has created and registered the sockets. Incoming
    signals are delivered by :class:`Handler` to callbacks added with
    :meth:`register`.
    """

    def __init__(self, role, loop=None):
        """
        :param role: role of this process. ``'stores'`` binds its subscriber
            to ``settings.SUBSCRIBER_ENDPOINTS['stores']``; ``'storeclient'``
            connects its publisher to that endpoint. Other roles bind to a
            wildcard port on ``settings.INTERNAL_HOST``.
        :param loop: event loop to use; defaults to
            ``asyncio.get_event_loop()``.
        """
        self.loop = loop or asyncio.get_event_loop()
        self.role = role
        # signal name -> set of callbacks; '__all__' holds wildcard listeners.
        self._signal_listeners = defaultdict(set)
        self.handler = Handler(self)
        self.subscriber = self.publisher = None
        # Held from construction until setup_pubsub() has registered our
        # subscriber; process_queue() waits for its release before publishing.
        self.lock = asyncio.Lock(loop=self.loop)
        self.queue = asyncio.Queue(loop=self.loop)
        # True while a process_queue() task is responsible for the queue.
        self._processing = False
        # Bind the tasks to self.loop (like the queue) so a custom loop works.
        asyncio.Task(self.lock.acquire(), loop=self.loop)
        asyncio.Task(self.setup_pubsub(), loop=self.loop)

    @asyncio.coroutine
    def process_queue(self):
        """Publish queued ``(signal, data, taskid)`` items until the queue is empty.

        Waits (polling every 10 ms) while :attr:`lock` is held, i.e. while
        pub/sub setup is still in progress. A failed publish is logged and
        the remaining items are still sent, so one error cannot strand the
        queue.
        """
        try:
            while True:
                while self.lock.locked():
                    yield from asyncio.sleep(0.01)
                try:
                    signal, data, taskid = self.queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                try:
                    yield from self._send(signal, data, taskid)
                except asyncio.CancelledError:
                    # CancelledError derives from Exception on older Pythons;
                    # never swallow a cancellation.
                    raise
                except Exception as e:
                    logger.error('PUBLISH ERROR: {} {} {} {}'.format(
                        e, signal, data, taskid), exc_info=True)
        finally:
            # No await between the QueueEmpty break and this line, so an
            # enqueue() cannot slip an item in unnoticed.
            self._processing = False

    @asyncio.coroutine
    def setup_pubsub(self):
        """Create the pub/sub sockets and register our subscriber with peers.

        Binds the subscriber, creates the publisher (connecting it to the
        stores endpoint for the ``'storeclient'`` role), then broadcasts
        ``register_sub`` every 10 ms until a peer acknowledges by setting the
        Redis key ``SUBSCRIBER_REGISTERED_<addr>`` to ``1``. On success it
        releases :attr:`lock` so queued signals start flowing.
        """
        redis = yield from aioredis.create_redis(
            (settings.REDIS_HOST, settings.REDIS_PORT)
        )
        try:
            if self.role == 'stores':
                bind_addr = settings.SUBSCRIBER_ENDPOINTS[self.role]
            else:
                bind_addr = 'tcp://{host}:*'.format(
                    host=settings.INTERNAL_HOST)
            self.subscriber = yield from aiozmq.rpc.serve_pubsub(
                self.handler, subscribe='',
                bind=bind_addr,
                log_exceptions=True)
            # Peers connect to this address; assumed to be the concrete
            # endpoint even when a '*' port was requested.
            subscriber_addr = list(self.subscriber.transport.bindings())[0]
            self.publisher = yield from aiozmq.rpc.connect_pubsub()
            if self.role == 'storeclient':
                self.publisher.transport.connect(
                    settings.SUBSCRIBER_ENDPOINTS['stores'])
            key = 'SUBSCRIBER_REGISTERED_{}'.format(subscriber_addr)
            ret = 0
            yield from redis.set(key, ret)
            # Handler.register_sub on a peer sets the key to 1 once that
            # peer's publisher is connected to our subscriber.
            while ret != b'1':
                yield from self.publisher.publish(
                    'register_sub'
                ).register_sub(
                    subscriber_addr, key
                )
                ret = yield from redis.get(key)
                yield from asyncio.sleep(0.01)
            self.lock.release()
        finally:
            redis.close()

    @property
    def subscriber_endpoints(self):
        """Configured subscriber endpoints of every role except our own."""
        return [endpoint for role, endpoint in
                settings.SUBSCRIBER_ENDPOINTS.items()
                if role != self.role]

    def send(self, signal, data, taskid=None):
        """Schedule ``signal`` for publishing; returns immediately."""
        asyncio.Task(self.enqueue((signal, data, taskid)), loop=self.loop)

    @asyncio.coroutine
    def enqueue(self, data):
        """Queue a ``(signal, data, taskid)`` tuple and ensure a consumer runs."""
        yield from self.queue.put(data)
        # Start a consumer only if none is active, so at most one task
        # drains the queue and publishes stay in FIFO order.
        if not self._processing:
            self._processing = True
            asyncio.Task(self.process_queue(), loop=self.loop)

    @asyncio.coroutine
    def _send(self, signal, data, taskid=None):
        """Publish one signal on the bus, using ``signal`` as the topic."""
        logger.debug('PUBLISH from %s: %s', self.role, (signal, data, taskid))
        yield from self.publisher.publish(signal).send(signal, data, taskid)

    def register(self, callback, signal=None):
        """Register ``callback`` for ``signal`` (all signals when ``None``).

        A callback already registered for ``'__all__'`` is not added again
        for a specific signal, so it is not invoked twice per delivery.
        Errors are logged rather than raised.
        """
        try:
            signal = signal or '__all__'
            if callback not in self._signal_listeners.get('__all__', []):
                callbacks = self._signal_listeners[signal]
                callbacks.add(callback)
            logger.debug('REGISTER: {} {}'.format(callback, signal))
        except Exception as e:
            logger.error('REGISTER ERROR: {} {} {}'.format(
                e, callback, signal), exc_info=True)

    def run_forever(self):
        """Run this dispatcher's event loop until it is stopped."""
        logger.info('Running %s...', self.role)
        self.loop.run_forever()