diff --git a/graphql_subscriptions/__init__.py b/graphql_subscriptions/__init__.py index dfb2e48..6a17301 100644 --- a/graphql_subscriptions/__init__.py +++ b/graphql_subscriptions/__init__.py @@ -3,4 +3,3 @@ from .subscription_transport_ws import SubscriptionServer __all__ = ['RedisPubsub', 'SubscriptionManager', 'SubscriptionServer'] - diff --git a/graphql_subscriptions/executors/__init__.py b/graphql_subscriptions/executors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphql_subscriptions/executors/asyncio.py b/graphql_subscriptions/executors/asyncio.py new file mode 100644 index 0000000..6142c22 --- /dev/null +++ b/graphql_subscriptions/executors/asyncio.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import + +import asyncio +from websockets import ConnectionClosed + +try: + from asyncio import ensure_future +except ImportError: + # ensure_future is only implemented in Python 3.4.4+ + # Reference: https://github.com/graphql-python/graphql-core/blob/master/graphql/execution/executors/asyncio.py + def ensure_future(coro_or_future, loop=None): + """Wrap a coroutine or an awaitable in a future. + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, asyncio.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif asyncio.iscoroutine(coro_or_future): + if loop is None: + loop = asyncio.get_event_loop() + task = loop.create_task(coro_or_future) + if task._source_traceback: + del task._source_traceback[-1] + return task + else: + raise TypeError( + 'A Future, a coroutine or an awaitable is required') + + +class AsyncioExecutor(object): + error = ConnectionClosed + task_cancel_error = asyncio.CancelledError + + def __init__(self, loop=None): + if loop is None: + loop = asyncio.get_event_loop() + self.loop = loop + self.futures = [] + + def ws_close(self, code): + return self.ws.close(code) + + def ws_protocol(self): + return self.ws.subprotocol + + def ws_isopen(self): + if self.ws.open: + return True + else: + return False + + def ws_send(self, msg): + return self.ws.send(msg) + + def ws_recv(self): + return self.ws.recv() + + def sleep(self, time): + if self.loop.is_running(): + return asyncio.sleep(time) + return self.loop.run_until_complete(asyncio.sleep(time)) + + @staticmethod + def kill(future): + future.cancel() + + def join(self, future=None, timeout=None): + if not isinstance(future, asyncio.Future): + return + if self.loop.is_running(): + return asyncio.wait_for(future, timeout=timeout) + return self.loop.run_until_complete( + asyncio.wait_for(future, timeout=timeout)) + + def execute(self, fn, *args, **kwargs): + result = fn(*args, **kwargs) + if isinstance(result, asyncio.Future) or asyncio.iscoroutine(result): + future = ensure_future(result, loop=self.loop) + self.futures.append(future) + return future + return result diff --git a/graphql_subscriptions/executors/django.py b/graphql_subscriptions/executors/django.py new file mode 100644 index 0000000..e69de29 diff --git a/graphql_subscriptions/executors/gevent.py b/graphql_subscriptions/executors/gevent.py new file mode 100644 index 0000000..009f444 --- /dev/null +++ b/graphql_subscriptions/executors/gevent.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import + +from geventwebsocket.exceptions import WebSocketError +import gevent + + +class GeventExecutor(object): + # used to patch socket library so it doesn't block + socket = gevent.socket + error = WebSocketError + + def __init__(self): + self.greenlets = [] + + def ws_close(self, code): + self.ws.close(code) + + def ws_protocol(self): + return self.ws.protocol + + def ws_isopen(self): + if self.ws.closed: + return False + else: + return True + + def ws_send(self, msg, **kwargs): + self.ws.send(msg, **kwargs) + + def ws_recv(self): + return self.ws.receive() + + @staticmethod + def sleep(time): + gevent.sleep(time) + + @staticmethod + def kill(greenlet): + gevent.kill(greenlet) + + @staticmethod + def join(greenlet, timeout=None): + greenlet.join(timeout) + + def join_all(self): + gevent.joinall(self.greenlets) + self.greenlets = [] + + def execute(self, fn, *args, **kwargs): + greenlet = gevent.spawn(fn, *args, **kwargs) + self.greenlets.append(greenlet) + return greenlet diff --git a/graphql_subscriptions/subscription_manager/__init__.py b/graphql_subscriptions/subscription_manager/__init__.py new file mode 100644 index 0000000..cd1440e --- /dev/null +++ b/graphql_subscriptions/subscription_manager/__init__.py @@ -0,0 +1,4 @@ +from .manager import SubscriptionManager +from .pubsub import RedisPubsub + +__all__ = ['SubscriptionManager', 'RedisPubsub'] diff --git a/graphql_subscriptions/subscription_manager.py b/graphql_subscriptions/subscription_manager/manager.py similarity index 66% rename from graphql_subscriptions/subscription_manager.py rename to graphql_subscriptions/subscription_manager/manager.py index ee05668..70930a8 100644 --- a/graphql_subscriptions/subscription_manager.py +++ b/graphql_subscriptions/subscription_manager/manager.py @@ -2,71 +2,15 @@ standard_library.install_aliases() from builtins import object from types import FunctionType -import pickle from graphql import parse, validate, specified_rules, value_from_ast, execute from graphql.language.ast import OperationDefinition from promise import Promise -import gevent -import redis from .utils import to_snake_case from .validation import SubscriptionHasSingleRootField -class RedisPubsub(object): - def __init__(self, host='localhost', port=6379, *args, **kwargs): - redis.connection.socket = gevent.socket - self.redis = redis.StrictRedis(host, port, *args, **kwargs) - self.pubsub = self.redis.pubsub() - self.subscriptions = {} - self.sub_id_counter = 0 - self.greenlet = None - - def publish(self, trigger_name, message): - self.redis.publish(trigger_name, pickle.dumps(message)) - return True - - def subscribe(self, trigger_name, on_message_handler, options): - self.sub_id_counter += 1 - try: - if trigger_name not in list(self.subscriptions.values())[0]: - self.pubsub.subscribe(trigger_name) - except IndexError: - self.pubsub.subscribe(trigger_name) - self.subscriptions[self.sub_id_counter] = [ - trigger_name, on_message_handler - ] - if not self.greenlet: - self.greenlet = gevent.spawn(self.wait_and_get_message) - return Promise.resolve(self.sub_id_counter) - - def unsubscribe(self, sub_id): - trigger_name, on_message_handler = self.subscriptions[sub_id] - del self.subscriptions[sub_id] - try: - if trigger_name not in list(self.subscriptions.values())[0]: - self.pubsub.unsubscribe(trigger_name) - except IndexError: - self.pubsub.unsubscribe(trigger_name) - if not self.subscriptions: - self.greenlet = self.greenlet.kill() - - def wait_and_get_message(self): - while True: - message = self.pubsub.get_message(ignore_subscribe_messages=True) - if message: - self.handle_message(message) - gevent.sleep(.001) - - def handle_message(self, message): - if isinstance(message['channel'], bytes): - channel = message['channel'].decode() - for sub_id, trigger_map in self.subscriptions.items(): - if trigger_map[0] == channel: - trigger_map[1](pickle.loads(message['data'])) - - class ValidationError(Exception): def __init__(self, errors): self.errors = errors @@ -79,7 +23,7 @@ def __init__(self, schema, pubsub, setup_funcs={}): self.pubsub = pubsub self.setup_funcs = setup_funcs self.subscriptions = {} - self.max_subscription_id = 0 + self.max_subscription_id = 1 def publish(self, trigger_name, payload): self.pubsub.publish(trigger_name, payload) @@ -145,11 +89,6 @@ def subscribe(self, query, operation_name, callback, variables, context, except AttributeError: channel_options = {} - # TODO: Think about this some more...the Apollo library - # let's all messages through by default, even if - # the users incorrectly uses the setup_funcs (does not - # use 'filter' or 'channel_options' keys); I think it - # would be better to raise an exception here def filter(arg1, arg2): return True @@ -181,7 +120,8 @@ def context_do_execute_handler(result): subscription_promises.append( self.pubsub. subscribe(trigger_name, on_message, channel_options).then( - lambda id: self.subscriptions[external_subscription_id].append(id) + lambda id: self.subscriptions[external_subscription_id]. + append(id) )) return Promise.all(subscription_promises).then( diff --git a/graphql_subscriptions/subscription_manager/pubsub.py b/graphql_subscriptions/subscription_manager/pubsub.py new file mode 100644 index 0000000..3b6c133 --- /dev/null +++ b/graphql_subscriptions/subscription_manager/pubsub.py @@ -0,0 +1,113 @@ +from future import standard_library +standard_library.install_aliases() +from builtins import object +import pickle +import sys + +from promise import Promise +import redis + +from ..executors.gevent import GeventExecutor +from ..executors.asyncio import AsyncioExecutor + +PY3 = sys.version_info[0] == 3 + + +class RedisPubsub(object): + def __init__(self, + host='localhost', + port=6379, + executor=GeventExecutor, + *args, + **kwargs): + + if executor == AsyncioExecutor: + try: + import aredis + except: + raise ImportError( + 'You need the redis client "aredis" installed for use w/ ' + 'asyncio') + + redis_client = aredis + else: + redis_client = redis + + # patch redis socket library so it doesn't block if using gevent + if executor == GeventExecutor: + redis_client.connection.socket = executor.socket + + self.redis = redis_client.StrictRedis(host, port, *args, **kwargs) + self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True) + + self.executor = executor() + self.backgrd_task = None + + self.subscriptions = {} + self.sub_id_counter = 0 + + def publish(self, trigger_name, message): + self.executor.execute(self.redis.publish, trigger_name, + pickle.dumps(message)) + return True + + def subscribe(self, trigger_name, on_message_handler, options): + self.sub_id_counter += 1 + + self.subscriptions[self.sub_id_counter] = [ + trigger_name, on_message_handler] + + if PY3: + trigger_name = trigger_name.encode() + + if trigger_name not in list(self.pubsub.channels.keys()): + self.executor.join(self.executor.execute(self.pubsub.subscribe, + trigger_name)) + if not self.backgrd_task: + self.backgrd_task = self.executor.execute( + self.wait_and_get_message) + + return Promise.resolve(self.sub_id_counter) + + def unsubscribe(self, sub_id): + trigger_name, on_message_handler = self.subscriptions[sub_id] + del self.subscriptions[sub_id] + + if PY3: + trigger_name = trigger_name.encode() + + if trigger_name not in list(self.pubsub.channels.keys()): + self.executor.execute(self.pubsub.unsubscribe, trigger_name) + + if not self.subscriptions: + self.backgrd_task = self.executor.kill(self.backgrd_task) + + async def _wait_and_get_message_async(self): + try: + while True: + message = await self.pubsub.get_message() + if message: + self.handle_message(message) + await self.executor.sleep(.001) + except self.executor.task_cancel_error: + return + + def _wait_and_get_message_sync(self): + while True: + message = self.pubsub.get_message() + if message: + self.handle_message(message) + self.executor.sleep(.001) + + def wait_and_get_message(self): + if hasattr(self.executor, 'loop'): + return self._wait_and_get_message_async() + return self._wait_and_get_message_sync() + + def handle_message(self, message): + + channel = message['channel'].decode() if PY3 else message['channel'] + + for sub_id, trigger_map in self.subscriptions.items(): + if trigger_map[0] == channel: + trigger_map[1](pickle.loads(message['data'])) diff --git a/graphql_subscriptions/utils.py b/graphql_subscriptions/subscription_manager/utils.py similarity index 100% rename from graphql_subscriptions/utils.py rename to graphql_subscriptions/subscription_manager/utils.py diff --git a/graphql_subscriptions/validation.py b/graphql_subscriptions/subscription_manager/validation.py similarity index 80% rename from graphql_subscriptions/validation.py rename to graphql_subscriptions/subscription_manager/validation.py index 9b0bc50..34993bf 100644 --- a/graphql_subscriptions/validation.py +++ b/graphql_subscriptions/subscription_manager/validation.py @@ -3,8 +3,7 @@ FIELD = 'Field' -# XXX from Apollo pacakge: Temporarily use this validation -# rule to make our life a bit easier. +# Temporarily use this validation rule to make our life a bit easier. class SubscriptionHasSingleRootField(ValidationRule): @@ -27,8 +26,8 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): else: self.context.report_error( GraphQLError( - 'Apollo subscriptions do not support fragments on\ - the root field', [node])) + 'Subscriptions do not support fragments on ' + 'the root field', [node])) if num_fields > 1: self.context.report_error( GraphQLError( @@ -38,5 +37,5 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): @staticmethod def too_many_subscription_fields_error(subscription_name): - return 'Subscription "{0}" must have only one\ - field.'.format(subscription_name) + return ('Subscription "{0}" must have only one ' + 'field.'.format(subscription_name)) diff --git a/graphql_subscriptions/subscription_transport_ws/__init__.py b/graphql_subscriptions/subscription_transport_ws/__init__.py new file mode 100644 index 0000000..1ed6b70 --- /dev/null +++ b/graphql_subscriptions/subscription_transport_ws/__init__.py @@ -0,0 +1 @@ +from .server import SubscriptionServer diff --git a/graphql_subscriptions/subscription_transport_ws/message_types.py b/graphql_subscriptions/subscription_transport_ws/message_types.py new file mode 100644 index 0000000..58f7de9 --- /dev/null +++ b/graphql_subscriptions/subscription_transport_ws/message_types.py @@ -0,0 +1,10 @@ +SUBSCRIPTION_FAIL = 'subscription_fail' +SUBSCRIPTION_END = 'subscription_end' +SUBSCRIPTION_DATA = 'subscription_data' +SUBSCRIPTION_START = 'subscription_start' +SUBSCRIPTION_SUCCESS = 'subscription_success' +KEEPALIVE = 'keepalive' +INIT = 'init' +INIT_SUCCESS = 'init_success' +INIT_FAIL = 'init_fail' +GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions' diff --git a/graphql_subscriptions/subscription_transport_ws.py b/graphql_subscriptions/subscription_transport_ws/server.py similarity index 72% rename from graphql_subscriptions/subscription_transport_ws.py rename to graphql_subscriptions/subscription_transport_ws/server.py index a0d3908..13b2cd5 100644 --- a/graphql_subscriptions/subscription_transport_ws.py +++ b/graphql_subscriptions/subscription_transport_ws/server.py @@ -1,49 +1,72 @@ from builtins import str -from geventwebsocket import WebSocketApplication from promise import Promise -import gevent import json -SUBSCRIPTION_FAIL = 'subscription_fail' -SUBSCRIPTION_END = 'subscription_end' -SUBSCRIPTION_DATA = 'subscription_data' -SUBSCRIPTION_START = 'subscription_start' -SUBSCRIPTION_SUCCESS = 'subscription_success' -KEEPALIVE = 'keepalive' -INIT = 'init' -INIT_SUCCESS = 'init_success' -INIT_FAIL = 'init_fail' -GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions' +from .message_types import (SUBSCRIPTION_FAIL, SUBSCRIPTION_END, + SUBSCRIPTION_DATA, SUBSCRIPTION_START, + SUBSCRIPTION_SUCCESS, KEEPALIVE, INIT, + INIT_SUCCESS, INIT_FAIL, GRAPHQL_SUBSCRIPTIONS) -class SubscriptionServer(WebSocketApplication): +class SubscriptionServer(object): def __init__(self, subscription_manager, websocket, + executor=None, keep_alive=None, on_subscribe=None, on_unsubscribe=None, on_connect=None, on_disconnect=None): - assert subscription_manager, "Must provide\ - 'subscription_manager' to websocket app constructor" + assert subscription_manager, ("Must provide\ + 'subscription_manager' to websocket app constructor") self.subscription_manager = subscription_manager self.on_subscribe = on_subscribe self.on_unsubscribe = on_unsubscribe self.on_connect = on_connect self.on_disconnect = on_disconnect - self.keep_alive = keep_alive + self.keep_alive_period = keep_alive self.connection_subscriptions = {} self.connection_context = {} - super(SubscriptionServer, self).__init__(websocket) + if executor: + self.executor = executor() + else: + self.executor = subscription_manager.pubsub.executor + + self.ws = self.executor.ws = websocket + + def _handle_sync(self): + self.on_open() - def timer(self, callback, period): while True: - callback() - gevent.sleep(period) + try: + message = self.executor.ws_recv() + except self.executor.error: + self.on_close() + break + + self.on_message(message) + + async def _handle_async(self): + self.on_open() + + while True: + try: + message = await self.executor.ws_recv() + except self.executor.error: + self.on_close() + break + + self.on_message(message) + + def handle(self): + if hasattr(self.executor, 'loop'): + return self._handle_async() + else: + return self._handle_sync() def unsubscribe(self, graphql_sub_id): self.subscription_manager.unsubscribe(graphql_sub_id) @@ -51,22 +74,40 @@ def unsubscribe(self, graphql_sub_id): if self.on_unsubscribe: self.on_unsubscribe(self.ws) + async def _timer_async(self, callback, period): + try: + while True: + callback() + await self.executor.sleep(period) + except self.executor.task_cancel_error: + return + + def _timer_sync(self, callback, period): + while True: + callback() + self.executor.sleep(period) + + def timer(self, callback, period): + if hasattr(self.executor, 'loop'): + return self._timer_async(callback, period) + return self._timer_sync(callback, period) + def on_open(self): - if self.ws.protocol is None or ( - GRAPHQL_SUBSCRIPTIONS not in self.ws.protocol): - self.ws.close(1002) + if self.executor.ws_protocol() is None or ( + GRAPHQL_SUBSCRIPTIONS not in self.executor.ws_protocol()): + self.executor.execute(self.executor.ws_close, 1002) def keep_alive_callback(): - if not self.ws.closed: + if self.executor.ws_isopen(): self.send_keep_alive() else: - gevent.kill(keep_alive_timer) + self.executor.kill(keep_alive_task) - if self.keep_alive: - keep_alive_timer = gevent.spawn(self.timer, keep_alive_callback, - self.keep_alive) + if self.keep_alive_period: + keep_alive_task = self.executor.execute( + self.timer, keep_alive_callback, self.keep_alive_period) - def on_close(self, reason): + def on_close(self): for sub_id in list(self.connection_subscriptions.keys()): self.unsubscribe(self.connection_subscriptions[sub_id]) del self.connection_subscriptions[sub_id] @@ -75,8 +116,11 @@ def on_close(self, reason): self.on_disconnect(self.ws) def on_message(self, msg): + if msg is None: return + elif hasattr(msg, 'result'): # check if future from asyncio + msg = msg.result() non_local = {'on_init_resolve': None, 'on_init_reject': None} @@ -94,6 +138,7 @@ def on_message_return_handler(message): {'errors': [{ 'message': str(e) }]}) + return sub_id = parsed_message.get('id') @@ -131,9 +176,9 @@ def subscription_start_promise_handler(init_result): 'operation_name': parsed_message.get('operation_name'), 'callback': None, 'variables': parsed_message.get('variables'), - 'context': init_result if isinstance( - init_result, dict) else - parsed_message.get('context', {}), + 'context': + init_result if isinstance(init_result, dict) else + parsed_message.get('context', {}), 'format_error': None, 'format_response': None } @@ -150,8 +195,9 @@ def subscription_start_promise_handler(init_result): def promised_params_handler(params): if not isinstance(params, dict): - error = 'Invalid params returned from\ -OnSubscribe! Return value must be an dict' + error = ('Invalid params returned from' + 'OnSubscribe! Return value must be an' + 'dict') self.send_subscription_fail( sub_id, {'errors': [{ @@ -160,7 +206,6 @@ def promised_params_handler(params): raise TypeError(error) def params_callback(error, result): - # import ipdb; ipdb.set_trace() if not error: self.send_subscription_data( sub_id, {'data': result.data}) @@ -213,11 +258,6 @@ def error_catch_handler(e): graphql_sub_id_promise_handler).catch( error_catch_handler) - # Promise from init statement (line 54) - # seems to reset between if statements - # not sure if this behavior is correct or - # not per promises A spec...need to - # investigate non_local['on_init_resolve'](Promise.resolve(True)) self.connection_context['init_promise'].then( @@ -230,7 +270,6 @@ def subscription_end_promise_handler(result): self.unsubscribe(self.connection_subscriptions[sub_id]) del self.connection_subscriptions[sub_id] - # same rationale as above non_local['on_init_resolve'](Promise.resolve(True)) self.connection_context['init_promise'].then( @@ -247,21 +286,21 @@ def subscription_end_promise_handler(result): def send_subscription_data(self, sub_id, payload): message = {'type': SUBSCRIPTION_DATA, 'id': sub_id, 'payload': payload} - self.ws.send(json.dumps(message)) + self.executor.execute(self.executor.ws_send, json.dumps(message)) def send_subscription_fail(self, sub_id, payload): message = {'type': SUBSCRIPTION_FAIL, 'id': sub_id, 'payload': payload} - self.ws.send(json.dumps(message)) + self.executor.execute(self.executor.ws_send, json.dumps(message)) def send_subscription_success(self, sub_id): message = {'type': SUBSCRIPTION_SUCCESS, 'id': sub_id} - self.ws.send(json.dumps(message)) + self.executor.execute(self.executor.ws_send, json.dumps(message)) def send_init_result(self, result): - self.ws.send(json.dumps(result)) + self.executor.execute(self.executor.ws_send, json.dumps(result)) if result.get('type') == INIT_FAIL: - self.ws.close(1011) + self.executor.execute(self.executor.ws_close, 1011) def send_keep_alive(self): message = {'type': KEEPALIVE} - self.ws.send(json.dumps(message)) + self.executor.execute(self.executor.ws_send, json.dumps(message)) diff --git a/setup.py b/setup.py index e6e8a40..39fd0b9 100644 --- a/setup.py +++ b/setup.py @@ -8,8 +8,8 @@ long_description = open('README.md').read() tests_dep = [ - 'pytest', 'pytest-mock', 'fakeredis', 'graphene', - 'flask', 'flask-graphql', 'flask-sockets', 'multiprocess', 'requests' + 'pytest', 'pytest-mock', 'graphene', 'flask', 'flask-graphql', + 'flask-sockets', 'multiprocess', 'requests' ] if sys.version_info[0] < 3: diff --git a/tests/test_subscription_manager.py b/tests/test_subscription_manager.py index 2091a3a..5b677e5 100644 --- a/tests/test_subscription_manager.py +++ b/tests/test_subscription_manager.py @@ -1,60 +1,41 @@ -from types import FunctionType import sys +from types import FunctionType -from graphql import validate, parse -from promise import Promise -import fakeredis import graphene +import os import pytest -import redis +from graphql import validate, parse +from promise import Promise from graphql_subscriptions import RedisPubsub, SubscriptionManager -from graphql_subscriptions.validation import SubscriptionHasSingleRootField - - -@pytest.fixture -def pubsub(monkeypatch): - monkeypatch.setattr(redis, 'StrictRedis', fakeredis.FakeStrictRedis) - return RedisPubsub() +from graphql_subscriptions.subscription_manager.validation import ( + SubscriptionHasSingleRootField) +from graphql_subscriptions.executors.gevent import GeventExecutor +from graphql_subscriptions.executors.asyncio import AsyncioExecutor +if os.name == 'posix' and sys.version_info[0] < 3: + import subprocess32 as subprocess +else: + import subprocess -@pytest.mark.parametrize('test_input, expected', [('test', 'test'), ({ - 1: 'test' -}, { - 1: 'test' -}), (None, None)]) -def test_pubsub_subscribe_and_publish(pubsub, test_input, expected): - def message_callback(message): - try: - assert message == expected - pubsub.greenlet.kill() - except AssertionError as e: - sys.exit(e) - - def publish_callback(sub_id): - assert pubsub.publish('a', test_input) - pubsub.greenlet.join() - p1 = pubsub.subscribe('a', message_callback, {}) - p2 = p1.then(publish_callback) - p2.get() +@pytest.fixture(scope="module") +def start_redis_server(): + try: + proc = subprocess.Popen(['redis-server']) + except FileNotFoundError: + raise RuntimeError( + "You must have redis installed in order to run these tests") + yield + proc.terminate() -def test_pubsub_subscribe_and_unsubscribe(pubsub): - def message_callback(message): - sys.exit('Message callback should not have been called') +pytestmark = pytest.mark.usefixtures('start_redis_server') - def unsubscribe_publish_callback(sub_id): - pubsub.unsubscribe(sub_id) - assert pubsub.publish('a', 'test') - try: - sub_mgr.pubsub.greenlet.join() - except AttributeError: - return - p1 = pubsub.subscribe('a', message_callback, {}) - p2 = p1.then(unsubscribe_publish_callback) - p2.get() +@pytest.fixture(params=[GeventExecutor, AsyncioExecutor]) +def pubsub(request): + return RedisPubsub(executor=request.param) @pytest.fixture @@ -95,15 +76,17 @@ def resolve_test_channel_options(self, args, context, info): @pytest.fixture -def sub_mgr(pubsub, schema): +def setup_funcs(): def filter_single(**kwargs): args = kwargs.get('args') return { 'filter_1': { - 'filter': lambda root, context: root.get('filterBoolean') == args.get('filterBoolean') + 'filter': lambda root, context: root.get( + 'filterBoolean') == args.get('filterBoolean') }, 'filter_2': { - 'filter': lambda root, context: Promise.resolve(root.get('filterBoolean') == args.get('filterBoolean')) + 'filter': lambda root, context: Promise.resolve(root.get( + 'filterBoolean') == args.get('filterBoolean')) }, } @@ -123,15 +106,49 @@ def filter_channel_options(**kwargs): def filter_context(**kwargs): return {'context_trigger': lambda root, context: context == 'trigger'} - return SubscriptionManager( - schema, - pubsub, - setup_funcs={ - 'test_filter': filter_single, - 'test_filter_multi': filter_multi, - 'test_channel_options': filter_channel_options, - 'test_context': filter_context - }) + return { + 'test_filter': filter_single, + 'test_filter_multi': filter_multi, + 'test_channel_options': filter_channel_options, + 'test_context': filter_context + } + + +@pytest.fixture +def sub_mgr(pubsub, schema, setup_funcs): + return SubscriptionManager(schema, pubsub, setup_funcs) + + +@pytest.mark.parametrize('test_input, expected', [('test', 'test'), ({ + 1: 'test'}, {1: 'test'}), (None, None)]) +def test_pubsub_subscribe_and_publish(pubsub, test_input, expected): + def message_callback(message): + try: + assert message == expected + pubsub.executor.kill(pubsub.backgrd_task) + except AssertionError as e: + sys.exit(e) + + def publish_callback(sub_id): + assert pubsub.publish('a', test_input) + pubsub.executor.join(pubsub.backgrd_task) + + p1 = pubsub.subscribe('a', message_callback, {}) + p2 = p1.then(publish_callback) + p2.get() + + +def test_pubsub_subscribe_and_unsubscribe(pubsub): + def message_callback(message): + sys.exit('Message callback should not have been called') + + def unsubscribe_publish_callback(sub_id): + pubsub.unsubscribe(sub_id) + assert pubsub.publish('a', 'test') + + p1 = pubsub.subscribe('a', message_callback, {}) + p2 = p1.then(unsubscribe_publish_callback) + p2.get() def test_query_is_valid_and_throws_error(sub_mgr): @@ -186,13 +203,13 @@ def test_subscribe_with_valid_query_and_return_root_value(sub_mgr): def callback(e, payload): try: assert payload.data.get('testSubscription') == 'good' - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) def publish_and_unsubscribe_handler(sub_id): sub_mgr.publish('testSubscription', 'good') - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) sub_mgr.unsubscribe(sub_id) p1 = sub_mgr.subscribe(query, 'X', callback, {}, {}, None, None) @@ -213,14 +230,14 @@ def callback(err, payload): assert True else: assert payload.data.get('testFilter') == 'good_filter' - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) def publish_and_unsubscribe_handler(sub_id): sub_mgr.publish('filter_1', {'filterBoolean': False}) sub_mgr.publish('filter_1', {'filterBoolean': True}) - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) sub_mgr.unsubscribe(sub_id) p1 = sub_mgr.subscribe(query, 'Filter1', callback, {'filterBoolean': True}, @@ -242,7 +259,7 @@ def callback(err, payload): assert True else: assert payload.data.get('testFilter') == 'good_filter' - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) @@ -250,7 +267,7 @@ def publish_and_unsubscribe_handler(sub_id): sub_mgr.publish('filter_2', {'filterBoolean': False}) sub_mgr.publish('filter_2', {'filterBoolean': True}) try: - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) except: raise sub_mgr.unsubscribe(sub_id) @@ -281,13 +298,13 @@ def callback(err, payload): except AssertionError as e: sys.exit(e) if non_local['trigger_count'] == 2: - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) def publish_and_unsubscribe_handler(sub_id): sub_mgr.publish('not_a_trigger', {'filterBoolean': False}) sub_mgr.publish('trigger_1', {'filterBoolean': True}) sub_mgr.publish('trigger_2', {'filterBoolean': True}) - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) sub_mgr.unsubscribe(sub_id) p1 = sub_mgr.subscribe(query, 'multiTrigger', callback, @@ -338,7 +355,7 @@ def unsubscribe_and_publish_handler(sub_id): sub_mgr.unsubscribe(sub_id) sub_mgr.publish('testSubscription', 'good') try: - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) except AttributeError: return @@ -379,7 +396,8 @@ def unsubscribe_and_unsubscribe_handler(sub_id): p2.get() -def test_calls_the_error_callback_if_there_is_an_execution_error(sub_mgr): +def test_calls_the_error_callback_if_there_is_an_execution_error( + sub_mgr): query = 'subscription X($uga: Boolean!){\ testSubscription @skip(if: $uga)\ }' @@ -387,17 +405,17 @@ def test_calls_the_error_callback_if_there_is_an_execution_error(sub_mgr): def callback(err, payload): try: assert payload is None - assert err.message == 'Variable "$uga" of required type\ - "Boolean!" was not provided.' + assert err.message == ('Variable "$uga" of required type ' + '"Boolean!" was not provided.') - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) def unsubscribe_and_publish_handler(sub_id): sub_mgr.publish('testSubscription', 'good') try: - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) except AttributeError: return sub_mgr.unsubscribe(sub_id) @@ -421,14 +439,14 @@ def callback(err, payload): try: assert err is None assert payload.data.get('testContext') == 'trigger' - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) def unsubscribe_and_publish_handler(sub_id): sub_mgr.publish('context_trigger', 'ignored') try: - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) except AttributeError: return sub_mgr.unsubscribe(sub_id) @@ -445,21 +463,22 @@ def unsubscribe_and_publish_handler(sub_id): p2.get() -def test_calls_the_error_callback_if_context_func_throws_error(sub_mgr): +def test_calls_the_error_callback_if_context_func_throws_error( + sub_mgr): query = 'subscription TestContext { testContext }' def callback(err, payload): try: assert payload is None assert str(err) == 'context error' - sub_mgr.pubsub.greenlet.kill() + sub_mgr.pubsub.executor.kill(sub_mgr.pubsub.backgrd_task) except AssertionError as e: sys.exit(e) def unsubscribe_and_publish_handler(sub_id): sub_mgr.publish('context_trigger', 'ignored') try: - sub_mgr.pubsub.greenlet.join() + sub_mgr.pubsub.executor.join(sub_mgr.pubsub.backgrd_task) except AttributeError: return sub_mgr.unsubscribe(sub_id) @@ -497,40 +516,41 @@ class Subscription(graphene.ObjectType): def test_should_allow_a_valid_subscription(validation_schema): sub = 'subscription S1{ test1 }' - errors = validate(validation_schema, parse(sub), - [SubscriptionHasSingleRootField]) + errors = validate(validation_schema, + parse(sub), [SubscriptionHasSingleRootField]) assert len(errors) == 0 def test_should_allow_another_valid_subscription(validation_schema): sub = 'subscription S1{ test1 } subscription S2{ test2 }' - errors = validate(validation_schema, parse(sub), - [SubscriptionHasSingleRootField]) + errors = validate(validation_schema, + parse(sub), [SubscriptionHasSingleRootField]) assert len(errors) == 0 def test_should_not_allow_two_fields_in_the_subscription(validation_schema): sub = 'subscription S3{ test1 test2 }' - errors = validate(validation_schema, parse(sub), - [SubscriptionHasSingleRootField]) + errors = validate(validation_schema, + parse(sub), [SubscriptionHasSingleRootField]) assert len(errors) == 1 assert errors[0].message == 'Subscription "S3" must have only one field.' def test_should_not_allow_inline_fragments(validation_schema): sub = 'subscription S4{ ...on Subscription { test1 } }' - errors = validate(validation_schema, parse(sub), - [SubscriptionHasSingleRootField]) + errors = validate(validation_schema, + parse(sub), [SubscriptionHasSingleRootField]) assert len(errors) == 1 - assert errors[0].message == 'Apollo subscriptions do not support\ - fragments on the root field' + assert errors[0].message == ('Subscriptions do not support ' + 'fragments on the root field') def test_should_not_allow_fragments(validation_schema): - sub = 'subscription S5{ ...testFragment }\ - fragment testFragment on Subscription{ test2 }' - errors = validate(validation_schema, parse(sub), - [SubscriptionHasSingleRootField]) + sub = ('subscription S5{ ...testFragment }' + 'fragment testFragment on Subscription{ test2 }') + + errors = validate(validation_schema, + parse(sub), [SubscriptionHasSingleRootField]) assert len(errors) == 1 - assert errors[0].message == 'Apollo subscriptions do not support\ - fragments on the root field' + assert errors[0].message == ('Subscriptions do not support ' + 'fragments on the root field') diff --git a/tests/test_subscription_transport.py b/tests/test_subscription_transport.py index b0b57b6..3cd0583 100644 --- a/tests/test_subscription_transport.py +++ b/tests/test_subscription_transport.py @@ -6,33 +6,32 @@ from future import standard_library standard_library.install_aliases() -from builtins import object -from functools import wraps import copy import json import os import sys import threading import time +from builtins import object +from functools import wraps -from flask import Flask, request, jsonify -from flask_graphql import GraphQLView -from flask_sockets import Sockets -from geventwebsocket import WebSocketServer -from promise import Promise import queue -import fakeredis import graphene import multiprocess import pytest -import redis import requests +from flask import Flask, request, jsonify +from flask_graphql import GraphQLView +from flask_sockets import Sockets +from geventwebsocket import WebSocketServer +from promise import Promise -from graphql_subscriptions import (RedisPubsub, SubscriptionManager, - SubscriptionServer) - -from graphql_subscriptions.subscription_transport_ws import (SUBSCRIPTION_FAIL, - SUBSCRIPTION_DATA) +from graphql_subscriptions import RedisPubsub, SubscriptionManager +from graphql_subscriptions.executors.gevent import GeventExecutor +from graphql_subscriptions.subscription_transport_ws import ( + SubscriptionServer) +from graphql_subscriptions.subscription_transport_ws.message_types import ( + SUBSCRIPTION_FAIL, SUBSCRIPTION_DATA) if os.name == 'posix' and sys.version_info[0] < 3: import subprocess32 as subprocess @@ -42,7 +41,7 @@ TEST_PORT = 5000 -class PickableMock(object): +class Picklable(object): def __init__(self, return_value=None, side_effect=None, name=None): self._return_value = return_value self._side_effect = side_effect @@ -114,10 +113,33 @@ def data(): } +@pytest.fixture(scope="module") +def start_redis_server(): + try: + proc = subprocess.Popen(['redis-server']) + except FileNotFoundError: + raise RuntimeError( + "You must have redis installed in order to run these tests") + yield + proc.terminate() + + +pytestmark = pytest.mark.usefixtures('start_redis_server') + + +@pytest.fixture(params=[GeventExecutor]) +def executor(request): + return request.param + + @pytest.fixture -def pubsub(monkeypatch): - monkeypatch.setattr(redis, 'StrictRedis', fakeredis.FakeStrictRedis) - return RedisPubsub() +def pubsub(executor): + return RedisPubsub(executor=executor) + + +@pytest.fixture +def sub_server(executor): + return SubscriptionServer @pytest.fixture @@ -196,7 +218,7 @@ def on_subscribe(self, msg, params, websocket): on_sub_mock = { 'on_subscribe': - PickableMock(side_effect=promisify(on_subscribe), name='on_subscribe') + Picklable(side_effect=promisify(on_subscribe), name='on_subscribe') } return on_sub_mock, q @@ -225,22 +247,22 @@ def on_unsubscribe(self, websocket): options_mocks = { 'on_subscribe': - PickableMock(side_effect=promisify(on_subscribe), name='on_subscribe'), + Picklable(side_effect=promisify(on_subscribe), name='on_subscribe'), 'on_unsubscribe': - PickableMock(side_effect=on_unsubscribe, name='on_unsubscribe'), + Picklable(side_effect=on_unsubscribe, name='on_unsubscribe'), 'on_connect': - PickableMock( + Picklable( return_value={'test': 'test_context'}, side_effect=on_connect, name='on_connect'), 'on_disconnect': - PickableMock(side_effect=on_disconnect, name='on_disconnect') + Picklable(side_effect=on_disconnect, name='on_disconnect') } return options_mocks, q -def create_app(sub_mgr, schema, options): +def create_app(sub_mgr, schema, options, executor, sub_server): app = Flask(__name__) sockets = Sockets(app) @@ -257,7 +279,8 @@ def sub_mgr_publish(): @sockets.route('/socket') def socket_channel(websocket): - subscription_server = SubscriptionServer(sub_mgr, websocket, **options) + subscription_server = sub_server(sub_mgr, websocket, + executor, **options) subscription_server.handle() return [] @@ -269,11 +292,11 @@ def app_worker(app, port): server.serve_forever() -@pytest.fixture() -def server(sub_mgr, schema, on_sub_mock): +@pytest.fixture +def server(sub_mgr, executor, sub_server, schema, on_sub_mock): options, q = on_sub_mock - app = create_app(sub_mgr, schema, options) + app = create_app(sub_mgr, schema, options, executor, sub_server) process = multiprocess.Process( target=app_worker, kwargs={'app': app, @@ -283,11 +306,12 @@ def server(sub_mgr, schema, on_sub_mock): process.terminate() -@pytest.fixture() -def server_with_mocks(sub_mgr, schema, options_mocks): +@pytest.fixture +def server_with_mocks(sub_mgr, executor, sub_server, schema, + options_mocks): options, q = options_mocks - app = create_app(sub_mgr, schema, options) + app = create_app(sub_mgr, schema, options, executor, sub_server) process = multiprocess.Process( target=app_worker, kwargs={'app': app, @@ -298,10 +322,12 @@ def server_with_mocks(sub_mgr, schema, options_mocks): process.terminate() -@pytest.fixture() -def server_with_on_sub_handler(sub_mgr, schema, on_sub_handler): +@pytest.fixture +def server_with_on_sub_handler(sub_mgr, executor, sub_server, schema, + on_sub_handler): - app = create_app(sub_mgr, schema, on_sub_handler) + app = create_app(sub_mgr, schema, on_sub_handler, executor, + sub_server) process = multiprocess.Process( target=app_worker, kwargs={'app': app, @@ -311,10 +337,11 @@ def server_with_on_sub_handler(sub_mgr, schema, on_sub_handler): process.terminate() -@pytest.fixture() -def server_with_keep_alive(sub_mgr, schema): +@pytest.fixture +def server_with_keep_alive(sub_mgr, executor, sub_server, schema): - app = create_app(sub_mgr, schema, {'keep_alive': .250}) + app = create_app(sub_mgr, schema, {'keep_alive': .250}, executor, + sub_server) process = multiprocess.Process( target=app_worker, kwargs={'app': app, @@ -324,9 +351,9 @@ def server_with_keep_alive(sub_mgr, schema): process.terminate() -def test_raise_exception_when_create_server_and_no_sub_mgr(): +def test_raise_exception_when_create_server_and_no_sub_mgr(sub_server): with pytest.raises(AssertionError): - SubscriptionServer(None, None) + sub_server(None, None) def test_should_trigger_on_connect_if_client_connect_valid(server_with_mocks): @@ -340,7 +367,7 @@ def test_should_trigger_on_connect_if_client_connect_valid(server_with_mocks): os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) try: subprocess.check_output( - ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.2) + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.3) except: mock = server_with_mocks.get_nowait() assert mock.name == 'on_connect' diff --git a/tests/test_subscription_transport_asyncio.py b/tests/test_subscription_transport_asyncio.py new file mode 100644 index 0000000..de79429 --- /dev/null +++ b/tests/test_subscription_transport_asyncio.py @@ -0,0 +1,1297 @@ +# Many, if not most of these tests rely on using a graphql subscriptions +# client. "apollographql/subscriptions-transport-ws" is used here for testing +# the graphql subscriptions server implementation. In order to run these tests, +# "cd" to the "tests" directory and "npm install". Make sure you have nodejs +# installed in your $PATH. + +from future import standard_library +standard_library.install_aliases() +import copy +import json +import os +import sys +import threading +import time +from builtins import object +from functools import wraps + +import queue +import graphene +import multiprocess +import pytest +import requests +import websockets +from promise import Promise +from sanic import Sanic, response +from sanic_graphql import GraphQLView + +from graphql_subscriptions import RedisPubsub, SubscriptionManager +from graphql_subscriptions.executors.asyncio import AsyncioExecutor +from graphql_subscriptions.subscription_transport_ws import ( + SubscriptionServer) +from graphql_subscriptions.subscription_transport_ws.message_types import ( + SUBSCRIPTION_FAIL, SUBSCRIPTION_DATA, GRAPHQL_SUBSCRIPTIONS) + +if os.name == 'posix' and sys.version_info[0] < 3: + import subprocess32 as subprocess +else: + import subprocess + +TEST_PORT = 5000 + + +class Picklable(object): + def __init__(self, return_value=None, side_effect=None, name=None): + self._return_value = return_value + self._side_effect = side_effect + self.name = name + self.called = False + self.call_count = 0 + self.call_args = set() + + def __call__(mock_self, *args, **kwargs): + mock_self.called = True + mock_self.call_count += 1 + call_args = {repr(arg) for arg in args} + call_kwargs = {repr(item) for item in kwargs} + mock_self.call_args = call_args | call_kwargs | mock_self.call_args + + if mock_self._side_effect and mock_self._return_value: + mock_self._side_effect(mock_self, *args, **kwargs) + return mock_self._return_value + elif mock_self._side_effect: + return mock_self._side_effect(mock_self, *args, **kwargs) + elif mock_self._return_value: + return mock_self._return_value + + def assert_called_once(self): + assert self.call_count == 1 + + def assert_called_with(self, *args, **kwargs): + call_args = {repr(json.loads(json.dumps(arg))) for arg in args} + call_kwargs = {repr(json.loads(json.dumps(item))) for item in kwargs} + all_call_args = call_args | call_kwargs + assert all_call_args.issubset(self.call_args) + + def assert_called_with_contains(self, arg_fragment): + assert any([arg_fragment in item for item in self.call_args]) + + +def promisify(f): + @wraps(f) + def wrapper(*args, **kwargs): + def executor(resolve, reject): + return resolve(f(*args, **kwargs)) + + return Promise(executor) + + return wrapper + + +def enqueue_output(out, queue): + with out: + for line in iter(out.readline, b''): + queue.put(line) + + +@pytest.fixture +def data(): + return { + '1': { + 'id': '1', + 'name': 'Dan' + }, + '2': { + 'id': '2', + 'name': 'Marie' + }, + '3': { + 'id': '3', + 'name': 'Jessie' + } + } + + +@pytest.fixture(scope="module") +def start_redis_server(): + try: + proc = subprocess.Popen(['redis-server']) + except FileNotFoundError: + raise RuntimeError( + "You must have redis installed in order to run these tests") + yield + proc.terminate() + + +pytestmark = pytest.mark.usefixtures('start_redis_server') + + +@pytest.fixture(params=[AsyncioExecutor]) +def executor(request): + return request.param + + +@pytest.fixture +def pubsub(monkeypatch, executor): + return RedisPubsub(executor=executor) + + +@pytest.fixture +def sub_server(executor): + return SubscriptionServer + + +@pytest.fixture +def schema(data): + class UserType(graphene.ObjectType): + id = graphene.String() + name = graphene.String() + + class Query(graphene.ObjectType): + test_string = graphene.String() + + class Subscription(graphene.ObjectType): + user = graphene.Field(UserType, id=graphene.String()) + user_filtered = graphene.Field(UserType, id=graphene.String()) + context = graphene.String() + error = graphene.String() + + def resolve_user(self, args, context, info): + id = args['id'] + name = data[args['id']]['name'] + return UserType(id=id, name=name) + + def resolve_user_filtered(self, args, context, info): + id = args['id'] + name = data[args['id']]['name'] + return UserType(id=id, name=name) + + def resolve_context(self, args, context, info): + return context + + def resolve_error(self, args, context, info): + raise Exception('E1') + + return graphene.Schema(query=Query, subscription=Subscription) + + +@pytest.fixture +def sub_mgr(pubsub, schema): + def user_filtered(**kwargs): + args = kwargs.get('args') + return { + 'user_filtered': { + 'filter': lambda root, ctx: root.get('id') == args.get('id') + } + } + + setup_funcs = {'user_filtered': user_filtered} + + return SubscriptionManager(schema, pubsub, setup_funcs) + + +@pytest.fixture +def on_sub_handler(): + def context_handler(): + raise Exception('bad') + + def on_subscribe(msg, params, websocket): + new_params = copy.deepcopy(params) + new_params.update({'context': context_handler}) + return new_params + + return {'on_subscribe': promisify(on_subscribe)} + + +@pytest.fixture +def on_sub_mock(mocker): + + mgr = multiprocess.Manager() + q = mgr.Queue() + + def on_subscribe(self, msg, params, websocket): + new_params = copy.deepcopy(params) + new_params.update({'context': msg.get('context', {})}) + q.put(self) + return new_params + + on_sub_mock = { + 'on_subscribe': + Picklable(side_effect=promisify(on_subscribe), name='on_subscribe') + } + + return on_sub_mock, q + + +@pytest.fixture +def options_mocks(mocker): + + mgr = multiprocess.Manager() + q = mgr.Queue() + + def on_subscribe(self, msg, params, websocket): + new_params = copy.deepcopy(params) + new_params.update({'context': msg.get('context', {})}) + q.put(self) + return new_params + + def on_connect(self, message, websocket): + q.put(self) + + def on_disconnect(self, websocket): + q.put(self) + + def on_unsubscribe(self, websocket): + q.put(self) + + options_mocks = { + 'on_subscribe': + Picklable(side_effect=promisify(on_subscribe), name='on_subscribe'), + 'on_unsubscribe': + Picklable(side_effect=on_unsubscribe, name='on_unsubscribe'), + 'on_connect': + Picklable( + return_value={'test': 'test_context'}, + side_effect=on_connect, + name='on_connect'), + 'on_disconnect': + Picklable(side_effect=on_disconnect, name='on_disconnect') + } + + return options_mocks, q + +def create_app(sub_mgr, schema, options, executor, sub_server): + app = Sanic(__name__) + + app.add_route(GraphQLView.as_view(schema=schema, graphiql=True), + '/graphql') + + @app.route('/publish', methods=['POST']) + async def sub_mgr_publish(request): + await sub_mgr.publish(*request.json) + return await response.json(request.json) + + async def websocket(websocket, path): + if path == '/socket': + subscription_server = sub_server(sub_mgr, websocket, + executor, **options) + await subscription_server.handle() + + ws_server = websockets.serve(websocket, 'localhost', TEST_PORT, + subprotocols=[GRAPHQL_SUBSCRIPTIONS]) + + app.add_task(ws_server) + + return app + + +def app_worker(app, port): + app.run(host="0.0.0.0", port=port) + + +@pytest.fixture +def server(sub_mgr, executor, sub_server, schema, on_sub_mock): + + options, q = on_sub_mock + app = create_app(sub_mgr, schema, options, executor, sub_server) + + process = multiprocess.Process( + target=app_worker, kwargs={'app': app, + 'port': TEST_PORT}) + process.start() + yield q + process.terminate() + + +@pytest.fixture +def server_with_mocks(sub_mgr, executor, sub_server, schema, + options_mocks): + + options, q = options_mocks + app = create_app(sub_mgr, schema, options, executor, sub_server) + + process = multiprocess.Process( + target=app_worker, kwargs={'app': app, + 'port': TEST_PORT}) + + process.start() + yield q + process.terminate() + + +@pytest.fixture +def server_with_on_sub_handler(sub_mgr, executor, sub_server, schema, + on_sub_handler): + + app = create_app(sub_mgr, schema, on_sub_handler, executor, + sub_server) + + process = multiprocess.Process( + target=app_worker, kwargs={'app': app, + 'port': TEST_PORT}) + process.start() + yield + process.terminate() + + +@pytest.fixture +def server_with_keep_alive(sub_mgr, executor, sub_server, schema): + + app = create_app(sub_mgr, schema, {'keep_alive': .250}, executor, + sub_server) + + process = multiprocess.Process( + target=app_worker, kwargs={'app': app, + 'port': TEST_PORT}) + process.start() + yield + process.terminate() + + +def test_raise_exception_when_create_server_and_no_sub_mgr(sub_server): + with pytest.raises(AssertionError): + sub_server(None, None) + + +def test_should_trigger_on_connect_if_client_connect_valid(server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + new SubscriptionClient('ws://localhost:{1}/socket') + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.3) + except: + mock = server_with_mocks.get_nowait() + assert mock.name == 'on_connect' + mock.assert_called_once() + + +def test_should_trigger_on_connect_with_correct_cxn_params(server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const connectionParams = {{test: true}} + new SubscriptionClient('ws://localhost:{1}/socket', {{ + connectionParams, + }}) + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.2) + except: + mock = server_with_mocks.get_nowait() + assert mock.name == 'on_connect' + mock.assert_called_once() + mock.assert_called_with({'test': True}) + + +def test_trigger_on_disconnect_when_client_disconnects(server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.client.close() + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + subprocess.check_output(['node', '-e', node_script]) + mock = server_with_mocks.get_nowait() + assert mock.name == 'on_disconnect' + mock.assert_called_once() + + +def test_should_call_unsubscribe_when_client_closes_cxn(server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + }}, function (error, result) {{ + // nothing + }} + ) + setTimeout(() => {{ + client.client.close() + }}, 500) + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=1) + except: + while True: + mock = server_with_mocks.get_nowait() + if mock.name == 'on_unsubscribe': + mock.assert_called_once() + break + + +def test_should_trigger_on_subscribe_when_client_subscribes(server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + }}, function (error, result) {{ + // nothing + }}) + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.2) + except: + while True: + mock = server_with_mocks.get_nowait() + if mock.name == 'on_subscribe': + mock.assert_called_once() + break + + +def test_should_trigger_on_unsubscribe_when_client_unsubscribes( + server_with_mocks): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + const subId = client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + }}, function (error, result) {{ + // nothing + }}) + client.unsubscribe(subId) + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.2) + except: + while True: + mock = server_with_mocks.get_nowait() + if mock.name == 'on_unsubscribe': + mock.assert_called_once() + break + + +def test_should_send_correct_results_to_multiple_client_subscriptions(server): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + const client1 = new SubscriptionClient('ws://localhost:{1}/socket') + let numResults = 0; + client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + + }}, function (error, result) {{ + if (error) {{ + console.log(JSON.stringify(error)); + }} + if (result) {{ + numResults++; + console.log(JSON.stringify({{ + client: {{ + result: result, + numResults: numResults + }} + }})); + }} else {{ + // pass + }} + }} + ); + const client2 = new SubscriptionClient('ws://localhost:{1}/socket') + let numResults1 = 0; + client2.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 2, + }}, + + }}, function (error, result) {{ + if (error) {{ + console.log(JSON.stringify(error)); + }} + if (result) {{ + numResults1++; + console.log(JSON.stringify({{ + client2: {{ + result: result, + numResults: numResults1 + }} + }})); + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), json=['user', {}]) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + client = ret_values['client'] + assert client['result']['user'] + assert client['result']['user']['id'] == '3' + assert client['result']['user']['name'] == 'Jessie' + assert client['numResults'] == 1 + client2 = ret_values['client2'] + assert client2['result']['user'] + assert client2['result']['user']['id'] == '2' + assert client2['result']['user']['name'] == 'Marie' + assert client2['numResults'] == 1 + + +# TODO: Graphene subscriptions implementation does not currently return an +# error for missing or incorrect field(s); this test will continue to fail +# until that is fixed +def test_send_subscription_fail_message_to_client_with_invalid_query(server): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + setTimeout(function () {{ + client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + birthday + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + + }}, function (error, result) {{ + }} + ); + }}, 100); + client.client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + console.log(JSON.stringify({{[msg.type]: msg}})) + }}; + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + assert ret_values['type'] == SUBSCRIPTION_FAIL + assert len(ret_values['payload']['errors']) > 0 + + +# TODO: troubleshoot this a bit...passes, but receives extra messages which I'm +# filtering out w/ the "AttributeError" exception clause +def test_should_setup_the_proper_filters_when_subscribing(server): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + const client2 = new SubscriptionClient('ws://localhost:{1}/socket') + let numResults = 0; + client.subscribe({{ + query: `subscription useInfoFilter1($id: String) {{ + userFiltered(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfoFilter1', + variables: {{ + id: 3, + }}, + + }}, function (error, result) {{ + if (error) {{ + console.log(JSON.stringify(error)); + }} + if (result) {{ + numResults += 1; + console.log(JSON.stringify({{ + client: {{ + result: result, + numResults: numResults + }} + }})); + }} else {{ + // pass + }} + }} + ); + client2.subscribe({{ + query: `subscription useInfoFilter1($id: String) {{ + userFiltered(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfoFilter1', + variables: {{ + id: 1, + }}, + + }}, function (error, result) {{ + if (error) {{ + console.log(JSON.stringify(error)); + }} + if (result) {{ + numResults += 1; + console.log(JSON.stringify({{ + client2: {{ + result: result, + numResults: numResults + }} + }})); + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), + json=['user_filtered', { + 'id': 1 + }]) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), + json=['user_filtered', { + 'id': 2 + }]) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), + json=['user_filtered', { + 'id': 3 + }]) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except AttributeError: + pass + except queue.Empty: + break + client = ret_values['client'] + assert client['result']['userFiltered'] + assert client['result']['userFiltered']['id'] == '3' + assert client['result']['userFiltered']['name'] == 'Jessie' + assert client['numResults'] == 2 + client2 = ret_values['client2'] + assert client2['result']['userFiltered'] + assert client2['result']['userFiltered']['id'] == '1' + assert client2['result']['userFiltered']['name'] == 'Dan' + assert client2['numResults'] == 1 + + +def test_correctly_sets_the_context_in_on_subscribe(server): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const CTX = 'testContext'; + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `subscription context {{ + context + }}`, + variables: {{}}, + context: CTX, + }}, (error, result) => {{ + client.unsubscribeAll(); + if (error) {{ + console.log(JSON.stringify(error)); + }} + if (result) {{ + console.log(JSON.stringify({{ + client: {{ + result: result, + }} + }})); + }} else {{ + // pass + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), json=['context', {}]) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + client = ret_values['client'] + assert client['result']['context'] + assert client['result']['context'] == 'testContext' + + +def test_passes_through_websocket_request_to_on_subscribe(server): + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `subscription context {{ + context + }}`, + variables: {{}}, + }}, (error, result) => {{ + if (error) {{ + console.log(JSON.stringify(error)); + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + try: + subprocess.check_output( + ['node', '-e', node_script], stderr=subprocess.STDOUT, timeout=.2) + except: + while True: + mock = server.get_nowait() + if mock.name == 'on_subscribe': + mock.assert_called_once() + mock.assert_called_with_contains('websocket') + break + + +def test_does_not_send_subscription_data_after_client_unsubscribes(server): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + setTimeout(function () {{ + let subId = client.subscribe({{ + query: `subscription useInfo($id: String) {{ + user(id: $id) {{ + id + name + }} + }}`, + operationName: 'useInfo', + variables: {{ + id: 3, + }}, + + }}, function (error, result) {{ + }} + ); + client.unsubscribe(subId); + }}, 100); + client.client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + console.log(JSON.stringify({{[msg.type]: msg}})) + }}; + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), json=['user', {}]) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + with pytest.raises(KeyError): + assert ret_values[SUBSCRIPTION_DATA] + + +# TODO: Need to look into why this test is throwing code 1006, not 1002 like +# it should be (1006 more general than 1002 protocol error) +def test_rejects_client_that_does_not_specifiy_a_supported_protocol(server): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const client = new WebSocket('ws://localhost:{1}/socket') + client.on('close', (code) => {{ + console.log(JSON.stringify(code)) + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = [] + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values.append(line) + except ValueError: + pass + except queue.Empty: + break + assert ret_values[0] == 1002 or 1006 + + +def test_rejects_unparsable_message(server): + + node_script = ''' + module.paths.push('{0}'); + WebSocket = require('ws'); + const GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions'; + const client = new WebSocket('ws://localhost:{1}/socket', + GRAPHQL_SUBSCRIPTIONS); + client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + console.log(JSON.stringify({{[msg.type]: msg}})) + client.close(); + }}; + client.onopen = () => {{ + client.send('HI'); + }} + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + assert ret_values['subscription_fail'] + assert len(ret_values['subscription_fail']['payload']['errors']) > 0 + + +def test_rejects_nonsense_message(server): + + node_script = ''' + module.paths.push('{0}'); + WebSocket = require('ws'); + const GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions'; + const client = new WebSocket('ws://localhost:{1}/socket', + GRAPHQL_SUBSCRIPTIONS); + client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + console.log(JSON.stringify({{[msg.type]: msg}})) + client.close(); + }}; + client.onopen = () => {{ + client.send(JSON.stringify({{}})); + }} + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + assert ret_values['subscription_fail'] + assert len(ret_values['subscription_fail']['payload']['errors']) > 0 + + +def test_does_not_crash_on_unsub_from_unknown_sub(server): + + node_script = ''' + module.paths.push('{0}'); + WebSocket = require('ws'); + const GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions'; + const client = new WebSocket('ws://localhost:{1}/socket', + GRAPHQL_SUBSCRIPTIONS); + setTimeout(function () {{ + client.onopen = () => {{ + const SUBSCRIPTION_END = 'subscription_end'; + let subEndMsg = {{type: SUBSCRIPTION_END, id: 'toString'}} + client.send(JSON.stringify(subEndMsg)); + }} + }}, 200); + client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + console.log(JSON.stringify({{[msg.type]: msg}})) + }}; + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = [] + while True: + try: + line = q.get_nowait() + ret_values.append(line) + except ValueError: + pass + except queue.Empty: + break + assert len(ret_values) == 0 + + +def test_sends_back_any_type_of_error(server): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `invalid useInfo {{ + error + }}`, + variables: {{}}, + }}, function (errors, result) {{ + client.unsubscribeAll(); + if (errors) {{ + console.log(JSON.stringify({{'errors': errors}})) + }} + if (result) {{ + console.log(JSON.stringify({{'result': result}})) + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(5) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + assert len(ret_values['errors']) > 0 + + +def test_handles_errors_prior_to_graphql_execution(server_with_on_sub_handler): + + node_script = ''' + module.paths.push('{0}') + WebSocket = require('ws') + const SubscriptionClient = + require('subscriptions-transport-ws').SubscriptionClient + const client = new SubscriptionClient('ws://localhost:{1}/socket') + client.subscribe({{ + query: `subscription context {{ + context + }}`, + variables: {{}}, + context: {{}}, + }}, function (errors, result) {{ + client.unsubscribeAll(); + if (errors) {{ + console.log(JSON.stringify({{'errors': errors}})) + }} + if (result) {{ + console.log(JSON.stringify({{'result': result}})) + }} + }} + ); + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + time.sleep(.2) + requests.post( + 'http://localhost:{0}/publish'.format(TEST_PORT), json=['context', {}]) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.2) + ret_values = {} + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + line = json.loads(line) + ret_values[list(line.keys())[0]] = line[list(line.keys())[0]] + except ValueError: + pass + except queue.Empty: + break + assert isinstance(ret_values['errors'], list) + assert ret_values['errors'][0]['message'] == 'bad' + + +def test_sends_a_keep_alive_signal_in_the_socket(server_with_keep_alive): + + node_script = ''' + module.paths.push('{0}'); + WebSocket = require('ws'); + const GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions'; + const KEEP_ALIVE = 'keepalive'; + const client = new WebSocket('ws://localhost:{1}/socket', + GRAPHQL_SUBSCRIPTIONS); + let yieldCount = 0; + client.onmessage = (message) => {{ + let msg = JSON.parse(message.data) + if (msg.type === KEEP_ALIVE) {{ + yieldCount += 1; + if (yieldCount > 1) {{ + let returnMsg = {{'type': msg.type, 'yieldCount': yieldCount}} + console.log(JSON.stringify(returnMsg)) + client.close(); + }} + }} + }}; + '''.format( + os.path.join(os.path.dirname(__file__), 'node_modules'), TEST_PORT) + + p = subprocess.Popen( + ['node', '-e', node_script], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + q = queue.Queue() + t = threading.Thread(target=enqueue_output, args=(p.stdout, q)) + t.daemon = True + t.start() + time.sleep(.5) + while True: + try: + _line = q.get_nowait() + if isinstance(_line, bytes): + line = _line.decode() + ret_value = json.loads(line) + except ValueError: + pass + except queue.Empty: + break + assert ret_value['type'] == 'keepalive' + assert ret_value['yieldCount'] > 1