diff --git a/binance/ws/reconnecting_websocket.py b/binance/ws/reconnecting_websocket.py index 36277956..e0021e33 100644 --- a/binance/ws/reconnecting_websocket.py +++ b/binance/ws/reconnecting_websocket.py @@ -3,7 +3,7 @@ import json import logging from socket import gaierror -from typing import Optional +from typing import Optional, Union from asyncio import sleep from random import random @@ -14,23 +14,7 @@ except ImportError: pass -try: - from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore -except ImportError: - from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore - - -Proxy = None -proxy_connect = None -try: - from websockets_proxy import Proxy as w_Proxy, proxy_connect as w_proxy_connect - - Proxy = w_Proxy - proxy_connect = w_proxy_connect -except ImportError: - pass - -import websockets as ws +import picows from binance.exceptions import ( BinanceWebsocketClosed, @@ -42,6 +26,57 @@ from binance.ws.constants import WSListenerState +_DISCONNECT_SENTINEL = object() + + +class _PicowsWebSocket(picows.WSListener): + def __init__(self): + self._transport = None + self._queue = asyncio.Queue() + self.closed = False + + def on_ws_connected(self, transport): + self._transport = transport + + def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame) -> None: + if frame.msg_type == picows.WSMsgType.TEXT: + payload: Union[str, bytes] = frame.get_payload_as_utf8_text() + elif frame.msg_type == picows.WSMsgType.BINARY: + payload = frame.get_payload_as_bytes() + else: + return + self._queue.put_nowait(payload) + + def on_ws_disconnected(self, transport: picows.WSTransport) -> None: + self.closed = True + self._queue.put_nowait(_DISCONNECT_SENTINEL) + + async def recv(self): + if self.closed: + raise ConnectionError("WebSocket is closed") + msg = await self._queue.get() + self._queue.task_done() + if msg is _DISCONNECT_SENTINEL: + self.closed = True + raise ConnectionError("WebSocket disconnected") + return msg + + async def send(self, payload: Union[str, bytes]) -> None: + if self.closed: + raise ConnectionError("WebSocket is closed") + if isinstance(payload, bytes): + self._transport.send(picows.WSMsgType.BINARY, payload) + else: + self._transport.send(picows.WSMsgType.TEXT, payload.encode("utf-8")) + + async def close(self) -> None: + if self.closed: + return + self._transport.send_close() + self._transport.disconnect() + await self._transport.wait_disconnected() + + class ReconnectingWebsocket: MAX_RECONNECTS = 5 MAX_RECONNECT_SECONDS = 60 @@ -70,7 +105,7 @@ def __init__( self._is_binary = is_binary self._conn = None self._socket = None - self.ws: Optional[ws.WebSocketClientProtocol] = None # type: ignore + self.ws: Optional[_PicowsWebSocket] = None self.ws_state = WSListenerState.INITIALISING self._queue = asyncio.Queue() self._handle_read_loop = None @@ -103,8 +138,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._exit_coro(self._path) if self.ws: await self.ws.close() - if self._conn and hasattr(self._conn, "protocol"): - await self._conn.__aexit__(exc_type, exc_val, exc_tb) self.ws = None async def connect(self): @@ -116,21 +149,17 @@ async def connect(self): f"{self._url}{getattr(self, '_prefix', '')}{getattr(self, '_path', '')}" ) - # handle https_proxy - if self._https_proxy: - if not Proxy or not proxy_connect: - raise ImportError( - "websockets_proxy is not installed, please install it to use a websockets proxy (pip install websockets_proxy)" - ) - proxy = Proxy.from_url(self._https_proxy) # type: ignore - self._conn = proxy_connect( - ws_url, close_timeout=0.1, proxy=proxy, **self._ws_kwargs - ) # type: ignore - else: - self._conn = ws.connect(ws_url, close_timeout=0.1, **self._ws_kwargs) # type: ignore - try: - self.ws = await self._conn.__aenter__() + if self._https_proxy and self._https_proxy.lower().startswith("https://"): + raise ValueError( + "picows does not support https:// proxy URLs; use http://, socks4://, or socks5://" + ) + _, self.ws = await picows.ws_connect( + _PicowsWebSocket, + ws_url, + proxy=self._https_proxy, + **self._ws_kwargs, + ) except Exception as e: # noqa self._log.error(f"Failed to connect to websocket: {e}") self.ws_state = WSListenerState.RECONNECTING @@ -189,10 +218,7 @@ async def _read_loop(self): f"_read_loop {self._path} break for {self.ws_state}" ) break - elif self.ws.state == ws.protocol.State.CLOSING: # type: ignore - await asyncio.sleep(0.1) - continue - elif self.ws.state == ws.protocol.State.CLOSED: # type: ignore + elif self.ws and self.ws.closed: self._reconnect() raise BinanceWebsocketClosed( "Connection closed. Reconnecting..." @@ -225,8 +251,7 @@ async def _read_loop(self): except ( asyncio.IncompleteReadError, gaierror, - ConnectionClosedError, - ConnectionClosedOK, + ConnectionError, BinanceWebsocketClosed, ) as e: # reports errors and continue loop @@ -299,11 +324,9 @@ def _get_reconnect_wait(self, attempts: int) -> int: async def before_reconnect(self): if self.ws: + await self.ws.close() self.ws = None - if self._conn and hasattr(self._conn, "protocol"): - await self._conn.__aexit__(None, None, None) - self._reconnects += 1 def _reconnect(self): diff --git a/binance/ws/websocket_api.py b/binance/ws/websocket_api.py index 3ef9ed13..7cc8c218 100644 --- a/binance/ws/websocket_api.py +++ b/binance/ws/websocket_api.py @@ -1,8 +1,6 @@ from typing import Dict, Optional import asyncio -from websockets import WebSocketClientProtocol # type: ignore - from .constants import WSListenerState from .reconnecting_websocket import ReconnectingWebsocket from binance.exceptions import BinanceAPIException, BinanceWebsocketUnableToConnect @@ -92,7 +90,7 @@ async def _ensure_ws_connection(self) -> None: try: if ( self.ws is None - or (isinstance(self.ws, WebSocketClientProtocol) and self.ws.closed) + or self.ws.closed or self.ws_state != WSListenerState.STREAMING ): await self.connect() diff --git a/requirements.txt b/requirements.txt index b7e63909..1681b7c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,4 @@ aiohttp dateparser pycryptodome requests -websockets -websockets_proxy; python_version >= '3.8' \ No newline at end of file +picows diff --git a/setup.py b/setup.py index f455906f..82ee965b 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "six", "dateparser", "aiohttp", - "websockets", + "picows", "pycryptodome", ], keywords="binance exchange rest api bitcoin ethereum btc eth neo", diff --git a/tests/test_reconnecting_websocket.py b/tests/test_reconnecting_websocket.py index 7433a411..91babbed 100644 --- a/tests/test_reconnecting_websocket.py +++ b/tests/test_reconnecting_websocket.py @@ -2,12 +2,11 @@ import pytest import gzip import json -from unittest.mock import patch, create_autospec, Mock +from unittest.mock import patch, Mock from binance.ws.reconnecting_websocket import ReconnectingWebsocket from binance.ws.constants import WSListenerState from binance.exceptions import BinanceWebsocketUnableToConnect, ReadLoopClosed -from websockets import WebSocketClientProtocol # type: ignore -from websockets.protocol import State +import picows import asyncio try: @@ -83,16 +82,52 @@ async def test_recv_message(): assert result == {"test": "data"} +class MockFrame: + def __init__(self, msg_type, payload): + self.msg_type = msg_type + self._payload = payload + + def get_payload_as_utf8_text(self): + return self._payload + + def get_payload_as_bytes(self): + return self._payload + + +class MockTransport: + def __init__(self, listener): + self.listener = listener + self.sent = [] + self._disconnected = asyncio.Event() + + def send(self, msg_type, payload): + self.sent.append((msg_type, payload)) + + def send_close(self): + self.listener.on_ws_disconnected(self) + self._disconnected.set() + + def disconnect(self, graceful=False): + self._disconnected.set() + + async def wait_disconnected(self): + await self._disconnected.wait() + + def emit_text(self, payload: str): + frame = MockFrame(picows.WSMsgType.TEXT, payload) + self.listener.on_ws_frame(self, frame) + + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+") @pytest.mark.asyncio async def test_before_reconnect(): ws = ReconnectingWebsocket(url="wss://test.url") ws.ws = AsyncMock() - ws._conn = AsyncMock() + ws_connection = ws.ws ws._reconnects = 0 await ws.before_reconnect() + ws_connection.close.assert_awaited_once() assert ws.ws is None - ws._conn.__aexit__.assert_awaited() assert ws._reconnects == 1 @@ -110,14 +145,14 @@ async def test_connect_max_reconnects_exceeded(): ws.MAX_RECONNECTS = 2 # type: ignore # Set max reconnects to a low number for testing ws._before_connect = AsyncMock() ws._after_connect = AsyncMock() - ws._conn = AsyncMock() - exception = Exception("Connection failed") - ws._conn.__aenter__.side_effect = exception - with patch.object(ws._log, "error") as mock_log: - with pytest.raises(BinanceWebsocketUnableToConnect): - for _ in range(3): # Exceed MAX_RECONNECTS - await ws._run_reconnect() + with patch( + "binance.ws.reconnecting_websocket.picows.ws_connect", + side_effect=Exception("Connection failed"), + ): + with pytest.raises(BinanceWebsocketUnableToConnect): + for _ in range(3): # Exceed MAX_RECONNECTS + await ws._run_reconnect() mock_log.assert_called_with(f"Max reconnections {ws.MAX_RECONNECTS} reached:") assert ws._reconnects == ws.MAX_RECONNECTS @@ -126,17 +161,23 @@ async def test_connect_max_reconnects_exceeded(): @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+") @pytest.mark.asyncio async def test_recieve_invalid_json(): - # Create mock WebSocket client - mock_socket = create_autospec(WebSocketClientProtocol) - mock_socket.recv = AsyncMock(return_value="invalid json{") - mock_socket.state = AsyncMock() - - # Mock websockets.connect to return our mock socket - with patch("websockets.connect") as mock_connect: - mock_connect.return_value.__aenter__.return_value = mock_socket - + transport = None + + async def _mock_connect(listener_factory, *_args, **_kwargs): + nonlocal transport + listener = listener_factory() + transport = MockTransport(listener) + listener.on_ws_connected(transport) + return transport, listener + + with patch( + "binance.ws.reconnecting_websocket.picows.ws_connect", + side_effect=_mock_connect, + ): ws = ReconnectingWebsocket(url="wss://test.url") async with ws: + assert transport is not None + transport.emit_text("invalid json{") msg = await ws.recv() assert msg["e"] == "error" assert msg["type"] == "JSONDecodeError" # JSON parsing error @@ -145,18 +186,24 @@ async def test_recieve_invalid_json(): @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+") @pytest.mark.asyncio async def test_receive_valid_json(): - # Create mock WebSocket client msgRecv = '{"e": "value"}' - mock_socket = create_autospec(WebSocketClientProtocol) - mock_socket.recv = AsyncMock(return_value=msgRecv) - mock_socket.state = AsyncMock() - - # Mock websockets.connect to return our mock socket - with patch("websockets.connect") as mock_connect: - mock_connect.return_value.__aenter__.return_value = mock_socket - + transport = None + + async def _mock_connect(listener_factory, *_args, **_kwargs): + nonlocal transport + listener = listener_factory() + transport = MockTransport(listener) + listener.on_ws_connected(transport) + return transport, listener + + with patch( + "binance.ws.reconnecting_websocket.picows.ws_connect", + side_effect=_mock_connect, + ): ws = ReconnectingWebsocket(url="wss://test.url") async with ws: + assert transport is not None + transport.emit_text(msgRecv) msg = await ws.recv() assert msg == json.loads(msgRecv) @@ -166,50 +213,46 @@ async def test_receive_valid_json(): async def test_connect_fails_to_connect_on_enter_context(): """Test ws.connect raises a ConnectionClosedError.""" ws = ReconnectingWebsocket(url="wss://test.url") - ws._conn = AsyncMock() - exception = Exception("Connection closed") - ws._conn.__aenter__.side_effect = exception - with pytest.raises(Exception): - await ws.__aenter__() + with patch( + "binance.ws.reconnecting_websocket.picows.ws_connect", + side_effect=Exception("Connection closed"), + ): + with pytest.raises(Exception): + await ws.__aenter__() @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+") @pytest.mark.asyncio async def test_connect_fails_to_connect_after_disconnect(): - # Create mock WebSocket client - mock_socket = create_autospec(WebSocketClientProtocol) - mock_socket.recv = AsyncMock(side_effect=delayed_return) - mock_socket.state = AsyncMock() - - # Create mock connect that succeeds first, then fails - mock_connect = AsyncMock() - mock_connect.return_value.__aenter__.side_effect = [ - mock_socket, # First call succeeds - Exception("Connection failed"), # Subsequent calls fail - ] - - with patch("websockets.connect", return_value=mock_connect.return_value): + connect_calls = 0 + + async def _mock_connect(listener_factory, *_args, **_kwargs): + nonlocal connect_calls + connect_calls += 1 + if connect_calls > 1: + raise Exception("Connection failed") + listener = listener_factory() + transport = MockTransport(listener) + listener.on_ws_connected(transport) + return transport, listener + + with patch( + "binance.ws.reconnecting_websocket.picows.ws_connect", + side_effect=_mock_connect, + ): ws = ReconnectingWebsocket(url="wss://test.url") async with ws as ws: assert ws.ws is not None - msg = await ws.recv() - ws.ws.state = State.CLOSED + ws.ws._transport.emit_text('{"e":"value"}') + _ = await ws.recv() await ws.ws.close() - while msg["e"] != "error": - msg = await ws.recv() - # Receive the closed message attempting to reconnect - while msg["type"] == "BinanceWebsocketClosed": + msg = await ws.recv() + while msg["type"] in {"ConnectionError", "BinanceWebsocketClosed"}: msg = await ws.recv() - # After retrying to reconnect, receive BinanceWebsocketUnableToConnect assert msg["e"] == "error" assert msg["type"] == "BinanceWebsocketUnableToConnect" -async def delayed_return(): - await asyncio.sleep(0.1) # 100 ms delay - return '{"e": "value"}' - - @pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+") @pytest.mark.asyncio async def test_recv_read_loop_closed(): diff --git a/tests/test_threaded_stream.py b/tests/test_threaded_stream.py index 44a62a52..bcd433de 100644 --- a/tests/test_threaded_stream.py +++ b/tests/test_threaded_stream.py @@ -1,7 +1,6 @@ import pytest import asyncio -import websockets from binance.ws.threaded_stream import ThreadedApiManager from unittest.mock import Mock @@ -71,7 +70,7 @@ async def controlled_recv(): recv_count += 1 # If we've stopped the socket or read enough times, simulate connection closing if not manager._socket_running.get(socket_name) or recv_count > 2: - raise websockets.exceptions.ConnectionClosed(None, None) + raise ConnectionError("connection closed") await asyncio.sleep(0.1) return '{"e": "value"}' @@ -95,7 +94,7 @@ async def controlled_recv(): # Wait for the listener task to complete try: await asyncio.wait_for(listener_task, timeout=1.0) - except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + except (asyncio.TimeoutError, ConnectionError): pass # These exceptions are expected during shutdown assert socket_name not in manager._socket_running @@ -134,7 +133,7 @@ async def controlled_recv(): # Wait for the listener to finish try: await asyncio.wait_for(listener_task, timeout=1.0) - except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed): + except (asyncio.TimeoutError, ConnectionError): listener_task.cancel() # Callback should not have been called (no successful messages)