From d4d2e272463d65a9c49b070f0cc4253d5b74d9cc Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 6 Jul 2018 11:11:47 +0200 Subject: [PATCH 01/17] Add tox configuration file and modify travis config to use tox --- .travis.yml | 11 +++-------- tox.ini | 11 +++++++++++ 2 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 tox.ini diff --git a/.travis.yml b/.travis.yml index 5a415e1..de5441b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,6 @@ language: python python: - "3.6" -before_install: -- pip install -q coveralls -install: -- pip install .[tests] -script: -- coverage run -m unittest discover tests -after_success: -- coveralls +- "3.7" +install: pip install tox-travis +script: tox diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..8443c01 --- /dev/null +++ b/tox.ini @@ -0,0 +1,11 @@ +[tox] +envlist = py36,py37 + +[testenv] +passenv = TRAVIS TRAVIS_* +deps = + coveralls +extras = tests +commands = + coverage run -m unittest discover tests + coveralls From 2f2ed35bd189b3958a2697100135b772277b69c4 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 6 Jul 2018 11:21:34 +0200 Subject: [PATCH 02/17] Use xenail as the distro for python 3.7 testing on travis --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index de5441b..3da4318 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,10 @@ +dist: xenial +sudo: required + language: python python: - "3.6" - "3.7" + install: pip install tox-travis script: tox From 087dabeb8ed4912ede65aabc4cb0b4118f3a1151 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Sat, 13 Oct 2018 09:14:10 +0200 Subject: [PATCH 03/17] Fix newly discovered code style issues --- aiocometd/client.py | 9 +++++---- aiocometd/exceptions.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/aiocometd/client.py b/aiocometd/client.py index 0c71ea5..41ffb40 100644 --- a/aiocometd/client.py +++ b/aiocometd/client.py @@ -379,9 +379,9 @@ async def receive(self): response = await self._get_message(self.connection_timeout) self._verify_response(response) return response - else: - raise ClientInvalidOperation("The client is closed and there are " - "no pending messages.") + + raise ClientInvalidOperation("The client is closed and there are " + "no pending messages.") async def __aiter__(self): """Asynchronous iterator @@ -466,7 +466,8 @@ async def _get_message(self, connection_timeout): # handle the completed task if get_task in done: return get_task.result() - elif server_disconnected_task in done: + + if server_disconnected_task in done: await self._check_server_disconnected() else: raise TransportTimeoutError("Lost connection with the " diff --git a/aiocometd/exceptions.py b/aiocometd/exceptions.py index 47cba01..8fe9ef3 100644 --- a/aiocometd/exceptions.py +++ b/aiocometd/exceptions.py @@ -52,12 +52,12 @@ class ServerError(AiocometdException): @property def message(self): """Error description""" - return self.args[0] + return self.args[0] # pylint: disable=unsubscriptable-object @property def response(self): """Server response message""" - return self.args[1] + return self.args[1] # pylint: disable=unsubscriptable-object @property def error(self): From 3c90fa2a5d8aca1dd6a31fed9345f6fa9acea870 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Sat, 13 Oct 2018 13:07:17 +0200 Subject: [PATCH 04/17] Add type hints to the modules in the main package --- aiocometd/client.py | 121 +++++++++++++++++++++++----------------- aiocometd/exceptions.py | 32 ++++++----- aiocometd/extensions.py | 18 +++--- aiocometd/utils.py | 41 +++++++------- docs/source/conf.py | 3 +- setup.py | 3 +- 6 files changed, 123 insertions(+), 95 deletions(-) diff --git a/aiocometd/client.py b/aiocometd/client.py index 41ffb40..f342f89 100644 --- a/aiocometd/client.py +++ b/aiocometd/client.py @@ -5,16 +5,25 @@ from collections import abc from contextlib import suppress import json +from typing import Optional, List, Union, Set, AsyncIterator, Type +import ssl as ssl_module +from types import TracebackType + +import aiohttp from .transports import create_transport +from .transports.abc import Transport from .constants import DEFAULT_CONNECTION_TYPE, ConnectionType, MetaChannel, \ SERVICE_CHANNEL_PREFIX, TransportState from .exceptions import ServerError, ClientInvalidOperation, \ TransportTimeoutError, ClientError -from .utils import is_server_error_message +from .utils import is_server_error_message, JsonObject, JsonDumper, JsonLoader +from .extensions import Extension, AuthExtension LOGGER = logging.getLogger(__name__) +ConnectionTypeSpec = Union[ConnectionType, List[ConnectionType]] +SSLValidationMode = Union[ssl_module.SSLContext, aiohttp.Fingerprint, bool] class Client: # pylint: disable=too-many-instance-attributes @@ -31,21 +40,25 @@ class Client: # pylint: disable=too-many-instance-attributes _DEFAULT_CONNECTION_TYPES = [ConnectionType.WEBSOCKET, ConnectionType.LONG_POLLING] - def __init__(self, url, connection_types=None, *, - connection_timeout=10.0, ssl=None, max_pending_count=100, - extensions=None, auth=None, json_dumps=json.dumps, - json_loads=json.loads, loop=None): + def __init__(self, url: str, + connection_types: Optional[ConnectionTypeSpec] = None, *, + connection_timeout: Union[int, float] = 10.0, + ssl: Optional[SSLValidationMode] = None, + max_pending_count: int = 100, + extensions: Optional[List[Extension]] = None, + auth: Optional[AuthExtension] = None, + json_dumps: JsonDumper = json.dumps, + json_loads: JsonLoader = json.loads, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: """ - :param str url: CometD service url + :param url: CometD service url :param connection_types: List of connection types in order of \ preference, or a single connection type name. If ``None``, \ [:obj:`~ConnectionType.WEBSOCKET`, \ :obj:`~ConnectionType.LONG_POLLING`] will be used as a default value. - :type connection_types: list[ConnectionType], ConnectionType or None :param connection_timeout: The maximum amount of time to wait for the \ transport to re-establish a connection with the server when the \ connection fails. - :type connection_timeout: int, float or None :param ssl: SSL validation mode. None for default SSL check \ (:func:`ssl.create_default_context` is used), False for skip SSL \ certificate validation, \ @@ -53,20 +66,17 @@ def __init__(self, url, connection_types=None, *, client_reference.html#aiohttp.Fingerprint>`_ for fingerprint \ validation, :obj:`ssl.SSLContext` for custom SSL certificate \ validation. - :param int max_pending_count: The maximum number of messages to \ + :param max_pending_count: The maximum number of messages to \ prefetch from the server. If the number of prefetched messages reach \ this size then the connection will be suspended, until messages are \ consumed. \ If it is less than or equal to zero, the count is infinite. :param extensions: List of protocol extension objects - :type extensions: list[Extension] or None - :param AuthExtension auth: An auth extension + :param auth: An auth extension :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` - :type json_dumps: :func:`callable` :param json_loads: Function for JSON deserialization, the default is \ :func:`json.loads` - :type json_loads: :func:`callable` :param loop: Event :obj:`loop ` used to schedule tasks. If *loop* is ``None`` then :func:`asyncio.get_event_loop` is used to get the default @@ -75,7 +85,6 @@ def __init__(self, url, connection_types=None, *, #: CometD service url self.url = url #: List of connection types to use in order of preference - self._connection_types = None if isinstance(connection_types, ConnectionType): self._connection_types = [connection_types] elif isinstance(connection_types, abc.Iterable): @@ -84,9 +93,9 @@ def __init__(self, url, connection_types=None, *, self._connection_types = self._DEFAULT_CONNECTION_TYPES self._loop = loop or asyncio.get_event_loop() #: queue for consuming incoming event messages - self._incoming_queue = None + self._incoming_queue: Optional[asyncio.Queue] = None #: transport object - self._transport = None + self._transport: Optional[Transport] = None #: marks whether the client is open or closed self._closed = True #: The maximum amount of time to wait for the transport to re-establish @@ -105,7 +114,7 @@ def __init__(self, url, connection_types=None, *, #: Function for JSON deserialization self._json_loads = json_loads - def __repr__(self): + def __repr__(self) -> str: """Formal string representation""" cls_name = type(self).__name__ fmt_spec = "{}({}, {}, connection_timeout={}, ssl={}, " \ @@ -121,19 +130,19 @@ def __repr__(self): reprlib.repr(self._loop)) @property - def closed(self): + def closed(self) -> bool: """Marks whether the client is open or closed""" return self._closed @property - def subscriptions(self): + def subscriptions(self) -> Set[str]: """Set of subscribed channels""" if self._transport: return self._transport.subscriptions return set() @property - def connection_type(self): + def connection_type(self) -> Optional[ConnectionType]: """The current connection type in use if the client is open, otherwise ``None``""" if self._transport is not None: @@ -141,7 +150,7 @@ def connection_type(self): return None @property - def pending_count(self): + def pending_count(self) -> int: """The number of pending incoming messages Once :obj:`open` is called the client starts listening for messages @@ -153,19 +162,19 @@ def pending_count(self): return self._incoming_queue.qsize() @property - def has_pending_messages(self): + def has_pending_messages(self) -> bool: """Marks whether the client has any pending incoming messages""" return self.pending_count > 0 - def _pick_connection_type(self, connection_types): + def _pick_connection_type(self, connection_types: List[str]) \ + -> Optional[ConnectionType]: """Pick a connection type based on the *connection_types* supported by the server and on the user's preferences - :param list[str] connection_types: Connection types \ + :param connection_types: Connection types \ supported by the server :return: The connection type with the highest precedence \ which is supported by the server - :rtype: ConnectionType or None """ server_connection_types = [] for type_string in connection_types: @@ -180,12 +189,11 @@ def _pick_connection_type(self, connection_types): result = min(intersection, key=self._connection_types.index) return result - async def _negotiate_transport(self): + async def _negotiate_transport(self) -> Transport: """Negotiate the transport type to use with the server and create the transport object :return: Transport object - :rtype: Transport :raise ClientError: If none of the connection types offered by the \ server are supported """ @@ -234,7 +242,7 @@ async def _negotiate_transport(self): await transport.close() raise - async def open(self): + async def open(self) -> None: """Establish a connection with the CometD server This method works mostly the same way as the `handshake` method of @@ -258,10 +266,12 @@ async def open(self): response = await self._transport.connect() self._verify_response(response) self._closed = False + + assert self.connection_type is not None LOGGER.info("Client opened with connection_type %r", self.connection_type.value) - async def close(self): + async def close(self) -> None: """Disconnect from the CometD server""" if not self.closed: if self.pending_count == 0: @@ -278,10 +288,10 @@ async def close(self): self._closed = True LOGGER.info("Client closed.") - async def subscribe(self, channel): + async def subscribe(self, channel: str) -> None: """Subscribe to *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :raise ClientInvalidOperation: If the client is :obj:`closed` :raise TransportError: If a network or transport related error occurs :raise ServerError: If the subscribe request gets rejected by the \ @@ -292,14 +302,15 @@ async def subscribe(self, channel): "the client is closed.") await self._check_server_disconnected() + assert self._transport is not None response = await self._transport.subscribe(channel) self._verify_response(response) LOGGER.info("Subscribed to channel %s", channel) - async def unsubscribe(self, channel): + async def unsubscribe(self, channel: str) -> None: """Unsubscribe from *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :raise ClientInvalidOperation: If the client is :obj:`closed` :raise TransportError: If a network or transport related error occurs :raise ServerError: If the unsubscribe request gets rejected by the \ @@ -310,17 +321,17 @@ async def unsubscribe(self, channel): "while, the client is closed.") await self._check_server_disconnected() + assert self._transport is not None response = await self._transport.unsubscribe(channel) self._verify_response(response) LOGGER.info("Unsubscribed from channel %s", channel) - async def publish(self, channel, data): + async def publish(self, channel: str, data: JsonObject) -> JsonObject: """Publish *data* to the given *channel* - :param str channel: Name of the channel - :param dict data: Data to send to the server + :param channel: Name of the channel + :param data: Data to send to the server :return: Publish response - :rtype: dict :raise ClientInvalidOperation: If the client is :obj:`closed` :raise TransportError: If a network or transport related error occurs :raise ServerError: If the publish request gets rejected by the server @@ -330,28 +341,29 @@ async def publish(self, channel, data): "the client is closed.") await self._check_server_disconnected() + assert self._transport is not None response = await self._transport.publish(channel, data) self._verify_response(response) return response - def _verify_response(self, response): + def _verify_response(self, response: JsonObject) -> None: """Check the ``successful`` status of the *response* and raise \ the appropriate :obj:`~aiocometd.exceptions.ServerError` if it's False If the *response* has no ``successful`` field, it's considered to be successful. - :param dict response: Response message + :param response: Response message :raise ServerError: If the *response* is not ``successful`` """ if is_server_error_message(response): self._raise_server_error(response) - def _raise_server_error(self, response): + def _raise_server_error(self, response: JsonObject) -> None: """Raise the appropriate :obj:`~aiocometd.exceptions.ServerError` for \ the failed *response* - :param dict response: Response message + :param response: Response message :raise ServerError: If the *response* is not ``successful`` """ channel = response["channel"] @@ -363,11 +375,10 @@ def _raise_server_error(self, response): message = "Publish request failed." raise ServerError(message, response) - async def receive(self): + async def receive(self) -> JsonObject: """Wait for incoming messages from the server :return: Incoming message - :rtype: dict :raise ClientInvalidOperation: If the client is closed, and has no \ more pending incoming messages :raise ServerError: If the client receives a confirmation message \ @@ -383,7 +394,7 @@ async def receive(self): raise ClientInvalidOperation("The client is closed and there are " "no pending messages.") - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator: """Asynchronous iterator :raise ServerError: If the client receives a confirmation message \ @@ -397,7 +408,7 @@ async def __aiter__(self): except ClientInvalidOperation: break - async def __aenter__(self): + async def __aenter__(self) -> "Client": """Enter the runtime context and call :obj:`open` :raise ClientInvalidOperation: If the client is already open, or in \ @@ -415,18 +426,20 @@ async def __aenter__(self): raise return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType) -> None: """Exit the runtime context and call :obj:`open`""" await self.close() - async def _get_message(self, connection_timeout): + async def _get_message(self, connection_timeout: Union[int, float]) \ + -> JsonObject: """Get the next incoming message :param connection_timeout: The maximum amount of time to wait for the \ transport to re-establish a connection with the server when the \ connection fails. :return: Incoming message - :rtype: dict :raise TransportTimeoutError: If the transport can't re-establish \ connection with the server in :obj:`connection_timeout` time. :raise ServerError: If the connection gets closed by the server. @@ -440,11 +453,13 @@ async def _get_message(self, connection_timeout): ) tasks.append(timeout_task) + assert self._incoming_queue is not None # task waiting on incoming messages get_task = asyncio.ensure_future(self._incoming_queue.get(), loop=self._loop) tasks.append(get_task) + assert self._transport is not None # task waiting on server side disconnect server_disconnected_task = asyncio.ensure_future( self._transport.wait_for_state( @@ -468,7 +483,9 @@ async def _get_message(self, connection_timeout): return get_task.result() if server_disconnected_task in done: - await self._check_server_disconnected() + await self.close() + raise ServerError("Connection closed by the server", + self._transport.last_connect_result) else: raise TransportTimeoutError("Lost connection with the " "server.") @@ -478,7 +495,8 @@ async def _get_message(self, connection_timeout): task.cancel() raise - async def _wait_connection_timeout(self, timeout): + async def _wait_connection_timeout(self, timeout: Union[int, float]) \ + -> None: """Wait for and return when the transport can't re-establish \ connection with the server in *timeout* time @@ -486,6 +504,7 @@ async def _wait_connection_timeout(self, timeout): transport to re-establish a connection with the server when the \ connection fails. """ + assert self._transport is not None while True: await self._transport.wait_for_state(TransportState.CONNECTING) try: @@ -496,7 +515,7 @@ async def _wait_connection_timeout(self, timeout): except asyncio.TimeoutError: break - async def _check_server_disconnected(self): + async def _check_server_disconnected(self) -> None: """Checks whether the current transport'state is :obj:`TransportState.SERVER_DISCONNECTED` and if it is then closes the client and raises an error diff --git a/aiocometd/exceptions.py b/aiocometd/exceptions.py index 8fe9ef3..b8996d1 100644 --- a/aiocometd/exceptions.py +++ b/aiocometd/exceptions.py @@ -11,6 +11,8 @@ TransportConnectionClosed ServerError """ +from typing import Optional, List + from . import utils @@ -39,47 +41,51 @@ class TransportConnectionClosed(TransportError): class ServerError(AiocometdException): - """CometD server side error + """CometD server side error""" + # pylint: disable=useless-super-delegation + def __init__(self, message: str, response: utils.JsonObject) -> None: + """If the *response* contains an error field it gets parsed + according to the \ + `specs `_ - If the *response* contains an error field it gets parsed - according to the \ - `specs `_ + :param message: Error description + :param response: Server response message + """ + super().__init__(message, response) - :param str message: Error description - :param dict response: Server response message - """ + # pylint: enable=useless-super-delegation @property - def message(self): + def message(self) -> str: """Error description""" return self.args[0] # pylint: disable=unsubscriptable-object @property - def response(self): + def response(self) -> utils.JsonObject: """Server response message""" return self.args[1] # pylint: disable=unsubscriptable-object @property - def error(self): + def error(self) -> Optional[str]: """Error field in the :obj:`response`""" return self.response.get("error") @property - def error_code(self): + def error_code(self) -> Optional[int]: """Error code part of the error code part of the `error\ `_, \ message field""" return utils.get_error_code(self.error) @property - def error_message(self): + def error_message(self) -> Optional[str]: """Description part of the `error\ `_, \ message field""" return utils.get_error_message(self.error) @property - def error_args(self): + def error_args(self) -> Optional[List[str]]: """Arguments part of the `error\ `_, \ message field""" diff --git a/aiocometd/extensions.py b/aiocometd/extensions.py index 7d59fbe..a30490f 100644 --- a/aiocometd/extensions.py +++ b/aiocometd/extensions.py @@ -1,34 +1,38 @@ """Extension classes""" from abc import ABC, abstractmethod +from typing import List, Dict, Optional + +from .utils import JsonObject class Extension(ABC): """Defines operations supported by extensions""" @abstractmethod - async def outgoing(self, payload, headers): + async def outgoing(self, payload: List[JsonObject], + headers: Dict[str, str]) -> None: """Process outgoing *payload* and *headers* Called just before a payload is sent - :param list[dict] payload: List of outgoing messages - :param dict headers: Headers to send + :param payload: List of outgoing messages + :param headers: Headers to send """ @abstractmethod - async def incoming(self, payload, headers=None): + async def incoming(self, payload: List[JsonObject], + headers: Optional[Dict[str, str]] = None) -> None: """Process incoming *payload* and *headers* Called just after a payload is received - :param list[dict] payload: List of incoming messages + :param payload: List of incoming messages :param headers: Headers to send - :type headers: dict or None """ class AuthExtension(Extension): # pylint: disable=abstract-method """Extension with support for authentication""" - async def authenticate(self): + async def authenticate(self) -> None: """Called after a failed authentication attempt For authentication schemes where the credentials are static it doesn't diff --git a/aiocometd/utils.py b/aiocometd/utils.py index 5050c43..c14fd6a 100644 --- a/aiocometd/utils.py +++ b/aiocometd/utils.py @@ -3,19 +3,25 @@ import asyncio from functools import wraps from http import HTTPStatus +from typing import Union, Optional, List, Awaitable, Callable, Dict, Any from .constants import META_CHANNEL_PREFIX, SERVICE_CHANNEL_PREFIX -def defer(coro_func, delay=None, *, loop=None): +CoroFunction = Callable[..., Awaitable[Any]] +JsonObject = Dict[str, Any] +JsonDumper = Callable[[JsonObject], str] +JsonLoader = Callable[[str], JsonObject] + + +def defer(coro_func: CoroFunction, delay: Union[int, float, None] = None, *, + loop: Optional[asyncio.AbstractEventLoop] = None) -> CoroFunction: """Returns a coroutine function that will defer the call to the given *coro_func* by *delay* seconds - :param asyncio.coroutine coro_func: A coroutine function + :param coro_func: A coroutine function :param delay: Delay in seconds - :type delay: int, float or None :param loop: An event loop - :type loop: asyncio.BaseEventLoop or None :return: Coroutine function wrapper """ @wraps(coro_func) @@ -27,7 +33,7 @@ async def wrapper(*args, **kwargs): # pylint: disable=missing-docstring return wrapper -def get_error_code(error_field): +def get_error_code(error_field: Union[str, None]) -> Optional[int]: """Get the error code part of the `error\ `_, message \ field @@ -35,11 +41,9 @@ def get_error_code(error_field): :param error_field: `Error\ `_, message \ field - :type error_field: str or None :return: The error code as an int if 3 digits can be matched at the \ beginning of the error field, for all other cases (``None`` or invalid \ error field) return ``None`` - :rtype: int or None """ result = None if error_field is not None: @@ -49,7 +53,7 @@ def get_error_code(error_field): return result -def get_error_message(error_field): +def get_error_message(error_field: Union[str, None]) -> Optional[str]: """Get the description part of the `error\ `_, message \ field @@ -57,10 +61,8 @@ def get_error_message(error_field): :param error_field: `Error\ `_, message \ field - :type error_field: str or None :return: The third part of the error field as a string if it can be \ matched otherwise return ``None`` - :rtype: str or None """ result = None if error_field is not None: @@ -70,7 +72,7 @@ def get_error_message(error_field): return result -def get_error_args(error_field): +def get_error_args(error_field: Union[str, None]) -> Optional[List[str]]: """Get the arguments part of the `error\ `_, message \ field @@ -78,10 +80,8 @@ def get_error_args(error_field): :param error_field: `Error\ `_, message \ field - :type error_field: str or None :return: The second part of the error field as a list of strings if it \ can be matched otherwise return ``None`` - :rtype: list[str] or None """ result = None if error_field is not None: @@ -94,15 +94,15 @@ def get_error_args(error_field): return result -def is_matching_response(response_message, message): +def is_matching_response(response_message: JsonObject, + message: JsonObject) -> bool: """Check whether the *response_message* is a response for the given *message*. - :param dict message: A sent message + :param message: A sent message :param response_message: A response message :return: True if the *response_message* is a match for *message* otherwise False. - :rtype: bool """ if message is None or response_message is None: return False @@ -115,24 +115,22 @@ def is_matching_response(response_message, message): "successful" in response_message) -def is_server_error_message(response_message): +def is_server_error_message(response_message: JsonObject) -> bool: """Check whether the *response_message* is a server side error message :param response_message: A response message :return: True if the *response_message* is a server side error message otherwise False. - :rtype: bool """ return not response_message.get("successful", True) -def is_event_message(response_message): +def is_event_message(response_message: JsonObject) -> bool: """Check whether the *response_message* is an event message :param response_message: A response message :return: True if the *response_message* is an event message otherwise False. - :rtype: bool """ channel = response_message["channel"] return (not channel.startswith(META_CHANNEL_PREFIX) and @@ -140,14 +138,13 @@ def is_event_message(response_message): "data" in response_message) -def is_auth_error_message(response_message): +def is_auth_error_message(response_message: JsonObject) -> bool: """Check whether the *response_message* is an authentication error message :param response_message: A response message :return: True if the *response_message* is an authentication error \ message, otherwise False. - :rtype: bool """ error_code = get_error_code(response_message.get("error")) # Strictly speaking, only UNAUTHORIZED should be considered as an auth diff --git a/docs/source/conf.py b/docs/source/conf.py index 0afb460..c327d6b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -57,7 +57,8 @@ def read(file_path): 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', - 'sphinxcontrib.asyncio' + 'sphinxcontrib.asyncio', + 'sphinx_autodoc_typehints' ] # Add any paths that contain templates here, relative to this directory. diff --git a/setup.py b/setup.py index 0c2d3c8..2e99b3f 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,8 @@ ] DOCS_REQUIRE = [ "Sphinx>=1.7,<2.0", - "sphinxcontrib-asyncio>=0.2.0" + "sphinxcontrib-asyncio>=0.2.0", + "sphinx-autodoc-typehints" ] EXAMPLES_REQUIRE = [ "aioconsole>=0.1.7,<1.0.0" From 537815d9266e317432188d66017d573bef2cf8bc Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Tue, 27 Nov 2018 21:55:17 +0100 Subject: [PATCH 05/17] Add integration tests Add integration tests for testing normal interaction with the cometd chat example service. Test network error detection and recovery as well. --- setup.py | 3 +- .../__init__.py | 0 tests/integration/helpers.py | 74 +++++ tests/integration/test_client.py | 305 ++++++++++++++++++ tests/unit/__init__.py | 0 tests/{ => unit}/test_client.py | 0 tests/{ => unit}/test_exceptions.py | 0 tests/unit/test_transports/__init__.py | 0 tests/{ => unit}/test_transports/test_base.py | 0 .../test_transports/test_long_polling.py | 0 .../test_transports/test_registry.py | 0 .../test_transports/test_websocket.py | 0 tests/{ => unit}/test_utils.py | 0 tox.ini | 3 +- 14 files changed, 383 insertions(+), 2 deletions(-) rename tests/{test_transports => integration}/__init__.py (100%) create mode 100644 tests/integration/helpers.py create mode 100644 tests/integration/test_client.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/test_client.py (100%) rename tests/{ => unit}/test_exceptions.py (100%) create mode 100644 tests/unit/test_transports/__init__.py rename tests/{ => unit}/test_transports/test_base.py (100%) rename tests/{ => unit}/test_transports/test_long_polling.py (100%) rename tests/{ => unit}/test_transports/test_registry.py (100%) rename tests/{ => unit}/test_transports/test_websocket.py (100%) rename tests/{ => unit}/test_utils.py (100%) diff --git a/setup.py b/setup.py index 0c2d3c8..504b51e 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,8 @@ ] TESTS_REQUIRE = [ "asynctest>=0.12.0,<1.0.0", - "coverage>=4.5,<5.0" + "coverage>=4.5,<5.0", + "docker>=3.5.1" ] DOCS_REQUIRE = [ "Sphinx>=1.7,<2.0", diff --git a/tests/test_transports/__init__.py b/tests/integration/__init__.py similarity index 100% rename from tests/test_transports/__init__.py rename to tests/integration/__init__.py diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py new file mode 100644 index 0000000..fa8267f --- /dev/null +++ b/tests/integration/helpers.py @@ -0,0 +1,74 @@ +import time +from urllib.request import urlopen +from http import HTTPStatus +from contextlib import suppress + +import docker + + +class DockerContainer: + def __init__(self, image_name, name, container_port, host_port): + self.image_name = image_name + self.name = name + self.contaner_port = container_port + self.host_port = host_port + self._container = None + self.client = docker.from_env() + self._ensure_exists() + + def _ensure_exists(self): + filter = { + "name": self.name, + "ancestor": self.image_name + } + results = self.client.containers.list(all=True, filters=filter) + if not results: + self._container = self.client.containers.run( + name=self.name, + image=self.image_name, + ports={f"{self.contaner_port}/tcp": self.host_port}, + detach=True + ) + else: + self._container = results[0] + + def _wait_for_state(self, state): + while self._container.status != state: + time.sleep(1) + self._container.reload() + + def _ensure_running(self): + self._ensure_exists() + + if self._container.status != "running": + if self._container.status == "exited": + self._container.start() + if self._container.status == "paused": + self._container.unpause() + + self._wait_for_state("running") + + def _get_url(self): + return f"http://localhost:{self.host_port}" + + def ensure_reacheable(self): + self._ensure_running() + url = self._get_url() + + status = None + with suppress(Exception): + status = urlopen(url).status + while status != HTTPStatus.OK: + time.sleep(1) + with suppress(Exception): + status = urlopen(url).status + + return url + + def stop(self): + self._container.stop() + self._wait_for_state("exited") + + def pause(self): + self._container.pause() + self._wait_for_state("paused") diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py new file mode 100644 index 0000000..3b4520f --- /dev/null +++ b/tests/integration/test_client.py @@ -0,0 +1,305 @@ +import logging +import asyncio + +from asynctest import TestCase, TestSuite + +from aiocometd import Client, ConnectionType +from aiocometd.exceptions import TransportTimeoutError +from tests.integration.helpers import DockerContainer + + +class BaseTestCase(TestCase): + IMAGE_NAME = "robertmrk/cometd-demos:alpine" + CONTAINER_NAME = "aiocometd-integration-test" + CONTAINER_PORT = 8080 + HOST_PORT = 9999 + COMETD_URL = f"http://localhost:{HOST_PORT}/cometd" + CONNECTION_TYPE = None + + container = None + + CHAT_ROOM = "demo" + CHAT_ROOM_CHANNEL = "/chat/" + CHAT_ROOM + MEMBERS_CHANGED_CHANNEL = "/members/" + CHAT_ROOM + MEMBERS_CHANNEL = "/service/members" + USER_NAME1 = "user1" + USER_NAME2 = "user2" + + def setUp(self): + self.container.ensure_reacheable() + + @classmethod + def setUpClass(cls): + cls.container = DockerContainer(cls.IMAGE_NAME, + cls.CONTAINER_NAME, + cls.CONTAINER_PORT, + cls.HOST_PORT) + + @classmethod + def tearDownClass(cls): + cls.container.stop() + + +class TestChat(BaseTestCase): + async def test_single_client_chat(self): + # create client + async with Client(self.COMETD_URL, + self.CONNECTION_TYPE) as client: + # subscribe to channels + await client.subscribe(self.CHAT_ROOM_CHANNEL) + await client.subscribe(self.MEMBERS_CHANGED_CHANNEL) + + # send initial message + join_message = { + "user": self.USER_NAME1, + "membership": "join", + "chat": self.USER_NAME1 + " has joined" + } + await client.publish(self.CHAT_ROOM_CHANNEL, join_message) + # add the user to the room's members + await client.publish(self.MEMBERS_CHANNEL, { + "user": self.USER_NAME1, + "room": self.CHAT_ROOM_CHANNEL + }) + + # verify the reception of the initial and members messages + message = await client.receive() + self.assertEqual(message, { + "data": join_message, + "channel": self.CHAT_ROOM_CHANNEL + }) + message = await client.receive() + self.assertEqual(message, { + "data": [self.USER_NAME1], + "channel": self.MEMBERS_CHANGED_CHANNEL + }) + + # send a chat message + chat_message = { + "user": self.USER_NAME1, + "chat": "test_message" + } + await client.publish(self.CHAT_ROOM_CHANNEL, chat_message) + # verify chat message + message = await client.receive() + self.assertEqual(message, { + "data": chat_message, + "channel": self.CHAT_ROOM_CHANNEL + }) + + async def test_multi_client_chat(self): + # create two clients + async with Client(self.COMETD_URL, self.CONNECTION_TYPE) as client1, \ + Client(self.COMETD_URL, self.CONNECTION_TYPE) as client2: + # subscribe to channels with client1 + await client1.subscribe(self.CHAT_ROOM_CHANNEL) + await client1.subscribe(self.MEMBERS_CHANGED_CHANNEL) + + # send initial message with client1 + join_message1 = { + "user": self.USER_NAME1, + "membership": "join", + "chat": self.USER_NAME1 + " has joined" + } + await client1.publish(self.CHAT_ROOM_CHANNEL, join_message1) + # add the user1 to the room's members + await client1.publish(self.MEMBERS_CHANNEL, { + "user": self.USER_NAME1, + "room": self.CHAT_ROOM_CHANNEL + }) + + # verify the reception of the initial and members messages + # for client1 + message = await client1.receive() + self.assertEqual(message, { + "data": join_message1, + "channel": self.CHAT_ROOM_CHANNEL + }) + message = await client1.receive() + self.assertEqual(message, { + "data": [self.USER_NAME1], + "channel": self.MEMBERS_CHANGED_CHANNEL + }) + + # subscribe to channels with client2 + await client2.subscribe(self.CHAT_ROOM_CHANNEL) + await client2.subscribe(self.MEMBERS_CHANGED_CHANNEL) + + # send initial message with client2 + join_message2 = { + "user": self.USER_NAME2, + "membership": "join", + "chat": self.USER_NAME2 + " has joined" + } + await client2.publish(self.CHAT_ROOM_CHANNEL, join_message2) + # add the user2 to the room's members + await client2.publish(self.MEMBERS_CHANNEL, { + "user": self.USER_NAME2, + "room": self.CHAT_ROOM_CHANNEL + }) + + # verify the reception of the initial and members messages of + # client2 for client1 + message = await client1.receive() + self.assertEqual(message, { + "data": join_message2, + "channel": self.CHAT_ROOM_CHANNEL + }) + message = await client1.receive() + self.assertEqual(message, { + "data": [self.USER_NAME1, self.USER_NAME2], + "channel": self.MEMBERS_CHANGED_CHANNEL + }) + + # verify the reception of the initial and members messages + # for client2 + message = await client2.receive() + self.assertEqual(message, { + "data": join_message2, + "channel": self.CHAT_ROOM_CHANNEL + }) + message = await client2.receive() + self.assertEqual(message, { + "data": [self.USER_NAME1, self.USER_NAME2], + "channel": self.MEMBERS_CHANGED_CHANNEL + }) + + +class TestTimeoutDetection(BaseTestCase): + async def test_timeout_on_server_termination(self): + # create client + async with Client(self.COMETD_URL, self.CONNECTION_TYPE, + connection_timeout=1) as client: + # subscribe to the room's channel + await client.subscribe(self.CHAT_ROOM_CHANNEL) + + # stop the service + self.container.stop() + # give a few seconds for the client to detect the lost connection + await asyncio.sleep(3) + + with self.assertRaises(TransportTimeoutError): + await client.receive() + + async def test_timeout_network_outage(self): + # create client + async with Client(self.COMETD_URL, self.CONNECTION_TYPE, + connection_timeout=1) as client: + # subscribe to the room's channel + await client.subscribe(self.CHAT_ROOM_CHANNEL) + + # pause the service + self.container.pause() + # wait for the server's connect request timeout to elapse to be + # able to detect the problem + await asyncio.sleep(client._transport.request_timeout) + # give a few seconds for the client to detect the lost connection + await asyncio.sleep(3) + + with self.assertRaises(TransportTimeoutError): + await client.receive() + + +class TestErrorRecovery(BaseTestCase): + async def test_revover_on_server_restart(self): + # create client + async with Client(self.COMETD_URL, self.CONNECTION_TYPE, + connection_timeout=3*60) as client: + # subscribe to the room's channel + await client.subscribe(self.CHAT_ROOM_CHANNEL) + + # stop the service + self.container.stop() + # give a few seconds for the client to detect the lost connection + await asyncio.sleep(3) + + # start the service + self.container.ensure_reacheable() + # give a few seconds for the client to recover the connection + await asyncio.sleep(3) + + # send a chat message + chat_message = { + "user": self.USER_NAME1, + "chat": "test_message" + } + await client.publish(self.CHAT_ROOM_CHANNEL, chat_message) + # verify chat message + message = await client.receive() + self.assertEqual(message, { + "data": chat_message, + "channel": self.CHAT_ROOM_CHANNEL + }) + + async def test_recover_on_temporary_network_outage(self): + # create client + async with Client(self.COMETD_URL, self.CONNECTION_TYPE, + connection_timeout=1) as client: + # subscribe to the room's channel + await client.subscribe(self.CHAT_ROOM_CHANNEL) + + # pause the service + self.container.pause() + # wait for the server's connect request timeout to elapse to be + # able to detect the problem + await asyncio.sleep(client._transport.request_timeout) + # give a few seconds for the client to detect the lost connection + await asyncio.sleep(3) + + # start the service + self.container.ensure_reacheable() + # give a few seconds for the client to recover the connection + await asyncio.sleep(3) + + # send a chat message + chat_message = { + "user": self.USER_NAME1, + "chat": "test_message" + } + await client.publish(self.CHAT_ROOM_CHANNEL, chat_message) + # verify chat message + message = await client.receive() + self.assertEqual(message, { + "data": chat_message, + "channel": self.CHAT_ROOM_CHANNEL + }) + + +class TestChatLongPolling(TestChat): + CONNECTION_TYPE = ConnectionType.LONG_POLLING + + +class TestChatWebsocket(TestChat): + CONNECTION_TYPE = ConnectionType.WEBSOCKET + + +class TestTimeoutDetectionLongPolling(TestTimeoutDetection): + CONNECTION_TYPE = ConnectionType.LONG_POLLING + + +class TestTimeoutDetectionWebsocket(TestTimeoutDetection): + CONNECTION_TYPE = ConnectionType.WEBSOCKET + + +class TestErrorRecoveryLongPolling(TestErrorRecovery): + CONNECTION_TYPE = ConnectionType.LONG_POLLING + + +class TestErrorRecoveryWebsocket(TestErrorRecovery): + CONNECTION_TYPE = ConnectionType.WEBSOCKET + + +def load_tests(loader, tests, pattern): + suite = TestSuite() + cases = ( + TestChatLongPolling, + TestChatWebsocket, + TestTimeoutDetectionLongPolling, + TestTimeoutDetectionWebsocket, + TestErrorRecoveryLongPolling, + TestErrorRecoveryWebsocket + ) + for case in cases: + tests = loader.loadTestsFromTestCase(case) + suite.addTests(tests) + return suite diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_client.py b/tests/unit/test_client.py similarity index 100% rename from tests/test_client.py rename to tests/unit/test_client.py diff --git a/tests/test_exceptions.py b/tests/unit/test_exceptions.py similarity index 100% rename from tests/test_exceptions.py rename to tests/unit/test_exceptions.py diff --git a/tests/unit/test_transports/__init__.py b/tests/unit/test_transports/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_transports/test_base.py b/tests/unit/test_transports/test_base.py similarity index 100% rename from tests/test_transports/test_base.py rename to tests/unit/test_transports/test_base.py diff --git a/tests/test_transports/test_long_polling.py b/tests/unit/test_transports/test_long_polling.py similarity index 100% rename from tests/test_transports/test_long_polling.py rename to tests/unit/test_transports/test_long_polling.py diff --git a/tests/test_transports/test_registry.py b/tests/unit/test_transports/test_registry.py similarity index 100% rename from tests/test_transports/test_registry.py rename to tests/unit/test_transports/test_registry.py diff --git a/tests/test_transports/test_websocket.py b/tests/unit/test_transports/test_websocket.py similarity index 100% rename from tests/test_transports/test_websocket.py rename to tests/unit/test_transports/test_websocket.py diff --git a/tests/test_utils.py b/tests/unit/test_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/unit/test_utils.py diff --git a/tox.ini b/tox.ini index 8443c01..74125f4 100644 --- a/tox.ini +++ b/tox.ini @@ -7,5 +7,6 @@ deps = coveralls extras = tests commands = - coverage run -m unittest discover tests + coverage run -m unittest discover tests/unit + python -m unittest discover tests/integration coveralls From 51a4d4cddb9ea1278f3d02aa7382ed76ce4ef52f Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Tue, 27 Nov 2018 22:09:47 +0100 Subject: [PATCH 06/17] Enable docker usage in travis-ci --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 3da4318..9a1c5b2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,5 +6,8 @@ python: - "3.6" - "3.7" +services: + - docker + install: pip install tox-travis script: tox From 2138689d2065902e41649903ade73d2769070d0e Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Wed, 28 Nov 2018 08:12:06 +0100 Subject: [PATCH 07/17] Improve the documentation of integration tests --- tests/integration/helpers.py | 42 +++++++++++++++++++++++++++++++- tests/integration/test_client.py | 20 ++++++++++++--- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index fa8267f..1a993c6 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -1,3 +1,4 @@ +"""Helper classes for integration tests""" import time from urllib.request import urlopen from http import HTTPStatus @@ -7,21 +8,43 @@ class DockerContainer: + """Docker container encapsulation + + If the container with the given *name* doesn't exists yet it'll be created. + """ def __init__(self, image_name, name, container_port, host_port): + """ + :param str image_name: A docker image name with or without a tag + :param name: Container name + :param container_port: TCP port exposed by the container + :param host_port: TCP port on the host where the exposed container \ + port gets published + """ + #: docker image name with or without a tag self.image_name = image_name + #: container name self.name = name + #: TCP port exposed by the container self.contaner_port = container_port + #: TCP port on the host where the exposed container port gets published self.host_port = host_port + #: container instance self._container = None + #: docker client self.client = docker.from_env() + self._ensure_exists() def _ensure_exists(self): + """Create the container if it doesn't already exists""" + # try to find the container by name and image filter = { "name": self.name, "ancestor": self.image_name } results = self.client.containers.list(all=True, filters=filter) + + # if it doesn't exists then create it if not results: self._container = self.client.containers.run( name=self.name, @@ -29,32 +52,47 @@ def _ensure_exists(self): ports={f"{self.contaner_port}/tcp": self.host_port}, detach=True ) + # if it exists assign it to the instance attribute else: self._container = results[0] def _wait_for_state(self, state): + """Wait until the state of the container becomes the given *state* + value + """ while self._container.status != state: time.sleep(1) self._container.reload() def _ensure_running(self): + """Get the container into the ``running`` state if it's in a different + one""" + # make sure the container exists self._ensure_exists() + # if the container is not running if self._container.status != "running": + # if the container is stopped then start it if self._container.status == "exited": self._container.start() + # if the container is paused then resume it if self._container.status == "paused": self._container.unpause() self._wait_for_state("running") def _get_url(self): + """Return the url of the container's service""" return f"http://localhost:{self.host_port}" - def ensure_reacheable(self): + def ensure_reachable(self): + """Start the container and make sure it's exposed service is reachable + """ + # make sure the container is running self._ensure_running() url = self._get_url() + # query the service's URL until it can be reached status = None with suppress(Exception): status = urlopen(url).status @@ -66,9 +104,11 @@ def ensure_reacheable(self): return url def stop(self): + """Stop the container""" self._container.stop() self._wait_for_state("exited") def pause(self): + """Pause the container""" self._container.pause() self._wait_for_state("paused") diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 3b4520f..e4857fb 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -1,4 +1,3 @@ -import logging import asyncio from asynctest import TestCase, TestSuite @@ -9,24 +8,37 @@ class BaseTestCase(TestCase): + #: name of the docker image containing the CometD demo services IMAGE_NAME = "robertmrk/cometd-demos:alpine" + #: a name for the container CONTAINER_NAME = "aiocometd-integration-test" + #: TCP port exposed by the container CONTAINER_PORT = 8080 + #: TCP port where the container's port will be published HOST_PORT = 9999 + #: URL of the CometD service COMETD_URL = f"http://localhost:{HOST_PORT}/cometd" + #: CometD connection type CONNECTION_TYPE = None + #: container instance container = None + #: name of the chat room CHAT_ROOM = "demo" + #: channel where the room's messages get published CHAT_ROOM_CHANNEL = "/chat/" + CHAT_ROOM + #: channel where the room's memeber get published MEMBERS_CHANGED_CHANNEL = "/members/" + CHAT_ROOM + #: channel for adding user's to the room MEMBERS_CHANNEL = "/service/members" + #: name of the first user USER_NAME1 = "user1" + #: name of the second user USER_NAME2 = "user2" def setUp(self): - self.container.ensure_reacheable() + self.container.ensure_reachable() @classmethod def setUpClass(cls): @@ -214,7 +226,7 @@ async def test_revover_on_server_restart(self): await asyncio.sleep(3) # start the service - self.container.ensure_reacheable() + self.container.ensure_reachable() # give a few seconds for the client to recover the connection await asyncio.sleep(3) @@ -247,7 +259,7 @@ async def test_recover_on_temporary_network_outage(self): await asyncio.sleep(3) # start the service - self.container.ensure_reacheable() + self.container.ensure_reachable() # give a few seconds for the client to recover the connection await asyncio.sleep(3) From 4d5c8480aef7e0a7929fd89190f3088913f74758 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Wed, 28 Nov 2018 08:22:39 +0100 Subject: [PATCH 08/17] Run flake8 and pylint during CI testing --- setup.py | 9 ++++----- tox.ini | 2 ++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 504b51e..8b5ba51 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,9 @@ TESTS_REQUIRE = [ "asynctest>=0.12.0,<1.0.0", "coverage>=4.5,<5.0", - "docker>=3.5.1" + "docker>=3.5.1", + "flake8", + "pylint" ] DOCS_REQUIRE = [ "Sphinx>=1.7,<2.0", @@ -18,10 +20,7 @@ EXAMPLES_REQUIRE = [ "aioconsole>=0.1.7,<1.0.0" ] -DEV_REQUIRE = [ - "flake8", - "pylint" -] +DEV_REQUIRE = [] def read(file_path): diff --git a/tox.ini b/tox.ini index 74125f4..1327bee 100644 --- a/tox.ini +++ b/tox.ini @@ -10,3 +10,5 @@ commands = coverage run -m unittest discover tests/unit python -m unittest discover tests/integration coveralls + flake8 + pylint aiocometd From 73930953a4c3e367b7971a5b40511933ff8f135b Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Wed, 28 Nov 2018 08:37:55 +0100 Subject: [PATCH 09/17] Add .tox to the list of directories ignored by flake8 --- setup.cfg | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 46409dc..ab70fd7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,8 +7,10 @@ exclude = # sphinx conf file docs/source/conf.py, # virtual enviroments - venv - env + venv, + env, + # tox + .tox [coverage:run] source = From e3f9fbf458198e5bc2bfdce292b4feb3f3810a44 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 16:47:47 +0100 Subject: [PATCH 10/17] Add type hints to the transports subpackage --- aiocometd/_typing.py | 25 +++ aiocometd/client.py | 17 +- aiocometd/exceptions.py | 16 +- aiocometd/extensions.py | 11 +- aiocometd/transports/abc.py | 56 ++--- aiocometd/transports/base.py | 211 +++++++++--------- aiocometd/transports/long_polling.py | 25 ++- aiocometd/transports/registry.py | 26 ++- aiocometd/transports/websocket.py | 74 +++--- aiocometd/utils.py | 14 +- setup.py | 3 +- tests/unit/test_exceptions.py | 11 +- tests/unit/test_transports/test_base.py | 47 ++-- .../unit/test_transports/test_long_polling.py | 44 ++++ tests/unit/test_transports/test_websocket.py | 6 +- 15 files changed, 354 insertions(+), 232 deletions(-) create mode 100644 aiocometd/_typing.py diff --git a/aiocometd/_typing.py b/aiocometd/_typing.py new file mode 100644 index 0000000..f4296d9 --- /dev/null +++ b/aiocometd/_typing.py @@ -0,0 +1,25 @@ +"""Type definitions""" +from typing import List, Union, Callable, Awaitable, Any, Dict +import ssl as ssl_module + +import aiohttp + +from .constants import ConnectionType + + +#: Coroutine function +CoroFunction = Callable[..., Awaitable[Any]] +#: JSON object value +JsonObject = Dict[str, Any] +#: JSON serializer function +JsonDumper = Callable[[JsonObject], str] +#: JSON deserializer function +JsonLoader = Callable[[str], JsonObject] +#: Message payload (list of messages) +Payload = List[JsonObject] +#: Header values +Headers = Dict[str, str] +#: Connection type specification +ConnectionTypeSpec = Union[ConnectionType, List[ConnectionType]] +#: SSL validation mode +SSLValidationMode = Union[ssl_module.SSLContext, aiohttp.Fingerprint, bool] diff --git a/aiocometd/client.py b/aiocometd/client.py index f342f89..b23f628 100644 --- a/aiocometd/client.py +++ b/aiocometd/client.py @@ -5,25 +5,22 @@ from collections import abc from contextlib import suppress import json -from typing import Optional, List, Union, Set, AsyncIterator, Type -import ssl as ssl_module +from typing import Optional, List, Union, Set, AsyncIterator, Type, Any from types import TracebackType -import aiohttp - from .transports import create_transport from .transports.abc import Transport from .constants import DEFAULT_CONNECTION_TYPE, ConnectionType, MetaChannel, \ SERVICE_CHANNEL_PREFIX, TransportState from .exceptions import ServerError, ClientInvalidOperation, \ TransportTimeoutError, ClientError -from .utils import is_server_error_message, JsonObject, JsonDumper, JsonLoader +from .utils import is_server_error_message from .extensions import Extension, AuthExtension +from ._typing import ConnectionTypeSpec, SSLValidationMode, JsonObject, \ + JsonDumper, JsonLoader LOGGER = logging.getLogger(__name__) -ConnectionTypeSpec = Union[ConnectionType, List[ConnectionType]] -SSLValidationMode = Union[ssl_module.SSLContext, aiohttp.Fingerprint, bool] class Client: # pylint: disable=too-many-instance-attributes @@ -93,7 +90,7 @@ def __init__(self, url: str, self._connection_types = self._DEFAULT_CONNECTION_TYPES self._loop = loop or asyncio.get_event_loop() #: queue for consuming incoming event messages - self._incoming_queue: Optional[asyncio.Queue] = None + self._incoming_queue: "Optional[asyncio.Queue[JsonObject]]" = None #: transport object self._transport: Optional[Transport] = None #: marks whether the client is open or closed @@ -394,7 +391,7 @@ async def receive(self) -> JsonObject: raise ClientInvalidOperation("The client is closed and there are " "no pending messages.") - async def __aiter__(self) -> AsyncIterator: + async def __aiter__(self) -> AsyncIterator[JsonObject]: """Asynchronous iterator :raise ServerError: If the client receives a confirmation message \ @@ -444,7 +441,7 @@ async def _get_message(self, connection_timeout: Union[int, float]) \ connection with the server in :obj:`connection_timeout` time. :raise ServerError: If the connection gets closed by the server. """ - tasks = [] + tasks: List[asyncio.Future[Any]] = [] # task waiting on connection timeout if connection_timeout: timeout_task = asyncio.ensure_future( diff --git a/aiocometd/exceptions.py b/aiocometd/exceptions.py index b8996d1..a437e7e 100644 --- a/aiocometd/exceptions.py +++ b/aiocometd/exceptions.py @@ -11,7 +11,7 @@ TransportConnectionClosed ServerError """ -from typing import Optional, List +from typing import Optional, List, cast from . import utils @@ -43,7 +43,8 @@ class TransportConnectionClosed(TransportError): class ServerError(AiocometdException): """CometD server side error""" # pylint: disable=useless-super-delegation - def __init__(self, message: str, response: utils.JsonObject) -> None: + def __init__(self, message: str, response: Optional[utils.JsonObject]) \ + -> None: """If the *response* contains an error field it gets parsed according to the \ `specs `_ @@ -58,16 +59,21 @@ def __init__(self, message: str, response: utils.JsonObject) -> None: @property def message(self) -> str: """Error description""" - return self.args[0] # pylint: disable=unsubscriptable-object + # pylint: disable=unsubscriptable-object + return cast(str, self.args[0]) + # pylint: enable=unsubscriptable-object @property - def response(self) -> utils.JsonObject: + def response(self) -> Optional[utils.JsonObject]: """Server response message""" - return self.args[1] # pylint: disable=unsubscriptable-object + return cast(Optional[utils.JsonObject], + self.args[1]) # pylint: disable=unsubscriptable-object @property def error(self) -> Optional[str]: """Error field in the :obj:`response`""" + if self.response is None: + return None return self.response.get("error") @property diff --git a/aiocometd/extensions.py b/aiocometd/extensions.py index a30490f..a0281cd 100644 --- a/aiocometd/extensions.py +++ b/aiocometd/extensions.py @@ -1,15 +1,14 @@ """Extension classes""" from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import Optional -from .utils import JsonObject +from ._typing import Payload, Headers class Extension(ABC): """Defines operations supported by extensions""" @abstractmethod - async def outgoing(self, payload: List[JsonObject], - headers: Dict[str, str]) -> None: + async def outgoing(self, payload: Payload, headers: Headers) -> None: """Process outgoing *payload* and *headers* Called just before a payload is sent @@ -19,8 +18,8 @@ async def outgoing(self, payload: List[JsonObject], """ @abstractmethod - async def incoming(self, payload: List[JsonObject], - headers: Optional[Dict[str, str]] = None) -> None: + async def incoming(self, payload: Payload, + headers: Optional[Headers] = None) -> None: """Process incoming *payload* and *headers* Called just after a payload is received diff --git a/aiocometd/transports/abc.py b/aiocometd/transports/abc.py index 5a25229..fb0586d 100644 --- a/aiocometd/transports/abc.py +++ b/aiocometd/transports/abc.py @@ -1,58 +1,67 @@ """Transport abstract base class definition""" from abc import ABC, abstractmethod +from typing import Set, Optional, List + +from ..constants import ConnectionType, TransportState +from .._typing import JsonObject class Transport(ABC): """Defines the operations that all transport classes should support""" @property @abstractmethod - def connection_type(self): + def connection_type(self) -> ConnectionType: """The transport's connection type""" @property @abstractmethod - def endpoint(self): + def endpoint(self) -> str: """CometD service url""" @property @abstractmethod - def client_id(self): + def client_id(self) -> Optional[str]: """Clinet id value assigned by the server""" @property @abstractmethod - def state(self): + def state(self) -> TransportState: """Current state of the transport""" @property @abstractmethod - def subscriptions(self): + def subscriptions(self) -> Set[str]: """Set of subscribed channels""" @property @abstractmethod - def last_connect_result(self): + def last_connect_result(self) -> Optional[JsonObject]: """Result of the last connect request""" + @property + @abstractmethod + def reconnect_advice(self) -> JsonObject: + """Reconnection advice parameters returned by the server""" + @abstractmethod - async def handshake(self, connection_types): + async def handshake(self, connection_types: List[ConnectionType]) \ + -> JsonObject: """Executes the handshake operation - :param list[ConnectionType] connection_types: list of connection types + :param connection_types: list of connection types :return: Handshake response - :rtype: dict :raises TransportError: When the network request fails. """ @abstractmethod - async def connect(self): + async def connect(self) -> JsonObject: """Connect to the server The transport will try to start and maintain a continuous connection with the server, but it'll return with the response of the first successful connection as soon as possible. - :return dict: The response of the first successful connection. + :return: The response of the first successful connection. :raise TransportInvalidOperation: If the transport doesn't has a \ client id yet, or if it's not in a :obj:`~TransportState.DISCONNECTED`\ :obj:`state`. @@ -60,7 +69,7 @@ async def connect(self): """ @abstractmethod - async def disconnect(self): + async def disconnect(self) -> None: """Disconnect from server The disconnect message is only sent to the server if the transport is @@ -68,16 +77,15 @@ async def disconnect(self): """ @abstractmethod - async def close(self): + async def close(self) -> None: """Close transport and release resources""" @abstractmethod - async def subscribe(self, channel): + async def subscribe(self, channel: str) -> JsonObject: """Subscribe to *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :return: Subscribe response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` @@ -85,12 +93,11 @@ async def subscribe(self, channel): """ @abstractmethod - async def unsubscribe(self, channel): + async def unsubscribe(self, channel: str) -> JsonObject: """Unsubscribe from *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :return: Unsubscribe response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` @@ -98,13 +105,12 @@ async def unsubscribe(self, channel): """ @abstractmethod - async def publish(self, channel, data): + async def publish(self, channel: str, data: JsonObject) -> JsonObject: """Publish *data* to the given *channel* - :param str channel: Name of the channel - :param dict data: Data to send to the server + :param channel: Name of the channel + :param data: Data to send to the server :return: Publish response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` @@ -112,8 +118,8 @@ async def publish(self, channel, data): """ @abstractmethod - async def wait_for_state(self, state): + async def wait_for_state(self, state: TransportState) -> None: """Waits for and returns when the transport enters the given *state* - :param TransportState state: A state value + :param state: A state value """ diff --git a/aiocometd/transports/base.py b/aiocometd/transports/base.py index 1968c07..e9d4f2c 100644 --- a/aiocometd/transports/base.py +++ b/aiocometd/transports/base.py @@ -4,16 +4,20 @@ from abc import abstractmethod from contextlib import suppress import json +from typing import Union, Optional, List, Set, Awaitable, Any import aiohttp from .abc import Transport -from ..constants import MetaChannel, TransportState, HANDSHAKE_MESSAGE, \ - CONNECT_MESSAGE, DISCONNECT_MESSAGE, SUBSCRIBE_MESSAGE, \ - UNSUBSCRIBE_MESSAGE, PUBLISH_MESSAGE +from ..constants import ConnectionType, MetaChannel, TransportState, \ + HANDSHAKE_MESSAGE, CONNECT_MESSAGE, DISCONNECT_MESSAGE, \ + SUBSCRIBE_MESSAGE, UNSUBSCRIBE_MESSAGE, PUBLISH_MESSAGE from ..utils import defer, is_matching_response, is_auth_error_message, \ is_server_error_message, is_event_message from ..exceptions import TransportInvalidOperation, TransportError +from .._typing import SSLValidationMode, JsonObject, JsonLoader, JsonDumper, \ + Headers, Payload +from ..extensions import Extension, AuthExtension LOGGER = logging.getLogger(__name__) @@ -32,18 +36,24 @@ class TransportBase(Transport): # pylint: disable=too-many-instance-attributes #: The increase factor to use for request timeout REQUEST_TIMEOUT_INCREASE_FACTOR = 1.2 - def __init__(self, *, url, incoming_queue, client_id=None, - reconnection_timeout=1, ssl=None, extensions=None, auth=None, - json_dumps=json.dumps, json_loads=json.loads, - reconnect_advice=None, loop=None): + def __init__(self, *, url: str, + incoming_queue: "asyncio.Queue[JsonObject]", + client_id: Optional[str] = None, + reconnection_timeout: Union[int, float] = 1, + ssl: Optional[SSLValidationMode] = None, + extensions: Optional[List[Extension]] = None, + auth: Optional[AuthExtension] = None, + json_dumps: JsonDumper = json.dumps, + json_loads: JsonLoader = json.loads, + reconnect_advice: Optional[JsonObject] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: """ - :param str url: CometD service url - :param asyncio.Queue incoming_queue: Queue for consuming incoming event + :param url: CometD service url + :param incoming_queue: Queue for consuming incoming event messages - :param str client_id: Clinet id value assigned by the server + :param client_id: Clinet id value assigned by the server :param reconnection_timeout: The time to wait before trying to \ reconnect to the server after a network failure - :type reconnection_timeout: None or int or float :param ssl: SSL validation mode. None for default SSL check \ (:func:`ssl.create_default_context` is used), False for skip SSL \ certificate validation, \ @@ -52,16 +62,12 @@ def __init__(self, *, url, incoming_queue, client_id=None, validation, :obj:`ssl.SSLContext` for custom SSL certificate \ validation. :param extensions: List of protocol extension objects - :type extensions: list[Extension] or None - :param AuthExtension auth: An auth extension + :param auth: An auth extension :param json_dumps: Function for JSON serialization, the default is \ :func:`json.dumps` - :type json_dumps: :func:`callable` :param json_loads: Function for JSON deserialization, the default is \ :func:`json.loads` - :type json_loads: :func:`callable` :param reconnect_advice: Initial reconnect advice - :type reconnect_advice: dict or None :param loop: Event :obj:`loop ` used to schedule tasks. If *loop* is ``None`` then :func:`asyncio.get_event_loop` is used to get the default @@ -79,11 +85,9 @@ def __init__(self, *, url, incoming_queue, client_id=None, #: session self._message_id = 0 #: reconnection advice parameters returned by the server - self.reconnect_advice = reconnect_advice - if self.reconnect_advice is None: - self.reconnect_advice = {} + self._reconnect_advice: JsonObject = reconnect_advice or dict() #: set of subscribed channels - self._subscriptions = set() + self._subscriptions: Set[str] = set() #: boolean to mark whether to resubscribe on connect self._subscribe_on_connect = False #: dictionary of TransportState and asyncio.Event pairs @@ -91,13 +95,13 @@ def __init__(self, *, url, incoming_queue, client_id=None, #: current state of the transport self._state = TransportState.DISCONNECTED #: asyncio connection task - self._connect_task = None + self._connect_task: Optional[asyncio.Future[JsonObject]] = None #: time to wait before reconnecting after a network failure self._reconnect_timeout = reconnection_timeout #: SSL validation mode self.ssl = ssl #: http session - self._http_session = None + self._http_session: Optional[aiohttp.ClientSession] = None #: List of protocol extension objects self._extensions = extensions or [] #: An auth extension @@ -107,7 +111,7 @@ def __init__(self, *, url, incoming_queue, client_id=None, #: Function for JSON deserialization self._json_loads = json_loads - async def _get_http_session(self): + async def _get_http_session(self) -> aiohttp.ClientSession: """Factory method for getting the current HTTP session :return: The current session if it's not None, otherwise it creates a @@ -123,7 +127,7 @@ async def _get_http_session(self): ) return self._http_session - async def _close_http_session(self): + async def _close_http_session(self) -> None: """Close the http session if it's not already closed""" # graceful shutdown recommended by the documentation # https://aiohttp.readthedocs.io/en/stable/client_advanced.html\ @@ -133,86 +137,90 @@ async def _close_http_session(self): await asyncio.sleep(self._HTTP_SESSION_CLOSE_TIMEOUT) @property - def connection_type(self): + def connection_type(self) -> ConnectionType: # pragma: no cover """The transport's connection type""" - return None # pragma: no cover + return None # type: ignore @property - def endpoint(self): + def endpoint(self) -> str: """CometD service url""" return self._url @property - def client_id(self): + def client_id(self) -> Optional[str]: """clinet id value assigned by the server""" return self._client_id @property - def subscriptions(self): + def subscriptions(self) -> Set[str]: """Set of subscribed channels""" return self._subscriptions @property - def last_connect_result(self): + def last_connect_result(self) -> Optional[JsonObject]: """Result of the last connect request""" if self._connect_task and self._connect_task.done(): return self._connect_task.result() return None @property - def state(self): + def reconnect_advice(self) -> JsonObject: + """Reconnection advice parameters returned by the server""" + return self._reconnect_advice + + @property + def state(self) -> TransportState: """Current state of the transport""" return self._state @property - def _state(self): + def _state(self) -> TransportState: """Current state of the transport""" - return self.__dict__.setdefault("_state") + return self.__dict__.get("_state", TransportState.DISCONNECTED) @_state.setter - def _state(self, value): + def _state(self, value: TransportState) -> None: self._set_state_event(self._state, value) self.__dict__["_state"] = value @property - def request_timeout(self): + def request_timeout(self) -> Optional[Union[int, float]]: """Number of seconds after a network request should time out""" timeout = self.reconnect_advice.get("timeout") - if timeout: + if isinstance(timeout, (int, float)): # convert milliseconds to seconds timeout /= 1000 # increase the timeout specified by the server to avoid timing out # by mistake - timeout *= self.__class__.REQUEST_TIMEOUT_INCREASE_FACTOR - return timeout + timeout *= self.REQUEST_TIMEOUT_INCREASE_FACTOR + return timeout + return None - def _set_state_event(self, old_state, new_state): + def _set_state_event(self, old_state: TransportState, + new_state: TransportState) -> None: """Set event associated with the *new_state* and clear the event for the *old_state* :param old_state: Old state value - :type old_state: TransportState or None :param new_state: New state value - :type new_state: TransportState """ if new_state != old_state: - if old_state is not None: - self._state_events[old_state].clear() + self._state_events[old_state].clear() self._state_events[new_state].set() - async def wait_for_state(self, state): + async def wait_for_state(self, state: TransportState) -> None: """Waits for and returns when the transport enters the given *state* :param TransportState state: A state value """ await self._state_events[state].wait() - async def handshake(self, connection_types): + async def handshake(self, connection_types: List[ConnectionType]) \ + -> JsonObject: """Executes the handshake operation - :param list[ConnectionType] connection_types: list of connection types + :param connection_types: list of connection types :return: Handshake response - :rtype: dict :raises TransportError: When the network request fails. """ # reset message id for a new client session @@ -237,11 +245,11 @@ async def handshake(self, connection_types): self._subscribe_on_connect = True return response_message - def _finalize_message(self, message): + def _finalize_message(self, message: JsonObject) -> None: """Update the ``id``, ``clientId`` and ``connectionType`` message fields as a side effect if they're are present in the *message*. - :param dict message: Outgoing message + :param message: Outgoing message """ if "id" in message: message["id"] = str(self._message_id) @@ -253,14 +261,14 @@ def _finalize_message(self, message): if "connectionType" in message: message["connectionType"] = self.connection_type.value - def _finalize_payload(self, payload): + def _finalize_payload(self, payload: Union[JsonObject, Payload]) \ + -> None: """Update the ``id``, ``clientId`` and ``connectionType`` message fields in the *payload*, as a side effect if they're are present in the *message*. The *payload* can be either a single message or a list of messages. :param payload: A message or a list of messages - :type payload: dict or list[dict] """ if isinstance(payload, list): for item in payload: @@ -268,20 +276,21 @@ def _finalize_payload(self, payload): else: self._finalize_message(payload) - async def _send_message(self, message, **kwargs): + async def _send_message(self, message: JsonObject, **kwargs: Any) \ + -> JsonObject: """Send message to server - :param dict message: A message + :param message: A message :param kwargs: Optional key-value pairs that'll be used to update the \ the values in the *message* :return: Response message - :rtype: dict :raises TransportError: When the network request fails. """ message.update(kwargs) return await self._send_payload_with_auth([message]) - async def _send_payload_with_auth(self, payload): + async def _send_payload_with_auth(self, payload: Payload) \ + -> JsonObject: """Finalize and send *payload* to server and retry on authentication failure @@ -289,9 +298,8 @@ async def _send_payload_with_auth(self, payload): response message can be provided for the first message in the *payload*. - :param list[dict] payload: A list of messages + :param payload: A list of messages :return: The response message for the first message in the *payload* - :rtype: dict :raises TransportError: When the network request fails. """ response = await self._send_payload(payload) @@ -305,30 +313,30 @@ async def _send_payload_with_auth(self, payload): # otherwise return the response return response - async def _send_payload(self, payload): + async def _send_payload(self, payload: Payload) -> JsonObject: """Finalize and send *payload* to server Finalize and send the *payload* to the server and return once a response message can be provided for the first message in the *payload*. - :param list[dict] payload: A list of messages + :param payload: A list of messages :return: The response message for the first message in the *payload* - :rtype: dict :raises TransportError: When the network request fails. """ self._finalize_payload(payload) - headers = {} + headers: Headers = {} # process the outgoing payload with the extensions await self._process_outgoing_payload(payload, headers) # send the payload to the server return await self._send_final_payload(payload, headers=headers) - async def _process_outgoing_payload(self, payload, headers): + async def _process_outgoing_payload(self, payload: Payload, + headers: Headers) -> None: """Process the outgoing *payload* and *headers* with the extensions - :param list[dict] payload: A list of messages - :param dict headers: Headers to send + :param payload: A list of messages + :param headers: Headers to send """ for extension in self._extensions: await extension.outgoing(payload, headers) @@ -336,7 +344,8 @@ async def _process_outgoing_payload(self, payload, headers): await self._auth.outgoing(payload, headers) @abstractmethod - async def _send_final_payload(self, payload, *, headers): + async def _send_final_payload(self, payload: Payload, *, + headers: Headers) -> JsonObject: """Send *payload* to server Send the *payload* to the server and return once a @@ -347,14 +356,13 @@ async def _send_final_payload(self, payload, *, headers): consumers. To enqueue the received messages :meth:`_consume_payload` can be used. - :param list[dict] payload: A list of messages - :param dict headers: Headers to send + :param payload: A list of messages + :param headers: Headers to send :return: The response message for the first message in the *payload* - :rtype: dict :raises TransportError: When the network request fails. """ - async def _consume_message(self, response_message): + async def _consume_message(self, response_message: JsonObject) -> None: """Enqueue the *response_message* for consumers if it's a type of message that consumers should receive @@ -364,7 +372,7 @@ async def _consume_message(self, response_message): is_event_message(response_message)): await self.incoming_queue.put(response_message) - def _update_subscriptions(self, response_message): + def _update_subscriptions(self, response_message: JsonObject) -> None: """Update the set of subscriptions based on the *response_message* :param response_message: A response message @@ -387,33 +395,32 @@ def _update_subscriptions(self, response_message): response_message["subscription"] in self._subscriptions): self._subscriptions.remove(response_message["subscription"]) - async def _process_incoming_payload(self, payload, headers=None): + async def _process_incoming_payload(self, payload: Payload, + headers: Optional[Headers] = None) \ + -> None: """Process incoming *payload* and *headers* with the extensions :param payload: A list of response messages - :type payload: list[dict] :param headers: Received headers - :type headers: dict or None """ if self._auth: await self._auth.incoming(payload, headers) for extension in self._extensions: await extension.incoming(payload, headers) - async def _consume_payload(self, payload, *, headers=None, - find_response_for=None): + async def _consume_payload(self, payload: Payload, *, + headers: Optional[Headers] = None, + find_response_for: JsonObject) \ + -> Optional[JsonObject]: """Enqueue event messages for the consumers and update the internal state of the transport, based on response messages in the *payload*. :param payload: A list of response messages - :type payload: list[dict] :param headers: Received headers - :type headers: dict or None - :param dict find_response_for: Find and return the matching \ + :param find_response_for: Find and return the matching \ response message for the given *find_response_for* message. :return: The response message for the *find_response_for* message, \ otherwise ``None`` - :rtype: dict or None """ # process incoming payload and headers with the extensions await self._process_incoming_payload(payload, headers) @@ -424,7 +431,7 @@ async def _consume_payload(self, payload, *, headers=None, # if there is an advice in the message then update the transport's # reconnect advice if "advice" in message: - self.reconnect_advice = message["advice"] + self._reconnect_advice = message["advice"] # update subscriptions based on responses self._update_subscriptions(message) @@ -439,7 +446,8 @@ async def _consume_payload(self, payload, *, headers=None, await self._consume_message(message) return result - def _start_connect_task(self, coro): + def _start_connect_task(self, coro: Awaitable[JsonObject]) \ + -> Awaitable[JsonObject]: """Wrap the *coro* in a future and schedule it The future is stored internally in :obj:`_connect_task`. The future's @@ -452,7 +460,7 @@ def _start_connect_task(self, coro): self._connect_task.add_done_callback(self._connect_done) return self._connect_task - async def _stop_connect_task(self): + async def _stop_connect_task(self) -> None: """Stop the connection task If no connect task exists or if it's done it does nothing. @@ -461,14 +469,14 @@ async def _stop_connect_task(self): self._connect_task.cancel() await asyncio.wait([self._connect_task]) - async def connect(self): + async def connect(self) -> JsonObject: """Connect to the server The transport will try to start and maintain a continuous connection with the server, but it'll return with the response of the first successful connection as soon as possible. - :return dict: The response of the first successful connection. + :return: The response of the first successful connection. :raise TransportInvalidOperation: If the transport doesn't has a \ client id yet, or if it's not in a :obj:`~TransportState.DISCONNECTED`\ :obj:`state`. @@ -486,11 +494,10 @@ async def connect(self): self._state = TransportState.CONNECTING return await self._start_connect_task(self._connect()) - async def _connect(self): + async def _connect(self) -> JsonObject: """Connect to the server :return: Connect response - :rtype: dict :raises TransportError: When the network request fails. """ message = CONNECT_MESSAGE.copy() @@ -498,21 +505,21 @@ async def _connect(self): if self._subscribe_on_connect and self.subscriptions: for subscription in self.subscriptions: extra_message = SUBSCRIBE_MESSAGE.copy() - extra_message["subscription"] = subscription + extra_message["subscription"] = subscription # type: ignore payload.append(extra_message) result = await self._send_payload_with_auth(payload) self._subscribe_on_connect = not result["successful"] return result - def _connect_done(self, future): + def _connect_done(self, future: "asyncio.Future[JsonObject]") -> None: """Consume the result of the *future* and follow the server's \ connection advice if the transport is still connected - :param asyncio.Future future: A :obj:`_connect` or :obj:`handshake` \ + :param future: A :obj:`_connect` or :obj:`handshake` \ future """ try: - result = future.result() + result: Union[JsonObject, Exception] = future.result() reconnect_timeout = self.reconnect_advice["interval"] self._state = TransportState.CONNECTED except Exception as error: # pylint: disable=broad-except @@ -526,7 +533,8 @@ def _connect_done(self, future): if self.state != TransportState.DISCONNECTING: self._follow_advice(reconnect_timeout) - def _follow_advice(self, reconnect_timeout): + def _follow_advice(self, reconnect_timeout: Union[int, float, None]) \ + -> None: """Follow the server's reconnect advice Either a :obj:`_connect` or :obj:`handshake` operation is started @@ -559,7 +567,7 @@ def _follow_advice(self, reconnect_timeout): "will be scheduled.") self._state = TransportState.SERVER_DISCONNECTED - async def disconnect(self): + async def disconnect(self) -> None: """Disconnect from server The disconnect message is only sent to the server if the transport is @@ -577,16 +585,15 @@ async def disconnect(self): finally: self._state = TransportState.DISCONNECTED - async def close(self): + async def close(self) -> None: """Close transport and release resources""" await self._close_http_session() - async def subscribe(self, channel): + async def subscribe(self, channel: str) -> JsonObject: """Subscribe to *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :return: Subscribe response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` @@ -599,12 +606,11 @@ async def subscribe(self, channel): return await self._send_message(SUBSCRIBE_MESSAGE.copy(), subscription=channel) - async def unsubscribe(self, channel): + async def unsubscribe(self, channel: str) -> JsonObject: """Unsubscribe from *channel* - :param str channel: Name of the channel + :param channel: Name of the channel :return: Unsubscribe response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` @@ -617,13 +623,12 @@ async def unsubscribe(self, channel): return await self._send_message(UNSUBSCRIBE_MESSAGE.copy(), subscription=channel) - async def publish(self, channel, data): + async def publish(self, channel: str, data: JsonObject) -> JsonObject: """Publish *data* to the given *channel* - :param str channel: Name of the channel - :param dict data: Data to send to the server + :param channel: Name of the channel + :param data: Data to send to the server :return: Publish response - :rtype: dict :raise TransportInvalidOperation: If the transport is not in the \ :obj:`~TransportState.CONNECTED` or :obj:`~TransportState.CONNECTING` \ :obj:`state` diff --git a/aiocometd/transports/long_polling.py b/aiocometd/transports/long_polling.py index 92eace5..a88eff8 100644 --- a/aiocometd/transports/long_polling.py +++ b/aiocometd/transports/long_polling.py @@ -1,13 +1,15 @@ """Long polling transport class definition""" import asyncio import logging +from typing import Any import aiohttp from ..constants import ConnectionType from .registry import register_transport -from .base import TransportBase +from .base import TransportBase, Payload, Headers from ..exceptions import TransportError +from .._typing import JsonObject LOGGER = logging.getLogger(__name__) @@ -16,13 +18,14 @@ @register_transport(ConnectionType.LONG_POLLING) class LongPollingTransport(TransportBase): """Long-polling type transport""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) #: semaphore to limit the number of concurrent HTTP connections to 2 self._http_semaphore = asyncio.Semaphore(2, loop=self._loop) - async def _send_final_payload(self, payload, *, headers): + async def _send_final_payload(self, payload: Payload, *, + headers: Headers) -> JsonObject: try: session = await self._get_http_session() async with self._http_semaphore: @@ -34,5 +37,15 @@ async def _send_final_payload(self, payload, *, headers): except aiohttp.client_exceptions.ClientError as error: LOGGER.warning("Failed to send payload, %s", error) raise TransportError(str(error)) from error - return await self._consume_payload(response_payload, headers=headers, - find_response_for=payload[0]) + response_message = await self._consume_payload( + response_payload, + headers=headers, + find_response_for=payload[0] + ) + + if response_message is None: + error_message = "No response message received for the " \ + "first message in the payload" + LOGGER.warning(error_message) + raise TransportError(error_message) + return response_message diff --git a/aiocometd/transports/registry.py b/aiocometd/transports/registry.py index 2e03f39..767c237 100644 --- a/aiocometd/transports/registry.py +++ b/aiocometd/transports/registry.py @@ -1,43 +1,47 @@ """Functions for transport class registration and instantiation""" +from typing import Type, Callable, Any + from ..exceptions import TransportInvalidOperation +from .abc import Transport +from ..constants import ConnectionType TRANSPORT_CLASSES = {} -def register_transport(conn_type): +def register_transport(conn_type: ConnectionType) \ + -> Callable[[Type[Transport]], Type[Transport]]: """Class decorator for registering transport classes The class' connection_type property will be also defined to return the given *connection_type* - :param ConnectionType conn_type: A connection type + :param conn_type: A connection type :return: The updated class """ # pylint: disable=unused-argument, missing-docstring - def decorator(cls): + def decorator(cls: Type[Transport]) -> Type[Transport]: TRANSPORT_CLASSES[conn_type] = cls - @property - def connection_type(instance): + @property # type: ignore + def connection_type(self: Transport) -> ConnectionType: return conn_type - cls.connection_type = connection_type + cls.connection_type = connection_type # type: ignore return cls return decorator -def create_transport(connection_type, *args, **kwargs): +def create_transport(connection_type: ConnectionType, *args: Any, + **kwargs: Any) -> Transport: """Create a transport object that can be used for the given *connection_type* - :param ConnectionType connection_type: A connection type - :param args: Positional arguments to pass to the transport + :param connection_type: A connection type :param kwargs: Keyword arguments to pass to the transport :return: A transport object - :rtype: Transport """ if connection_type not in TRANSPORT_CLASSES: raise TransportInvalidOperation("There is no transport for connection " "type {!r}".format(connection_type)) - return TRANSPORT_CLASSES[connection_type](*args, **kwargs) + return TRANSPORT_CLASSES[connection_type](*args, **kwargs) # type: ignore diff --git a/aiocometd/transports/websocket.py b/aiocometd/transports/websocket.py index c942959..3deb9f4 100644 --- a/aiocometd/transports/websocket.py +++ b/aiocometd/transports/websocket.py @@ -2,16 +2,26 @@ import asyncio import logging from contextlib import suppress +from typing import Callable, Optional, AsyncContextManager, Any, Awaitable, \ + cast import aiohttp +import aiohttp.client_ws from ..constants import ConnectionType, MetaChannel from .registry import register_transport -from .base import TransportBase +from .base import TransportBase, Payload, Headers from ..exceptions import TransportError, TransportConnectionClosed +from .._typing import JsonObject LOGGER = logging.getLogger(__name__) +#: Asynchronous factory function of ClientSessions +AsyncSessionFactory = Callable[[], Awaitable[aiohttp.ClientSession]] +#: Web socket type +WebSocket = aiohttp.client_ws.ClientWebSocketResponse +#: Context manager type managing a WebSocket +WebSocketContextManager = AsyncContextManager[WebSocket] class WebSocketFactory: # pylint: disable=too-few-public-methods @@ -20,30 +30,27 @@ class WebSocketFactory: # pylint: disable=too-few-public-methods This class allows the usage of factory objects without context blocks """ - def __init__(self, session_factory): + def __init__(self, session_factory: AsyncSessionFactory): """ - :param asyncio.coroutine session_factory: Coroutine factory function \ + :param session_factory: Coroutine factory function \ which returns an HTTP session """ self._session_factory = session_factory - self._context = None - self._socket = None + self._context: Optional[WebSocketContextManager] = None + self._socket: Optional[WebSocket] = None - async def close(self): + async def close(self) -> None: """Close the factory""" with suppress(Exception): await self._exit() - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: Any, **kwargs: Any) -> WebSocket: """Create a new factory object or returns a previously created one if it's not closed :param args: positional arguments for the ws_connect function :param kwargs: keyword arguments for the ws_connect function :return: Websocket object - :rtype: `aiohttp.ClientWebSocketResponse \ - `_ """ # if a the factory object already exists and if it's in closed state # exit the context manager and clear the references @@ -53,21 +60,22 @@ async def __call__(self, *args, **kwargs): # if there is no factory object, then create it and enter the \ # context manager to initialize it if self._socket is None: - await self._enter(*args, **kwargs) + self._socket = await self._enter(*args, **kwargs) return self._socket - async def _enter(self, *args, **kwargs): + async def _enter(self, *args: Any, **kwargs: Any) -> WebSocket: """Enter factory context :param args: positional arguments for the ws_connect function :param kwargs: keyword arguments for the ws_connect function + :return: Websocket object """ session = await self._session_factory() self._context = session.ws_connect(*args, **kwargs) - self._socket = await self._context.__aenter__() + return await self._context.__aenter__() - async def _exit(self): + async def _exit(self) -> None: """Exit factory context""" if self._context: await self._context.__aexit__(None, None, None) @@ -78,8 +86,8 @@ async def _exit(self): class WebSocketTransport(TransportBase): """WebSocket type transport""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) #: channels used during the connect task, requests on these channels #: are usually long running self._long_duration_channels = (MetaChannel.HANDSHAKE, @@ -95,19 +103,17 @@ def __init__(self, *args, **kwargs): #: exclusive lock for the objects created by _socket_factory_long self._socket_lock_long = asyncio.Lock() - async def _get_socket(self, channel, headers): + async def _get_socket(self, channel: str, headers: Headers) \ + -> WebSocket: """Get a websocket object for the given *channel* Returns different websocket objects for long running and short duration requests, so while a long running request is pending, short duration requests can be transmitted. - :param str channel: CometD channel name - :param dict headers: Headers to send + :param channel: CometD channel name + :param headers: Headers to send :return: Websocket object - :rtype: `aiohttp.ClientWebSocketResponse \ - `_ """ if channel in self._long_duration_channels: factory = self._socket_factory_long @@ -117,25 +123,25 @@ async def _get_socket(self, channel, headers): return await factory(self.endpoint, ssl=self.ssl, headers=headers, receive_timeout=self.request_timeout) - def _get_socket_lock(self, channel): + def _get_socket_lock(self, channel: str) -> asyncio.Lock: """Get an exclusive lock object for the given *channel* - :param str channel: CometD channel name + :param channel: CometD channel name :return: lock object for the *channel* - :rtype: asyncio.Lock """ if channel in self._long_duration_channels: return self._socket_lock_long return self._socket_lock_short - async def _reset_sockets(self): + async def _reset_sockets(self) -> None: """Close all socket factories and recreate them""" await self._socket_factory_short.close() self._socket_factory_short = WebSocketFactory(self._get_http_session) await self._socket_factory_long.close() self._socket_factory_long = WebSocketFactory(self._get_http_session) - async def _send_final_payload(self, payload, *, headers): + async def _send_final_payload(self, payload: Payload, *, + headers: Headers) -> JsonObject: try: # the channel of the first message channel = payload[0]["channel"] @@ -164,16 +170,13 @@ async def _send_final_payload(self, payload, *, headers): LOGGER.warning("Failed to send payload, %s", error) raise TransportError(str(error)) from error - async def _send_socket_payload(self, socket, payload): + async def _send_socket_payload(self, socket: WebSocket, + payload: Payload) -> JsonObject: """Send *payload* to the server on the given *socket* :param socket: WebSocket object - :type socket: `aiohttp.ClientWebSocketResponse \ - `_ - :param list[dict] payload: A message or a list of messages + :param payload: A message or a list of messages :return: Response payload - :rtype: list[dict] :raises TransportError: When the request fails. :raises TransportConnectionClosed: When the *socket* receives a CLOSE \ message instead of the expected response @@ -186,7 +189,8 @@ async def _send_socket_payload(self, socket, payload): if response.type == aiohttp.WSMsgType.CLOSE: raise TransportConnectionClosed("Received CLOSE message on " "the factory.") - response_payload = response.json(loads=self._json_loads) + response_payload = cast(Payload, + response.json(loads=self._json_loads)) matching_response = await self._consume_payload( response_payload, headers=None, @@ -194,7 +198,7 @@ async def _send_socket_payload(self, socket, payload): if matching_response: return matching_response - async def close(self): + async def close(self) -> None: await self._socket_factory_short.close() await self._socket_factory_long.close() await super().close() diff --git a/aiocometd/utils.py b/aiocometd/utils.py index c14fd6a..f3c8cf1 100644 --- a/aiocometd/utils.py +++ b/aiocometd/utils.py @@ -3,15 +3,10 @@ import asyncio from functools import wraps from http import HTTPStatus -from typing import Union, Optional, List, Awaitable, Callable, Dict, Any +from typing import Union, Optional, List, Any from .constants import META_CHANNEL_PREFIX, SERVICE_CHANNEL_PREFIX - - -CoroFunction = Callable[..., Awaitable[Any]] -JsonObject = Dict[str, Any] -JsonDumper = Callable[[JsonObject], str] -JsonLoader = Callable[[str], JsonObject] +from ._typing import CoroFunction, JsonObject def defer(coro_func: CoroFunction, delay: Union[int, float, None] = None, *, @@ -25,9 +20,10 @@ def defer(coro_func: CoroFunction, delay: Union[int, float, None] = None, *, :return: Coroutine function wrapper """ @wraps(coro_func) - async def wrapper(*args, **kwargs): # pylint: disable=missing-docstring + async def wrapper(*args: Any, **kwargs: Any) -> Any: \ + # pylint: disable=missing-docstring if delay: - await asyncio.sleep(delay, loop=loop) + await asyncio.sleep(delay, loop=loop) # type: ignore return await coro_func(*args, **kwargs) return wrapper diff --git a/setup.py b/setup.py index 33c5356..f582be3 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,8 @@ "coverage>=4.5,<5.0", "docker>=3.5.1", "flake8", - "pylint" + "pylint", + "mypy" ] DOCS_REQUIRE = [ "Sphinx>=1.7,<2.0", diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index 2eabc31..07206b1 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -29,7 +29,16 @@ def test_properties_on_no_error(self): self.assertEqual(error.message, "description message") self.assertEqual(error.response, response) - self.assertEqual(error.error, None) + self.assertIsNone(error.error) + + def test_properties_on_no_response(self): + response = None + + error = ServerError("description message", response) + + self.assertEqual(error.message, "description message") + self.assertIsNone(error.response) + self.assertIsNone(error.error) @mock.patch("aiocometd.exceptions.utils") def test_error_code(self, utils): diff --git a/tests/unit/test_transports/test_base.py b/tests/unit/test_transports/test_base.py index d86add3..1afa9f2 100644 --- a/tests/unit/test_transports/test_base.py +++ b/tests/unit/test_transports/test_base.py @@ -653,6 +653,10 @@ async def test__connect_subscribe_on_connect_error(self): [CONNECT_MESSAGE] + additional_messages) self.assertTrue(self.transport._subscribe_on_connect) + def test__state_initially_disconnected(self): + self.assertIs(self.transport._state, + TransportState.DISCONNECTED) + def test_state(self): self.assertIs(self.transport.state, self.transport._state) @@ -775,7 +779,7 @@ async def test_connect_done_with_result(self): task.set_result("result") self.transport._follow_advice = mock.MagicMock() self.transport._state = TransportState.CONNECTING - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": "retry" } @@ -797,7 +801,7 @@ async def test_connect_done_with_error(self): task.set_exception(error) self.transport._follow_advice = mock.MagicMock() self.transport._state = TransportState.CONNECTED - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": "retry" } @@ -819,7 +823,7 @@ async def test_connect_dont_follow_advice_on_disconnecting(self): task.set_exception(error) self.transport._follow_advice = mock.MagicMock() self.transport._state = TransportState.DISCONNECTING - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": "retry" } @@ -836,7 +840,7 @@ async def test_connect_dont_follow_advice_on_disconnecting(self): @mock.patch("aiocometd.transports.base.defer") def test_follow_advice_handshake(self, defer): - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": "handshake" } @@ -858,7 +862,7 @@ def test_follow_advice_handshake(self, defer): @mock.patch("aiocometd.transports.base.defer") def test_follow_advice_retry(self, defer): - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": "retry" } @@ -883,7 +887,7 @@ def test_follow_advice_none(self, defer): advices = ["none", "", None] for advice in advices: self.transport._state = TransportState.CONNECTED - self.transport.reconnect_advice = { + self.transport._reconnect_advice = { "interval": 1, "reconnect": advice } @@ -1194,15 +1198,6 @@ def test_set_state_event(self): self.assertFalse(self.transport._state_events[old_state].is_set()) self.assertTrue(self.transport._state_events[new_state].is_set()) - def test_set_state_event_no_old_state(self): - old_state = None - new_state = TransportState.CONNECTED - self.transport._state_events[new_state].clear() - - self.transport._set_state_event(old_state, new_state) - - self.assertTrue(self.transport._state_events[new_state].is_set()) - def test_set_state_event_unchanged_state(self): state = TransportState.CONNECTED event_mock = mock.MagicMock() @@ -1229,8 +1224,8 @@ def test_last_connect_result_on_no_connect_task(self): self.assertIsNone(result) - def test_request_timeout(self): - self.transport.reconnect_advice = { + def test_request_timeout_int(self): + self.transport._reconnect_advice = { "timeout": 2000 } @@ -1238,7 +1233,23 @@ def test_request_timeout(self): (self.transport.reconnect_advice["timeout"] / 1000) * type(self.transport).REQUEST_TIMEOUT_INCREASE_FACTOR) + def test_request_timeout_float(self): + self.transport._reconnect_advice = { + "timeout": 2000.0 + } + + self.assertEqual(self.transport.request_timeout, + (self.transport.reconnect_advice["timeout"] / 1000) * + type(self.transport).REQUEST_TIMEOUT_INCREASE_FACTOR) + def test_request_timeout_none(self): - self.transport.reconnect_advice = {} + self.transport._reconnect_advice = {} + + self.assertIsNone(self.transport.request_timeout) + + def test_request_timeout_none_on_unsupported_timeout_type(self): + self.transport._reconnect_advice = { + "timeout": "2000" + } self.assertIsNone(self.transport.request_timeout) diff --git a/tests/unit/test_transports/test_long_polling.py b/tests/unit/test_transports/test_long_polling.py index 55c3de3..e9dd842 100644 --- a/tests/unit/test_transports/test_long_polling.py +++ b/tests/unit/test_transports/test_long_polling.py @@ -91,3 +91,47 @@ async def test_send_payload_final_payload_client_error(self): headers=headers, timeout=self.transport.request_timeout) self.transport._consume_payload.assert_not_called() + + async def test_send_payload_final_payload_missing_response(self): + resp_data = [{ + "channel": "test/channel3", + "data": {}, + "id": 4 + }] + response_mock = mock.MagicMock() + response_mock.json = mock.CoroutineMock(return_value=resp_data) + response_mock.headers = object() + session = mock.MagicMock() + session.post = mock.CoroutineMock(return_value=response_mock) + self.transport._get_http_session = \ + mock.CoroutineMock(return_value=session) + self.transport._http_semaphore = mock.MagicMock() + payload = [object(), object()] + self.transport.ssl = object() + self.transport._consume_payload = \ + mock.CoroutineMock(return_value=None) + headers = dict(key="value") + error_message = "No response message received for the " \ + "first message in the payload" + + with self.assertLogs(LongPollingTransport.__module__, + level="DEBUG") as log: + with self.assertRaisesRegex(TransportError, error_message): + await self.transport._send_final_payload(payload, + headers=headers) + + log_message = "WARNING:{}:{}" \ + .format(LongPollingTransport.__module__, error_message) + self.assertEqual(log.output, [log_message]) + self.transport._http_semaphore.__aenter__.assert_called() + self.transport._http_semaphore.__aexit__.assert_called() + session.post.assert_called_with(self.transport._url, + json=payload, + ssl=self.transport.ssl, + headers=headers, + timeout=self.transport.request_timeout) + response_mock.json.assert_called_with(loads=self.transport._json_loads) + self.transport._consume_payload.assert_called_with( + resp_data, + headers=response_mock.headers, + find_response_for=payload[0]) diff --git a/tests/unit/test_transports/test_websocket.py b/tests/unit/test_transports/test_websocket.py index 6e31b7e..49ccaaf 100644 --- a/tests/unit/test_transports/test_websocket.py +++ b/tests/unit/test_transports/test_websocket.py @@ -23,11 +23,11 @@ async def test_enter(self): args = [object()] kwargs = {"key": "value"} - await self.factory._enter(*args, **kwargs) + result = await self.factory._enter(*args, **kwargs) self.session.ws_connect.assert_called_with(*args, **kwargs) self.assertEqual(self.factory._context, context) - self.assertEqual(self.factory._socket, socket) + self.assertEqual(result, socket) async def test_exit(self): socket = object() @@ -69,6 +69,8 @@ async def test_call_socket_creates_socket(self): await self.factory(*args, **kwargs) self.factory._enter.assert_called_with(*args, **kwargs) + self.assertEqual(self.factory._socket, + self.factory._enter.return_value) async def test_call_socket_returns_open_socket(self): self.factory._enter = mock.CoroutineMock() From 62c72af6fabe19dde083c82f103fd9f63d4ee46a Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 16:53:21 +0100 Subject: [PATCH 11/17] Run mypy during CI testing --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 1327bee..761b953 100644 --- a/tox.ini +++ b/tox.ini @@ -12,3 +12,4 @@ commands = coveralls flake8 pylint aiocometd + mypy --strict aiocometd From aa86c9141cb7dc9a69e230622505c8be088f0847 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 17:16:39 +0100 Subject: [PATCH 12/17] Replace all explicit relative imports with absolute imports --- aiocometd/__init__.py | 10 +++++----- aiocometd/client.py | 18 +++++++++--------- aiocometd/exceptions.py | 2 +- aiocometd/extensions.py | 2 +- aiocometd/transports/__init__.py | 4 ++-- aiocometd/transports/abc.py | 4 ++-- aiocometd/transports/base.py | 16 ++++++++-------- aiocometd/transports/long_polling.py | 10 +++++----- aiocometd/transports/registry.py | 6 +++--- aiocometd/transports/websocket.py | 10 +++++----- aiocometd/{_typing.py => typing.py} | 2 +- aiocometd/utils.py | 4 ++-- 12 files changed, 44 insertions(+), 44 deletions(-) rename aiocometd/{_typing.py => typing.py} (93%) diff --git a/aiocometd/__init__.py b/aiocometd/__init__.py index 20239e5..e2dc25c 100644 --- a/aiocometd/__init__.py +++ b/aiocometd/__init__.py @@ -1,11 +1,11 @@ """CometD client for asyncio""" import logging -from ._metadata import VERSION as __version__ # noqa: F401 -from .client import Client # noqa: F401 -from .constants import ConnectionType # noqa: F401 -from .extensions import Extension, AuthExtension # noqa: F401 -from . import transports # noqa: F401 +from aiocometd._metadata import VERSION as __version__ # noqa: F401 +from aiocometd.client import Client # noqa: F401 +from aiocometd.constants import ConnectionType # noqa: F401 +from aiocometd.extensions import Extension, AuthExtension # noqa: F401 +from aiocometd import transports # noqa: F401 # Create a default handler to avoid warnings in applications without logging # configuration diff --git a/aiocometd/client.py b/aiocometd/client.py index b23f628..6580690 100644 --- a/aiocometd/client.py +++ b/aiocometd/client.py @@ -8,16 +8,16 @@ from typing import Optional, List, Union, Set, AsyncIterator, Type, Any from types import TracebackType -from .transports import create_transport -from .transports.abc import Transport -from .constants import DEFAULT_CONNECTION_TYPE, ConnectionType, MetaChannel, \ - SERVICE_CHANNEL_PREFIX, TransportState -from .exceptions import ServerError, ClientInvalidOperation, \ +from aiocometd.transports import create_transport +from aiocometd.transports.abc import Transport +from aiocometd.constants import DEFAULT_CONNECTION_TYPE, ConnectionType, \ + MetaChannel, SERVICE_CHANNEL_PREFIX, TransportState +from aiocometd.exceptions import ServerError, ClientInvalidOperation, \ TransportTimeoutError, ClientError -from .utils import is_server_error_message -from .extensions import Extension, AuthExtension -from ._typing import ConnectionTypeSpec, SSLValidationMode, JsonObject, \ - JsonDumper, JsonLoader +from aiocometd.utils import is_server_error_message +from aiocometd.extensions import Extension, AuthExtension +from aiocometd.typing import ConnectionTypeSpec, SSLValidationMode, \ + JsonObject, JsonDumper, JsonLoader LOGGER = logging.getLogger(__name__) diff --git a/aiocometd/exceptions.py b/aiocometd/exceptions.py index a437e7e..4a02227 100644 --- a/aiocometd/exceptions.py +++ b/aiocometd/exceptions.py @@ -13,7 +13,7 @@ """ from typing import Optional, List, cast -from . import utils +from aiocometd import utils class AiocometdException(Exception): diff --git a/aiocometd/extensions.py b/aiocometd/extensions.py index a0281cd..b8e87e3 100644 --- a/aiocometd/extensions.py +++ b/aiocometd/extensions.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from ._typing import Payload, Headers +from aiocometd.typing import Payload, Headers class Extension(ABC): diff --git a/aiocometd/transports/__init__.py b/aiocometd/transports/__init__.py index cb7c562..849b0d2 100644 --- a/aiocometd/transports/__init__.py +++ b/aiocometd/transports/__init__.py @@ -1,3 +1,3 @@ """Transport classes and functions""" -from .registry import create_transport # noqa: F401 -from . import long_polling, websocket # noqa: F401 +from aiocometd.transports.registry import create_transport # noqa: F401 +from aiocometd.transports import long_polling, websocket # noqa: F401 diff --git a/aiocometd/transports/abc.py b/aiocometd/transports/abc.py index fb0586d..19e3311 100644 --- a/aiocometd/transports/abc.py +++ b/aiocometd/transports/abc.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import Set, Optional, List -from ..constants import ConnectionType, TransportState -from .._typing import JsonObject +from aiocometd.constants import ConnectionType, TransportState +from aiocometd.typing import JsonObject class Transport(ABC): diff --git a/aiocometd/transports/base.py b/aiocometd/transports/base.py index e9d4f2c..ee7f09a 100644 --- a/aiocometd/transports/base.py +++ b/aiocometd/transports/base.py @@ -8,16 +8,16 @@ import aiohttp -from .abc import Transport -from ..constants import ConnectionType, MetaChannel, TransportState, \ +from aiocometd.constants import ConnectionType, MetaChannel, TransportState, \ HANDSHAKE_MESSAGE, CONNECT_MESSAGE, DISCONNECT_MESSAGE, \ SUBSCRIBE_MESSAGE, UNSUBSCRIBE_MESSAGE, PUBLISH_MESSAGE -from ..utils import defer, is_matching_response, is_auth_error_message, \ - is_server_error_message, is_event_message -from ..exceptions import TransportInvalidOperation, TransportError -from .._typing import SSLValidationMode, JsonObject, JsonLoader, JsonDumper, \ - Headers, Payload -from ..extensions import Extension, AuthExtension +from aiocometd.utils import defer, is_matching_response, \ + is_auth_error_message, is_server_error_message, is_event_message +from aiocometd.exceptions import TransportInvalidOperation, TransportError +from aiocometd.typing import SSLValidationMode, JsonObject, JsonLoader, \ + JsonDumper, Headers, Payload +from aiocometd.extensions import Extension, AuthExtension +from aiocometd.transports.abc import Transport LOGGER = logging.getLogger(__name__) diff --git a/aiocometd/transports/long_polling.py b/aiocometd/transports/long_polling.py index a88eff8..a05fcbd 100644 --- a/aiocometd/transports/long_polling.py +++ b/aiocometd/transports/long_polling.py @@ -5,11 +5,11 @@ import aiohttp -from ..constants import ConnectionType -from .registry import register_transport -from .base import TransportBase, Payload, Headers -from ..exceptions import TransportError -from .._typing import JsonObject +from aiocometd.constants import ConnectionType +from aiocometd.exceptions import TransportError +from aiocometd.typing import JsonObject +from aiocometd.transports.registry import register_transport +from aiocometd.transports.base import TransportBase, Payload, Headers LOGGER = logging.getLogger(__name__) diff --git a/aiocometd/transports/registry.py b/aiocometd/transports/registry.py index 767c237..e83eb43 100644 --- a/aiocometd/transports/registry.py +++ b/aiocometd/transports/registry.py @@ -1,9 +1,9 @@ """Functions for transport class registration and instantiation""" from typing import Type, Callable, Any -from ..exceptions import TransportInvalidOperation -from .abc import Transport -from ..constants import ConnectionType +from aiocometd.exceptions import TransportInvalidOperation +from aiocometd.constants import ConnectionType +from aiocometd.transports.abc import Transport TRANSPORT_CLASSES = {} diff --git a/aiocometd/transports/websocket.py b/aiocometd/transports/websocket.py index 3deb9f4..b695e94 100644 --- a/aiocometd/transports/websocket.py +++ b/aiocometd/transports/websocket.py @@ -8,11 +8,11 @@ import aiohttp import aiohttp.client_ws -from ..constants import ConnectionType, MetaChannel -from .registry import register_transport -from .base import TransportBase, Payload, Headers -from ..exceptions import TransportError, TransportConnectionClosed -from .._typing import JsonObject +from aiocometd.constants import ConnectionType, MetaChannel +from aiocometd.exceptions import TransportError, TransportConnectionClosed +from aiocometd.typing import JsonObject +from aiocometd.transports.registry import register_transport +from aiocometd.transports.base import TransportBase, Payload, Headers LOGGER = logging.getLogger(__name__) diff --git a/aiocometd/_typing.py b/aiocometd/typing.py similarity index 93% rename from aiocometd/_typing.py rename to aiocometd/typing.py index f4296d9..c4a4042 100644 --- a/aiocometd/_typing.py +++ b/aiocometd/typing.py @@ -4,7 +4,7 @@ import aiohttp -from .constants import ConnectionType +from aiocometd.constants import ConnectionType #: Coroutine function diff --git a/aiocometd/utils.py b/aiocometd/utils.py index f3c8cf1..5e86c75 100644 --- a/aiocometd/utils.py +++ b/aiocometd/utils.py @@ -5,8 +5,8 @@ from http import HTTPStatus from typing import Union, Optional, List, Any -from .constants import META_CHANNEL_PREFIX, SERVICE_CHANNEL_PREFIX -from ._typing import CoroFunction, JsonObject +from aiocometd.constants import META_CHANNEL_PREFIX, SERVICE_CHANNEL_PREFIX +from aiocometd.typing import CoroFunction, JsonObject def defer(coro_func: CoroFunction, delay: Union[int, float, None] = None, *, From 70b374e305554db948f8fe1ee18892a445582ecc Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 18:14:59 +0100 Subject: [PATCH 13/17] Add type hints to chat example --- examples/chat.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/chat.py b/examples/chat.py index 42b9530..cef1afc 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -1,19 +1,21 @@ """Client for the CometD Chat Example""" import asyncio import argparse +from typing import Dict, Any -from aioconsole import ainput +from aioconsole import ainput # type: ignore from aiocometd import Client, ConnectionType from aiocometd.exceptions import AiocometdException -async def chat_client(url, nickname, connection_type): +async def chat_client(url: str, nickname: str, + connection_type: ConnectionType) -> None: """Runs the chat client until it's canceled - :param str url: CometD server URL - :param str nickname: The user's nickname - :param aiocometd.ConnectionType connection_type: Connection type + :param url: CometD server URL + :param nickname: The user's nickname + :param connection_type: Connection type """ try: room_name = "demo" @@ -81,12 +83,13 @@ async def chat_client(url, nickname, connection_type): print("\nExiting...") -async def input_publisher(client, nickname, room_channel): +async def input_publisher(client: Client, nickname: str, + room_channel: str) -> None: """Read text from stdin and publish it on the *room_channel* - :param aiocometd.Client client: A client object - :param str nickname: The user's nickname - :param str room_channel: The chat room's channel + :param client: A client object + :param nickname: The user's nickname + :param room_channel: The chat room's channel """ up_one_line = "\033[F" clear_line = "\033[K" @@ -109,11 +112,8 @@ async def input_publisher(client, nickname, room_channel): }) -def get_arguments(): - """Returns the argument's parsed from the command line - - :rtype: dict - """ +def get_arguments() -> Dict[str, Any]: + """Returns the argument's parsed from the command line""" parser = argparse.ArgumentParser(description="CometD chat example client") parser.add_argument("url", metavar="server_url", type=str, help="CometD server URL") @@ -126,7 +126,7 @@ def get_arguments(): return vars(parser.parse_args()) -def main(): +def main() -> None: """Starts the chat client application""" arguments = get_arguments() From 4be3aa9296f189280350dd2530f39a40a5fc6ac7 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 18:21:09 +0100 Subject: [PATCH 14/17] Make the package PEP 561 compatible Add the py.typed marker file to the main package, to make the package available for type checking. --- MANIFEST.in | 3 +++ aiocometd/py.typed | 1 + 2 files changed, 4 insertions(+) create mode 100644 aiocometd/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index d66f4cc..b6f20b3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,8 +2,11 @@ include LICENSE.txt include DESCRIPTION.rst include README.rst include .pylintrc +include aiocometd/py.typed graft tests graft docs graft examples global-exclude *.pyc +global-exclude .python-version prune docs/build +prune venv* diff --git a/aiocometd/py.typed b/aiocometd/py.typed new file mode 100644 index 0000000..6df9ef9 --- /dev/null +++ b/aiocometd/py.typed @@ -0,0 +1 @@ +PEP-561 marker \ No newline at end of file From 59619c4748685a5b006524213f6d4ace5b3133fe Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 18:27:12 +0100 Subject: [PATCH 15/17] Run mypy for examples during CI testing --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 761b953..b9d61e4 100644 --- a/tox.ini +++ b/tox.ini @@ -12,4 +12,4 @@ commands = coveralls flake8 pylint aiocometd - mypy --strict aiocometd + mypy --strict aiocometd examples From 79edc4d42f3791470b9c0f344f5344ef191bdbf0 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 18:49:43 +0100 Subject: [PATCH 16/17] Update changelog with version 0.4.0 --- docs/source/changes.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/changes.rst b/docs/source/changes.rst index 1879bfe..a797428 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -1,6 +1,12 @@ Changelog ========= +0.4.0 (2019-01-04) +------------------ + +- Add type hints +- Add integration tests + 0.3.1 (2018-06-15) ------------------ From e170ecc81fc0b5db2748bc0adedc264f43300312 Mon Sep 17 00:00:00 2001 From: Robert Marki Date: Fri, 4 Jan 2019 18:50:34 +0100 Subject: [PATCH 17/17] Bump version number to 0.4.0 --- aiocometd/_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiocometd/_metadata.py b/aiocometd/_metadata.py index d7707d1..5660acb 100644 --- a/aiocometd/_metadata.py +++ b/aiocometd/_metadata.py @@ -3,6 +3,6 @@ DESCRIPTION = "CometD client for asyncio" KEYWORDS = "asyncio aiohttp comet cometd bayeux push streaming" URL = "https://github.com/robertmrk/aiocometd" -VERSION = "0.3.1" +VERSION = "0.4.0" AUTHOR = "Róbert Márki" AUTHOR_EMAIL = "gsmiko@gmail.com"