From e3b8118d09fe6fe2b5950d6e74c861a443b9078d Mon Sep 17 00:00:00 2001 From: Timofey Koolin Date: Thu, 23 Mar 2023 03:45:54 +0300 Subject: [PATCH] group messages to batches --- ydb/_topic_writer/topic_writer.py | 92 ++++++++++++++++++++++++-- ydb/_topic_writer/topic_writer_test.py | 52 +++++++++++++++ ydb/connection.py | 7 +- 3 files changed, 143 insertions(+), 8 deletions(-) create mode 100644 ydb/_topic_writer/topic_writer_test.py diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index b94ff46b..c24ee259 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -1,6 +1,7 @@ import concurrent.futures import datetime import enum +import itertools import uuid from dataclasses import dataclass from enum import Enum @@ -12,6 +13,7 @@ from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec +from .. import connection Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] @@ -200,14 +202,94 @@ def default_serializer_message_content(data: Any) -> bytes: def messages_to_proto_requests( messages: List[InternalMessage], ) -> List[StreamWriteMessage.FromClient]: - # todo split by proto message size and codec - res = [] - for msg in messages: + + groups = _split_messages_for_send(messages) + + res = [] # type: List[StreamWriteMessage.FromClient] + for group in groups: req = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( - messages=[msg.to_message_data()], - codec=msg.codec, + messages=list(map(InternalMessage.to_message_data, group)), + codec=group[0].codec, ) ) res.append(req) return res + + +_max_int = 2**63 - 1 + +_message_data_overhead = ( + StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=_max_int, + created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1), +
data=bytes(1), + uncompressed_size=_max_int, + partitioning=StreamWriteMessage.PartitioningMessageGroupID( + message_group_id="a" * 100, + ), + ), + ], + codec=20000, + ) + ) + .to_proto() + .ByteSize() +) + + +def _split_messages_for_send( + messages: List[InternalMessage], +) -> List[List[InternalMessage]]: + codec_groups = [] # type: List[List[InternalMessage]] + for _, messages in itertools.groupby(messages, lambda x: x.codec): + codec_groups.append(list(messages)) + + res = [] # type: List[List[InternalMessage]] + for codec_group in codec_groups: + group_by_size = _split_messages_by_size_with_default_overhead(codec_group) + res.extend(group_by_size) + return res + + +def _split_messages_by_size_with_default_overhead( + messages: List[InternalMessage], +) -> List[List[InternalMessage]]: + def get_message_size(msg: InternalMessage): + return len(msg.data) + _message_data_overhead + + return _split_messages_by_size( + messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size + ) + + +def _split_messages_by_size( + messages: List[InternalMessage], + split_size: int, + get_msg_size: typing.Callable[[InternalMessage], int], +) -> List[List[InternalMessage]]: + res = [] + group = [] + group_size = 0 + + for msg in messages: + msg_size = get_msg_size(msg) + + if len(group) == 0: + group.append(msg) + group_size += msg_size + elif group_size + msg_size <= split_size: + group.append(msg) + group_size += msg_size + else: + res.append(group) + group = [msg] + group_size = msg_size + + if len(group) > 0: + res.append(group) + + return res diff --git a/ydb/_topic_writer/topic_writer_test.py b/ydb/_topic_writer/topic_writer_test.py new file mode 100644 index 00000000..6d2a96a4 --- /dev/null +++ b/ydb/_topic_writer/topic_writer_test.py @@ -0,0 +1,52 @@ +from typing import List + +import pytest + +from .topic_writer import _split_messages_by_size + + +@pytest.mark.parametrize( + "messages,split_size,expected", + [ + ( + [1, 2, 3], + 0, + [[1], [2], [3]], + ), + ( +
[1, 2, 3], + 1, + [[1], [2], [3]], + ), + ( + [1, 2, 3], + 3, + [[1, 2], [3]], + ), + ( + [1, 2, 3], + 100, + [[1, 2, 3]], + ), + ( + [100, 2, 3], + 100, + [[100], [2, 3]], + ), + ( + [], + 100, + [], + ), + ( + [10, 20], + 15, + [[10], [20]], + ), + ], +) +def test_split_messages_by_size( + messages: List[int], split_size: int, expected: List[List[int]] +): + res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa + assert res == expected diff --git a/ydb/connection.py b/ydb/connection.py index 95db084a..25a54bb7 100644 --- a/ydb/connection.py +++ b/ydb/connection.py @@ -24,6 +24,8 @@ YDB_TRACE_ID_HEADER = "x-ydb-trace-id" YDB_REQUEST_TYPE_HEADER = "x-ydb-request-type" +_DEFAULT_MAX_GRPC_MESSAGE_SIZE = 64 * 10**6 + def _message_to_string(message): """ @@ -179,10 +181,9 @@ def _construct_channel_options(driver_config, endpoint_options=None): :param endpoint_options: Endpoint options :return: A channel initialization options """ - _max_message_size = 64 * 10**6 _default_connect_options = [ - ("grpc.max_receive_message_length", _max_message_size), - ("grpc.max_send_message_length", _max_message_size), + ("grpc.max_receive_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE), + ("grpc.max_send_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE), ("grpc.primary_user_agent", driver_config.primary_user_agent), ( "grpc.lb_policy_name",