From 390274802b50f7ef62aa48ae7fae51be88689eed Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Mon, 6 Mar 2023 23:11:30 +0100 Subject: [PATCH 1/5] implement writer codecs --- tests/conftest.py | 4 +- tests/topics/test_topic_writer.py | 28 +++ .../grpcwrapper/ydb_topic_public_types.py | 11 +- ydb/_topic_writer/topic_writer.py | 16 +- ydb/_topic_writer/topic_writer_asyncio.py | 181 +++++++++++++++++- .../topic_writer_asyncio_test.py | 127 +++++++++++- ydb/_topic_writer/topic_writer_sync.py | 3 +- ydb/topic.py | 99 +++++++++- 8 files changed, 437 insertions(+), 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 99123660..e7a21970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,7 +131,9 @@ async def topic_path(driver, topic_consumer, database) -> str: @pytest.fixture() @pytest.mark.asyncio() async def topic_with_messages(driver, topic_path): - writer = driver.topic_client.writer(topic_path, producer_id="fixture-producer-id") + writer = driver.topic_client.writer( + topic_path, producer_id="fixture-producer-id", codec=ydb.TopicCodec.RAW + ) await writer.write_with_ack( [ ydb.TopicWriterMessage(data="123".encode()), diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index 9e8b0dfe..c53ce0db 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -84,6 +84,20 @@ async def test_write_multi_message_with_ack( assert batch.messages[0].seqno == 2 assert batch.messages[0].data == "456".encode() + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec): + async with driver.topic_client.writer(topic_path, codec=codec) as writer: + writer.write("a" * 1000) + writer.write("b" * 1000) + writer.write("c" * 1000) + class TestTopicWriterSync: def test_send_message(self, driver_sync: ydb.Driver, topic_path): @@ -163,3 +177,17 @@ def test_write_multi_message_with_ack( assert batch.messages[0].offset == 1 assert batch.messages[0].seqno == 2 assert batch.messages[0].data == "456".encode() + + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + def test_write_encoded(self, driver_sync: ydb.Driver, topic_path: str, codec): + with driver_sync.topic_client.writer(topic_path, codec=codec) as writer: + writer.write("a" * 1000) + writer.write("b" * 1000) + writer.write("c" * 1000) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py index 6d922137..4582f19a 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -31,11 +31,18 @@ class CreateTopicRequestParams: class PublicCodec(int): + """ + Codec value may contain any int number. + + Values below is only well-known predefined values, + but protocol support custom codecs. + """ + UNSPECIFIED = 0 RAW = 1 GZIP = 2 - LZOP = 3 - ZSTD = 4 + LZOP = 3 # Has not supported codec in standard library + ZSTD = 4 # Has not supported codec in standard library class PublicMeteringMode(IntEnum): diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index a2c3d0d7..2c9807d9 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -1,3 +1,4 @@ +import concurrent.futures import datetime import enum import uuid @@ -10,7 +11,7 @@ import ydb.aio from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage from .._grpc.grpcwrapper.common_utils import IToProto - +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @@ -29,6 +30,10 @@ class PublicWriterSettings: partition_id: Optional[int] = None auto_seqno: bool = True auto_created_at: bool = True + codec: Optional[PublicCodec] = None # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None # default shared client executor pool # get_last_seqno: bool = False # encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # serializer: Union[Callable[[Any], bytes], None] = None @@ -85,8 +90,9 @@ class SendMode(Enum): @dataclass class PublicWriterInitInfo: - __slots__ = "last_seqno" + __slots__ = ("last_seqno", "supported_codecs") last_seqno: Optional[int] + supported_codecs: List[PublicCodec] class PublicMessage: @@ -117,15 +123,17 @@ def _create_message( class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto): + codec: PublicCodec + def __init__(self, mess: PublicMessage): - StreamWriteMessage.WriteRequest.MessageData.__init__( - self, + super().__init__( seq_no=mess.seqno, created_at=mess.created_at, data=mess.data, uncompressed_size=len(mess.data), partitioning=None, ) + self.codec = PublicCodec.RAW def get_bytes(self) -> bytes: if self.data is None: diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 67b1be69..f658efe3 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,10 @@ import asyncio +import concurrent.futures import datetime +import gzip +import typing from collections import deque -from typing import Deque, AsyncIterator, Union, List +from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable import ydb from .topic_writer import ( @@ -22,6 +25,7 @@ check_retriable_error, RetrySettings, ) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.common import ( TokenGetterFuncType, ) @@ -41,6 +45,7 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" _closed: bool + _compressor_thread_pool: concurrent.futures.Executor @property def last_seqno(self) -> int: @@ -107,11 +112,11 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. """ input_single_message = not isinstance(messages, list) - if isinstance(messages, PublicMessage): - messages = [PublicMessage._create_message(messages)] if isinstance(messages, list): for index, m in enumerate(messages): messages[index] = PublicMessage._create_message(m) + else: + messages = [PublicMessage._create_message(messages)] futures = await self._reconnector.write_with_ack_future(messages) if input_single_message: @@ -152,7 +157,7 @@ async def wait_init(self) -> PublicWriterInitInfo: class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop - _credentials: Union[ydb.Credentials, None] + _credentials: Union[ydb.credentials.Credentials, None] _driver: ydb.aio.Driver _update_token_interval: int _token_get_function: TokenGetterFuncType @@ -160,8 +165,18 @@ class WriterAsyncIOReconnector: _init_info: asyncio.Future _stream_connected: asyncio.Event _settings: WriterSettings + _codec: PublicCodec + _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]] + _encode_executor: Optional[concurrent.futures.Executor] + _codec_selector_batch_num: int + _codec_selector_last_codec: Optional[PublicCodec] + _codec_selector_check_batches_interval: int _last_known_seq_no: int + if typing.TYPE_CHECKING: + _messages_for_encode: asyncio.Queue[List[InternalMessage]] + else: + _messages_for_encode: asyncio.Queue _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] _new_messages: asyncio.Queue @@ -179,13 +194,34 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._stream_connected = asyncio.Event() self._settings = settings + self._codec_functions = { + PublicCodec.RAW: lambda data: data, + PublicCodec.GZIP: gzip.compress, + } + self._encode_executor = settings.encoder_executor + + self._codec_selector_batch_num = 0 + self._codec_selector_last_codec = None + self._codec_selector_check_batches_interval = 10000 + + self._codec = self._settings.codec + if self._codec and self._codec not in self._codec_functions: + known_codecs = [key for key in self._codec_functions] + known_codecs.sort() + raise ValueError( + "Unknown codec for writer: %s, supported codecs: %s" + % (self._codec, known_codecs) + ) + self._last_known_seq_no = 0 + self._messages_for_encode = asyncio.Queue() self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() self._stop_reason = self._loop.create_future() self._background_tasks = [ - asyncio.create_task(self._connection_loop(), name="connection_loop") + asyncio.create_task(self._connection_loop(), name="connection_loop"), + asyncio.create_task(self._encode_loop(), name="encode_loop"), ] async def close(self, flush: bool): @@ -238,15 +274,23 @@ async def write_with_ack_future( internal_messages = self._prepare_internal_messages(messages) messages_future = [self._loop.create_future() for _ in internal_messages] - self._messages.extend(internal_messages) self._messages_future.extend(messages_future) - for m in internal_messages: - self._new_messages.put_nowait(m) + if self._codec == PublicCodec.RAW: + self._add_messages_to_send_queue(internal_messages) + else: + self._messages_for_encode.put_nowait(internal_messages) return messages_future - def _prepare_internal_messages(self, messages: List[PublicMessage]): + def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]): + self._messages.extend(internal_messages) + for m in internal_messages: + self._new_messages.put_nowait(m) + + def _prepare_internal_messages( + self, messages: List[PublicMessage] + ) -> List[InternalMessage]: if self._settings.auto_created_at: now = datetime.datetime.now() else: @@ -307,7 +351,10 @@ async def _connection_loop(self): try: self._last_known_seq_no = stream_writer.last_seqno self._init_info.set_result( - PublicWriterInitInfo(last_seqno=stream_writer.last_seqno) + PublicWriterInitInfo( + last_seqno=stream_writer.last_seqno, + supported_codecs=stream_writer.supported_codecs, + ) ) except asyncio.InvalidStateError: pass @@ -350,6 +397,118 @@ async def _connection_loop(self): task.cancel() await asyncio.wait(pending) + async def _encode_loop(self): + while True: + messages = await self._messages_for_encode.get() + while not self._messages_for_encode.empty(): + messages.extend(self._messages_for_encode.get_nowait()) + + batch_codec = await self._codec_selector(messages) + await self._encode_data_inplace(batch_codec, messages) + self._add_messages_to_send_queue(messages) + + async def _encode_data_inplace( + self, codec: PublicCodec, messages: List[InternalMessage] + ): + if codec == PublicCodec.RAW: + return + + eventloop = asyncio.get_running_loop() + encode_waiters = [] + encoder_function = self._codec_functions[codec] + + for message in messages: + encoded_data_futures = eventloop.run_in_executor( + self._encode_executor, encoder_function, message.get_bytes() + ) + encode_waiters.append(encoded_data_futures) + + encoded_datas = await asyncio.gather(*encode_waiters) + + for index, data in enumerate(encoded_datas): + message = messages[index] + message.codec = codec + message.data = data + + async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: + if self._codec is not None: + return self._codec + + if self._codec_selector_last_codec is None: + available_codecs = await self._get_available_codecs() + + if self._codec_selector_batch_num < len(available_codecs): + codec_index = self._codec_selector_batch_num % len(available_codecs) + codec = available_codecs[codec_index] + else: + codec = await self._codec_selector_by_check_compress(messages) + self._codec_selector_last_codec = codec + else: + if ( + self._codec_selector_batch_num + % self._codec_selector_check_batches_interval + == 0 + ): + self._codec_selector_last_codec = ( + await self._codec_selector_by_check_compress(messages) + ) + codec = self._codec_selector_last_codec + self._codec_selector_batch_num += 1 + return codec + + async def _get_available_codecs(self) -> List[PublicCodec]: + info = await self.wait_init() + topic_supported_codecs = info.supported_codecs + if not topic_supported_codecs: + topic_supported_codecs = [PublicCodec.RAW, PublicCodec.GZIP] + + res = [] + for codec in topic_supported_codecs: + if codec in self._codec_functions: + res.append(codec) + + if not res: + raise TopicWriterError("Writer does not support topic's codecs") + + res.sort() + + return res + + async def _codec_selector_by_check_compress( + self, messages: List[InternalMessage] + ) -> PublicCodec: + test_messages = messages + if len(test_messages) > 10: + test_messages = test_messages[:10] + + available_codecs = await self._get_available_codecs() + if len(available_codecs) == 1: + return available_codecs[0] + + def get_compressed_size(codec) -> int: + s = 0 + f = self._codec_functions[codec] + + for m in test_messages: + encoded = f(m.get_bytes()) + s += len(encoded) + + return s + + def select_codec() -> PublicCodec: + min_codec = available_codecs[0] + min_size = get_compressed_size(min_codec) + for codec in available_codecs[1:]: + size = get_compressed_size(codec) + if size < min_size: + min_codec = codec + min_size = size + return min_codec + + loop = asyncio.get_running_loop() + codec = await loop.run_in_executor(self._encode_executor, select_codec) + return codec + async def _read_loop(self, writer: "WriterAsyncIOStream"): while True: resp = await writer.receive() @@ -412,6 +571,7 @@ class WriterAsyncIOStream: # todo slots last_seqno: int + supported_codecs: Optional[List[PublicCodec]] _stream: IGrpcWrapperAsyncIO _token_getter: TokenGetterFuncType @@ -466,6 +626,7 @@ async def _start( raise TopicWriterError("Unexpected answer for init request: %s" % resp) self.last_seqno = resp.last_seq_no + self.supported_codecs = [PublicCodec(codec) for codec in resp.supported_codecs] self._stream = stream diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 32bb3de5..abb1f427 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -4,14 +4,15 @@ import copy import dataclasses import datetime +import gzip import typing from queue import Queue, Empty +from typing import List from unittest import mock import freezegun import pytest - from .. import aio from .. import StatusCode, issues from .._grpc.grpcwrapper.ydb_topic import Codec, StreamWriteMessage @@ -25,6 +26,7 @@ PublicWriteResult, TopicWriterError, ) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.test_helpers import StreamMock from .topic_writer_asyncio import ( @@ -151,9 +153,11 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): @pytest.mark.asyncio class TestWriterAsyncIOReconnector: init_last_seqno = 0 + time_for_mocks = 1678046714.639387 class StreamWriterMock: last_seqno: int + supported_codecs: List[PublicCodec] from_client: asyncio.Queue from_server: asyncio.Queue @@ -165,6 +169,7 @@ def __init__(self): self.from_server = asyncio.Queue() self.from_client = asyncio.Queue() self._closed = False + self.supported_codecs = [] def write(self, messages: typing.List[InternalMessage]): if self._closed: @@ -184,10 +189,7 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: def close(self): if self._closed: return - - self.from_server.put_nowait( - Exception("waited message while StreamWriterMock closed") - ) + self._closed = True @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): @@ -241,6 +243,7 @@ def default_settings(self) -> WriterSettings: producer_id="test-producer", auto_seqno=False, auto_created_at=False, + codec=PublicCodec.RAW, ) ) @@ -347,7 +350,9 @@ async def wait_stop(): async def test_wait_init(self, default_driver, default_settings, get_stream_writer): init_seqno = 100 - expected_init_info = PublicWriterInitInfo(init_seqno) + expected_init_info = PublicWriterInitInfo( + last_seqno=init_seqno, supported_codecs=[] + ) with mock.patch.object( TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno ): @@ -457,6 +462,116 @@ async def test_auto_created_at( ] == sent await reconnector.close(flush=False) + @pytest.mark.parametrize( + "codec,write_datas,expected_codecs,expected_datas", + [ + ( + PublicCodec.RAW, + [b"123"], + [PublicCodec.RAW], + [b"123"], + ), + ( + PublicCodec.GZIP, + [b"123"], + [PublicCodec.GZIP], + [gzip.compress(b"123", mtime=time_for_mocks)], + ), + ( + None, + [b"123", b"456", b"789", b"0" * 1000], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.RAW, PublicCodec.RAW], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + b"789", + b"0" * 1000, + ], + ), + ( + None, + [b"123", b"456", b"789" * 1000, b"0"], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.GZIP, PublicCodec.GZIP], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + gzip.compress(b"789" * 1000, mtime=time_for_mocks), + gzip.compress(b"0", mtime=time_for_mocks), + ], + ), + ], + ) + async def test_select_codecs( + self, + default_driver: aio.Driver, + default_settings: WriterSettings, + monkeypatch, + write_datas: List[typing.Optional[bytes]], + codec: typing.Optional[PublicCodec], + expected_codecs: List[PublicCodec], + expected_datas: List[bytes], + ): + assert len(write_datas) == len(expected_datas) + assert len(expected_codecs) == len(expected_datas) + + settings = copy.copy(default_settings) + settings.codec = codec + settings.auto_seqno = True + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + added_messages = asyncio.Queue() # type: asyncio.Queue[List[InternalMessage]] + + def add_messages(_self, messages: typing.List[InternalMessage]): + added_messages.put_nowait(messages) + + monkeypatch.setattr( + WriterAsyncIOReconnector, "_add_messages_to_send_queue", add_messages + ) + monkeypatch.setattr( + "time.time", lambda: TestWriterAsyncIOReconnector.time_for_mocks + ) + + for i in range(len(expected_datas)): + await reconnector.write_with_ack_future( + [PublicMessage(data=write_datas[i])] + ) + mess = await asyncio.wait_for(added_messages.get(), timeout=600) + mess = mess[0] + + assert mess.codec == expected_codecs[i] + assert mess.get_bytes() == expected_datas[i] + + await reconnector.close(flush=False) + + @pytest.mark.parametrize( + "codec,datas", + [ + ( + PublicCodec.RAW, + [b"123", b"456", b"789", b"0"], + ), + ( + PublicCodec.GZIP, + [b"123", b"456", b"789", b"0"], + ), + ], + ) + async def test_encode_data_inplace( + self, + reconnector: WriterAsyncIOReconnector, + codec: PublicCodec, + datas: List[bytes], + ): + f = reconnector._codec_functions[codec] + expected_datas = [f(data) for data in datas] + + messages = [InternalMessage(PublicMessage(data)) for data in datas] + await reconnector._encode_data_inplace(codec, messages) + + for index, mess in enumerate(messages): + assert mess.codec == codec + assert mess.get_bytes() == expected_datas[index] + @pytest.mark.asyncio class TestWriterAsyncIO: diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index dc7b7fbd..48649875 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -9,7 +9,6 @@ PublicWriterSettings, TopicWriterError, PublicWriterInitInfo, - PublicMessage, PublicWriteResult, MessageType, ) @@ -92,7 +91,7 @@ def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitIn def write( self, - messages: Union[PublicMessage, List[PublicMessage]], + messages: Union[MessageType, List[MessageType]], timeout: Union[float, None] = None, ): self._call_sync(self._async_writer.write(messages), timeout=timeout) diff --git a/ydb/topic.py b/ydb/topic.py index 45b8b073..1ed60686 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -1,7 +1,11 @@ +from __future__ import annotations + +import concurrent.futures import datetime +from dataclasses import dataclass from typing import List, Union, Mapping, Optional, Dict -from . import aio, Credentials, _apis +from . import aio, Credentials, _apis, issues from . import driver @@ -43,11 +47,25 @@ class TopicClientAsyncIO: + _closed: bool _driver: aio.Driver _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor - def __init__(self, driver: aio.Driver, settings: "TopicClientSettings" = None): + def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings]): + if not settings: + settings = TopicClientSettings() + self._closed = False self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() async def create_topic( self, @@ -148,21 +166,56 @@ def writer( partition_id: Union[int, None] = None, auto_seqno: bool = True, auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool ) -> TopicWriterAsyncIO: args = locals() del args["self"] + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + return TopicWriterAsyncIO(self._driver, settings) + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + + raise RuntimeError("Topic client closed") + class TopicClient: + _closed: bool _driver: driver.Driver _credentials: Union[Credentials, None] + _settings: TopicClientSettings + _executor: concurrent.futures.Executor - def __init__( - self, driver: driver.Driver, topic_client_settings: "TopicClientSettings" = None - ): + def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings]): + if not settings: + settings = TopicClientSettings() + + self._closed = False self._driver = driver + self._settings = settings + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=settings.encode_decode_threads_count, + thread_name_prefix="topic_asyncio_executor", + ) + + def __del__(self): + self.close() def create_topic( self, @@ -198,6 +251,8 @@ def create_topic( """ args = locals().copy() del args["self"] + self._check_closed() + req = _ydb_topic_public_types.CreateTopicRequestParams(**args) req = _ydb_topic.CreateTopicRequest.from_public(req) self._driver( @@ -212,6 +267,8 @@ def describe_topic( ) -> TopicDescription: args = locals().copy() del args["self"] + self._check_closed() + req = _ydb_topic_public_types.DescribeTopicRequestParams(**args) res = self._driver( req.to_proto(), @@ -222,6 +279,8 @@ def describe_topic( return res.to_public() def drop_topic(self, path: str): + self._check_closed() + req = _ydb_topic_public_types.DropTopicRequestParams(path=path) self._driver( req.to_proto(), @@ -251,6 +310,8 @@ def reader( ) -> TopicReader: args = locals() del args["self"] + self._check_closed() + settings = TopicReaderSettings(**args) return TopicReader(self._driver, settings) @@ -263,16 +324,40 @@ def writer( partition_id: Union[int, None] = None, auto_seqno: bool = True, auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None, # default shared client executor pool ) -> TopicWriter: args = locals() del args["self"] + self._check_closed() + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + return TopicWriter(self._driver, settings) + def close(self): + if self._closed: + return + + self._closed = True + self._executor.shutdown(wait=False, cancel_futures=True) + + def _check_closed(self): + if not self._closed: + return + raise RuntimeError("Topic client closed") + + +@dataclass class TopicClientSettings: - pass + encode_decode_threads_count: int = 4 -class StubEvent: +class TopicError(issues.Error): pass From 1104d9d45b329e6bf5f6a45519dea396f55497bc Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 09:59:14 +0100 Subject: [PATCH 2/5] add custom encoders --- ydb/_topic_writer/topic_writer.py | 4 ++- ydb/_topic_writer/topic_writer_asyncio.py | 5 +++ .../topic_writer_asyncio_test.py | 31 ++++++++++++++++++- ydb/topic.py | 8 ++++- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 2c9807d9..7bfb3f30 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -34,8 +34,10 @@ class PublicWriterSettings: encoder_executor: Optional[ concurrent.futures.Executor ] = None # default shared client executor pool + encoders: Optional[ + typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]] + ] = None # get_last_seqno: bool = False - # encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None # serializer: Union[Callable[[Any], bytes], None] = None # send_buffer_count: Optional[int] = 10000 # send_buffer_bytes: Optional[int] = 100 * 1024 * 1024 diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f658efe3..59daa48d 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -198,6 +198,11 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): PublicCodec.RAW: lambda data: data, PublicCodec.GZIP: gzip.compress, } + + if settings.encoders: + for codec, encoder in settings.encoders.items(): + self._codec_functions[codec] = encoder + self._encode_executor = settings.encoder_executor self._codec_selector_batch_num = 0 diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index abb1f427..921c6aa4 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -27,7 +27,7 @@ TopicWriterError, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -from .._topic_common.test_helpers import StreamMock +from .._topic_common.test_helpers import StreamMock, wait_for_fast from .topic_writer_asyncio import ( WriterAsyncIOStream, @@ -572,6 +572,35 @@ async def test_encode_data_inplace( assert mess.codec == codec assert mess.get_bytes() == expected_datas[index] + async def test_custom_encoder( + self, default_driver, default_settings, get_stream_writer + ): + codec = 10001 + + settings = copy.copy(default_settings) + settings.encoders = {codec: lambda x: bytes(reversed(x))} + settings.codec = codec + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + now = datetime.datetime.now() + seqno = self.init_last_seqno + 1 + + await reconnector.write_with_ack_future( + [PublicMessage(data=b"123", seqno=seqno, created_at=now)] + ) + + stream_writer = get_stream_writer() + sent_messages = await wait_for_fast(stream_writer.from_client.get()) + + expected_mess = InternalMessage( + PublicMessage(data=b"321", seqno=seqno, created_at=now) + ) + expected_mess.codec = codec + + assert sent_messages == [expected_mess] + + await reconnector.close(flush=False) + @pytest.mark.asyncio class TestWriterAsyncIO: diff --git a/ydb/topic.py b/ydb/topic.py index 1ed60686..a4162208 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -3,7 +3,7 @@ import concurrent.futures import datetime from dataclasses import dataclass -from typing import List, Union, Mapping, Optional, Dict +from typing import List, Union, Mapping, Optional, Dict, Callable from . import aio, Credentials, _apis, issues @@ -167,6 +167,9 @@ def writer( auto_seqno: bool = True, auto_created_at: bool = True, codec: Optional[TopicCodec] = None, # default mean auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, encoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool @@ -325,6 +328,9 @@ def writer( auto_seqno: bool = True, auto_created_at: bool = True, codec: Optional[TopicCodec] = None, # default mean auto-select + encoders: Optional[ + Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]] + ] = None, encoder_executor: Optional[ concurrent.futures.Executor ] = None, # default shared client executor pool From a537c65760a45d78f2d4371dbccc54c6cb315458 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 10:21:48 +0100 Subject: [PATCH 3/5] add comment --- ydb/_topic_writer/topic_writer_asyncio.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 59daa48d..5c337448 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -442,6 +442,8 @@ async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: if self._codec_selector_last_codec is None: available_codecs = await self._get_available_codecs() + # use every of available encoders at start for prevent problems + # with rare used encoders (on writer or reader side) if self._codec_selector_batch_num < len(available_codecs): codec_index = self._codec_selector_batch_num % len(available_codecs) codec = available_codecs[codec_index] @@ -482,6 +484,10 @@ async def _get_available_codecs(self) -> List[PublicCodec]: async def _codec_selector_by_check_compress( self, messages: List[InternalMessage] ) -> PublicCodec: + """ + Try to compress messages and choose codec with the smallest result size. + """ + test_messages = messages if len(test_messages) > 10: test_messages = test_messages[:10] From a66f0d0b9a236d69877692223d80aa2f84485221 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 16:56:44 +0100 Subject: [PATCH 4/5] fix style --- ydb/_topic_writer/topic_writer_asyncio.py | 23 +++++++++-------------- ydb/topic.py | 4 +++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5c337448..1b3b3d25 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -45,7 +45,6 @@ class WriterAsyncIO: _loop: asyncio.AbstractEventLoop _reconnector: "WriterAsyncIOReconnector" _closed: bool - _compressor_thread_pool: concurrent.futures.Executor @property def last_seqno(self) -> int: @@ -112,13 +111,14 @@ async def write_with_ack_future( For wait with timeout use asyncio.wait_for. """ input_single_message = not isinstance(messages, list) + converted_messages = [] if isinstance(messages, list): - for index, m in enumerate(messages): - messages[index] = PublicMessage._create_message(m) + for m in messages: + converted_messages.append(PublicMessage._create_message(m)) else: - messages = [PublicMessage._create_message(messages)] + converted_messages = [PublicMessage._create_message(messages)] - futures = await self._reconnector.write_with_ack_future(messages) + futures = await self._reconnector.write_with_ack_future(converted_messages) if input_single_message: return futures[0] else: @@ -200,8 +200,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): } if settings.encoders: - for codec, encoder in settings.encoders.items(): - self._codec_functions[codec] = encoder + self._codec_functions.update(settings.encoders) self._encode_executor = settings.encoder_executor @@ -211,8 +210,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._codec = self._settings.codec if self._codec and self._codec not in self._codec_functions: - known_codecs = [key for key in self._codec_functions] - known_codecs.sort() + known_codecs = sorted(self._codec_functions.keys()) raise ValueError( "Unknown codec for writer: %s, supported codecs: %s" % (self._codec, known_codecs) @@ -445,8 +443,7 @@ async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: # use every of available encoders at start for prevent problems # with rare used encoders (on writer or reader side) if self._codec_selector_batch_num < len(available_codecs): - codec_index = self._codec_selector_batch_num % len(available_codecs) - codec = available_codecs[codec_index] + codec = available_codecs[self._codec_selector_batch_num] else: codec = await self._codec_selector_by_check_compress(messages) self._codec_selector_last_codec = codec @@ -488,9 +485,7 @@ async def _codec_selector_by_check_compress( Try to compress messages and choose codec with the smallest result size. """ - test_messages = messages - if len(test_messages) > 10: - test_messages = test_messages[:10] + test_messages = messages[:10] available_codecs = await self._get_available_codecs() if len(available_codecs) == 1: diff --git a/ydb/topic.py b/ydb/topic.py index a4162208..3ccdda08 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -53,7 +53,9 @@ class TopicClientAsyncIO: _settings: TopicClientSettings _executor: concurrent.futures.Executor - def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings]): + def __init__( + self, driver: aio.Driver, settings: Optional[TopicClientSettings] = None + ): if not settings: settings = TopicClientSettings() self._closed = False From ccc6eecc2b4dc6f733e8e225d6fa541056570919 Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Tue, 7 Mar 2023 17:08:41 +0100 Subject: [PATCH 5/5] rename messagetype --- ydb/_topic_writer/topic_writer.py | 6 ++---- ydb/_topic_writer/topic_writer_asyncio.py | 8 ++++---- ydb/_topic_writer/topic_writer_sync.py | 8 ++++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 7bfb3f30..92212f65 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -13,7 +13,7 @@ from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -MessageType = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] +Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @dataclass @@ -116,9 +116,7 @@ def __init__( self.data = data @staticmethod - def _create_message( - data: Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] - ) -> "PublicMessage": + def _create_message(data: Message) -> "PublicMessage": if isinstance(data, PublicMessage): return data return PublicMessage(data=data) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 1b3b3d25..5e3bb455 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -17,7 +17,7 @@ TopicWriterError, messages_to_proto_requests, PublicWriteResultTypes, - MessageType, + Message, ) from .. import ( _apis, @@ -79,7 +79,7 @@ async def close(self, *, flush: bool = True): async def write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]: """ IT IS SLOWLY WAY. IT IS BAD CHOISE IN MOST CASES. @@ -100,7 +100,7 @@ async def write_with_ack( async def write_with_ack_future( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Union[asyncio.Future, List[asyncio.Future]]: """ send one or number of messages to server. @@ -126,7 +126,7 @@ async def write_with_ack_future( async def write( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ): """ send one or number of messages to server. diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index 48649875..4713c07d 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -10,7 +10,7 @@ TopicWriterError, PublicWriterInitInfo, PublicWriteResult, - MessageType, + Message, ) from .topic_writer_asyncio import WriterAsyncIO @@ -91,20 +91,20 @@ def wait_init(self, timeout: Optional[TimeoutType] = None) -> PublicWriterInitIn def write( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], timeout: Union[float, None] = None, ): self._call_sync(self._async_writer.write(messages), timeout=timeout) def async_write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: return self._call(self._async_writer.write_with_ack(messages)) def write_with_ack( self, - messages: Union[MessageType, List[MessageType]], + messages: Union[Message, List[Message]], timeout: Union[float, None] = None, ) -> Union[PublicWriteResult, List[PublicWriteResult]]: return self._call_sync(