diff --git a/.gitignore b/.gitignore index f48b391..e210641 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,7 @@ dmypy.json # Visual Studio Code .vscode/ + +# sphinx.ext.autosummary + +generated diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 0000000..151a51a --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,29 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + + {% block methods %} + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/_templates/autosummary/module.rst b/docs/_templates/autosummary/module.rst new file mode 100644 index 0000000..4cb0abc --- /dev/null +++ b/docs/_templates/autosummary/module.rst @@ -0,0 +1,61 @@ +{{ fullname | escape | underline}} + +.. automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Module Attributes') }} + + .. autosummary:: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/docs/conf.py b/docs/conf.py index e8e1fcc..e16d0e4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,7 +13,10 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.resolve())) # -- Project information ----------------------------------------------------- @@ -27,7 +30,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [] +extensions = ["sphinx.ext.autodoc", 'sphinx.ext.autosummary'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -58,3 +61,7 @@ html_static_path = ['_static'] source_suffix = [".rst"] + +autoclass_content = 'both' + +templates_path = ['_templates'] diff --git a/docs/index.rst b/docs/index.rst index 151d9d9..fd64962 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,15 @@ Welcome to pybotters's documentation! Exchanges Contributing +API Reference +------------- + +.. autosummary:: + :toctree: generated + :recursive: + + pybotters + Indices and tables ================== diff --git a/docs/requirements.txt b/docs/requirements.txt index 2543925..199eb38 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,4 @@ Sphinx -sphinx-rtd-theme \ No newline at end of file +sphinx-rtd-theme +aiohttp +rich \ No newline at end of file diff --git a/pybotters/__init__.py b/pybotters/__init__.py index 6f8fc9e..d51549a 100644 --- a/pybotters/__init__.py +++ b/pybotters/__init__.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import asyncio -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Mapping, Optional, Tuple, Union import aiohttp from aiohttp import hdrs @@ -9,8 +11,10 @@ from .models import experimental from .models.binance import BinanceDataStore from .models.bitbank import bitbankDataStore +from .models.bitflyer import bitFlyerDataStore from .models.bitmex import BitMEXDataStore from .models.bybit import BybitDataStore +from .models.experimental.bybit import BybitInverseDataStore, BybitUSDTDataStore from .models.ftx import FTXDataStore from .models.gmocoin import GMOCoinDataStore from .typedefs import WsJsonHandler, WsStrHandler @@ -23,9 +27,12 @@ 'put', 'delete', 'BybitDataStore', + 'BybitInverseDataStore', + 'BybitUSDTDataStore', 'FTXDataStore', 'BinanceDataStore', 'bitbankDataStore', + 'bitFlyerDataStore', 'BitMEXDataStore', 'GMOCoinDataStore', 'experimental', @@ -52,9 +59,11 @@ async def _request( *, params: Optional[Mapping[str, str]] = None, data: Any = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: + if apis is None: + apis = {} async with Client(apis=apis, response_class=SyncClientResponse) as client: async with client.request( method, url, params=params, data=data, **kwargs @@ -69,7 +78,7 @@ def request( *, params: Optional[Mapping[str, str]] = None, data: Any = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: loop = asyncio.get_event_loop() @@ -82,7 +91,7 @@ def get( url: str, *, params: Optional[Mapping[str, str]] = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: loop = asyncio.get_event_loop() @@ -95,7 +104,7 @@ def post( url: str, *, data: Any = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: loop = asyncio.get_event_loop() @@ -108,7 +117,7 @@ def put( url: str, *, data: Any = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: loop = asyncio.get_event_loop() @@ -121,7 +130,7 @@ def delete( url: str, *, data: Any = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> SyncClientResponse: loop = asyncio.get_event_loop() @@ -133,13 +142,15 @@ def delete( async def _ws_connect( url: str, *, - send_str: Optional[Union[str, List[str]]] = None, + send_str: Optional[Union[str, list[str]]] = None, send_json: Any = None, hdlr_str: Optional[WsStrHandler] = None, hdlr_json: Optional[WsJsonHandler] = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> None: + if apis is None: + apis = {} async with Client(apis=apis) as client: wstask = await client.ws_connect( url, @@ -155,11 +166,11 @@ async def _ws_connect( def ws_connect( url: str, *, - send_str: Optional[Union[str, List[str]]] = None, + send_str: Optional[Union[str, list[str]]] = None, send_json: Any = None, hdlr_str: Optional[WsStrHandler] = None, hdlr_json: Optional[WsJsonHandler] = None, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, **kwargs: Any, ) -> None: loop = asyncio.get_event_loop() diff --git a/pybotters/auth.py b/pybotters/auth.py index f288ec1..4246ec5 100644 --- a/pybotters/auth.py +++ b/pybotters/auth.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import base64 import hashlib import hmac import json import time from dataclasses import dataclass -from typing import Any, Dict, Tuple +from typing import Any import aiohttp from aiohttp.formdata import FormData @@ -16,10 +18,10 @@ class Auth: @staticmethod - def bybit(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def bybit(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} session: aiohttp.ClientSession = kwargs['session'] key: str = session.__dict__['_apis'][Hosts.items[url.host].name][0] @@ -58,10 +60,10 @@ def bybit(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def binance(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def binance(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -95,10 +97,10 @@ def binance(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def bitflyer(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def bitflyer(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -118,10 +120,10 @@ def bitflyer(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def gmocoin(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def gmocoin(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -145,9 +147,9 @@ def gmocoin(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def liquid(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def liquid(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -181,10 +183,10 @@ def liquid(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def bitbank(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def bitbank(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -207,10 +209,10 @@ def bitbank(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def ftx(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def ftx(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -228,10 +230,10 @@ def ftx(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def bitmex(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def bitmex(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: method: str = args[0] url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -251,9 +253,9 @@ def bitmex(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def phemex(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def phemex(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] @@ -278,9 +280,9 @@ def phemex(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: return args @staticmethod - def coincheck(args: Tuple[str, URL], kwargs: Dict[str, Any]) -> Tuple[str, URL]: + def coincheck(args: tuple[str, URL], kwargs: dict[str, Any]) -> tuple[str, URL]: url: URL = args[1] - data: Dict[str, Any] = kwargs['data'] or {} + data: dict[str, Any] = kwargs['data'] or {} headers: CIMultiDict = kwargs['headers'] session: aiohttp.ClientSession = kwargs['session'] diff --git a/pybotters/client.py b/pybotters/client.py index f9da3df..d64ee5e 100644 --- a/pybotters/client.py +++ b/pybotters/client.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import json import logging import os -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Mapping, Optional, Union import aiohttp from aiohttp import hdrs @@ -10,22 +12,81 @@ from .auth import Auth from .request import ClientRequest -from .typedefs import WsJsonHandler, WsStrHandler +from .typedefs import WsBytesHandler, WsJsonHandler, WsStrHandler from .ws import ClientWebSocketResponse, ws_run_forever logger = logging.getLogger(__name__) class Client: + """ + HTTPリクエストクライアントクラス + + .. note:: + 引数 apis は省略できます。 + + :Example: + + .. code-block:: python + + async def main(): + async with pybotters.Client(apis={'example': ['KEY', 'SECRET']}) as client: + r = await client.get('https://...', params={'foo': 'bar'}) + print(await r.json()) + + .. code-block:: python + + async def main(): + async with pybotters.Client(apis={'example': ['KEY', 'SECRET']}) as client: + wstask = await client.ws_connect( + 'wss://...', + send_json={'foo': 'bar'}, + hdlr_json=pybotters.print_handler + ) + await wstask + # Ctrl+C to break + + Basic API + + パッケージトップレベルで利用できるHTTPリクエスト関数です。 これらは同期関数です。 内部的にpybotters.Clientをラップしています。 + + :Example: + + .. code-block:: python + + r = pybotters.get( + 'https://...', + params={'foo': 'bar'}, + apis={'example': ['KEY', 'SECRET']} + ) + print(r.text()) + print(r.json()) + + .. code-block:: python + + pybotters.ws_connect( + 'wss://...', + send_json={'foo': 'bar'}, + hdlr_json=pybotters.print_handler, + apis={'example': ['KEY', 'SECRET']} + ) + # Ctrl+C to break + """ + _session: aiohttp.ClientSession _base_url: str def __init__( self, - apis: Union[Dict[str, List[str]], str] = {}, + apis: Optional[Union[dict[str, list[str]], str]] = None, base_url: str = '', **kwargs: Any, ) -> None: + """ + :param apis: APIキー・シークレットのデータ(optional) ex: {'exchange': ['key', 'secret']} + :param base_url: リクエストメソッドの url の前方に自動付加するURL(optional) + :param ``**kwargs``: aiohttp.Client.requestに渡されるキーワード引数(optional) + """ self._session = aiohttp.ClientSession( request_class=ClientRequest, ws_response_class=ClientWebSocketResponse, @@ -49,8 +110,8 @@ def _request( method: str, url: str, *, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: Optional[Mapping[str, Any]] = None, + data: Optional[dict[str, Any]] = None, auth: Optional[Auth] = Auth, **kwargs: Any, ) -> _RequestContextManager: @@ -76,6 +137,15 @@ def request( data: Any = None, **kwargs: Any, ) -> _RequestContextManager: + """ + :param method: GET, POST, PUT, DELETE などのHTTPメソッド + :param url: リクエストURL + :param params: URLのクエリ文字列(optional) + :param data: リクエストボディ(optional) + :param headers: リクエストヘッダー(optional) + :param auth: API自動認証の機能の有効/無効。デフォルトで有効。auth=Noneを指定することで無効になります(optional) + :param ``kwargs``: aiohttp.Client.requestに渡されるキーワード引数(optional) + """ return self._request(method, url, params=params, data=data, **kwargs) def get( @@ -118,12 +188,27 @@ async def ws_connect( self, url: str, *, - send_str: Optional[Union[str, List[str]]] = None, + send_str: Optional[Union[str, list[str]]] = None, + send_bytes: Optional[Union[bytes, list[bytes]]] = None, send_json: Any = None, hdlr_str: Optional[WsStrHandler] = None, + hdlr_bytes: Optional[WsBytesHandler] = None, hdlr_json: Optional[WsJsonHandler] = None, **kwargs: Any, ) -> asyncio.Task: + """ + :param url: WebSocket URL + :param send_str: WebSocketで送信する文字列。文字列、または文字列のリスト形式(optional) + :param send_json: WebSocketで送信する辞書オブジェクト。辞書、または辞書のリスト形式(optional) + :param hdlr_str: WebSocketの受信データをハンドリングする関数。 + 第1引数 msg に _str_型, 第2引数 ws にWebSocketClientResponse 型の変数が渡されます(optional) + :param hdlr_json: WebSocketの受信データをハンドリングする関数。 + 第1引数 msg に Any 型(JSON-like), 第2引数 ws に WebSocketClientResponse 型の変数が渡されます + (optional) + :param headers: リクエストヘッダー(optional) + :param auth: API自動認証の機能の有効/無効。デフォルトで有効。auth=Noneを指定することで無効になります(optional) + :param ``**kwargs``: aiohttp.ClientSession.ws_connectに渡されるキーワード引数(optional) + """ event = asyncio.Event() task = asyncio.create_task( ws_run_forever( @@ -131,8 +216,10 @@ async def ws_connect( self._session, event, send_str=send_str, + send_bytes=send_bytes, send_json=send_json, hdlr_str=hdlr_str, + hdlr_bytes=hdlr_bytes, hdlr_json=hdlr_json, **kwargs, ) @@ -141,7 +228,11 @@ async def ws_connect( return task @staticmethod - def _load_apis(apis: Union[Dict[str, List[str]], str]) -> Dict[str, List[str]]: + def _load_apis( + apis: Optional[Union[dict[str, list[str]], str]] + ) -> dict[str, list[str]]: + if apis is None: + apis = {} if isinstance(apis, dict): if apis: return apis @@ -165,7 +256,11 @@ def _load_apis(apis: Union[Dict[str, List[str]], str]) -> Dict[str, List[str]]: return {} @staticmethod - def _encode_apis(apis: Dict[str, List[str]]) -> Dict[str, Tuple[str, bytes]]: + def _encode_apis( + apis: Optional[dict[str, list[str]]] + ) -> dict[str, tuple[str, bytes]]: + if apis is None: + apis = {} encoded = {} for name in apis: if len(apis[name]) == 2: diff --git a/pybotters/models/binance.py b/pybotters/models/binance.py index 17887b4..960f8ab 100644 --- a/pybotters/models/binance.py +++ b/pybotters/models/binance.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import asyncio from collections import deque -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Awaitable, Optional, Union import aiohttp @@ -11,6 +13,10 @@ class BinanceDataStore(DataStoreManager): + """ + Binanceのデータストアマネージャー(※v0.4.0: Binance Futures USDⓈ-Mのみ) + """ + def _init(self) -> None: self.create('trade', datastore_class=Trade) self.create('markprice', datastore_class=MarkPrice) @@ -26,6 +32,22 @@ def _init(self) -> None: self.listenkey: Optional[str] = None async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /fapi/v1/depth (DataStore: orderbook) + + - Binance APIドキュメントに従ってWebSocket接続後にinitializeすること。 + - orderbook データストアの initialized がTrueになる。 + + - GET /fapi/v2/balance (DataStore: balance) + - GET /fapi/v2/positionRisk (DataStore: position) + - GET /fapi/v1/openOrders (DataStore: order) + - POST /fapi/v1/listenKey (Property: listenkey) + + - プロパティ listenkey にlistenKeyが格納され30分ごとに PUT /fapi/v1/listenKey + のリクエストがスケジュールされる。 + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -119,6 +141,9 @@ def position(self) -> 'Position': @property def order(self) -> 'Order': + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get('order', Order) @@ -132,7 +157,7 @@ def _onmessage(self, item: Item) -> None: class MarkPrice(DataStore): _KEYS = ['s'] - def _onmessage(self, data: Union[Item, List[Item]]) -> None: + def _onmessage(self, data: Union[Item, list[Item]]) -> None: if isinstance(data, list): self._update(data) else: @@ -156,7 +181,7 @@ def _onmessage(self, item: Item) -> None: class Ticker(DataStore): _KEYS = ['s'] - def _onmessage(self, data: Union[Item, List[Item]]) -> None: + def _onmessage(self, data: Union[Item, list[Item]]) -> None: if isinstance(data, list): self._update(data) else: @@ -183,7 +208,9 @@ def _init(self) -> None: self.initialized = False self._buff = deque(maxlen=200) - def sorted(self, query: Item = {}) -> Dict[str, List[float]]: + def sorted(self, query: Optional[Item] = None) -> dict[str, list[float]]: + if query is None: + query = {} result = {self._MAPSIDE[k]: [] for k in self._MAPSIDE} for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -220,7 +247,7 @@ class Balance(DataStore): def _onmessage(self, item: Item) -> None: self._update(item['a']['B']) - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: for item in data: self._update( [ @@ -239,7 +266,7 @@ class Position(DataStore): def _onmessage(self, item: Item) -> None: self._update(item['a']['P']) - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: for item in data: self._update( [ @@ -264,7 +291,7 @@ def _onmessage(self, item: Item) -> None: else: self._delete([item['o']]) - def _onresponse(self, symbol: Optional[str], data: List[Item]) -> None: + def _onresponse(self, symbol: Optional[str], data: list[Item]) -> None: if symbol is not None: self._delete(self.find({'symbol': symbol})) else: diff --git a/pybotters/models/bitbank.py b/pybotters/models/bitbank.py index e899bea..7926ab9 100644 --- a/pybotters/models/bitbank.py +++ b/pybotters/models/bitbank.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import json -from typing import Dict, List +from typing import Optional from ..store import DataStore, DataStoreManager from ..typedefs import Item @@ -40,7 +42,7 @@ def ticker(self) -> 'Ticker': class Transactions(DataStore): _MAXLEN = 99999 - def _onmessage(self, room_name: str, data: List[Item]) -> None: + def _onmessage(self, room_name: str, data: list[Item]) -> None: data = data['transactions'] for item in data: pair = room_name.replace('transactions_', '') @@ -51,7 +53,9 @@ class Depth(DataStore): _KEYS = ['pair', 'side', 'price'] _BDSIDE = {'sell': 'asks', 'buy': 'bids'} - def sorted(self, query: Item = {}) -> Dict[str, List[float]]: + def sorted(self, query: Optional[Item] = None) -> dict[str, list[float]]: + if query is None: + query = {} result = {'asks': [], 'bids': []} for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -60,7 +64,7 @@ def sorted(self, query: Item = {}) -> Dict[str, List[float]]: result['bids'].sort(key=lambda x: x[0], reverse=True) return result - def _onmessage(self, room_name: str, data: List[Item]) -> None: + def _onmessage(self, room_name: str, data: list[Item]) -> None: if 'whole' in room_name: pair = room_name.replace('depth_whole_', '') result = self.find({'pair': pair}) diff --git a/pybotters/models/bitflyer.py b/pybotters/models/bitflyer.py new file mode 100644 index 0000000..a480bfe --- /dev/null +++ b/pybotters/models/bitflyer.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import asyncio +import logging +import operator +from decimal import Decimal +from typing import Awaitable + +import aiohttp + +from ..store import DataStore, DataStoreManager +from ..typedefs import Item +from ..ws import ClientWebSocketResponse + +logger = logging.getLogger(__name__) + + +class bitFlyerDataStore(DataStoreManager): + def _init(self) -> None: + self.create('board', datastore_class=Board) + self.create('ticker', datastore_class=Ticker) + self.create('executions', datastore_class=Executions) + self.create('childorderevents', datastore_class=ChildOrderEvents) + self.create('childorders', datastore_class=ChildOrders) + self.create('parentorderevents', datastore_class=ParentOrderEvents) + self.create('parentorders', datastore_class=ParentOrders) + self.create('positions', datastore_class=Positions) + self._snapshots = set() + + async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + for f in asyncio.as_completed(aws): + resp = await f + data = await resp.json() + if resp.url.path == '/v1/me/getchildorders': + self.childorders._onresponse(data) + elif resp.url.path == '/v1/me/getparentorders': + self.parentorders._onresponse(data) + elif resp.url.path == '/v1/me/getpositions': + self.positions._onresponse(data) + + def _onmessage(self, msg: Item, ws: ClientWebSocketResponse) -> None: + if 'error' in msg: + logger.warning(msg) + if 'params' in msg: + channel: str = msg['params']['channel'] + message = msg['params']['message'] + if channel.startswith('lightning_board_'): + if channel.startswith('lightning_board_snapshot_'): + asyncio.create_task( + ws.send_json( + { + 'method': 'unsubscribe', + 'params': {'channel': channel}, + } + ) + ) + product_code = channel.replace('lightning_board_snapshot_', '') + self.board._delete(self.board.find({'product_code': product_code})) + self._snapshots.add(product_code) + else: + product_code = channel.replace('lightning_board_', '') + if product_code in self._snapshots: + self.board._onmessage(product_code, message) + elif channel.startswith('lightning_ticker_'): + self.ticker._onmessage(message) + elif channel.startswith('lightning_executions_'): + product_code = channel.replace('lightning_executions_', '') + self.executions._onmessage(product_code, message) + elif channel == 'child_order_events': + self.childorderevents._onmessage(message) + self.childorders._onmessage(message) + self.positions._onmessage(message) + elif channel == 'parent_order_events': + self.parentorderevents._onmessage(message) + self.parentorders._onmessage(message) + + @property + def board(self) -> 'Board': + return self.get('board', Board) + + @property + def ticker(self) -> 'Ticker': + return self.get('ticker', Ticker) + + @property + def executions(self) -> 'Executions': + return self.get('executions', Executions) + + @property + def childorderevents(self) -> 'ChildOrderEvents': + return self.get('childorderevents', ChildOrderEvents) + + @property + def childorders(self) -> 'ChildOrders': + return self.get('childorders', ChildOrders) + + @property + def parentorderevents(self) -> 'ParentOrderEvents': + return self.get('parentorderevents', ParentOrderEvents) + + @property + def parentorders(self) -> 'ParentOrders': + return self.get('parentorders', ParentOrders) + + @property + def positions(self) -> 'Positions': + return self.get('positions', Positions) + + +class Board(DataStore): + _KEYS = ['product_code', 'side', 'price'] + + def _init(self) -> None: + self.mid_price: dict[str, float] = {} + + def sorted(self, query: Item = None) -> dict[str, list[Item]]: + if query is None: + query = {} + result = {'SELL': [], 'BUY': []} + for item in self: + if all(k in item and query[k] == item[k] for k in query): + result[item['side']].append(item) + result['SELL'].sort(key=lambda x: x['price']) + result['BUY'].sort(key=lambda x: x['price'], reverse=True) + return result + + def _onmessage(self, product_code: str, message: Item) -> None: + self.mid_price[product_code] = message['mid_price'] + for key, side in (('bids', 'BUY'), ('asks', 'SELL')): + for item in message[key]: + if item['size']: + self._insert([{'product_code': product_code, 'side': side, **item}]) + else: + self._delete([{'product_code': product_code, 'side': side, **item}]) + board = self.sorted({'product_code': product_code}) + targets = [] + for side, ope in (('BUY', operator.le), ('SELL', operator.gt)): + for item in board[side]: + if ope(item['price'], message['mid_price']): + break + else: + targets.append(item) + self._delete(targets) + + +class Ticker(DataStore): + _KEYS = ['product_code'] + + def _onmessage(self, message: Item) -> None: + self._update([message]) + + +class Executions(DataStore): + _MAXLEN = 99999 + + def _onmessage(self, product_code: str, message: list[Item]) -> None: + for item in message: + self._insert([{'product_code': product_code, **item}]) + + +class ChildOrderEvents(DataStore): + def _onmessage(self, message: list[Item]) -> None: + self._insert(message) + + +class ParentOrderEvents(DataStore): + def _onmessage(self, message: list[Item]) -> None: + self._insert(message) + + +class ChildOrders(DataStore): + _KEYS = ['child_order_acceptance_id'] + + def _onresponse(self, data: list[Item]) -> None: + if data: + self._delete(self.find({'product_code': data[0]['product_code']})) + for item in data: + if item['child_order_state'] == 'ACTIVE': + self._insert([item]) + + def _onmessage(self, message: list[Item]) -> None: + for item in message: + if item['event_type'] == 'ORDER': + self._insert([item]) + elif item['event_type'] in ('CANCEL', 'EXPIRE'): + self._delete([item]) + elif item['event_type'] == 'EXECUTION': + if item['outstanding_size']: + childorder = self.get(item) + if childorder: + if isinstance(childorder['size'], int) and isinstance( + item['size'], int + ): + childorder['size'] -= item['size'] + else: + childorder['size'] = float( + Decimal(str(childorder['size'])) + - Decimal(str(item['size'])) + ) + else: + self._delete([item]) + + +class ParentOrders(DataStore): + _KEYS = ['parent_order_acceptance_id'] + + def _onresponse(self, data: list[Item]) -> None: + if data: + self._delete(self.find({'product_code': data[0]['product_code']})) + for item in data: + if item['parent_order_state'] == 'ACTIVE': + self._insert([item]) + + def _onmessage(self, message: list[Item]) -> None: + for item in message: + if item['event_type'] == 'ORDER': + self._insert([item]) + elif item['event_type'] in ('CANCEL', 'EXPIRE'): + self._delete([item]) + elif item['event_type'] == 'COMPLETE': + parentorder = self.get(item) + if parentorder: + if parentorder['parent_order_type'] in ('IFD', 'IFDOCO'): + if item['parameter_index'] >= 2: + self._delete([item]) + else: + self._delete([item]) + + +class Positions(DataStore): + _COMMON_KEYS = [ + 'product_code', + 'side', + 'price', + 'size', + 'commission', + 'sfd', + ] + + def _common_keys(self, item: Item) -> Item: + return {key: item[key] for key in self._COMMON_KEYS} + + def _onresponse(self, data: list[Item]) -> None: + if data: + self._delete(self.find({'product_code': data[0]['product_code']})) + for item in data: + self._insert([self._common_keys(item)]) + + def _onmessage(self, message: list[Item]) -> None: + for item in message: + if item['event_type'] == 'EXECUTION': + positions = self._find_with_uuid({'product_code': item['product_code']}) + if positions: + if positions[next(iter(positions))]['side'] == item['side']: + self._insert([self._common_keys(item)]) + else: + for uid, pos in positions.items(): + if pos['size'] > item['size']: + if isinstance(pos['size'], int) and isinstance( + item['size'], int + ): + pos['size'] -= item['size'] + else: + pos['size'] = float( + Decimal(str(pos['size'])) + - Decimal(str(item['size'])) + ) + break + else: + if isinstance(pos['size'], int) and isinstance( + item['size'], int + ): + item['size'] -= pos['size'] + else: + item['size'] = float( + Decimal(str(item['size'])) + - Decimal(str(pos['size'])) + ) + self._remove([uid]) + if not pos['size']: + break + else: + try: + self._insert([self._common_keys(item)]) + except KeyError: + pass diff --git a/pybotters/models/bitmex.py b/pybotters/models/bitmex.py index 89aced7..03c8d98 100644 --- a/pybotters/models/bitmex.py +++ b/pybotters/models/bitmex.py @@ -7,6 +7,10 @@ class BitMEXDataStore(DataStoreManager): + """ + BitMEXのデータストアマネージャー + """ + def _onmessage(self, msg: Item, ws: ClientWebSocketResponse) -> None: if 'error' in msg: logger.warning(msg) @@ -73,6 +77,9 @@ def execution(self) -> DataStore: @property def order(self) -> DataStore: + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get('order', DataStore) @property diff --git a/pybotters/models/bybit.py b/pybotters/models/bybit.py index 0a0db2d..e9aa08d 100644 --- a/pybotters/models/bybit.py +++ b/pybotters/models/bybit.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import asyncio import logging -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Awaitable, Optional, Union import aiohttp @@ -12,6 +14,17 @@ class BybitDataStore(DataStoreManager): + """ + Bybitのデータストアマネージャー + """ + + def __new__(cls) -> BybitDataStore: + logger.warning( + 'DEPRECATION WARNING: BybitDataStore will be changed to ' + 'BybitInverseDataStore and BybitUSDTDataStore' + ) + return super().__new__(cls) + def _init(self) -> None: self.create('orderbook', datastore_class=OrderBook) self.create('trade', datastore_class=Trade) @@ -28,6 +41,20 @@ def _init(self) -> None: self.timestamp_e6: Optional[int] = None async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /v2/private/order (DataStore: order) + - GET /private/linear/order/search (DataStore: order) + - GET /futures/private/order (DataStore: order) + - GET /v2/private/stop-order (DataStore: stoporder) + - GET /private/linear/stop-order/search (DataStore: stoporder) + - GET /futures/private/stop-order (DataStore: stoporder) + - GET /v2/private/position/list (DataStore: position_inverse) + - GET /futures/private/position/list (DataStore: position_inverse) + - GET /private/linear/position/list (DataStore: position_usdt) + - GET /v2/private/wallet/balance (DataStore: wallet) + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -127,10 +154,16 @@ def liquidation(self) -> 'Liquidation': @property def position_inverse(self) -> 'PositionInverse': + """ + インバース契約(無期限/先物)用のポジション + """ return self.get('position_inverse', PositionInverse) @property def position_usdt(self) -> 'PositionUSDT': + """ + USDT契約用のポジション + """ return self.get('position_usdt', PositionUSDT) @property @@ -139,10 +172,16 @@ def execution(self) -> 'Execution': @property def order(self) -> 'Order': + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get('order', Order) @property def stoporder(self) -> 'StopOrder': + """ + アクティブオーダーのみ(トリガー済みは削除される) + """ return self.get('stoporder', StopOrder) @property @@ -153,7 +192,9 @@ def wallet(self) -> 'Wallet': class OrderBook(DataStore): _KEYS = ['symbol', 'id', 'side'] - def sorted(self, query: Item = {}) -> Dict[str, List[Item]]: + def sorted(self, query: Optional[Item] = None) -> dict[str, list[Item]]: + if query is None: + query = {} result = {'Sell': [], 'Buy': []} for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -162,7 +203,7 @@ def sorted(self, query: Item = {}) -> Dict[str, List[Item]]: result['Buy'].sort(key=lambda x: x['id'], reverse=True) return result - def _onmessage(self, topic: str, type_: str, data: Union[List[Item], Item]) -> None: + def _onmessage(self, topic: str, type_: str, data: Union[list[Item], Item]) -> None: if type_ == 'snapshot': symbol = topic.split('.')[-1] # ex: 'orderBookL2_25.BTCUSD' result = self.find({'symbol': symbol}) @@ -180,14 +221,14 @@ class Trade(DataStore): _KEYS = ['trade_id'] _MAXLEN = 99999 - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._insert(data) class Insurance(DataStore): _KEYS = ['currency'] - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) @@ -207,14 +248,14 @@ def _onmessage(self, topic: str, type_: str, data: Item) -> None: class Kline(DataStore): _KEYS = ['symbol', 'period', 'start'] - def _onmessage(self, topic: str, data: List[Item]) -> None: + def _onmessage(self, topic: str, data: list[Item]) -> None: topic_split = topic.split('.') # ex:'klineV2.1.BTCUSD' for item in data: item['symbol'] = topic_split[-1] item['period'] = topic_split[-2] self._update(data) - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: for item in data: item["start"] = item.pop("open_time") item["period"] = item.pop("interval") @@ -242,13 +283,13 @@ class PositionInverse(DataStore): def getone(self, symbol: str) -> Optional[Item]: return self.get({'symbol': symbol, 'position_idx': 0}) - def getboth(self, symbol: str) -> Dict[str, Optional[Item]]: + def getboth(self, symbol: str) -> dict[str, Optional[Item]]: return { 'Sell': self.get({'symbol': symbol, 'position_idx': 2}), 'Buy': self.get({'symbol': symbol, 'position_idx': 1}), } - def _onresponse(self, data: Union[Item, List[Item]]) -> None: + def _onresponse(self, data: Union[Item, list[Item]]) -> None: if isinstance(data, dict): self._update([data]) elif isinstance(data, list): @@ -258,47 +299,47 @@ def _onresponse(self, data: Union[Item, List[Item]]) -> None: else: self._update(data) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) class PositionUSDT(DataStore): _KEYS = ['symbol', 'side'] - def getboth(self, symbol: str) -> Dict[str, Optional[Item]]: + def getboth(self, symbol: str) -> dict[str, Optional[Item]]: return { 'Sell': self.get({'symbol': symbol, 'side': 'Sell'}), 'Buy': self.get({'symbol': symbol, 'side': 'Buy'}), } - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: if len(data): if 'data' in data[0]: self._update([item['data'] for item in data]) else: self._update(data) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) class Execution(DataStore): _KEYS = ['exec_id'] - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) class Order(DataStore): _KEYS = ['order_id'] - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: if isinstance(data, list): self._update(data) elif isinstance(data, dict): self._update([data]) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: for item in data: if item['order_status'] in ('Created', 'New', 'PartiallyFilled'): self._update([item]) @@ -309,13 +350,13 @@ def _onmessage(self, data: List[Item]) -> None: class StopOrder(DataStore): _KEYS = ['stop_order_id'] - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: if isinstance(data, list): self._update(data) elif isinstance(data, dict): self._update([data]) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: for item in data: if 'order_id' in item: item['stop_order_id'] = item.pop('order_id') @@ -330,7 +371,7 @@ def _onmessage(self, data: List[Item]) -> None: class Wallet(DataStore): _KEYS = ['coin'] - def _onresponse(self, data: Dict[str, Item]) -> None: + def _onresponse(self, data: dict[str, Item]) -> None: for coin, item in data.items(): self._update( [ @@ -342,7 +383,7 @@ def _onresponse(self, data: Dict[str, Item]) -> None: ] ) - def _onposition(self, data: List[Item]) -> None: + def _onposition(self, data: list[Item]) -> None: for item in data: symbol: str = item['symbol'] if symbol.endswith('USD'): @@ -359,7 +400,7 @@ def _onposition(self, data: List[Item]) -> None: ] ) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: for item in data: self._update( [ diff --git a/pybotters/models/experimental/__init__.py b/pybotters/models/experimental/__init__.py new file mode 100644 index 0000000..659875c --- /dev/null +++ b/pybotters/models/experimental/__init__.py @@ -0,0 +1,7 @@ +from typing import Tuple +from .bybit import BybitInverseDataStore, BybitUSDTDataStore + +__all__: Tuple[str, ...] = ( + 'BybitInverseDataStore', + 'BybitUSDTDataStore', +) diff --git a/pybotters/models/experimental.py b/pybotters/models/experimental/bybit.py similarity index 70% rename from pybotters/models/experimental.py rename to pybotters/models/experimental/bybit.py index 1261edb..f281135 100644 --- a/pybotters/models/experimental.py +++ b/pybotters/models/experimental/bybit.py @@ -1,17 +1,23 @@ +from __future__ import annotations + import asyncio import logging -from typing import Awaitable, Dict, List, Optional, Union +from typing import Awaitable, Optional, Union import aiohttp -from ..store import DataStore, DataStoreManager -from ..typedefs import Item -from ..ws import ClientWebSocketResponse +from ...store import DataStore, DataStoreManager +from ...typedefs import Item +from ...ws import ClientWebSocketResponse logger = logging.getLogger(__name__) class BybitInverseDataStore(DataStoreManager): + """ + Bybit Inverse契約のデータストアマネージャー + """ + def _init(self) -> None: self.create("orderbook", datastore_class=OrderBookInverse) self.create("trade", datastore_class=TradeInverse) @@ -26,6 +32,16 @@ def _init(self) -> None: self.timestamp_e6: Optional[int] = None async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /v2/private/order (DataStore: order) + - GET /futures/private/order (DataStore: order) + - GET /v2/private/stop-order (DataStore: stoporder) + - GET /futures/private/stop-order (DataStore: stoporder) + - GET /v2/private/position/list (DataStore: position) + - GET /futures/private/position/list (DataStore: position) + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -114,6 +130,9 @@ def liquidation(self) -> "LiquidationInverse": @property def position(self) -> "PositionInverse": + """ + インバース契約(無期限/先物)用のポジション + """ return self.get("position", PositionInverse) @property @@ -122,14 +141,24 @@ def execution(self) -> "ExecutionInverse": @property def order(self) -> "OrderInverse": + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get("order", OrderInverse) @property def stoporder(self) -> "StopOrderInverse": + """ + アクティブオーダーのみ(トリガー済みは削除される) + """ return self.get("stoporder", StopOrderInverse) class BybitUSDTDataStore(DataStoreManager): + """ + Bybit USDT契約のデータストアマネージャー + """ + def _init(self) -> None: self.create("orderbook", datastore_class=OrderBookUSDT) self.create("trade", datastore_class=TradeUSDT) @@ -145,6 +174,13 @@ def _init(self) -> None: self.timestamp_e6: Optional[int] = None async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /private/linear/order/search (DataStore: order) + - GET /private/linear/stop-order/search (DataStore: stoporder) + - GET /private/linear/position/list (DataStore: position) + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -220,6 +256,9 @@ def liquidation(self) -> "LiquidationUSDT": @property def position(self) -> "PositionUSDT": + """ + USDT契約用のポジション + """ return self.get("position", PositionUSDT) @property @@ -228,10 +267,16 @@ def execution(self) -> "ExecutionUSDT": @property def order(self) -> "OrderUSDT": + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get("order", OrderUSDT) @property def stoporder(self) -> "StopOrderUSDT": + """ + アクティブオーダーのみ(トリガー済みは削除される) + """ return self.get("stoporder", StopOrderUSDT) @property @@ -239,42 +284,12 @@ def wallet(self) -> "Wallet": return self.get("wallet", Wallet) -class CastDataStore(DataStore): - _CAST_TYPES = {} - - def _cast(self, data: List[Item]) -> None: - for item in data: - for x in self._CAST_TYPES: - for k in self._CAST_TYPES[x]: - try: - item[k] = x(item[k]) - except KeyError: - pass - except TypeError: - pass - - def _insert(self, data: List[Item]) -> None: - self._cast(data) - super()._insert(data) - - def _update(self, data: List[Item]) -> None: - self._cast(data) - super()._update(data) - - def _delete(self, data: List[Item]) -> None: - self._cast(data) - super()._delete(data) - - -class OrderBookInverse(CastDataStore): +class OrderBookInverse(DataStore): _KEYS = ["symbol", "id", "side"] - _CAST_TYPES = { - float: [ - "price", - ], - } - def sorted(self, query: Item = {}) -> Dict[str, List[Item]]: + def sorted(self, query: Optional[Item] = None) -> dict[str, list[Item]]: + if query is None: + query = {} result = {"Sell": [], "Buy": []} for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -283,7 +298,7 @@ def sorted(self, query: Item = {}) -> Dict[str, List[Item]]: result["Buy"].sort(key=lambda x: x["id"], reverse=True) return result - def _onmessage(self, topic: str, type_: str, data: Union[List[Item], Item]) -> None: + def _onmessage(self, topic: str, type_: str, data: Union[list[Item], Item]) -> None: if type_ == "snapshot": symbol = topic.split(".")[-1] # ex: "orderBookL2_25.BTCUSD", "orderBook_200.100ms.BTCUSD" @@ -297,16 +312,7 @@ def _onmessage(self, topic: str, type_: str, data: Union[List[Item], Item]) -> N class OrderBookUSDT(OrderBookInverse): - _CAST_TYPES = { - float: [ - "price", - ], - int: [ - "id", - ], - } - - def _onmessage(self, topic: str, type_: str, data: Union[List[Item], Item]) -> None: + def _onmessage(self, topic: str, type_: str, data: Union[list[Item], Item]) -> None: if type_ == "snapshot": symbol = topic.split(".")[-1] # ex: "orderBookL2_25.BTCUSDT", "orderBook_200.100ms.BTCUSDT" @@ -319,47 +325,27 @@ def _onmessage(self, topic: str, type_: str, data: Union[List[Item], Item]) -> N self._insert(data["insert"]) -class TradeInverse(CastDataStore): +class TradeInverse(DataStore): _KEYS = ['trade_id'] _MAXLEN = 99999 - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._insert(data) class TradeUSDT(TradeInverse): - _CAST_TYPES = { - float: [ - "price", - ], - int: [ - "trade_time_ms", - ], - } + ... -class Insurance(CastDataStore): +class Insurance(DataStore): _KEYS = ["currency"] - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) -class InstrumentInverse(CastDataStore): +class InstrumentInverse(DataStore): _KEYS = ["symbol"] - _CAST_TYPES = { - float: [ - "last_price", - "bid1_price", - "ask1_price", - "prev_price_24h", - "high_price_24h", - "low_price_24h", - "prev_price_1h", - "mark_price", - "index_price", - ], - } def _onmessage(self, topic: str, type_: str, data: Item) -> None: if type_ == "snapshot": @@ -372,77 +358,30 @@ def _onmessage(self, topic: str, type_: str, data: Item) -> None: class InstrumentUSDT(InstrumentInverse): - _CAST_TYPES = { - float: [ - "last_price", - "prev_price_24h", - "high_price_24h", - "low_price_24h", - "prev_price_1h", - "mark_price", - "index_price", - "bid1_price", - "ask1_price", - ], - int: [ - "last_price_e4", - "prev_price_24h_e4", - "price_24h_pcnt_e6", - "high_price_24h_e4", - "low_price_24h_e4", - "prev_price_1h_e4", - "price_1h_pcnt_e6", - "mark_price_e4", - "index_price_e4", - "open_interest_e8", - "total_turnover_e8", - "turnover_24h_e8", - "total_volume_e8", - "volume_24h_e8", - "funding_rate_e6", - "predicted_funding_rate_e6", - "cross_seq", - "count_down_hour", - "bid1_price_e4", - "ask1_price_e4", - ], - } - - -class KlineInverse(CastDataStore): + ... + + +class KlineInverse(DataStore): _KEYS = ["start", "symbol", "interval"] - def _onmessage(self, topic: str, data: List[Item]) -> None: + def _onmessage(self, topic: str, data: list[Item]) -> None: topic_split = topic.split(".") # ex:"klineV2.1.BTCUSD" for item in data: item["symbol"] = topic_split[-1] item["interval"] = topic_split[-2] self._update(data) - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: for item in data: item["start"] = item.pop("open_time") self._update(data) class KlineUSDT(KlineInverse): - _CAST_TYPES = { - float: [ - "volume", - "turnover", - ], - } - - -class LiquidationInverse(CastDataStore): - _CAST_TYPES = { - float: [ - "price", - ], - int: [ - "qty", - ], - } + ... + + +class LiquidationInverse(DataStore): _MAXLEN = 99999 def _onmessage(self, item: Item) -> None: @@ -450,51 +389,22 @@ def _onmessage(self, item: Item) -> None: class LiquidationUSDT(LiquidationInverse): - _CAST_TYPES = { - float: [ - "price", - "qty", - ], - int: [ - "qty", - ], - } - - -class PositionInverse(CastDataStore): + ... + + +class PositionInverse(DataStore): _KEYS = ["symbol", "position_idx"] - _CAST_TYPES = { - float: [ - "position_value", - "entry_price", - "liq_price", - "bust_price", - "leverage", - "order_margin", - "position_margin", - "available_balance", - "take_profit", - "stop_loss", - "realised_pnl", - "trailing_stop", - "trailing_active", - "wallet_balance", - "occ_closing_fee", - "occ_funding_fee", - "cum_realised_pnl", - ], - } def one(self, symbol: str) -> Optional[Item]: return self.get({"symbol": symbol, "position_idx": 0}) - def both(self, symbol: str) -> Dict[str, Optional[Item]]: + def both(self, symbol: str) -> dict[str, Optional[Item]]: return { "Sell": self.get({"symbol": symbol, "position_idx": 2}), "Buy": self.get({"symbol": symbol, "position_idx": 1}), } - def _onresponse(self, data: Union[Item, List[Item]]) -> None: + def _onresponse(self, data: Union[Item, list[Item]]) -> None: if isinstance(data, dict): self._update([data]) # ex: {"symbol": "BTCUSD", ...} elif isinstance(data, list): @@ -514,68 +424,47 @@ def _onresponse(self, data: Union[Item, List[Item]]) -> None: self._update([item]) # ex: [{"symbol": "BTCUSDT", ...}, ...] - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: self._update(data) class PositionUSDT(PositionInverse): _KEYS = ["symbol", "side"] - _CAST_TYPES = { - int: [ - "user_id", - "auto_add_margin", - "position_id", - "position_seq", - "adl_rank_indicator", - "risk_id", - ], - } - - def both(self, symbol: str) -> Dict[str, Optional[Item]]: + + def one(self, symbol: str) -> dict[str, Optional[Item]]: + return { + "Sell": self.get({"symbol": symbol, "side": "Sell"}), + "Buy": self.get({"symbol": symbol, "side": "Buy"}), + } + + def both(self, symbol: str) -> dict[str, Optional[Item]]: return { "Sell": self.get({"symbol": symbol, "side": "Sell"}), "Buy": self.get({"symbol": symbol, "side": "Buy"}), } -class ExecutionInverse(CastDataStore): +class ExecutionInverse(DataStore): _KEYS = ["exec_id"] - _CAST_TYPES = { - float: [ - "price", - "exec_fee", - ], - } - - def _onmessage(self, data: List[Item]) -> None: + + def _onmessage(self, data: list[Item]) -> None: self._update(data) class ExecutionUSDT(ExecutionInverse): - _CAST_TYPES = {} + ... -class OrderInverse(CastDataStore): +class OrderInverse(DataStore): _KEYS = ["order_id"] - _CAST_TYPES = { - float: [ - "price", - "cum_exec_value", - "cum_exec_fee", - "take_profit", - "stop_loss", - "trailing_stop", - "last_exec_price", - ], - } - - def _onresponse(self, data: List[Item]) -> None: + + def _onresponse(self, data: list[Item]) -> None: if isinstance(data, list): self._update(data) elif isinstance(data, dict): self._update([data]) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: for item in data: if item["order_status"] in ("Created", "New", "PartiallyFilled"): self._update([item]) @@ -584,25 +473,19 @@ def _onmessage(self, data: List[Item]) -> None: class OrderUSDT(OrderInverse): - _CAST_TYPES = {} + ... -class StopOrderInverse(CastDataStore): +class StopOrderInverse(DataStore): _KEYS = ["order_id"] - _CAST_TYPES = { - float: [ - "price", - "trigger_price", - ], - } - - def _onresponse(self, data: List[Item]) -> None: + + def _onresponse(self, data: list[Item]) -> None: if isinstance(data, list): self._update(data) elif isinstance(data, dict): self._update([data]) - def _onmessage(self, data: List[Item]) -> None: + def _onmessage(self, data: list[Item]) -> None: for item in data: if item["order_status"] in ("Active", "Untriggered"): self._update([item]) @@ -612,14 +495,9 @@ def _onmessage(self, data: List[Item]) -> None: class StopOrderUSDT(StopOrderInverse): _KEYS = ["stop_order_id"] - _CAST_TYPES = { - int: [ - "user_id", - ], - } -class Wallet(CastDataStore): - def _onmessage(self, data: List[Item]) -> None: +class Wallet(DataStore): + def _onmessage(self, data: list[Item]) -> None: self._clear() self._update(data) diff --git a/pybotters/models/ftx.py b/pybotters/models/ftx.py index 7bf5903..53eb38c 100644 --- a/pybotters/models/ftx.py +++ b/pybotters/models/ftx.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import asyncio import logging -from typing import Any, Awaitable, Dict, List +from typing import Any, Awaitable, Optional import aiohttp @@ -13,6 +15,10 @@ class FTXDataStore(DataStoreManager): + """ + FTXのデータストアマネージャー + """ + def _init(self) -> None: self.create('ticker', datastore_class=Ticker) self.create('markets', datastore_class=Markets) @@ -23,6 +29,15 @@ def _init(self) -> None: self.create('positions', datastore_class=Positions) async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /orders (DataStore: orders) + - GET /conditional_orders (DataStore: orders) + - GET /positions (DataStore: positions) + + - fills 受信時に GET /positions の自動フェッチする機能が有効化される。 + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -83,6 +98,9 @@ def fills(self) -> 'Fills': @property def orders(self) -> 'Orders': + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get('orders', Orders) @property @@ -109,7 +127,7 @@ def _onmessage(self, item: Item) -> None: class Trades(DataStore): _MAXLEN = 99999 - def _onmessage(self, market: str, data: List[Item]) -> None: + def _onmessage(self, market: str, data: list[Item]) -> None: for item in data: self._insert([{'market': market, **item}]) @@ -118,7 +136,9 @@ class OrderBook(DataStore): _KEYS = ['market', 'side', 'price'] _BDSIDE = {'sell': 'asks', 'buy': 'bids'} - def sorted(self, query: Item = {}) -> Dict[str, List[float]]: + def sorted(self, query: Optional[Item] = None) -> dict[str, list[float]]: + if query is None: + query = {} result = {'asks': [], 'bids': []} for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -127,7 +147,7 @@ def sorted(self, query: Item = {}) -> Dict[str, List[float]]: result['bids'].sort(key=lambda x: x[0], reverse=True) return result - def _onmessage(self, market: str, data: List[Item]) -> None: + def _onmessage(self, market: str, data: list[Item]) -> None: if data['action'] == 'partial': result = self.find({'market': market}) self._delete(result) @@ -156,7 +176,7 @@ def _onmessage(self, item: Item) -> None: class Orders(DataStore): _KEYS = ['id'] - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: if data: results = self.find({'market': data[0]['market']}) self._delete(results) @@ -175,7 +195,7 @@ class Positions(DataStore): def _init(self) -> None: self._fetch = False - def _onresponse(self, data: List[Item]) -> None: + def _onresponse(self, data: list[Item]) -> None: self._update(data) async def _onfills(self, session: aiohttp.ClientSession) -> None: diff --git a/pybotters/models/gmocoin.py b/pybotters/models/gmocoin.py index b196409..e068f21 100644 --- a/pybotters/models/gmocoin.py +++ b/pybotters/models/gmocoin.py @@ -1,16 +1,11 @@ +from __future__ import annotations + import asyncio import logging from datetime import datetime, timezone from decimal import Decimal from enum import Enum, auto -from typing import ( - Any, - Awaitable, - Dict, - List, - Optional, - cast, -) +from typing import Any, Awaitable, Optional, cast import aiohttp from pybotters.store import DataStore, DataStoreManager @@ -211,8 +206,8 @@ class OrderLevel(TypedDict): class OrderBook(TypedDict): - asks: List[OrderLevel] - bids: List[OrderLevel] + asks: list[OrderLevel] + bids: list[OrderLevel] symbol: Symbol timestamp: datetime @@ -294,8 +289,13 @@ def _onmessage(self, mes: Ticker) -> None: class OrderBookStore(DataStore): _KEYS = ["symbol", "side", "price"] - def sorted(self, query: Item = {}) -> Dict[OrderSide, List[OrderLevel]]: - result: Dict[OrderSide, List[OrderLevel]] = { + def _init(self) -> None: + self.timestamp: Optional[datetime] = None + + def sorted(self, query: Optional[Item] = None) -> dict[OrderSide, list[OrderLevel]]: + if query is None: + query = {} + result: dict[OrderSide, list[OrderLevel]] = { OrderSide.BUY: [], OrderSide.SELL: [], } @@ -310,7 +310,8 @@ def _onmessage(self, mes: OrderBook) -> None: data = mes["asks"] + mes["bids"] result = self.find({"symbol": mes["symbol"]}) self._delete(result) - self._insert(cast(List[Item], data)) + self._insert(cast(list[Item], data)) + self.timestamp = mes["timestamp"] class TradeStore(DataStore): @@ -321,8 +322,8 @@ def _onmessage(self, mes: Trade) -> None: class OrderStore(DataStore): _KEYS = ["order_id"] - def _onresponse(self, data: List[Order]) -> None: - self._insert(cast(List[Item], data)) + def _onresponse(self, data: list[Order]) -> None: + self._insert(cast(list[Item], data)) def _onmessage(self, mes: Order) -> None: if mes["order_status"] in (OrderStatus.WAITING, OrderStatus.ORDERED): @@ -348,7 +349,9 @@ def _onexecution(self, mes: Execution) -> None: class ExecutionStore(DataStore): _KEYS = ["execution_id"] - def sorted(self, query: Item = {}) -> List[Execution]: + def sorted(self, query: Optional[Item] = None) -> list[Execution]: + if query is None: + query = {} result = [] for item in self: if all(k in item and query[k] == item[k] for k in query): @@ -356,8 +359,8 @@ def sorted(self, query: Item = {}) -> List[Execution]: result.sort(key=lambda x: x["execution_id"], reverse=True) return result - def _onresponse(self, data: List[Execution]) -> None: - self._insert(cast(List[Item], data)) + def _onresponse(self, data: list[Execution]) -> None: + self._insert(cast(list[Item], data)) def _onmessage(self, mes: Execution) -> None: self._insert([cast(Item, mes)]) @@ -366,8 +369,8 @@ def _onmessage(self, mes: Execution) -> None: class PositionStore(DataStore): _KEYS = ["position_id"] - def _onresponse(self, data: List[Position]) -> None: - self._update(cast(List[Item], data)) + def _onresponse(self, data: list[Position]) -> None: + self._update(cast(list[Item], data)) def _onmessage(self, mes: Position, type: MessageType) -> None: if type == MessageType.OPR: @@ -381,8 +384,8 @@ def _onmessage(self, mes: Position, type: MessageType) -> None: class PositionSummaryStore(DataStore): _KEYS = ["symbol", "side"] - def _onresponse(self, data: List[PositionSummary]) -> None: - self._update(cast(List[Item], data)) + def _onresponse(self, data: list[PositionSummary]) -> None: + self._update(cast(list[Item], data)) def _onmessage(self, mes: PositionSummary) -> None: self._update([cast(Item, mes)]) @@ -390,7 +393,7 @@ def _onmessage(self, mes: PositionSummary) -> None: class MessageHelper: @staticmethod - def to_tickers(data: List[Item]) -> List["Ticker"]: + def to_tickers(data: list[Item]) -> list["Ticker"]: return [MessageHelper.to_ticker(x) for x in data] @staticmethod @@ -432,7 +435,7 @@ def to_orderbook(data: Item) -> "OrderBook": ) @staticmethod - def to_trades(data: List[Item]) -> List["Trade"]: + def to_trades(data: list[Item]) -> list["Trade"]: return [MessageHelper.to_trade(x) for x in data] @staticmethod @@ -446,7 +449,7 @@ def to_trade(data: Item) -> "Trade": ) @staticmethod - def to_executions(data: List[Item]) -> List["Execution"]: + def to_executions(data: list[Item]) -> list["Execution"]: return [MessageHelper.to_execution(x) for x in data] @staticmethod @@ -481,7 +484,7 @@ def to_execution(data: Item) -> "Execution": ) @staticmethod - def to_orders(data: List[Item]) -> List["Order"]: + def to_orders(data: list[Item]) -> list["Order"]: return [MessageHelper.to_order(x) for x in data] @staticmethod @@ -507,7 +510,7 @@ def to_order(data: Item) -> "Order": ) @staticmethod - def to_positions(data: List[Item]) -> List["Position"]: + def to_positions(data: list[Item]) -> list["Position"]: return [MessageHelper.to_position(x) for x in data] @staticmethod @@ -526,7 +529,7 @@ def to_position(data: Item) -> "Position": ) @staticmethod - def to_position_summaries(data: List[Item]) -> List["PositionSummary"]: + def to_position_summaries(data: list[Item]) -> list["PositionSummary"]: return [MessageHelper.to_position_summary(x) for x in data] @staticmethod @@ -545,6 +548,10 @@ def to_position_summary(data: Item) -> "PositionSummary": class GMOCoinDataStore(DataStoreManager): + """ + GMOコインのデータストアマネージャー + """ + def _init(self) -> None: self.create("ticker", datastore_class=TickerStore) self.create("orderbooks", datastore_class=OrderBookStore) @@ -555,6 +562,14 @@ def _init(self) -> None: self.create("position_summary", datastore_class=PositionSummaryStore) async def initialize(self, *aws: Awaitable[aiohttp.ClientResponse]) -> None: + """ + 対応エンドポイント + + - GET /private/v1/latestExecutions (DataStore: executions) + - GET /private/v1/activeOrders (DataStore: orders) + - GET /private/v1/openPositions (DataStore: positions) + - GET /private/v1/positionSummary (DataStore: position_summary) + """ for f in asyncio.as_completed(aws): resp = await f data = await resp.json() @@ -615,6 +630,9 @@ def trades(self) -> TradeStore: @property def orders(self) -> OrderStore: + """ + アクティブオーダーのみ(約定・キャンセル済みは削除される) + """ return self.get("orders", OrderStore) @property diff --git a/pybotters/store.py b/pybotters/store.py index 96164c0..d98bb01 100644 --- a/pybotters/store.py +++ b/pybotters/store.py @@ -1,17 +1,8 @@ +from __future__ import annotations + import asyncio import uuid -from typing import ( - Any, - cast, - Dict, - Hashable, - Iterator, - List, - Optional, - Tuple, - Type, - TypeVar, -) +from typing import Any, Hashable, Iterator, Optional, Type, TypeVar, cast from .typedefs import Item from .ws import ClientWebSocketResponse @@ -21,11 +12,20 @@ class DataStore: _KEYS = [] _MAXLEN = 9999 - def __init__(self, keys: List[str] = [], data: List[Item] = []) -> None: - self._data: Dict[uuid.UUID, Item] = {} - self._index: Dict[int, uuid.UUID] = {} - self._keys: Tuple[str, ...] = tuple(keys if keys else self._KEYS) - self._events: Dict[asyncio.Event, List[Item]] = {} + def __init__( + self, + keys: Optional[list[str]] = None, + data: Optional[list[Item]] = None, + *, + auto_cast: bool = False, + ) -> None: + self._data: dict[uuid.UUID, Item] = {} + self._index: dict[int, uuid.UUID] = {} + self._keys: tuple[str, ...] = tuple(keys if keys else self._KEYS) + self._events: dict[asyncio.Event, list[Item]] = {} + self._auto_cast = auto_cast + if data is None: + data = [] self._insert(data) if hasattr(self, '_init'): getattr(self, '_init')() @@ -37,12 +37,26 @@ def __iter__(self) -> Iterator[Item]: return iter(self._data.values()) @staticmethod - def _hash(item: Dict[str, Hashable]) -> int: + def _hash(item: dict[str, Hashable]) -> int: return hash(tuple(item.items())) - def _insert(self, data: List[Item]) -> None: + @staticmethod + def _cast_item(item: dict[str, Hashable]) -> None: + for k in item: + if isinstance(item[k], str): + try: + item[k] = int(item[k]) + except ValueError: + try: + item[k] = float(item[k]) + except ValueError: + pass + + def _insert(self, data: list[Item]) -> None: if self._keys: for item in data: + if self._auto_cast: + self._cast_item(item) try: keyitem = {k: item[k] for k in self._keys} except KeyError: @@ -58,15 +72,19 @@ def _insert(self, data: List[Item]) -> None: self._sweep_with_key() else: for item in data: + if self._auto_cast: + self._cast_item(item) _id = uuid.uuid4() self._data[_id] = item self._sweep_without_key() # !TODO! This behaviour might be undesirable. self._set(data) - def _update(self, data: List[Item]) -> None: + def _update(self, data: list[Item]) -> None: if self._keys: for item in data: + if self._auto_cast: + self._cast_item(item) try: keyitem = {k: item[k] for k in self._keys} except KeyError: @@ -82,15 +100,19 @@ def _update(self, data: List[Item]) -> None: self._sweep_with_key() else: for item in data: + if self._auto_cast: + self._cast_item(item) _id = uuid.uuid4() self._data[_id] = item self._sweep_without_key() # !TODO! This behaviour might be undesirable. self._set(data) - def _delete(self, data: List[Item]) -> None: + def _delete(self, data: list[Item]) -> None: if self._keys: for item in data: + if self._auto_cast: + self._cast_item(item) try: keyitem = {k: item[k] for k in self._keys} except KeyError: @@ -103,6 +125,21 @@ def _delete(self, data: List[Item]) -> None: # !TODO! This behaviour might be undesirable. self._set(data) + def _remove(self, uuids: list[uuid.UUID]) -> None: + if self._keys: + for _id in uuids: + if _id in self._data: + item = self._data[_id] + keyhash = self._hash({k: item[k] for k in self._keys}) + del self._data[_id] + del self._index[keyhash] + else: + for _id in uuids: + if _id in self._data: + del self._data[_id] + # !TODO! This behaviour might be undesirable. + self._set([]) + def _clear(self) -> None: self._data.clear() self._index.clear() @@ -150,7 +187,7 @@ def _pop(self, item: Item) -> Optional[Item]: del self._index[keyhash] return ret - def find(self, query: Item = {}) -> List[Item]: + def find(self, query: Optional[Item] = None) -> list[Item]: if query: return [ item @@ -160,7 +197,21 @@ def find(self, query: Item = {}) -> List[Item]: else: return list(self) - def _find_and_delete(self, query: Item = {}) -> List[Item]: + def _find_with_uuid(self, query: Optional[Item] = None) -> dict[uuid.UUID, Item]: + if query is None: + query = {} + if query: + return { + _id: item + for _id, item in self._data.items() + if all(k in item and query[k] == item[k] for k in query) + } + else: + return self._data + + def _find_and_delete(self, query: Optional[Item] = None) -> list[Item]: + if query is None: + query = {} if query: ret = [ item @@ -174,12 +225,14 @@ def _find_and_delete(self, query: Item = {}) -> List[Item]: self._clear() return ret - def _set(self, data: List[Item] = None) -> None: + def _set(self, data: Optional[list[Item]] = None) -> None: + if data is None: + data = [] for event in self._events: event.set() self._events[event].extend(data) - async def wait(self) -> List[Item]: + async def wait(self) -> list[Item]: event = asyncio.Event() ret = [] self._events[event] = ret @@ -192,10 +245,15 @@ async def wait(self) -> List[Item]: class DataStoreManager: - def __init__(self) -> None: - self._stores: Dict[str, DataStore] = {} - self._events: List[asyncio.Event] = [] + """ + データストアマネージャーの抽象クラスです。 データストアの作成・参照・ハンドリングなどの役割を持ちます。 それぞれの取引所のクラスが継承します。 + """ + + def __init__(self, auto_cast: bool = False) -> None: + self._stores: dict[str, DataStore] = {} + self._events: list[asyncio.Event] = [] self._iscorofunc = asyncio.iscoroutinefunction(self._onmessage) + self._auto_cast = auto_cast if hasattr(self, '_init'): getattr(self, '_init')() @@ -209,11 +267,15 @@ def create( self, name: str, *, - keys: List[str] = [], - data: List[Item] = [], + keys: Optional[list[str]] = None, + data: Optional[list[Item]] = None, datastore_class: Type[DataStore] = DataStore, ) -> None: - self._stores[name] = datastore_class(keys, data) + if keys is None: + keys = [] + if data is None: + data = [] + self._stores[name] = datastore_class(keys, data, auto_cast=self._auto_cast) def get(self, name: str, type: Type[TDataStore]) -> TDataStore: return cast(type, self._stores.get(name)) @@ -222,6 +284,9 @@ def _onmessage(self, msg: Any, ws: ClientWebSocketResponse) -> None: print(msg) def onmessage(self, msg: Any, ws: ClientWebSocketResponse) -> None: + """ + Clientクラスws_connectメソッドの引数send_jsonに渡すハンドラです。 + """ self._onmessage(msg, ws) self._set() @@ -231,6 +296,9 @@ def _set(self) -> None: self._events.clear() async def wait(self) -> None: + """ + 非同期メソッド。onmessageのイベントがあるまで待機します。 + """ event = asyncio.Event() self._events.append(event) await event.wait() diff --git a/pybotters/typedefs.py b/pybotters/typedefs.py index 3063b92..0ea6d5d 100644 --- a/pybotters/typedefs.py +++ b/pybotters/typedefs.py @@ -5,6 +5,9 @@ WsStrHandler = Callable[ [str, ClientWebSocketResponse], Optional[Coroutine[Any, Any, None]] ] +WsBytesHandler = Callable[ + [bytes, ClientWebSocketResponse], Optional[Coroutine[Any, Any, None]] +] WsJsonHandler = Callable[ [Any, ClientWebSocketResponse], Optional[Coroutine[Any, Any, None]] ] diff --git a/pybotters/ws.py b/pybotters/ws.py index fab6eca..2006e24 100644 --- a/pybotters/ws.py +++ b/pybotters/ws.py @@ -1,13 +1,16 @@ +from __future__ import annotations + import asyncio import base64 import datetime import hashlib import hmac +import inspect import logging import time from dataclasses import dataclass from secrets import token_hex -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import aiohttp from aiohttp.http_websocket import json @@ -20,14 +23,24 @@ logger = logging.getLogger(__name__) +def pretty_modulename(e: Exception) -> str: + modulename = e.__class__.__name__ + module = inspect.getmodule(e) + if module: + modulename = f'{module.__name__}.{modulename}' + return modulename + + async def ws_run_forever( url: StrOrURL, session: aiohttp.ClientSession, event: asyncio.Event, *, - send_str: Optional[Union[str, List[str]]] = None, + send_str: Optional[Union[str, list[str]]] = None, + send_bytes: Optional[Union[bytes, list[bytes]]] = None, send_json: Any = None, hdlr_str=None, + hdlr_bytes=None, hdlr_json=None, auth=_Auth, **kwargs: Any, @@ -35,6 +48,7 @@ async def ws_run_forever( if all([hdlr_str is None, hdlr_json is None]): hdlr_json = pybotters.print_handler iscorofunc_str = asyncio.iscoroutinefunction(hdlr_str) + iscorofunc_bytes = asyncio.iscoroutinefunction(hdlr_bytes) iscorofunc_json = asyncio.iscoroutinefunction(hdlr_json) while not session.closed: cooldown = asyncio.create_task(asyncio.sleep(60.0)) @@ -48,6 +62,13 @@ async def ws_run_forever( await asyncio.gather(*[ws.send_str(item) for item in send_str]) else: await ws.send_str(send_str) + if send_bytes is not None: + if isinstance(send_bytes, list): + await asyncio.gather( + *[ws.send_bytes(item) for item in send_bytes] + ) + else: + await ws.send_bytes(send_bytes) if send_json is not None: if isinstance(send_json, list): await asyncio.gather( @@ -64,7 +85,7 @@ async def ws_run_forever( else: hdlr_str(msg.data, ws) except Exception as e: - logger.error(repr(e)) + logger.exception(f'{pretty_modulename(e)}: {e}') if hdlr_json is not None: try: data = msg.json() @@ -77,11 +98,24 @@ async def ws_run_forever( else: hdlr_json(data, ws) except Exception as e: - logger.error(repr(e)) + logger.exception(f'{pretty_modulename(e)}: {e}') + elif msg.type == aiohttp.WSMsgType.BINARY: + if hdlr_bytes is not None: + try: + if iscorofunc_bytes: + await hdlr_bytes(msg.data, ws) + else: + hdlr_bytes(msg.data, ws) + except Exception as e: + logger.exception(f'{pretty_modulename(e)}: {e}') elif msg.type == aiohttp.WSMsgType.ERROR: break - except (aiohttp.WSServerHandshakeError, aiohttp.ClientOSError) as e: - logger.warning(repr(e)) + except ( + aiohttp.WSServerHandshakeError, + aiohttp.ClientOSError, + ConnectionResetError, + ) as e: + logger.warning(f'{pretty_modulename(e)}: {e}') await cooldown diff --git a/tests/test_client.py b/tests/test_client.py index f758f6f..2def1cd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,11 +3,10 @@ from unittest.mock import mock_open import aiohttp +import pybotters import pytest import pytest_mock -import pybotters - async def test_client(): apis = { @@ -50,7 +49,7 @@ async def test_client_open(mocker: pytest_mock.MockerFixture): async def test_client_warn(mocker: pytest_mock.MockerFixture): apis = {'name1', 'key1', 'secret1'} base_url = 'http://example.com' - async with pybotters.Client(apis=apis, base_url=base_url) as client: + async with pybotters.Client(apis=apis, base_url=base_url) as client: # type: ignore assert isinstance(client._session, aiohttp.ClientSession) assert not client._session.closed assert client._base_url == base_url @@ -149,7 +148,25 @@ async def test_client_ws_connect_json(mocker: pytest_mock.MockerFixture): ret = await client.ws_connect( 'ws://test.org', send_json={'foo': 'bar'}, - hdlr_json=lambda msg, ws: ..., + hdlr_json=lambda msg, ws: None, + ) + assert coro.called + assert task.called + assert ret == task.return_value + + +@pytest.mark.asyncio +async def test_client_ws_connect_bytes(mocker: pytest_mock.MockerFixture): + event = asyncio.Event() + event.set() + mocker.patch('asyncio.Event', return_value=event) + task = mocker.patch('asyncio.create_task') + coro = mocker.patch('pybotters.client.ws_run_forever') + async with pybotters.Client() as client: + ret = await client.ws_connect( + 'ws://test.org', + send_bytes=b'{"foo":"bar"}', + hdlr_bytes=lambda msg, ws: ..., ) assert coro.called assert task.called diff --git a/tests/test_store.py b/tests/test_store.py index a9cbdc0..4863dcf 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -16,6 +16,11 @@ def test_interface(): assert 'example' in store assert isinstance(store['example'], pybotters.store.DataStore) + store = pybotters.store.DataStoreManager(auto_cast=True) + assert store._auto_cast is True + store.create('example') + store['example']._auto_cast is True + @pytest.mark.asyncio async def test_interface_onmessage(mocker: pytest_mock.MockerFixture): @@ -66,6 +71,29 @@ def test_hash(): assert isinstance(hashed, int) +def test_cast_item(): + actual = { + 'num_int': 123, + 'num_float': 1.23, + 'str_int': "123", + 'str_float': "1.23", + 'str_orig': "foo", + 'bool': True, + 'null': None, + } + expected = { + 'num_int': 123, + 'num_float': 1.23, + 'str_int': 123, + 'str_float': 1.23, + 'str_orig': "foo", + 'bool': True, + 'null': None, + } + pybotters.store.DataStore._cast_item(actual) + assert expected == actual + + def test_sweep_with_key(): data = [{'foo': f'bar{i}'} for i in range(1000)] ds = pybotters.store.DataStore(keys=['foo'], data=data)