Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 87 additions & 5 deletions ydb/_topic_writer/topic_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import concurrent.futures
import datetime
import enum
import itertools
import uuid
from dataclasses import dataclass
from enum import Enum
Expand All @@ -12,6 +13,7 @@
from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
from .._grpc.grpcwrapper.common_utils import IToProto
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
from .. import connection

Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]

Expand Down Expand Up @@ -200,14 +202,94 @@ def default_serializer_message_content(data: Any) -> bytes:
def messages_to_proto_requests(
messages: List[InternalMessage],
) -> List[StreamWriteMessage.FromClient]:
# todo split by proto message size and codec
res = []
for msg in messages:

gropus = _slit_messages_for_send(messages)

res = [] # type: List[StreamWriteMessage.FromClient]
for group in gropus:
req = StreamWriteMessage.FromClient(
StreamWriteMessage.WriteRequest(
messages=[msg.to_message_data()],
codec=msg.codec,
messages=list(map(InternalMessage.to_message_data, group)),
codec=group[0].codec,
)
)
res.append(req)
return res


_max_int = 2**63 - 1

_message_data_overhead = (
StreamWriteMessage.FromClient(
StreamWriteMessage.WriteRequest(
messages=[
StreamWriteMessage.WriteRequest.MessageData(
seq_no=_max_int,
created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1),
data=bytes(1),
uncompressed_size=_max_int,
partitioning=StreamWriteMessage.PartitioningMessageGroupID(
message_group_id="a" * 100,
),
),
],
codec=20000,
)
)
.to_proto()
.ByteSize()
)


def _slit_messages_for_send(
messages: List[InternalMessage],
) -> List[List[InternalMessage]]:
codec_groups = [] # type: List[List[InternalMessage]]
for _, messages in itertools.groupby(messages, lambda x: x.codec):
codec_groups.append(list(messages))

res = [] # type: List[List[InternalMessage]]
for codec_group in codec_groups:
group_by_size = _split_messages_by_size_with_default_overhead(codec_group)
res.extend(group_by_size)
return res


def _split_messages_by_size_with_default_overhead(
messages: List[InternalMessage],
) -> List[List[InternalMessage]]:
def get_message_size(msg: InternalMessage):
return len(msg.data) + _message_data_overhead

return _split_messages_by_size(
messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size
)


def _split_messages_by_size(
messages: List[InternalMessage],
split_size: int,
get_msg_size: typing.Callable[[InternalMessage], int],
) -> List[List[InternalMessage]]:
res = []
group = []
group_size = 0

for msg in messages:
msg_size = get_msg_size(msg)

if len(group) == 0:
group.append(msg)
group_size += msg_size
elif group_size + msg_size <= split_size:
group.append(msg)
group_size += msg_size
else:
res.append(group)
group = [msg]
group_size = msg_size

if len(group) > 0:
res.append(group)

return res
52 changes: 52 additions & 0 deletions ydb/_topic_writer/topic_writer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import List

import pytest

from .topic_writer import _split_messages_by_size


@pytest.mark.parametrize(
"messages,split_size,expected",
[
(
[1, 2, 3],
0,
[[1], [2], [3]],
),
(
[1, 2, 3],
1,
[[1], [2], [3]],
),
(
[1, 2, 3],
3,
[[1, 2], [3]],
),
(
[1, 2, 3],
100,
[[1, 2, 3]],
),
(
[100, 2, 3],
100,
[[100], [2, 3]],
),
(
[],
100,
[],
),
(
[],
100,
[],
),
],
)
def test_split_messages_by_size(
messages: List[int], split_size: int, expected: List[List[int]]
):
res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa
assert res == expected
7 changes: 4 additions & 3 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
YDB_TRACE_ID_HEADER = "x-ydb-trace-id"
YDB_REQUEST_TYPE_HEADER = "x-ydb-request-type"

_DEFAULT_MAX_GRPC_MESSAGE_SIZE = 64 * 10**6


def _message_to_string(message):
"""
Expand Down Expand Up @@ -179,10 +181,9 @@ def _construct_channel_options(driver_config, endpoint_options=None):
:param endpoint_options: Endpoint options
:return: A channel initialization options
"""
_max_message_size = 64 * 10**6
_default_connect_options = [
("grpc.max_receive_message_length", _max_message_size),
("grpc.max_send_message_length", _max_message_size),
("grpc.max_receive_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE),
("grpc.max_send_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE),
("grpc.primary_user_agent", driver_config.primary_user_agent),
(
"grpc.lb_policy_name",
Expand Down