diff --git a/integration_tests/samples/socket_mode/aiohttp_example.py b/integration_tests/samples/socket_mode/aiohttp_example.py index cf82aa6f2..eaa7d152b 100644 --- a/integration_tests/samples/socket_mode/aiohttp_example.py +++ b/integration_tests/samples/socket_mode/aiohttp_example.py @@ -23,12 +23,12 @@ async def process(client: SocketModeClient, req: SocketModeRequest): if req.type == "events_api": response = SocketModeResponse(envelope_id=req.envelope_id) await client.send_socket_mode_response(response) - - await client.web_client.reactions_add( - name="eyes", - channel=req.payload["event"]["channel"], - timestamp=req.payload["event"]["ts"], - ) + if req.payload["event"]["type"] == "message": + await client.web_client.reactions_add( + name="eyes", + channel=req.payload["event"]["channel"], + timestamp=req.payload["event"]["ts"], + ) client.socket_mode_request_listeners.append(process) await client.connect() diff --git a/slack_sdk/socket_mode/aiohttp/__init__.py b/slack_sdk/socket_mode/aiohttp/__init__.py index 5e6d41d2b..36e538f34 100644 --- a/slack_sdk/socket_mode/aiohttp/__init__.py +++ b/slack_sdk/socket_mode/aiohttp/__init__.py @@ -7,7 +7,7 @@ """ import asyncio import logging -from asyncio import Future +from asyncio import Future, Lock from asyncio import Queue from logging import Logger from typing import Union, Optional, List, Callable, Awaitable @@ -58,6 +58,7 @@ class SocketModeClient(AsyncBaseSocketModeClient): auto_reconnect_enabled: bool default_auto_reconnect_enabled: bool closed: bool + connect_operation_lock: Lock on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]] on_error_listeners: List[Callable[[WSMessage], Awaitable[None]]] @@ -92,6 +93,7 @@ def __init__( self.logger = logger or logging.getLogger(__name__) self.web_client = web_client or AsyncWebClient() self.closed = False + self.connect_operation_lock = Lock() self.proxy = proxy if self.proxy is None or len(self.proxy.strip()) == 0: env_variable = load_http_proxy_from_env(self.logger) @@ -185,6 +187,13 @@ async def receive_messages(self) -> None: else: await asyncio.sleep(consecutive_error_count) + async def is_connected(self) -> bool: + return ( + not self.closed + and self.current_session is not None + and not self.current_session.closed + ) + async def connect(self): old_session = None if self.current_session is None else self.current_session if self.wss_uri is None: diff --git a/slack_sdk/socket_mode/async_client.py b/slack_sdk/socket_mode/async_client.py index 60ced0933..ac58dc220 100644 --- a/slack_sdk/socket_mode/async_client.py +++ b/slack_sdk/socket_mode/async_client.py @@ -1,7 +1,7 @@ import asyncio import json import logging -from asyncio import Queue +from asyncio import Queue, Lock from asyncio.futures import Future from logging import Logger from typing import Dict, Union, Any, Optional, List, Callable, Awaitable @@ -23,6 +23,8 @@ class AsyncBaseSocketModeClient: wss_uri: str auto_reconnect_enabled: bool closed: bool + connect_operation_lock: Lock + message_queue: Queue message_listeners: List[ Union[ @@ -58,15 +60,24 @@ async def issue_new_wss_url(self) -> str: self.logger.error(f"Failed to retrieve WSS URL: {e}") raise e + async def is_connected(self) -> bool: + return False + async def connect(self): raise NotImplementedError() async def disconnect(self): raise NotImplementedError() - async def connect_to_new_endpoint(self): - self.wss_uri = await self.issue_new_wss_url() - await self.connect() + async def connect_to_new_endpoint(self, force: bool = False): + try: + await self.connect_operation_lock.acquire() + if force or not await self.is_connected(): + self.wss_uri = await self.issue_new_wss_url() + await self.connect() + finally: + if self.connect_operation_lock.locked() is True: + self.connect_operation_lock.release() async def close(self): self.closed = True @@ -116,7 +127,7 @@ async def run_message_listeners(self, message: dict, raw_message: str) -> None: ) try: if message.get("type") == "disconnect": - await self.connect_to_new_endpoint() + await self.connect_to_new_endpoint(force=True) return for listener in self.message_listeners: diff --git a/slack_sdk/socket_mode/websockets/__init__.py b/slack_sdk/socket_mode/websockets/__init__.py index 435216c6b..205c952d7 100644 --- a/slack_sdk/socket_mode/websockets/__init__.py +++ b/slack_sdk/socket_mode/websockets/__init__.py @@ -7,7 +7,7 @@ """ import asyncio import logging -from asyncio import Future +from asyncio import Future, Lock from logging import Logger from asyncio import Queue from typing import Union, Optional, List, Callable, Awaitable @@ -56,6 +56,7 @@ class SocketModeClient(AsyncBaseSocketModeClient): auto_reconnect_enabled: bool default_auto_reconnect_enabled: bool closed: bool + connect_operation_lock: Lock def __init__( self, @@ -78,6 +79,7 @@ def __init__( self.logger = logger or logging.getLogger(__name__) self.web_client = web_client or AsyncWebClient() self.closed = False + self.connect_operation_lock = Lock() self.default_auto_reconnect_enabled = auto_reconnect_enabled self.auto_reconnect_enabled = self.default_auto_reconnect_enabled self.ping_interval = ping_interval @@ -130,6 +132,13 @@ async def receive_messages(self) -> None: else: await asyncio.sleep(consecutive_error_count) + async def is_connected(self) -> bool: + return ( + not self.closed + and self.current_session is not None + and not self.current_session.closed + ) + async def connect(self): if self.wss_uri is None: self.wss_uri = await self.issue_new_wss_url()