@pytest.mark.parametrize(
    "codec",
    [
        ydb.TopicCodec.RAW,
        ydb.TopicCodec.GZIP,
        None,
    ],
)
async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec):
    """Write three 1000-character messages with every codec setting.

    ``codec`` is parametrized over RAW, GZIP and ``None`` (``None`` lets the
    writer auto-select the codec).
    """
    async with driver.topic_client.writer(topic_path, codec=codec) as writer:
        # write() is a coroutine on the asyncio writer: without ``await`` the
        # coroutine object is discarded and nothing is actually written.
        await writer.write("a" * 1000)
        await writer.write("b" * 1000)
        await writer.write("c" * 1000)
class PublicCodec(int):
    """
    Codec identifier; the value may be any integer.

    The constants below are only the well-known predefined values;
    the protocol also supports custom codecs.
    """

    UNSPECIFIED = 0
    RAW = 1
    GZIP = 2
    LZOP = 3  # No codec implementation for it in the standard library
    ZSTD = 4  # No codec implementation for it in the standard library
@dataclass
class PublicWriterInitInfo:
    # Values the writer exposes after its stream has been initialised.
    # Explicit slots: instances carry exactly these two attributes.
    __slots__ = ("last_seqno", "supported_codecs")

    # Last sequence number reported by the stream writer at init time.
    last_seqno: Optional[int]
    # Codecs reported by the stream writer at init time; the list may be empty.
    supported_codecs: List[PublicCodec]
import ( _apis, @@ -22,6 +25,7 @@ check_retriable_error, RetrySettings, ) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.common import ( TokenGetterFuncType, ) @@ -75,7 +79,7 @@ async def close(self, *, flush: bool = True): async def write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]: """ IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. @@ -96,7 +100,7 @@ async def write_with_ack( async def write_with_ack_future( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[asyncio.Future, List[asyncio.Future]]: """ send one or number of messages to server. @@ -107,13 +111,14 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. """ input_single_message = not isinstance(messages, list) - if isinstance(messages, PublicMessage): - messages = [PublicMessage._create_message(messages)] + converted_messages = [] if isinstance(messages, list): - for index, m in enumerate(messages): - messages[index] = PublicMessage._create_message(m) + for m in messages: + converted_messages.append(PublicMessage._create_message(m)) + else: + converted_messages = [PublicMessage._create_message(messages)] - futures = await self._reconnector.write_with_ack_future(messages) + futures = await self._reconnector.write_with_ack_future(converted_messages) if input_single_message: return futures[0] else: @@ -121,7 +126,7 @@ async def write_with_ack_future( async def write( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ): """ send one or number of messages to server. 
@@ -152,7 +157,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop - _credentials: Union[ydb.Credentials, None] + _credentials: Union[ydb.credentials.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int _token_get_function: TokenGetterFuncType @@ -160,8 +165,18 @@ class WriterAsyncIOReconnector: _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings + _codec: PublicCodec + _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]] + _encode_executor: Optional[concurrent.futures.Executor] + _codec_selector_batch_num: int + _codec_selector_last_codec: Optional[PublicCodec] + _codec_selector_check_batches_interval: int _last_known_seq_no: int + if typing.TYPE_CHECKING: + _messages_for_encode: asyncio.Queue[List[InternalMessage]] + else: + _messages_for_encode: asyncio.Queue _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] _new_messages: asyncio.Queue @@ -179,13 +194,37 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._stream_connected = asyncio.Event() self._settings = settings + self._codec_functions = { + PublicCodec.RAW: lambda data: data, + PublicCodec.GZIP: gzip.compress, + } + + if settings.encoders: + self._codec_functions.update(settings.encoders) + + self._encode_executor = settings.encoder_executor + + self._codec_selector_batch_num = 0 + self._codec_selector_last_codec = None + self._codec_selector_check_batches_interval = 10000 + + self._codec = self._settings.codec + if self._codec and self._codec not in self._codec_functions: + known_codecs = sorted(self._codec_functions.keys()) + raise ValueError( + "Unknown codec for writer: %s, supported codecs: %s" + % (self._codec, known_codecs) + ) + self._last_known_seq_no = 0 + self._messages_for_encode = asyncio.Queue() self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() 
def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]):
    """Append the batch to ``self._messages`` and enqueue every message.

    Each message is put, in order, into ``self._new_messages`` without
    waiting.
    """
    # Keep the whole batch first, then feed the queue one message at a time.
    self._messages.extend(internal_messages)
    enqueue = self._new_messages.put_nowait
    for message in internal_messages:
        enqueue(message)
async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec:
    """Pick the codec for the next batch of messages.

    A codec fixed in ``self._codec`` always wins and leaves the batch counter
    untouched. Otherwise the selector first hands out every available codec
    once, then keeps the codec chosen by a compression check and repeats that
    check every ``_codec_selector_check_batches_interval`` batches.
    """
    fixed_codec = self._codec
    if fixed_codec is not None:
        return fixed_codec

    if self._codec_selector_last_codec is not None:
        # Steady state: reuse the remembered codec, re-measuring periodically.
        recheck_due = (
            self._codec_selector_batch_num
            % self._codec_selector_check_batches_interval
            == 0
        )
        if recheck_due:
            self._codec_selector_last_codec = (
                await self._codec_selector_by_check_compress(messages)
            )
        chosen = self._codec_selector_last_codec
    else:
        candidates = await self._get_available_codecs()

        # Use each available encoder once at start-up so that problems with
        # rarely used encoders (on the writer or reader side) show up early.
        if self._codec_selector_batch_num < len(candidates):
            chosen = candidates[self._codec_selector_batch_num]
        else:
            chosen = await self._codec_selector_by_check_compress(messages)
            self._codec_selector_last_codec = chosen

    self._codec_selector_batch_num += 1
    return chosen
