Skip to content

Commit

Permalink
Revise WebClient/RTMClient internals for run_async=False
Browse files Browse the repository at this point in the history
* slackapi#530 Fixed by changing _execute_in_thread to be a coroutine
* slackapi#569 Resolved by removing a blocking loop (while future.running())
* slackapi#645 WebClient(run_async=False) no longer depends on asyncio by default
* slackapi#633 WebClient(run_async=False) doesn't internally depend on aiohttp
* slackapi#631 When run_async=True, RTM listner can be a normal function and WebClient is free from the event loop
* slackapi#630 WebClient no longer depends on aiohttp when run_async=False
* slackapi#497 Fixed when run_async=False / can be closed as we don't support run_async=True for this use case (in Flask)
  • Loading branch information
seratch committed Apr 30, 2020
1 parent e92442d commit fb21a25
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 67 deletions.
75 changes: 48 additions & 27 deletions slack/rtm/client.py
Expand Up @@ -5,9 +5,9 @@
import logging
import random
import collections
import concurrent
import inspect
import signal
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Callable, DefaultDict
from ssl import SSLContext
from threading import current_thread, main_thread
Expand Down Expand Up @@ -107,6 +107,8 @@ def __init__(
*,
token: str,
run_async: Optional[bool] = False,
# will be used only when run_async=False
run_sync_thread_pool_size: int = 3,
auto_reconnect: Optional[bool] = True,
ssl: Optional[SSLContext] = None,
proxy: Optional[str] = None,
Expand All @@ -119,6 +121,9 @@ def __init__(
):
self.token = token.strip()
self.run_async = run_async
self.thread_pool_executor = ThreadPoolExecutor(
max_workers=run_sync_thread_pool_size
)
self.auto_reconnect = auto_reconnect
self.ssl = ssl
self.proxy = proxy
Expand All @@ -135,6 +140,16 @@ def __init__(
self._last_message_id = 0
self._connection_attempts = 0
self._stopped = False
self._web_client = WebClient(
token=self.token,
base_url=self.base_url,
ssl=self.ssl,
proxy=self.proxy,
run_async=self.run_async,
loop=self._event_loop,
session=self._session,
headers=self.headers,
)

@staticmethod
def run_on(*, event: str):
Expand Down Expand Up @@ -195,8 +210,8 @@ def start(self) -> asyncio.Future:

if self.run_async:
return future

return self._event_loop.run_until_complete(future)
else:
return self._event_loop.run_until_complete(future)

def stop(self):
"""Closes the websocket connection and ensures it won't reconnect."""
Expand Down Expand Up @@ -351,7 +366,6 @@ async def _connect_and_read(self):
client_err.SlackApiError,
# TODO: Catch websocket exceptions thrown by aiohttp.
) as exception:
self._logger.debug(str(exception))
await self._dispatch_event(event="error", data=exception)
if self.auto_reconnect and not self._stopped:
await self._wait_exponentially(exception)
Expand Down Expand Up @@ -433,37 +447,27 @@ async def _dispatch_event(self, event, data=None):
# close/error callbacks.
break

if inspect.iscoroutinefunction(callback):
if self.run_async or inspect.iscoroutinefunction(callback):
await callback(
rtm_client=self, web_client=self._web_client, data=data
)
else:
self._execute_in_thread(callback, data)
await self._execute_in_thread(
callback=callback, web_client=self._web_client, data=data
)
except Exception as err:
name = callback.__name__
module = callback.__module__
msg = f"When calling '#{name}()' in the '{module}' module the following error was raised: {err}"
self._logger.error(msg)
raise

def _execute_in_thread(self, callback, data):
async def _execute_in_thread(self, callback, web_client, data):
"""Execute the callback in another thread. Wait for and return the results."""
web_client = WebClient(
token=self.token,
base_url=self.base_url,
ssl=self.ssl,
proxy=self.proxy,
headers=self.headers,
future = self.thread_pool_executor.submit(
callback, rtm_client=self, web_client=web_client, data=data
)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(
callback, rtm_client=self, web_client=web_client, data=data
)

while future.running():
pass

future.result()
return future.result()

async def _retrieve_websocket_info(self):
"""Retrieves the WebSocket info from Slack.
Expand Down Expand Up @@ -498,10 +502,18 @@ async def _retrieve_websocket_info(self):
headers=self.headers,
)
self._logger.debug("Retrieving websocket info.")
if self.connect_method in ["rtm.start", "rtm_start"]:
resp = await self._web_client.rtm_start()
use_rtm_start = self.connect_method in ["rtm.start", "rtm_start"]
if self.run_async:
if use_rtm_start:
resp = await self._web_client.rtm_start()
else:
resp = await self._web_client.rtm_connect()
else:
resp = await self._web_client.rtm_connect()
if use_rtm_start:
resp = self._web_client.rtm_start()
else:
resp = self._web_client.rtm_connect()

url = resp.get("url")
if url is None:
msg = "Unable to retrieve RTM URL from Slack."
Expand All @@ -513,15 +525,24 @@ async def _wait_exponentially(self, exception, max_wait_time=300):
Calculate the number of seconds to wait and then add
a random number of milliseconds to avoid coincidental
synchronized client retries. Wait up to the maximium amount
synchronized client retries. Wait up to the maximum amount
of wait time specified via 'max_wait_time'. However,
if Slack returned how long to wait use that.
"""
wait_time = min(
(2 ** self._connection_attempts) + random.random(), max_wait_time
)
try:
wait_time = exception.response["headers"]["Retry-After"]
headers = (
exception.response["headers"]
if "headers" in exception.response
else None
)
if headers and "Retry-After" in headers:
wait_time = headers["Retry-After"]
else:
# an error returned due to other unrecoverable reasons
raise exception
except (KeyError, AttributeError):
pass
self._logger.debug("Waiting %s seconds before reconnecting.", wait_time)
Expand Down
20 changes: 20 additions & 0 deletions slack/web/__init__.py
@@ -0,0 +1,20 @@
import platform
import sys

import slack.version as ver


def get_user_agent():
"""Construct the user-agent header with the package info,
Python version and OS version.
Returns:
The user agent string.
e.g. 'Python/3.6.7 slackclient/2.0.0 Darwin/17.7.0'
"""
# __name__ returns all classes, we only want the client
client = "{0}/{1}".format("slackclient", ver.__version__)
python_version = "Python/{v.major}.{v.minor}.{v.micro}".format(v=sys.version_info)
system_info = "{0}/{1}".format(platform.system(), platform.release())
user_agent_string = " ".join([python_version, client, system_info])
return user_agent_string
89 changes: 57 additions & 32 deletions slack/web/base_client.py
@@ -1,9 +1,8 @@
"""A Python module for interacting with Slack's Web API."""

# Standard Imports
import json
from urllib.parse import urljoin
import platform
import sys
import logging
import asyncio
from typing import Optional, Union
Expand All @@ -15,9 +14,10 @@
from aiohttp import FormData, BasicAuth

# Internal Imports
from slack.web import get_user_agent
from slack.web.slack_response import SlackResponse
import slack.version as ver
import slack.errors as err
from slack.web.urllib_client import UrllibWebClient


class BaseClient:
Expand All @@ -32,6 +32,7 @@ def __init__(
ssl=None,
proxy=None,
run_async=False,
use_sync_aiohttp=False,
session=None,
headers: Optional[dict] = None,
):
Expand All @@ -41,11 +42,16 @@ def __init__(
self.ssl = ssl
self.proxy = proxy
self.run_async = run_async
self.use_sync_aiohttp = use_sync_aiohttp
self.session = session
self.headers = headers or {}
self._logger = logging.getLogger(__name__)
self._event_loop = loop

self.urllib_client = UrllibWebClient(
token=self.token, default_headers=self.headers, web_client=self,
)

def _get_event_loop(self):
"""Retrieves the event loop or creates a new one."""
try:
Expand All @@ -58,7 +64,7 @@ def _get_event_loop(self):
def _get_headers(
self, has_json: bool, has_files: bool, request_specific_headers: Optional[dict]
):
"""Contructs the headers need for a request.
"""Constructs the headers need for a request.
Args:
has_json (bool): Whether or not the request has json.
has_files (bool): Whether or not the request has files.
Expand All @@ -73,7 +79,7 @@ def _get_headers(
}
"""
final_headers = {
"User-Agent": self._get_user_agent(),
"User-Agent": get_user_agent(),
"Content-Type": "application/x-www-form-urlencoded;charset=utf-8",
}

Expand Down Expand Up @@ -115,7 +121,7 @@ def api_call(
e.g. 'chat.postMessage'
http_verb (str): HTTP Verb. e.g. 'POST'
files (dict): Files to multipart upload.
e.g. {imageORfile: file_objectORfile_path}
e.g. {image OR file: file_object OR file_path}
data: The body to attach to the request. If a dictionary is
provided, form-encoding will take place.
e.g. {'key1': 'value1', 'key2': 'value2'}
Expand Down Expand Up @@ -160,18 +166,21 @@ def api_call(
"auth": auth,
}

if self._event_loop is None:
self._event_loop = self._get_event_loop()

future = asyncio.ensure_future(
self._send(http_verb=http_verb, api_url=api_url, req_args=req_args),
loop=self._event_loop,
)

if self.run_async:
return future
if self.run_async or self.use_sync_aiohttp:
if self._event_loop is None:
self._event_loop = self._get_event_loop()

return self._event_loop.run_until_complete(future)
future = asyncio.ensure_future(
self._send(http_verb=http_verb, api_url=api_url, req_args=req_args),
loop=self._event_loop,
)
if self.run_async:
return future
elif self.use_sync_aiohttp:
# Using this is no longer recommended - just keep this for backward-compatibility
return self._event_loop.run_until_complete(future)
else:
return self._sync_send(api_url=api_url, req_args=req_args)

def _get_url(self, api_method):
"""Joins the base Slack URL and an API method to form an absolute URL.
Expand Down Expand Up @@ -225,6 +234,7 @@ async def _send(self, http_verb, api_url, req_args):
"http_verb": http_verb,
"api_url": api_url,
"req_args": req_args,
"use_sync_aiohttp": self.use_sync_aiohttp,
}
return SlackResponse(**{**data, **res}).validate()

Expand Down Expand Up @@ -258,23 +268,38 @@ async def _request(self, *, http_verb, api_url, req_args):
await session.close()
return response

@staticmethod
def _get_user_agent():
"""Construct the user-agent header with the package info,
Python version and OS version.
def _sync_send(self, api_url, req_args):
params = req_args["params"] if "params" in req_args else None
data = req_args["data"] if "data" in req_args else None
files = req_args["files"] if "files" in req_args else None
json = req_args["json"] if "files" in req_args else None
headers = req_args["headers"] if "headers" in req_args else None
token = params.get("token") if params and "token" in params else None
body_params = {}
if params:
body_params.update(params)
if data:
body_params.update(data)

return self.urllib_client.api_call(
token=token,
url=api_url,
query_params={},
body_params=body_params,
files=files,
json_body=json,
additional_headers=headers,
)

Returns:
The user agent string.
e.g. 'Python/3.6.7 slackclient/2.0.0 Darwin/17.7.0'
"""
# __name__ returns all classes, we only want the client
client = "{0}/{1}".format("slackclient", ver.__version__)
python_version = "Python/{v.major}.{v.minor}.{v.micro}".format(
v=sys.version_info
def _sync_request(self, api_url, req_args):
response, response_body = self.urllib_client._perform_http_request(
url=api_url, args=req_args,
)
system_info = "{0}/{1}".format(platform.system(), platform.release())
user_agent_string = " ".join([python_version, client, system_info])
return user_agent_string
return {
"status_code": int(response.status),
"headers": dict(response.headers),
"data": json.loads(response_body),
}

@staticmethod
def validate_slack_signature(
Expand Down

0 comments on commit fb21a25

Please sign in to comment.