async def _codec_selector_by_check_compress(
    self, messages: List[InternalMessage]
) -> PublicCodec:
    """
    Try to compress messages and choose codec with the smallest result size.

    Only the first ten messages are used as a sample. When a single codec is
    available it is returned without compressing anything. The measurement
    runs in ``self._encode_executor`` so the event loop is not blocked.
    """

    test_messages = messages[:10]

    available_codecs = await self._get_available_codecs()
    if len(available_codecs) == 1:
        return available_codecs[0]

    def get_compressed_size(codec) -> int:
        # Total encoded size of the sample for one codec.
        encode = self._codec_functions[codec]
        return sum(len(encode(m.get_bytes())) for m in test_messages)

    def select_codec():
        # min() returns the first minimal element, so on equal sizes the
        # codec that comes first in available_codecs wins.
        return min(available_codecs, key=get_compressed_size)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._encode_executor, select_codec)
def close(self):
    """Mark the mock stream as closed; closing it again changes nothing."""
    if not self._closed:
        self._closed = True
@pytest.mark.parametrize(
    "codec,write_datas,expected_codecs,expected_datas",
    [
        (
            PublicCodec.RAW,
            [b"123"],
            [PublicCodec.RAW],
            [b"123"],
        ),
        (
            PublicCodec.GZIP,
            [b"123"],
            [PublicCodec.GZIP],
            [gzip.compress(b"123", mtime=time_for_mocks)],
        ),
        (
            None,
            [b"123", b"456", b"789", b"0" * 1000],
            [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.RAW, PublicCodec.RAW],
            [
                b"123",
                gzip.compress(b"456", mtime=time_for_mocks),
                b"789",
                b"0" * 1000,
            ],
        ),
        (
            None,
            [b"123", b"456", b"789" * 1000, b"0"],
            [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.GZIP, PublicCodec.GZIP],
            [
                b"123",
                gzip.compress(b"456", mtime=time_for_mocks),
                gzip.compress(b"789" * 1000, mtime=time_for_mocks),
                gzip.compress(b"0", mtime=time_for_mocks),
            ],
        ),
    ],
)
async def test_select_codecs(
    self,
    default_driver: aio.Driver,
    default_settings: WriterSettings,
    monkeypatch,
    write_datas: List[typing.Optional[bytes]],
    codec: typing.Optional[PublicCodec],
    expected_codecs: List[PublicCodec],
    expected_datas: List[bytes],
):
    """Check which codec the reconnector applies to consecutive writes.

    Every payload is written as its own batch; the message handed to the
    (patched) send queue must carry the expected codec and encoded bytes.
    With ``codec=None`` the writer first tries each available codec once
    and then keeps the one that compressed the third batch best.
    """
    assert len(write_datas) == len(expected_datas)
    assert len(expected_codecs) == len(expected_datas)

    settings = copy.copy(default_settings)
    settings.codec = codec
    settings.auto_seqno = True
    reconnector = WriterAsyncIOReconnector(default_driver, settings)

    # The reconnector owns background tasks: close it even when an assertion
    # below fails, so they do not leak out of this test.
    try:
        added_messages = asyncio.Queue()  # type: asyncio.Queue[List[InternalMessage]]

        def add_messages(_self, messages: typing.List[InternalMessage]):
            added_messages.put_nowait(messages)

        # Capture encoded batches instead of sending them to a stream.
        monkeypatch.setattr(
            WriterAsyncIOReconnector, "_add_messages_to_send_queue", add_messages
        )
        # gzip embeds time.time() in its header; freeze it so the encoded
        # bytes match the expectations built with mtime=time_for_mocks.
        monkeypatch.setattr(
            "time.time", lambda: TestWriterAsyncIOReconnector.time_for_mocks
        )

        for data, expected_codec, expected_data in zip(
            write_datas, expected_codecs, expected_datas
        ):
            await reconnector.write_with_ack_future([PublicMessage(data=data)])
            batch = await asyncio.wait_for(added_messages.get(), timeout=600)
            mess = batch[0]

            assert mess.codec == expected_codec
            assert mess.get_bytes() == expected_data
    finally:
        await reconnector.close(flush=False)
def write(
    self,
    messages: Union[Message, List[Message]],
    timeout: Union[float, None] = None,
):
    """Synchronous counterpart of the async writer's ``write``.

    :param messages: a single message or a list of messages; passed
        unchanged to ``self._async_writer.write``.
    :param timeout: forwarded to ``self._call_sync``; presumably limits how
        long this call may block (confirm against ``_call_sync``).
    """
    # The coroutine is created here and driven to completion by _call_sync.
    self._call_sync(self._async_writer.write(messages), timeout=timeout)
def __init__(
    self, driver: driver.Driver, settings: Optional[TopicClientSettings] = None
):
    """Create a synchronous topic client.

    :param driver: driver used for every request made by this client.
    :param settings: client settings; when omitted, a default
        ``TopicClientSettings()`` is used. The default is kept optional so
        that ``TopicClient(driver)`` keeps working, as it did before this
        parameter was introduced.
    """
    if not settings:
        settings = TopicClientSettings()

    # Set before the executor is created: close() (also run from __del__)
    # reads this flag.
    self._closed = False
    self._driver = driver
    self._settings = settings
    # Shared pool handed to writers that do not bring their own
    # encoder_executor.
    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=settings.encode_decode_threads_count,
        thread_name_prefix="topic_asyncio_executor",
    )
def _check_closed(self):
    """Raise ``RuntimeError`` when the client has already been closed."""
    if self._closed:
        raise RuntimeError("Topic client closed")