diff --git a/.github/scripts/increment_version_test.py b/.github/scripts/increment_version_test.py index a1c8bbf4..ee921a62 100644 --- a/.github/scripts/increment_version_test.py +++ b/.github/scripts/increment_version_test.py @@ -7,16 +7,17 @@ [ ("0.0.0", 'patch', False, "0.0.1"), ("0.0.1", 'patch', False, "0.0.2"), - ("0.0.1a1", 'patch', False, "0.0.1"), - ("0.0.0", 'patch', True, "0.0.1a1"), - ("0.0.1", 'patch', True, "0.0.2a1"), - ("0.0.2a1", 'patch', True, "0.0.2a2"), + ("0.0.1b1", 'patch', False, "0.0.1"), + ("0.0.0", 'patch', True, "0.0.1b1"), + ("0.0.1", 'patch', True, "0.0.2b1"), + ("0.0.2b1", 'patch', True, "0.0.2b2"), ("0.0.1", 'minor', False, "0.1.0"), - ("0.0.1a1", 'minor', False, "0.1.0"), - ("0.1.0a1", 'minor', False, "0.1.0"), - ("0.1.0", 'minor', True, "0.2.0a1"), - ("0.1.0a1", 'minor', True, "0.1.0a2"), - ("0.1.1a1", 'minor', True, "0.2.0a1"), + ("0.0.1b1", 'minor', False, "0.1.0"), + ("0.1.0b1", 'minor', False, "0.1.0"), + ("0.1.0", 'minor', True, "0.2.0b1"), + ("0.1.0b1", 'minor', True, "0.1.0b2"), + ("0.1.1b1", 'minor', True, "0.2.0b1"), + ("3.0.0b1", 'patch', True, "3.0.0b2"), ] ) def test_increment_version(source, inc_type, with_beta, result): diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9c0a2392..9682a48a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,14 +11,21 @@ jobs: runs-on: ubuntu-latest concurrency: - group: unit-${{ github.ref }}-${{ matrix.environment }}-${{ matrix.python-version }} + group: unit-${{ github.ref }}-${{ matrix.environment }}-${{ matrix.python-version }}-${{ matrix.folder }} cancel-in-progress: true strategy: + fail-fast: false max-parallel: 4 matrix: python-version: [3.8] environment: [py, py-tls, py-proto3, py-tls-proto3] + folder: [ydb, tests] + exclude: + - environment: py-tls + folder: ydb + - environment: py-tls-proto3 + folder: ydb steps: - uses: actions/checkout@v1 @@ -26,9 +33,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + + - name: Install tox run: | python -m pip install --upgrade pip pip install tox==4.2.6 - - name: Test with tox - run: tox -e ${{ matrix.environment }} + + - name: Run unit tests + run: tox -e ${{ matrix.environment }} -- ${{ matrix.folder }} diff --git a/AUTHORS b/AUTHORS index 200343e3..69fee17e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,5 @@ -The following authors have created the source code of "YDB Python SDK" +The following authors have created the source code of "Yandex Database Python SDK" published and distributed by YANDEX LLC as the owner: Vitalii Gridnev +Timofey Koolin diff --git a/CHANGELOG.md b/CHANGELOG.md index 28224067..c24d452d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,54 @@ +## 3.0.1b11 ## +* Fixed global_allow_split_transactions +* Added reader.receive_message() method +* Swap topic_path and consumer arguments in topic_client.reader method + +## 3.0.1b10 ## +* fixed sqlalchemy get_columns method with not null columns +* fixed requirements.txt + +## 3.0.1b9 ## +* Add function for global revert broken behaviour: ydb.global_allow_truncated_result, ydb.global_allow_split_transactions +* Change argument names from deny_split_transactions to allow_split_transactions (with reverse value +* Fixed check retriable for idempotent error +* Reader codecs +* Read one message +* fixed sqlalchemy get_columns method with not null columns + +## 3.0.1b8 ## +* Fixed exception while create ResultSet with None table_settings + +## 3.0.1b7 ## +* BROKEN 
CHANGE: deny any action in transaction after commit/rollback +* BROKEN CHANGE: raise exception for truncated response by default +* Compatible protobaf detection for arcadia +* Add codecs support for topic writer + +## 3.0.1b6 ## +* BROKEN CHANGES: remove writer.write(mess1, mess2) variant, use list instead: writer.write([mess1, mess2]) +* BROKEN CHANGES: change names of public method in topic client +* BROKEN CHANGES: rename parameter producer_and_message_group_id to producer_id +* producer_id is optional now + +## 3.0.1b5 ## +* Remove six package from code and dependencies (remove support python2) +* Use anonymous credentials by default instead of iam metadata (use ydb.driver.credentials_from_env_variables for creds by env var) +* Close grpc streams while closing readers/writers +* Add control plane operations for topic api: create, drop +* Add six package to requirements + +## 3.0.1b4 ## +* Initial implementation of topic reader + +## 3.0.1b3 ## +* Fix error of check retriable error for idempotent operations (error exist since 2.12.1) + +## 3.0.1b2 ## +* Add initial topic writer + +## 3.0.1b1 ## +* start 3.0 beta branch + ## 2.13.4 ## * fixed sqlalchemy get_columns method with not null columns diff --git a/docker-compose-tls.yml b/docker-compose-tls.yml index c9d6fac9..83403ea1 100644 --- a/docker-compose-tls.yml +++ b/docker-compose-tls.yml @@ -1,7 +1,7 @@ version: "3.9" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: cr.yandex/yc/yandex-docker-local-ydb@sha256:b569c23d6854564ec4d970bda86cddcf5b11c7c6362df62beb8ba8eafb8d54fd restart: always ports: - 2136:2136 diff --git a/docker-compose.yml b/docker-compose.yml index d8b898ae..3223033b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: cr.yandex/yc/yandex-docker-local-ydb@sha256:b569c23d6854564ec4d970bda86cddcf5b11c7c6362df62beb8ba8eafb8d54fd restart: always ports: - 2136:2136 diff --git a/examples/_sqlalchemy_example/example.py b/examples/_sqlalchemy_example/example.py new file mode 100644 index 00000000..70f4c465 --- /dev/null +++ b/examples/_sqlalchemy_example/example.py @@ -0,0 +1,229 @@ +import datetime +import logging +import argparse +import sqlalchemy as sa +from sqlalchemy import orm, exc, sql +from sqlalchemy import Table, Column, Integer, String, Float, TIMESTAMP +from ydb._sqlalchemy import register_dialect + +from fill_tables import fill_all_tables, to_days +from models import Base, Series, Episodes + + +def describe_table(engine, name): + inspect = sa.inspect(engine) + print(f"describe table {name}:") + for col in inspect.get_columns(name): + print(f"\t{col['name']}: {col['type']}") + + +def simple_select(conn): + stm = sa.select(Series).where(Series.series_id == 1) + res = conn.execute(stm) + print(res.one()) + + +def simple_insert(conn): + stm = Episodes.__table__.insert().values( + series_id=3, season_id=6, episode_id=1, title="TBD" + ) + conn.execute(stm) + + +def test_types(conn): + types_tb = Table( + "test_types", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("str", String), + Column("num", Float), + Column("dt", TIMESTAMP), + ) + types_tb.drop(bind=conn.engine, checkfirst=True) + types_tb.create(bind=conn.engine, checkfirst=True) + + stm = types_tb.insert().values( + id=1, + str=b"Hello World!", + num=3.1415, + dt=datetime.datetime.now(), + ) + conn.execute(stm) + + # GROUP BY + stm = sa.select(types_tb.c.str, sa.func.max(types_tb.c.num)).group_by( + 
types_tb.c.str + ) + rs = conn.execute(stm) + for x in rs: + print(x) + + +def run_example_orm(engine): + Base.metadata.bind = engine + Base.metadata.drop_all() + Base.metadata.create_all() + + session = orm.sessionmaker(bind=engine)() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + fill_all_tables(session.connection()) + + try: + session.add_all( + [ + Episodes( + series_id=2, + season_id=1, + episode_id=1, + title="Minimum Viable Product", + air_date=to_days("2014-04-06"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=2, + title="The Cap Table", + air_date=to_days("2014-04-13"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=3, + title="Articles of Incorporation", + air_date=to_days("2014-04-20"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=4, + title="Fiduciary Duties", + air_date=to_days("2014-04-27"), + ), + Episodes( + series_id=2, + season_id=1, + episode_id=5, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ), + ] + ) + session.commit() + except exc.DatabaseError: + print("Episodes already added!") + session.rollback() + + rs = session.query(Episodes).all() + for e in rs: + print(f"{e.episode_id}: {e.title}") + + rs = session.query(Episodes).filter(Episodes.title == "abc??").all() + for e in rs: + print(e.title) + + print("Episodes count:", session.query(Episodes).count()) + + max_episode = session.query(sql.expression.func.max(Episodes.episode_id)).scalar() + print("Maximum episodes id:", max_episode) + + session.add( + Episodes( + series_id=2, + season_id=1, + episode_id=max_episode + 1, + title="Signaling Risk", + air_date=to_days("2014-05-04"), + ) + ) + + print("Episodes count:", session.query(Episodes).count()) + + +def run_example_core(engine): + with engine.connect() as conn: + # raw sql + rs = conn.execute("SELECT 1 AS value") + print(rs.fetchone()["value"]) + + fill_all_tables(conn) + + for t in "series seasons episodes".split(): + describe_table(engine, t) + + tb = sa.Table("episodes", sa.MetaData(engine), autoload=True) + stm = ( + sa.select([tb.c.title]) + .where(sa.and_(tb.c.series_id == 1, tb.c.season_id == 3)) + .where(tb.c.title.like("%")) + .order_by(sa.asc(tb.c.title)) + # TODO: limit isn't working now + # .limit(3) + ) + rs = conn.execute(stm) + print(rs.fetchall()) + + simple_select(conn) + + simple_insert(conn) + + # simple join + stm = sa.select( + [Episodes.__table__.join(Series, Episodes.series_id == Series.series_id)] + ).where(sa.and_(Series.series_id == 1, Episodes.season_id == 1)) + rs = conn.execute(stm) + for row in rs: + print(f"{row.series_title}({row.episode_id}): {row.title}") + + rs = conn.execute(sa.select(Episodes).where(Episodes.series_id == 3)) + print(rs.fetchall()) + + # count + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + print("Episodes cnt:", cnt) + + # simple delete + conn.execute(sa.delete(Episodes).where(Episodes.title == "TBD")) + cnt = conn.execute(sa.func.count(Episodes.episode_id)).scalar() + print("Episodes cnt:", cnt) + + test_types(conn) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""", + ) + parser.add_argument( + "-d", + "--database", + help="Name of the database to use", + default="/local", + ) + parser.add_argument( + "-e", + "--endpoint", + help="Endpoint url to use", + default="grpc://localhost:2136", + ) + + args = parser.parse_args() + register_dialect() + 
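    # register_dialect() presumably installs the ydb SQLAlchemy dialect under the
    # "yql" scheme, so the "yql:///ydb/" URL passed to create_engine() below resolves.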
engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": args.database, "endpoint": args.endpoint}, + ) + + logging.basicConfig(level=logging.INFO) + logging.getLogger("_sqlalchemy.engine").setLevel(logging.INFO) + + run_example_core(engine) + # run_example_orm(engine) + + +if __name__ == "__main__": + main() diff --git a/examples/_sqlalchemy_example/fill_tables.py b/examples/_sqlalchemy_example/fill_tables.py new file mode 100644 index 00000000..5a9eb954 --- /dev/null +++ b/examples/_sqlalchemy_example/fill_tables.py @@ -0,0 +1,83 @@ +import iso8601 + +import sqlalchemy as sa +from models import Base, Series, Seasons, Episodes + + +def to_days(date): + timedelta = iso8601.parse_date(date) - iso8601.parse_date("1970-1-1") + return timedelta.days + + +def fill_series(conn): + data = [ + ( + 1, + "IT Crowd", + "The IT Crowd is a British sitcom produced by Channel 4, written by Graham Linehan, produced by " + "Ash Atalla and starring Chris O'Dowd, Richard Ayoade, Katherine Parkinson, and Matt Berry.", + to_days("2006-02-03"), + ), + ( + 2, + "Silicon Valley", + "Silicon Valley is an American comedy television series created by Mike Judge, John Altschuler and " + "Dave Krinsky. The series focuses on five young men who founded a startup company in Silicon Valley.", + to_days("2014-04-06"), + ), + ] + conn.execute(sa.insert(Series).values(data)) + + +def fill_seasons(conn): + data = [ + (1, 1, "Season 1", to_days("2006-02-03"), to_days("2006-03-03")), + (1, 2, "Season 2", to_days("2007-08-24"), to_days("2007-09-28")), + (1, 3, "Season 3", to_days("2008-11-21"), to_days("2008-12-26")), + (1, 4, "Season 4", to_days("2010-06-25"), to_days("2010-07-30")), + (2, 1, "Season 1", to_days("2014-04-06"), to_days("2014-06-01")), + (2, 2, "Season 2", to_days("2015-04-12"), to_days("2015-06-14")), + (2, 3, "Season 3", to_days("2016-04-24"), to_days("2016-06-26")), + (2, 4, "Season 4", to_days("2017-04-23"), to_days("2017-06-25")), + (2, 5, "Season 5", to_days("2018-03-25"), to_days("2018-05-13")), + ] + conn.execute(sa.insert(Seasons).values(data)) + + +def fill_episodes(conn): + data = [ + (1, 1, 1, "Yesterday's Jam", to_days("2006-02-03")), + (1, 1, 2, "Calamity Jen", to_days("2006-02-03")), + (1, 1, 3, "Fifty-Fifty", to_days("2006-02-10")), + (1, 1, 4, "The Red Door", to_days("2006-02-17")), + (1, 1, 5, "The Haunting of Bill Crouse", to_days("2006-02-24")), + (1, 1, 6, "Aunt Irma Visits", to_days("2006-03-03")), + (1, 2, 1, "The Work Outing", to_days("2006-08-24")), + (1, 2, 2, "Return of the Golden Child", to_days("2007-08-31")), + (1, 2, 3, "Moss and the German", to_days("2007-09-07")), + (1, 2, 4, "The Dinner Party", to_days("2007-09-14")), + (1, 2, 5, "Smoke and Mirrors", to_days("2007-09-21")), + (1, 2, 6, "Men Without Women", to_days("2007-09-28")), + (1, 3, 1, "From Hell", to_days("2008-11-21")), + (1, 3, 2, "Are We Not Men?", to_days("2008-11-28")), + (1, 3, 3, "Tramps Like Us", to_days("2008-12-05")), + (1, 3, 4, "The Speech", to_days("2008-12-12")), + (1, 3, 5, "Friendface", to_days("2008-12-19")), + (1, 3, 6, "Calendar Geeks", to_days("2008-12-26")), + (1, 4, 1, "Jen The Fredo", to_days("2010-06-25")), + (1, 4, 2, "The Final Countdown", to_days("2010-07-02")), + (1, 4, 3, "Something Happened", to_days("2010-07-09")), + (1, 4, 4, "Italian For Beginners", to_days("2010-07-16")), + (1, 4, 5, "Bad Boys", to_days("2010-07-23")), + (1, 4, 6, "Reynholm vs Reynholm", to_days("2010-07-30")), + ] + conn.execute(sa.insert(Episodes).values(data)) + + +def fill_all_tables(conn): + 
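    # drop and recreate the example schema, then populate series, seasons and episodes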
Base.metadata.drop_all(conn.engine) + Base.metadata.create_all(conn.engine) + + fill_series(conn) + fill_seasons(conn) + fill_episodes(conn) diff --git a/examples/_sqlalchemy_example/models.py b/examples/_sqlalchemy_example/models.py new file mode 100644 index 00000000..a02349a9 --- /dev/null +++ b/examples/_sqlalchemy_example/models.py @@ -0,0 +1,34 @@ +import sqlalchemy.orm as orm +from sqlalchemy import Column, Integer, Unicode + + +Base = orm.declarative_base() + + +class Series(Base): + __tablename__ = "series" + + series_id = Column(Integer, primary_key=True) + title = Column(Unicode) + series_info = Column(Unicode) + release_date = Column(Integer) + + +class Seasons(Base): + __tablename__ = "seasons" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + title = Column(Unicode) + first_aired = Column(Integer) + last_aired = Column(Integer) + + +class Episodes(Base): + __tablename__ = "episodes" + + series_id = Column(Integer, primary_key=True) + season_id = Column(Integer, primary_key=True) + episode_id = Column(Integer, primary_key=True) + title = Column(Unicode) + air_date = Column(Integer) diff --git a/examples/reservations-bot-demo/cloud_function/requirements.txt b/examples/reservations-bot-demo/cloud_function/requirements.txt index 9b243fbb..56665263 100644 --- a/examples/reservations-bot-demo/cloud_function/requirements.txt +++ b/examples/reservations-bot-demo/cloud_function/requirements.txt @@ -11,7 +11,6 @@ pycparser==2.20 pydantic==1.6.2 PyJWT==2.4.0 requests==2.24.0 -six==1.15.0 urllib3==1.26.5 yandexcloud==0.48.0 ydb==0.0.41 diff --git a/examples/reservations-bot-demo/cloud_function/utils.py b/examples/reservations-bot-demo/cloud_function/utils.py index 52aa56f9..10e0d53d 100644 --- a/examples/reservations-bot-demo/cloud_function/utils.py +++ b/examples/reservations-bot-demo/cloud_function/utils.py @@ -7,7 +7,7 @@ def make_driver_config(endpoint, database, path): return ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), + credentials=ydb.credentials_from_env_variables(), root_certificates=ydb.load_ydb_root_certificate(), ) diff --git a/examples/secondary_indexes_builtin/secondary_indexes_builtin.py b/examples/secondary_indexes_builtin/secondary_indexes_builtin.py index 03949970..877ed84c 100644 --- a/examples/secondary_indexes_builtin/secondary_indexes_builtin.py +++ b/examples/secondary_indexes_builtin/secondary_indexes_builtin.py @@ -245,7 +245,7 @@ def callee(session): def run(endpoint, database, path): driver_config = ydb.DriverConfig( - endpoint, database, credentials=ydb.construct_credentials_from_environ() + endpoint, database, credentials=ydb.credentials_from_env_variables() ) with ydb.Driver(driver_config) as driver: try: diff --git a/examples/time-series-serverless/database.py b/examples/time-series-serverless/database.py index df3e2f6a..8e8d9181 100644 --- a/examples/time-series-serverless/database.py +++ b/examples/time-series-serverless/database.py @@ -14,7 +14,7 @@ def create_driver(self) -> ydb.Driver: driver_config = ydb.DriverConfig( self.config.endpoint, self.config.database, - credentials=ydb.construct_credentials_from_environ(), + credentials=ydb.credentials_from_env_variables(), root_certificates=ydb.load_ydb_root_certificate(), ) diff --git a/examples/topic/reader_async_example.py b/examples/topic/reader_async_example.py new file mode 100644 index 00000000..848aeed2 --- /dev/null +++ b/examples/topic/reader_async_example.py @@ -0,0 +1,104 @@ +import asyncio 
+import time + +import ydb + + +async def connect(): + db = ydb.aio.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) + reader = db.topic_client.reader("/local/topic", consumer="consumer") + return reader + + +async def create_reader_and_close_with_context_manager(db: ydb.aio.Driver): + async with db.topic_client.reader( + "/database/topic/path", consumer="consumer" + ) as reader: # noqa + ... + + +async def print_message_content(reader: ydb.TopicReaderAsyncIO): + while True: + message = await reader.receive_message() + print("text", message.data.read().decode("utf-8")) + # await and async_commit need only for sync commit mode - for wait ack from servr + await reader.commit(message) + + +async def process_messages_batch_with_commit(reader: ydb.TopicReaderAsyncIO): + # Explicit commit example + while True: + batch = await reader.receive_batch() + ... + await reader.commit(batch) + + +async def get_message_with_timeout(reader: ydb.TopicReaderAsyncIO): + try: + message = await asyncio.wait_for(reader.receive_message(), timeout=1) + except asyncio.TimeoutError: + print("Have no new messages in a second") + return + + print("mess", message.data) + + +async def get_all_messages_with_small_wait(reader: ydb.TopicReaderAsyncIO): + while True: + try: + message = await reader.receive_message() + await _process(message) + except asyncio.TimeoutError: + print("Have no new messages in a second") + + +async def get_a_message_from_external_loop(reader: ydb.TopicReaderAsyncIO): + for i in range(10): + try: + message = await asyncio.wait_for(reader.receive_message(), timeout=1) + except asyncio.TimeoutError: + return + await _process(message) + + +async def get_one_batch_from_external_loop_async(reader: ydb.TopicReaderAsyncIO): + for i in range(10): + try: + batch = await asyncio.wait_for(reader.receive_batch(), timeout=2) + except asyncio.TimeoutError: + return + + for message in batch.messages: + await _process(message) + await reader.commit(batch) + + +async def handle_partition_stop(reader: ydb.TopicReaderAsyncIO): + while True: + message = await reader.receive_message() + time.sleep(123) # some work + if message.alive: + time.sleep(1) # some other work + await reader.commit(message) + + +async def handle_partition_stop_batch(reader: ydb.TopicReaderAsyncIO): + def process_batch(batch): + for message in batch.messages: + if not batch.alive: + # no reason work with expired batch + # go read next - good batch + return + _process(message) + reader.commit(batch) + + while True: + batch = await reader.receive_batch() + process_batch(batch) + + +async def _process(msg): + raise NotImplementedError() diff --git a/examples/topic/reader_example.py b/examples/topic/reader_example.py new file mode 100644 index 00000000..f1c4eb73 --- /dev/null +++ b/examples/topic/reader_example.py @@ -0,0 +1,112 @@ +import time + +import ydb + + +def connect(): + db = ydb.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) + reader = db.topic_client.reader("/local/topic", consumer="consumer") + return reader + + +def create_reader_and_close_with_context_manager(db: ydb.Driver): + with db.topic_client.reader( + "/database/topic/path", consumer="consumer", buffer_size_bytes=123 + ) as reader: + while True: + message = reader.receive_message() # noqa + pass + + +def print_message_content(reader: ydb.TopicReader): + while True: + message = reader.receive_message() + print("text", 
message.data.read().decode("utf-8")) + reader.commit(message) + + +def process_messages_batch_explicit_commit(reader: ydb.TopicReader): + while True: + batch = reader.receive_batch() + for message in batch.messages: + _process(message) + reader.commit(batch) + + +def get_message_with_timeout(reader: ydb.TopicReader): + try: + message = reader.receive_message(timeout=1) + except TimeoutError: + print("Have no new messages in a second") + return + + print("mess", message.data) + + +def get_all_messages_with_small_wait(reader: ydb.TopicReader): + while True: + try: + message = reader.receive_message(timeout=1) + _process(message) + except TimeoutError: + print("Have no new messages in a second") + + +def get_a_message_from_external_loop(reader: ydb.TopicReader): + for i in range(10): + try: + message = reader.receive_message(timeout=1) + except TimeoutError: + return + _process(message) + + +def get_one_batch_from_external_loop(reader: ydb.TopicReader): + for i in range(10): + try: + batch = reader.receive_batch(timeout=2) + except TimeoutError: + return + + for message in batch.messages: + _process(message) + reader.commit(batch) + + +def handle_partition_stop(reader: ydb.TopicReader): + while True: + message = reader.receive_message() + time.sleep(123) # some work + if message.alive: + time.sleep(1) # some other work + reader.commit(message) + + +def handle_partition_stop_batch(reader: ydb.TopicReader): + def process_batch(batch): + for message in batch.messages: + if not batch.alive: + # no reason work with expired batch + # go read next - good batch + return + _process(message) + reader.commit(batch) + + while True: + batch = reader.receive_batch() + process_batch(batch) + + +def handle_partition_graceful_stop_batch(reader: ydb.TopicReader): + # no special handle, but batch will contain less than prefer count messages + while True: + batch = reader.receive_batch() + _process(batch) + reader.commit(batch) + + +def _process(msg): + raise NotImplementedError() diff --git a/examples/topic/writer_async_example.py b/examples/topic/writer_async_example.py new file mode 100644 index 00000000..4d11a86d --- /dev/null +++ b/examples/topic/writer_async_example.py @@ -0,0 +1,121 @@ +import asyncio +import datetime +from typing import Dict, List + +import ydb +from ydb import TopicWriterMessage + + +async def create_writer(db: ydb.aio.Driver): + async with db.topic_client.writer( + "/database/topic/path", + producer_id="producer-id", + ) as writer: + await writer.write(TopicWriterMessage("asd")) + + +async def connect_and_wait(db: ydb.aio.Driver): + async with db.topic_client.writer( + "/database/topic/path", + producer_id="producer-id", + ) as writer: + info = await writer.wait_init() # noqa + ... 
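# A hedged sketch, not part of the original example file: the tests in this change
# read init_info.last_seqno from the object returned by wait_init(), so the init
# info obtained above can be used along these lines.
async def print_last_seqno_of_producer(db: ydb.aio.Driver):
    async with db.topic_client.writer(
        "/database/topic/path",
        producer_id="producer-id",
    ) as writer:
        info = await writer.wait_init()
        # last_seqno is the last sequence number stored for this producer id
        print("last stored seqno:", info.last_seqno)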
+ + +async def connect_without_context_manager(db: ydb.aio.Driver): + writer = db.topic_client.writer( + "/database/topic/path", + producer_id="producer-id", + ) + try: + pass # some code + finally: + await writer.close() + + +async def send_messages(writer: ydb.TopicWriterAsyncIO): + # simple str/bytes without additional metadata + await writer.write("mess") # send text + await writer.write(bytes([1, 2, 3])) # send single message with bytes 1,2,3 + await writer.write(["mess-1", "mess-2"]) # send two messages + + # full forms + await writer.write(ydb.TopicWriterMessage("mess")) # send text + await writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes + await writer.write( + [ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2")] + ) # send few messages by one call + + # with meta + await writer.write( + ydb.TopicWriterMessage("asd", seqno=123, created_at=datetime.datetime.now()) + ) + + +async def send_message_without_block_if_internal_buffer_is_full( + writer: ydb.TopicWriterAsyncIO, msg +) -> bool: + try: + # put message to internal queue for send, but if buffer is full - fast return + # without wait + await asyncio.wait_for(writer.write(msg), 0) + return True + except TimeoutError(): + return False + + +async def send_messages_with_manual_seqno(writer: ydb.TopicWriter): + await writer.write(ydb.TopicWriterMessage("mess")) # send text + + +async def send_messages_with_wait_ack(writer: ydb.TopicWriterAsyncIO): + # future wait + await writer.write_with_ack( + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] + ) + + # send with flush + await writer.write(["1", "2", "3"]) + await writer.flush() + + +async def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriterAsyncIO): + for i in range(10): + await writer.write(ydb.TopicWriterMessage("%s" % i)) + await writer.flush() + + +async def send_messages_and_wait_all_commit_with_results( + writer: ydb.TopicWriterAsyncIO, +): + for i in range(10): + content = "%s" % i + await writer.write(content) + + await writer.flush() + + +async def switch_messages_with_many_producers( + writers: Dict[str, ydb.TopicWriterAsyncIO], messages: List[str] +): + futures = [] # type: List[asyncio.Future] + + for msg in messages: + # select writer for the msg + writer_idx = msg[:1] + writer = writers[writer_idx] + future = await writer.write_with_ack_future(msg) + futures.append(future) + + # wait acks from all writes + await asyncio.wait(futures) + for future in futures: + if future.exception() is not None: + raise future.exception() + + # all ok, explicit return - for better + return diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py new file mode 100644 index 00000000..14b401e0 --- /dev/null +++ b/examples/topic/writer_example.py @@ -0,0 +1,139 @@ +import concurrent.futures +import datetime +from typing import Dict, List +from concurrent.futures import Future, wait + +import ydb +from ydb import TopicWriterMessage + + +def connect(): + db = ydb.Driver( + connection_string="grpc://localhost:2135?database=/local", + credentials=ydb.credentials.AnonymousCredentials(), + ) + writer = db.topic_client.writer( + "/local/topic", + producer_id="producer-id", + ) + writer.write(TopicWriterMessage("asd")) + + +def create_writer(db: ydb.Driver): + with db.topic_client.writer( + "/database/topic/path", + producer_id="producer-id", + ) as writer: + writer.write(TopicWriterMessage("asd")) + + +def connect_and_wait(db: ydb.Driver): + with db.topic_client.writer( + 
"/database/topic/path", + producer_id="producer-id", + ) as writer: + info = writer.wait_init() # noqa + + +def connect_without_context_manager(db: ydb.Driver): + writer = db.topic_client.writer( + "/database/topic/path", + producer_id="producer-id", + ) + try: + pass # some code + finally: + writer.close() + + +def send_messages(writer: ydb.TopicWriter): + # simple str/bytes without additional metadata + writer.write("mess") # send text + writer.write(bytes([1, 2, 3])) # send single message with bytes 1,2,3 + writer.write(["mess-1", "mess-2"]) # send two messages + + # full forms + writer.write(ydb.TopicWriterMessage("mess")) # send text + writer.write(ydb.TopicWriterMessage(bytes([1, 2, 3]))) # send bytes + writer.write( + [ydb.TopicWriterMessage("mess-1"), ydb.TopicWriterMessage("mess-2")] + ) # send few messages by one call + + # with meta + writer.write( + ydb.TopicWriterMessage("asd", seqno=123, created_at=datetime.datetime.now()) + ) + + +def send_message_without_block_if_internal_buffer_is_full( + writer: ydb.TopicWriter, msg +) -> bool: + try: + # put message to internal queue for send, but if buffer is full - fast return + # without wait + writer.write(msg, timeout=0) + return True + except TimeoutError(): + return False + + +def send_messages_with_manual_seqno(writer: ydb.TopicWriter): + writer.write(ydb.TopicWriterMessage("mess")) # send text + + +def send_messages_with_wait_ack(writer: ydb.TopicWriter): + # Explicit future wait + writer.async_write_with_ack( + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] + ).result() + + # implicit, by sync call + writer.write_with_ack( + [ + ydb.TopicWriterMessage("mess", seqno=1), + ydb.TopicWriterMessage("mess", seqno=2), + ] + ) + # write_with_ack + + # send with flush + writer.write(["1", "2", "3"]) + writer.flush() + + +def send_messages_and_wait_all_commit_with_flush(writer: ydb.TopicWriter): + for i in range(10): + content = "%s" % i + writer.write(content) + writer.flush() + + +def send_messages_and_wait_all_commit_with_results(writer: ydb.TopicWriter): + futures = [] # type: List[concurrent.futures.Future] + for i in range(10): + future = writer.async_write_with_ack() + futures.append(future) + + concurrent.futures.wait(futures) + for future in futures: + if future.exception() is not None: + raise future.exception() + + +def switch_messages_with_many_producers( + writers: Dict[str, ydb.TopicWriter], messages: List[str] +): + futures = [] # type: List[Future] + + for msg in messages: + # select writer for the msg + writer_idx = msg[:1] + writer = writers[writer_idx] + future = writer.async_write_with_ack(msg) + futures.append(future) + + # wait acks from all writes + wait(futures) diff --git a/examples/ttl/ttl.py b/examples/ttl/ttl.py index bdbe0dec..7a4c5cc4 100644 --- a/examples/ttl/ttl.py +++ b/examples/ttl/ttl.py @@ -307,7 +307,7 @@ def _run(driver, database, path): def run(endpoint, database, path): driver_config = ydb.DriverConfig( - endpoint, database, credentials=ydb.construct_credentials_from_environ() + endpoint, database, credentials=ydb.credentials_from_env_variables() ) with ydb.Driver(driver_config) as driver: try: diff --git a/examples/ttl_readtable/ttl.py b/examples/ttl_readtable/ttl.py index f5fff741..f9ca50aa 100644 --- a/examples/ttl_readtable/ttl.py +++ b/examples/ttl_readtable/ttl.py @@ -288,7 +288,7 @@ def _run(driver, session_pool, database, path): def run(endpoint, database, path): driver_config = ydb.DriverConfig( - endpoint, database, 
credentials=ydb.construct_credentials_from_environ() + endpoint, database, credentials=ydb.credentials_from_env_variables() ) with ydb.Driver(driver_config) as driver: try: diff --git a/kikimr/public/api/__init__.py b/kikimr/public/api/__init__.py deleted file mode 100644 index c97daf69..00000000 --- a/kikimr/public/api/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ydb.public.api import * # noqa -import sys -import warnings - -sys.modules['kikimr.public.api'] = sys.modules['ydb.public.api'] -warnings.warn("using kikimr.public.api module is deprecated. please use ydb.public.api import instead") diff --git a/kikimr/public/sdk/python/client/__init__.py b/kikimr/public/sdk/python/client/__init__.py deleted file mode 100644 index e2050d47..00000000 --- a/kikimr/public/sdk/python/client/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- -from ydb import * # noqa -import sys -import six -import warnings - -warnings.warn("module kikimr.public.sdk.python.client is deprecated. please use ydb instead") - - -for name, module in six.iteritems(sys.modules.copy()): - if not name.startswith("ydb"): - continue - - if name.startswith("ydb.public"): - continue - - module_import_path = name.split('.') - if len(module_import_path) < 2: - continue - - sys.modules['kikimr.public.sdk.python.client.' + '.'.join(module_import_path[1:])] = module diff --git a/kikimr/public/sdk/python/client/frameworks/__init__.py b/kikimr/public/sdk/python/client/frameworks/__init__.py deleted file mode 100644 index e5c4940a..00000000 --- a/kikimr/public/sdk/python/client/frameworks/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -try: - from ydb.tornado import * # noqa - import sys - import warnings - - warnings.warn("module kikimr.public.sdk.python.client.frameworks is deprecated. please use ydb.tornado instead") - - sys.modules['kikimr.public.sdk.python.client.frameworks.tornado_helpers'] = sys.modules['ydb.tornado.tornado_helpers'] -except ImportError: - pass diff --git a/kikimr/public/sdk/python/iam/__init__.py b/kikimr/public/sdk/python/iam/__init__.py deleted file mode 100644 index 884d5c1b..00000000 --- a/kikimr/public/sdk/python/iam/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ydb.iam import * # noqa -import sys -import warnings - -warnings.warn("using kikimr.public.sdk.python.iam module is deprecated. please use ydb.iam import instead") - -sys.modules['kikimr.public.sdk.python.iam.auth'] = sys.modules['ydb.iam.auth'] diff --git a/kikimr/stub.txt b/kikimr/stub.txt new file mode 100644 index 00000000..7467d31a --- /dev/null +++ b/kikimr/stub.txt @@ -0,0 +1 @@ +the folder must not use for prevent issues with intersect with old packages. 
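A minimal sketch of the credentials migration applied across the examples above, assuming a local endpoint and database (both placeholders): ydb.construct_credentials_from_environ() is replaced by ydb.credentials_from_env_variables(), and the rest of the driver setup stays the same.

import ydb

driver_config = ydb.DriverConfig(
    "grpc://localhost:2136",  # placeholder endpoint
    "/local",                 # placeholder database
    credentials=ydb.credentials_from_env_variables(),  # was: ydb.construct_credentials_from_environ()
    root_certificates=ydb.load_ydb_root_certificate(),
)
with ydb.Driver(driver_config) as driver:
    driver.wait(timeout=10)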
diff --git a/requirements.txt b/requirements.txt index 8ee59727..6a60288e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,4 @@ -aiohttp==3.7.4 -enum-compat>=0.0.1 -grpcio>=1.5.0 +grpcio>=1.42.0 packaging -protobuf>3.13.0,<5.0.0 -pytest==6.2.4 -six<2 +protobuf>=3.13.0,<5.0.0 +aiohttp==3.7.4 diff --git a/setup.py b/setup.py index 879af9bc..6d322334 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ setuptools.setup( name="ydb", - version="2.13.3", # AUTOVERSION + version="3.0.1b11", # AUTOVERSION description="YDB Python SDK", author="Yandex LLC", author_email="ydb@yandex-team.ru", diff --git a/test-requirements.txt b/test-requirements.txt index 433096d7..af6c3a53 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -11,8 +11,8 @@ docker==5.0.0 docker-compose==1.29.2 dockerpty==0.4.1 docopt==0.6.2 -grpcio -grpcio-tools +grpcio==1.47.0 +grpcio-tools==1.47.0 idna==3.2 importlib-metadata==4.6.1 iniconfig==1.1.1 @@ -20,7 +20,7 @@ jsonschema==3.2.0 packaging==21.0 paramiko==2.10.1 pluggy==0.13.1 -protobuf>3.13.0,<5.0.0 +protobuf==3.13.0 py==1.10.0 pycparser==2.20 PyNaCl==1.4.0 @@ -32,7 +32,6 @@ pytest-docker-compose==3.2.1 python-dotenv==0.18.0 PyYAML==5.4.1 requests==2.26.0 -six==1.16.0 texttable==1.6.4 toml==0.10.2 typing-extensions==3.10.0.0 @@ -46,4 +45,7 @@ flake8==3.9.2 sqlalchemy==1.4.26 pylint-protobuf cython +freezegun==1.2.2 grpcio-tools +pytest-cov +-e . diff --git a/tests/_sqlalchemy/_test_inspect.py b/tests/_sqlalchemy/_test_inspect.py new file mode 100644 index 00000000..69a57243 --- /dev/null +++ b/tests/_sqlalchemy/_test_inspect.py @@ -0,0 +1,24 @@ +import ydb + +import sqlalchemy as sa + + +def test_get_columns(driver_sync, engine): + session = ydb.retry_operation_sync( + lambda: driver_sync.table_client.session().create() + ) + session.execute_scheme( + "CREATE TABLE test(id Int64 NOT NULL, value TEXT, num DECIMAL(22, 9), PRIMARY KEY (id))" + ) + inspect = sa.inspect(engine) + columns = inspect.get_columns("test") + for c in columns: + c["type"] = type(c["type"]) + + assert columns == [ + {"name": "id", "type": sa.INTEGER, "nullable": False}, + {"name": "value", "type": sa.TEXT, "nullable": True}, + {"name": "num", "type": sa.DECIMAL, "nullable": True}, + ] + + session.execute_scheme("DROP TABLE test") diff --git a/tests/_sqlalchemy/conftest.py b/tests/_sqlalchemy/conftest.py new file mode 100644 index 00000000..6ebabac3 --- /dev/null +++ b/tests/_sqlalchemy/conftest.py @@ -0,0 +1,22 @@ +import pytest +import sqlalchemy as sa + +from ydb._sqlalchemy import register_dialect + + +@pytest.fixture(scope="module") +def engine(endpoint, database): + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": database, "endpoint": endpoint}, + ) + + yield engine + engine.dispose() + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn diff --git a/tests/_sqlalchemy/test_dbapi.py b/tests/_sqlalchemy/test_dbapi.py new file mode 100644 index 00000000..407cc9e4 --- /dev/null +++ b/tests/_sqlalchemy/test_dbapi.py @@ -0,0 +1,84 @@ +from ydb import _dbapi as dbapi + + +def test_dbapi(endpoint, database): + conn = dbapi.connect(endpoint, database=database) + assert conn + + conn.commit() + conn.rollback() + + cur = conn.cursor() + assert cur + + cur.execute( + "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", + context={"isddl": True}, + ) + + cur.execute('INSERT INTO test(id, text) VALUES (1, "foo")') + + cur.execute("SELECT id, text FROM test") + assert 
cur.fetchone() == (1, "foo"), "fetchone is ok" + + cur.execute("SELECT id, text FROM test WHERE id = %(id)s", {"id": 1}) + assert cur.fetchone() == (1, "foo"), "parametrized query is ok" + + cur.execute( + "INSERT INTO test(id, text) VALUES (%(id1)s, %(text1)s), (%(id2)s, %(text2)s)", + {"id1": 2, "text1": "", "id2": 3, "text2": "bar"}, + ) + + cur.execute( + "UPDATE test SET text = %(t)s WHERE id = %(id)s", {"id": 2, "t": "foo2"} + ) + + cur.execute("SELECT id FROM test") + assert cur.fetchall() == [(1,), (2,), (3,)], "fetchall is ok" + + cur.execute("SELECT id FROM test ORDER BY id DESC") + assert cur.fetchmany(2) == [(3,), (2,)], "fetchmany is ok" + assert cur.fetchmany(1) == [(1,)] + + cur.execute("SELECT id FROM test ORDER BY id LIMIT 2") + assert cur.fetchall() == [(1,), (2,)], "limit clause without params is ok" + + # TODO: Failed to convert type: Int64 to Uint64 + # cur.execute("SELECT id FROM test ORDER BY id LIMIT %(limit)s", {"limit": 2}) + # assert cur.fetchall() == [(1,), (2,)], "limit clause with params is ok" + + cur2 = conn.cursor() + cur2.execute( + "INSERT INTO test(id) VALUES (%(id1)s), (%(id2)s)", {"id1": 5, "id2": 6} + ) + + cur.execute("SELECT id FROM test ORDER BY id") + assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,)], "cursor2 commit changes" + + cur.execute("SELECT text FROM test WHERE id > %(min_id)s", {"min_id": 3}) + assert cur.fetchall() == [(None,), (None,)], "NULL returns as None" + + cur.execute("SELECT id, text FROM test WHERE text LIKE %(p)s", {"p": "foo%"}) + assert cur.fetchall() == [(1, "foo"), (2, "foo2")], "like clause works" + + cur.execute( + # DECLARE statement (DECLARE $data AS List>) + # will generate automatically + """INSERT INTO test SELECT id, text FROM AS_TABLE($data);""", + { + "data": [ + {"id": 17, "text": "seventeen"}, + {"id": 21, "text": "twenty one"}, + ] + }, + ) + + cur.execute("SELECT id FROM test ORDER BY id") + assert cur.rowcount == 7, "rowcount ok" + assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,), (17,), (21,)], "ok" + + cur.execute("DROP TABLE test", context={"isddl": True}) + + cur.close() + cur2.close() + conn.close() diff --git a/tests/_sqlalchemy/test_sqlalchemy.py b/tests/_sqlalchemy/test_sqlalchemy.py new file mode 100644 index 00000000..914553ea --- /dev/null +++ b/tests/_sqlalchemy/test_sqlalchemy.py @@ -0,0 +1,27 @@ +import sqlalchemy as sa +from sqlalchemy import MetaData, Table, Column, Integer, Unicode + +meta = MetaData() + + +def clear_sql(stm): + return stm.replace("\n", " ").replace(" ", " ").strip() + + +def test_sqlalchemy_core(connection): + # raw sql + rs = connection.execute("SELECT 1 AS value") + assert rs.fetchone()["value"] == 1 + + tb_test = Table( + "test", + meta, + Column("id", Integer, primary_key=True), + Column("text", Unicode), + ) + + stm = sa.select(tb_test) + assert clear_sql(str(stm)) == "SELECT test.id, test.text FROM test" + + stm = sa.insert(tb_test).values(id=2, text="foo") + assert clear_sql(str(stm)) == "INSERT INTO test (id, text) VALUES (:id, :text)" diff --git a/tests/aio/test_connection_pool.py b/tests/aio/test_connection_pool.py index 221f1e39..12882f38 100644 --- a/tests/aio/test_connection_pool.py +++ b/tests/aio/test_connection_pool.py @@ -11,8 +11,6 @@ async def test_async_call(endpoint, database): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -26,8 +24,6 @@ async def 
test_gzip_compression(endpoint, database): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), compression=ydb.RPCCompression.Gzip, ) @@ -53,8 +49,6 @@ async def test_session(endpoint, database): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -98,8 +92,6 @@ async def test_raises_when_disconnect(endpoint, database, docker_project): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) @@ -124,8 +116,6 @@ async def test_disconnect_by_call(endpoint, database, docker_project): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = Driver(driver_config=driver_config) diff --git a/tests/aio/test_tx.py b/tests/aio/test_tx.py index f421a70e..69e144af 100644 --- a/tests/aio/test_tx.py +++ b/tests/aio/test_tx.py @@ -85,6 +85,7 @@ async def test_tx_snapshot_ro(driver, database): await ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: await ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -167,12 +168,14 @@ async def check_transaction(s: ydb.aio.table.Session): await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) await tx.commit() - await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + with pytest.raises(RuntimeError): + await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + await tx.commit() async with s.transaction() as tx: rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) - assert rs[0].rows[0].cnt == 2 + assert rs[0].rows[0].cnt == 1 await pool.retry_operation(check_transaction) @@ -193,14 +196,12 @@ async def test_truncated_response(driver, table_name, table_path): s = table_client.session() await s.create() t = s.transaction() - - result = await t.execute("SELECT * FROM %s" % table_name) - assert result[0].truncated - assert len(result[0].rows) == 1000 + with pytest.raises(ydb.TruncatedResponseError): + await t.execute("SELECT * FROM %s" % table_name) @pytest.mark.asyncio -async def test_truncated_response_deny(driver, table_name, table_path): +async def test_truncated_response_allow(driver, table_name, table_path): column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) rows = [] @@ -212,11 +213,11 @@ async def test_truncated_response_deny(driver, table_name, table_path): await driver.table_client.bulk_upsert(table_path, rows, column_types) table_client = ydb.TableClient( - driver, ydb.TableClientSettings().with_allow_truncated_result(False) + driver, ydb.TableClientSettings().with_allow_truncated_result(True) ) s = table_client.session() await s.create() t = s.transaction() - - with pytest.raises(ydb.TruncatedResponseError): - await t.execute("SELECT * FROM %s" % table_name) + result = await t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/tests/conftest.py b/tests/conftest.py index a7b42c09..73281a4f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ 
-1,15 +1,9 @@ import os -from unittest import mock import pytest import ydb import time - - -@pytest.fixture(autouse=True, scope="session") -def mock_settings_env_vars(): - with mock.patch.dict(os.environ, {"YDB_ANONYMOUS_CREDENTIALS": "1"}): - yield +from ydb import issues @pytest.fixture(scope="module") @@ -18,7 +12,7 @@ def docker_compose_file(pytestconfig): def wait_container_ready(driver): - driver.wait(timeout=10) + driver.wait(timeout=30) with ydb.SessionPool(driver) as pool: @@ -86,8 +80,6 @@ async def driver(endpoint, database, event_loop): driver_config = ydb.DriverConfig( endpoint, database, - credentials=ydb.construct_credentials_from_environ(), - root_certificates=ydb.load_ydb_root_certificate(), ) driver = ydb.aio.Driver(driver_config=driver_config) @@ -143,3 +135,62 @@ def create_table(s): @pytest.fixture() def table_path(database, table_name) -> str: return database + "/" + table_name + + +@pytest.fixture() +def topic_consumer(): + return "fixture-consumer" + + +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_path(driver, topic_consumer, database) -> str: + topic_path = database + "/test-topic" + + try: + await driver.topic_client.drop_topic(topic_path) + except issues.SchemeError: + pass + + await driver.topic_client.create_topic( + path=topic_path, + consumers=[topic_consumer], + ) + + return topic_path + + +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_with_messages(driver, topic_path): + writer = driver.topic_client.writer( + topic_path, producer_id="fixture-producer-id", codec=ydb.TopicCodec.RAW + ) + await writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] + ) + await writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="789".encode()), + ydb.TopicWriterMessage(data="0".encode()), + ] + ) + await writer.close() + + +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_reader(driver, topic_consumer, topic_path) -> ydb.TopicReaderAsyncIO: + reader = driver.topic_client.reader(topic=topic_path, consumer=topic_consumer) + yield reader + await reader.close() + + +@pytest.fixture() +def topic_reader_sync(driver_sync, topic_consumer, topic_path) -> ydb.TopicReader: + reader = driver_sync.topic_client.reader(topic=topic_path, consumer=topic_consumer) + yield reader + reader.close() diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 1ee3b8a9..eb79c579 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -80,6 +80,7 @@ def test_tx_snapshot_ro(driver_sync, database): ro_tx.commit() + ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly()) with pytest.raises(ydb.issues.GenericError) as exc_info: ro_tx.execute("UPDATE `test` SET value = value + 1") assert "read only transaction" in exc_info.value.message @@ -91,7 +92,7 @@ def test_tx_snapshot_ro(driver_sync, database): assert data[0].rows == [{"value": 2}] -def test_split_transactions_deny_split_explicit_commit(driver_sync, table_name): +def test_split_transactions_deny_split(driver_sync, table_name): with ydb.SessionPool(driver_sync, 1) as pool: def check_transaction(s: ydb.table.Session): @@ -158,12 +159,14 @@ def check_transaction(s: ydb.table.Session): tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name) tx.commit() - tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + with pytest.raises(RuntimeError): + tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name) + tx.commit() with s.transaction() as tx: rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name) - assert 
rs[0].rows[0].cnt == 2 + assert rs[0].rows[0].cnt == 1 pool.retry_operation_sync(check_transaction) @@ -185,14 +188,11 @@ def test_truncated_response(driver_sync, table_name, table_path): s = table_client.session() s.create() t = s.transaction() - - result = t.execute("SELECT * FROM %s" % table_name) - assert result[0].truncated - assert len(result[0].rows) == 1000 + with pytest.raises(ydb.TruncatedResponseError): + t.execute("SELECT * FROM %s" % table_name) -@pytest.mark.asyncio -async def test_truncated_response_deny(driver_sync, table_name, table_path): +def test_truncated_response_allow(driver_sync, table_name, table_path): column_types = ydb.BulkUpsertColumns().add_column("id", ydb.PrimitiveType.Int64) rows = [] @@ -204,11 +204,11 @@ async def test_truncated_response_deny(driver_sync, table_name, table_path): driver_sync.table_client.bulk_upsert(table_path, rows, column_types) table_client = ydb.TableClient( - driver_sync, ydb.TableClientSettings().with_allow_truncated_result(False) + driver_sync, ydb.TableClientSettings().with_allow_truncated_result(True) ) s = table_client.session() s.create() t = s.transaction() - - with pytest.raises(ydb.TruncatedResponseError): - t.execute("SELECT * FROM %s" % table_name) + result = t.execute("SELECT * FROM %s" % table_name) + assert result[0].truncated + assert len(result[0].rows) == 1000 diff --git a/tests/topics/test_control_plane.py b/tests/topics/test_control_plane.py new file mode 100644 index 00000000..2446ddcf --- /dev/null +++ b/tests/topics/test_control_plane.py @@ -0,0 +1,74 @@ +import os.path + +import pytest + +from ydb import issues + + +@pytest.mark.asyncio +class TestTopicClientControlPlaneAsyncIO: + async def test_create_topic(self, driver, database): + client = driver.topic_client + + topic_path = database + "/my-test-topic" + + await client.create_topic(topic_path) + + with pytest.raises(issues.SchemeError): + # double create is ok - try create topic with bad path + await client.create_topic(database) + + async def test_drop_topic(self, driver, topic_path): + client = driver.topic_client + + await client.drop_topic(topic_path) + + with pytest.raises(issues.SchemeError): + await client.drop_topic(topic_path) + + async def test_describe_topic(self, driver, topic_path: str, topic_consumer): + res = await driver.topic_client.describe_topic(topic_path) + + assert res.self.name == os.path.basename(topic_path) + + has_consumer = False + for consumer in res.consumers: + if consumer.name == topic_consumer: + has_consumer = True + break + + assert has_consumer + + +class TestTopicClientControlPlane: + def test_create_topic(self, driver_sync, database): + client = driver_sync.topic_client + + topic_path = database + "/my-test-topic" + + client.create_topic(topic_path) + + with pytest.raises(issues.SchemeError): + # double create is ok - try create topic with bad path + client.create_topic(database) + + def test_drop_topic(self, driver_sync, topic_path): + client = driver_sync.topic_client + + client.drop_topic(topic_path) + + with pytest.raises(issues.SchemeError): + client.drop_topic(topic_path) + + def test_describe_topic(self, driver_sync, topic_path: str, topic_consumer): + res = driver_sync.topic_client.describe_topic(topic_path) + + assert res.self.name == os.path.basename(topic_path) + + has_consumer = False + for consumer in res.consumers: + if consumer.name == topic_consumer: + has_consumer = True + break + + assert has_consumer diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py new file mode 
100644 index 00000000..03731b76 --- /dev/null +++ b/tests/topics/test_topic_reader.py @@ -0,0 +1,137 @@ +import pytest + +import ydb + + +@pytest.mark.asyncio +class TestTopicReaderAsyncIO: + async def test_read_batch( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + reader = driver.topic_client.reader(topic_path, topic_consumer) + batch = await reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + await reader.close() + + async def test_read_message( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + reader = driver.topic_client.reader(topic_path, topic_consumer) + msg = await reader.receive_message() + + assert msg is not None + assert msg.seqno + + await reader.close() + + async def test_read_and_commit_message( + self, driver, topic_path, topic_with_messages, topic_consumer + ): + + reader = driver.topic_client.reader(topic_path, topic_consumer) + batch = await reader.receive_batch() + await reader.commit_with_ack(batch) + + reader = driver.topic_client.reader(topic_path, topic_consumer) + batch2 = await reader.receive_batch() + assert batch.messages[0] != batch2.messages[0] + + await reader.close() + + async def test_read_compressed_messages(self, driver, topic_path, topic_consumer): + async with driver.topic_client.writer( + topic_path, codec=ydb.TopicCodec.GZIP + ) as writer: + await writer.write("123") + + async with driver.topic_client.reader(topic_path, topic_consumer) as reader: + batch = await reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + + async def test_read_custom_encoded(self, driver, topic_path, topic_consumer): + codec = 10001 + + def encode(b: bytes): + return bytes(reversed(b)) + + def decode(b: bytes): + return bytes(reversed(b)) + + async with driver.topic_client.writer( + topic_path, codec=codec, encoders={codec: encode} + ) as writer: + await writer.write("123") + + async with driver.topic_client.reader( + topic_path, topic_consumer, decoders={codec: decode} + ) as reader: + batch = await reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + + +class TestTopicReaderSync: + def test_read_batch( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) + batch = reader.receive_batch() + + assert batch is not None + assert len(batch.messages) > 0 + + reader.close() + + def test_read_message( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) + msg = reader.receive_message() + + assert msg is not None + assert msg.seqno + + reader.close() + + def test_read_and_commit_message( + self, driver_sync, topic_path, topic_with_messages, topic_consumer + ): + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) + batch = reader.receive_batch() + reader.commit_with_ack(batch) + + reader = driver_sync.topic_client.reader(topic_path, topic_consumer) + batch2 = reader.receive_batch() + assert batch.messages[0] != batch2.messages[0] + + def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer): + with driver_sync.topic_client.writer( + topic_path, codec=ydb.TopicCodec.GZIP + ) as writer: + writer.write("123") + + with driver_sync.topic_client.reader(topic_path, topic_consumer) as reader: + batch = reader.receive_batch() + assert batch.messages[0].data.decode() == "123" + + def test_read_custom_encoded(self, driver_sync, topic_path, 
topic_consumer): + codec = 10001 + + def encode(b: bytes): + return bytes(reversed(b)) + + def decode(b: bytes): + return bytes(reversed(b)) + + with driver_sync.topic_client.writer( + topic_path, codec=codec, encoders={codec: encode} + ) as writer: + writer.write("123") + + with driver_sync.topic_client.reader( + topic_path, topic_consumer, decoders={codec: decode} + ) as reader: + batch = reader.receive_batch() + assert batch.messages[0].data.decode() == "123" diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py new file mode 100644 index 00000000..68c34a8e --- /dev/null +++ b/tests/topics/test_topic_writer.py @@ -0,0 +1,193 @@ +import pytest + +import ydb.aio + + +@pytest.mark.asyncio +class TestTopicWriterAsyncIO: + async def test_send_message(self, driver: ydb.aio.Driver, topic_path): + writer = driver.topic_client.writer(topic_path, producer_id="test") + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + await writer.close() + + async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): + async with driver.topic_client.writer( + topic_path, + producer_id="test", + auto_seqno=False, + ) as writer: + await writer.write_with_ack( + ydb.TopicWriterMessage(data="123".encode(), seqno=5) + ) + + async with driver.topic_client.writer( + topic_path, + producer_id="test", + ) as writer2: + init_info = await writer2.wait_init() + assert init_info.last_seqno == 5 + + async def test_random_producer_id( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with driver.topic_client.writer(topic_path) as writer: + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + async with driver.topic_client.writer(topic_path) as writer: + await writer.write(ydb.TopicWriterMessage(data="123".encode())) + + batch1 = await topic_reader.receive_batch() + batch2 = await topic_reader.receive_batch() + + assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + + async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path): + async with driver.topic_client.writer( + topic_path, + producer_id="test", + auto_seqno=False, + ) as writer: + last_seqno = 0 + for i in range(10): + last_seqno = i + 1 + await writer.write( + ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno) + ) + + async with driver.topic_client.writer( + topic_path, + producer_id="test", + ) as writer: + init_info = await writer.wait_init() + assert init_info.last_seqno == last_seqno + + async def test_write_multi_message_with_ack( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with driver.topic_client.writer(topic_path) as writer: + await writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] + ) + + batch = await topic_reader.receive_batch() + + assert batch.messages[0].offset == 0 + assert batch.messages[0].seqno == 1 + assert batch.messages[0].data == "123".encode() + + # remove second recieve batch when implement batching + # https://github.com/ydb-platform/ydb-python-sdk/issues/142 + batch = await topic_reader.receive_batch() + assert batch.messages[0].offset == 1 + assert batch.messages[0].seqno == 2 + assert batch.messages[0].data == "456".encode() + + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + async def test_write_encoded(self, driver: ydb.Driver, topic_path: str, codec): + async with driver.topic_client.writer(topic_path, 
codec=codec) as writer: + await writer.write("a" * 1000) + await writer.write("b" * 1000) + await writer.write("c" * 1000) + + +class TestTopicWriterSync: + def test_send_message(self, driver_sync: ydb.Driver, topic_path): + writer = driver_sync.topic_client.writer(topic_path, producer_id="test") + writer.write(ydb.TopicWriterMessage(data="123".encode())) + writer.close() + + def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): + with driver_sync.topic_client.writer( + topic_path, + producer_id="test", + auto_seqno=False, + ) as writer: + writer.write_with_ack(ydb.TopicWriterMessage(data="123".encode(), seqno=5)) + + with driver_sync.topic_client.writer( + topic_path, + producer_id="test", + ) as writer2: + init_info = writer2.wait_init() + assert init_info.last_seqno == 5 + + def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path): + with driver_sync.topic_client.writer( + topic_path, + producer_id="test", + auto_seqno=False, + ) as writer: + last_seqno = 0 + for i in range(10): + last_seqno = i + 1 + writer.write(ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno)) + + with driver_sync.topic_client.writer( + topic_path, + producer_id="test", + ) as writer: + init_info = writer.wait_init() + assert init_info.last_seqno == last_seqno + + def test_random_producer_id( + self, + driver_sync: ydb.aio.Driver, + topic_path, + topic_reader_sync: ydb.TopicReader, + ): + with driver_sync.topic_client.writer(topic_path) as writer: + writer.write(ydb.TopicWriterMessage(data="123".encode())) + with driver_sync.topic_client.writer(topic_path) as writer: + writer.write(ydb.TopicWriterMessage(data="123".encode())) + + batch1 = topic_reader_sync.receive_batch() + batch2 = topic_reader_sync.receive_batch() + + assert batch1.messages[0].producer_id != batch2.messages[0].producer_id + + def test_write_multi_message_with_ack( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + with driver_sync.topic_client.writer(topic_path) as writer: + writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode()), + ydb.TopicWriterMessage(data="456".encode()), + ] + ) + + batch = topic_reader_sync.receive_batch() + + assert batch.messages[0].offset == 0 + assert batch.messages[0].seqno == 1 + assert batch.messages[0].data == "123".encode() + + # remove second recieve batch when implement batching + # https://github.com/ydb-platform/ydb-python-sdk/issues/142 + batch = topic_reader_sync.receive_batch() + assert batch.messages[0].offset == 1 + assert batch.messages[0].seqno == 2 + assert batch.messages[0].data == "456".encode() + + @pytest.mark.parametrize( + "codec", + [ + ydb.TopicCodec.RAW, + ydb.TopicCodec.GZIP, + None, + ], + ) + def test_write_encoded(self, driver_sync: ydb.Driver, topic_path: str, codec): + with driver_sync.topic_client.writer(topic_path, codec=codec) as writer: + writer.write("a" * 1000) + writer.write("b" * 1000) + writer.write("c" * 1000) diff --git a/tox.ini b/tox.ini index f7876df5..7aca13db 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py,py-proto3,py-tls,py-tls-proto3,style,pylint,black,protoc +envlist = py,py-proto3,py-tls,py-tls-proto3,style,pylint,black,protoc,py-cov minversion = 4.2.6 skipsdist = True ignore_basepython_conflict = true @@ -25,6 +25,12 @@ deps = commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} +[testenv:py-cov] +commands = + pytest -v -m "not tls" \ + --cov-report html:cov_html --cov=ydb \ + 
--docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} + [testenv:py-proto3] commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} @@ -46,12 +52,12 @@ deps = [testenv:black-format] skip_install = true commands = - black ydb examples tests --extend-exclude ydb/_grpc + black ydb examples tests --extend-exclude "ydb/_grpc/v3|ydb/_grpc/v4" [testenv:black] skip_install = true commands = - black --diff --check ydb examples tests --extend-exclude ydb/_grpc + black --diff --check ydb examples tests --extend-exclude "ydb/_grpc/v3|ydb/_grpc/v4" [testenv:pylint] deps = pylint @@ -69,3 +75,6 @@ builtins = _ max-line-length = 160 ignore=E203,W503 exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,ydb/public/api/protos/*,docs/*,ydb/public/api/grpc/*,persqueue/*,client/*,dbapi/*,ydb/default_pem.py,*docs/conf.py + +[pytest] +asyncio_mode = auto diff --git a/ydb/__init__.py b/ydb/__init__.py index 56b73478..f305fdbc 100644 --- a/ydb/__init__.py +++ b/ydb/__init__.py @@ -13,6 +13,7 @@ from .scripting import * # noqa from .import_client import * # noqa from .tracing import * # noqa +from .topic import * # noqa try: import ydb.aio as aio # noqa diff --git a/ydb/_apis.py b/ydb/_apis.py index 6f2fc3ab..27bc1bbe 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- +import typing + # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4 import ( ydb_cms_v1_pb2_grpc, ydb_discovery_v1_pb2_grpc, ydb_scheme_v1_pb2_grpc, ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, + ydb_topic_v1_pb2_grpc, ) from ._grpc.v4.protos import ( @@ -26,6 +28,7 @@ ydb_scheme_v1_pb2_grpc, ydb_table_v1_pb2_grpc, ydb_operation_v1_pb2_grpc, + ydb_topic_v1_pb2_grpc, ) from ._grpc.common.protos import ( @@ -38,6 +41,7 @@ ydb_common_pb2, ) + StatusIds = ydb_status_codes_pb2.StatusIds FeatureFlag = ydb_common_pb2.FeatureFlag primitive_types = ydb_value_pb2.Type.PrimitiveTypeId @@ -95,3 +99,13 @@ class TableService(object): KeepAlive = "KeepAlive" StreamReadTable = "StreamReadTable" BulkUpsert = "BulkUpsert" + + +class TopicService(object): + Stub = ydb_topic_v1_pb2_grpc.TopicServiceStub + + CreateTopic = "CreateTopic" + DescribeTopic = "DescribeTopic" + DropTopic = "DropTopic" + StreamRead = "StreamRead" + StreamWrite = "StreamWrite" diff --git a/ydb/_dbapi/__init__.py b/ydb/_dbapi/__init__.py new file mode 100644 index 00000000..8756b0f2 --- /dev/null +++ b/ydb/_dbapi/__init__.py @@ -0,0 +1,36 @@ +from .connection import Connection +from .errors import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) + +apilevel = "1.0" + +threadsafety = 0 + +paramstyle = "pyformat" + +errors = ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) + + +def connect(*args, **kwargs): + return Connection(*args, **kwargs) diff --git a/ydb/_dbapi/connection.py b/ydb/_dbapi/connection.py new file mode 100644 index 00000000..75bfeb58 --- /dev/null +++ b/ydb/_dbapi/connection.py @@ -0,0 +1,73 @@ +import posixpath + +import ydb +from .cursor import Cursor +from .errors import DatabaseError + + +class Connection: + def __init__(self, endpoint, database=None, **conn_kwargs): + self.endpoint = endpoint + self.database = database + self.driver 
= self._create_driver(self.endpoint, self.database, **conn_kwargs) + self.pool = ydb.SessionPool(self.driver) + + def cursor(self): + return Cursor(self) + + def describe(self, table_path): + full_path = posixpath.join(self.database, table_path) + try: + res = self.pool.retry_operation_sync( + lambda cli: cli.describe_table(full_path) + ) + return res.columns + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + raise DatabaseError(f"Failed to describe table {table_path}") + + def check_exists(self, table_path): + try: + self.driver.scheme_client.describe_path(table_path) + return True + except ydb.SchemeError: + return False + + def commit(self): + pass + + def rollback(self): + pass + + def close(self): + if self.pool: + self.pool.stop() + if self.driver: + self.driver.stop() + + @staticmethod + def _create_driver(endpoint, database, **conn_kwargs): + # TODO: add cache for initialized drivers/pools? + driver_config = ydb.DriverConfig( + endpoint, + database=database, + table_client_settings=ydb.TableClientSettings() + .with_native_date_in_result_sets(True) + .with_native_datetime_in_result_sets(True) + .with_native_timestamp_in_result_sets(True) + .with_native_interval_in_result_sets(True) + .with_native_json_in_result_sets(True), + **conn_kwargs, + ) + driver = ydb.Driver(driver_config) + try: + driver.wait(timeout=5, fail_fast=True) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + except Exception: + driver.stop() + raise DatabaseError( + f"Failed to connect to YDB, details {driver.discovery_debug_details()}" + ) + return driver diff --git a/ydb/_dbapi/cursor.py b/ydb/_dbapi/cursor.py new file mode 100644 index 00000000..57659c7a --- /dev/null +++ b/ydb/_dbapi/cursor.py @@ -0,0 +1,172 @@ +import datetime +import itertools +import logging +import uuid +import decimal + +import ydb +from .errors import DatabaseError, ProgrammingError + + +logger = logging.getLogger(__name__) + + +def get_column_type(type_obj): + return str(ydb.convert.type_to_native(type_obj)) + + +def _generate_type_str(value): + tvalue = type(value) + + stype = { + bool: "Bool", + bytes: "String", + str: "Utf8", + int: "Int64", + float: "Double", + decimal.Decimal: "Decimal(22, 9)", + datetime.date: "Date", + datetime.datetime: "Timestamp", + datetime.timedelta: "Interval", + uuid.UUID: "Uuid", + }.get(tvalue) + + if tvalue == dict: + types_lst = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items()) + stype = f"Struct<{types_lst}>" + + elif tvalue == tuple: + types_lst = ", ".join(_generate_type_str(x) for x in value) + stype = f"Tuple<{types_lst}>" + + elif tvalue == list: + nested_type = _generate_type_str(value[0]) + stype = f"List<{nested_type}>" + + elif tvalue == set: + nested_type = _generate_type_str(next(iter(value))) + stype = f"Set<{nested_type}>" + + if stype is None: + raise ProgrammingError( + "Cannot translate python type to ydb type.", tvalue, value + ) + + return stype + + +def _generate_declare_stms(params: dict) -> str: + return "".join( + f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items() + ) + + +class Cursor(object): + def __init__(self, connection): + self.connection = connection + self.description = None + self.arraysize = 1 + self.rows = None + self._rows_prefetched = None + + def execute(self, sql, parameters=None, context=None): + self.description = None + sql_params = None + + if parameters: + sql = sql % {k: f"${k}" for k, v in parameters.items()} + sql_params = {f"${k}": v for k, 
v in parameters.items()} + declare_stms = _generate_declare_stms(sql_params) + sql = f"{declare_stms}{sql}" + + logger.info("execute sql: %s, params: %s", sql, sql_params) + + def _execute_in_pool(cli): + try: + if context and context.get("isddl"): + return cli.execute_scheme(sql) + else: + prepared_query = cli.prepare(sql) + return cli.transaction().execute( + prepared_query, sql_params, commit_tx=True + ) + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + chunks = self.connection.pool.retry_operation_sync(_execute_in_pool) + rows = self._rows_iterable(chunks) + # Prefetch the description: + try: + first_row = next(rows) + except StopIteration: + pass + else: + rows = itertools.chain((first_row,), rows) + if self.rows is not None: + rows = itertools.chain(self.rows, rows) + + self.rows = rows + + def _rows_iterable(self, chunks_iterable): + try: + for chunk in chunks_iterable: + self.description = [ + ( + col.name, + get_column_type(col.type), + None, + None, + None, + None, + None, + ) + for col in chunk.columns + ] + for row in chunk.rows: + # returns tuple to be compatible with SqlAlchemy and because + # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany + yield row[::] + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) + + def _ensure_prefetched(self): + if self.rows is not None and self._rows_prefetched is None: + self._rows_prefetched = list(self.rows) + self.rows = iter(self._rows_prefetched) + return self._rows_prefetched + + def executemany(self, sql, seq_of_parameters): + for parameters in seq_of_parameters: + self.execute(sql, parameters) + + def executescript(self, script): + return self.execute(script) + + def fetchone(self): + if self.rows is None: + return None + return next(self.rows, None) + + def fetchmany(self, size=None): + size = self.arraysize if size is None else size + return list(itertools.islice(self.rows, size)) + + def fetchall(self): + return list(self.rows) + + def nextset(self): + self.fetchall() + + def setinputsizes(self, sizes): + pass + + def setoutputsize(self, column=None): + pass + + def close(self): + self.rows = None + self._rows_prefetched = None + + @property + def rowcount(self): + return len(self._ensure_prefetched()) diff --git a/ydb/_dbapi/errors.py b/ydb/_dbapi/errors.py new file mode 100644 index 00000000..ddb55b4c --- /dev/null +++ b/ydb/_dbapi/errors.py @@ -0,0 +1,92 @@ +class Warning(Exception): + pass + + +class Error(Exception): + def __init__(self, message, issues=None, status=None): + super(Error, self).__init__(message) + + pretty_issues = _pretty_issues(issues) + self.issues = issues + self.message = pretty_issues or message + self.status = status + + +class InterfaceError(Error): + pass + + +class DatabaseError(Error): + pass + + +class DataError(DatabaseError): + pass + + +class OperationalError(DatabaseError): + pass + + +class IntegrityError(DatabaseError): + pass + + +class InternalError(DatabaseError): + pass + + +class ProgrammingError(DatabaseError): + pass + + +class NotSupportedError(DatabaseError): + pass + + +def _pretty_issues(issues): + if issues is None: + return None + + children_messages = [_get_messages(issue, root=True) for issue in issues] + + if None in children_messages: + return None + + return "\n" + "\n".join(children_messages) + + +def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False): + if depth >= max_depth: + return None + + margin_str = " " * depth * indent + pre_message = "" + children = "" + + 
if issue.issues: + collapsed_messages = [] + while not root and len(issue.issues) == 1: + collapsed_messages.append(issue.message) + issue = issue.issues[0] + + if collapsed_messages: + pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n" + depth += 1 + margin_str = " " * depth * indent + + children_messages = [ + _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1) + for iss in issue.issues + ] + + if None in children_messages: + return None + + children = "\n".join(children_messages) + + return ( + f"{pre_message}{margin_str}{issue.message}\n{margin_str}" + f"severity level: {issue.severity}\n{margin_str}" + f"issue code: {issue.issue_code}\n{children}" + ) diff --git a/ydb/_grpc/common/__init__.py b/ydb/_grpc/common/__init__.py index 10138358..1a077800 100644 --- a/ydb/_grpc/common/__init__.py +++ b/ydb/_grpc/common/__init__.py @@ -1,4 +1,5 @@ import sys +import importlib.util import google.protobuf from packaging.version import Version @@ -8,13 +9,30 @@ # sdk code must always import from ydb._grpc.common protobuf_version = Version(google.protobuf.__version__) -if protobuf_version < Version("4.0"): - from ydb._grpc.v3 import * # noqa - from ydb._grpc.v3 import protos # noqa - sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] - sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] +# for compatible with arcadia +if importlib.util.find_spec("ydb.public.api"): + from ydb.public.api.grpc import * # noqa + + sys.modules["ydb._grpc.common"] = sys.modules["ydb.public.api.grpc"] + + from ydb.public.api import protos + + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb.public.api.protos"] else: - from ydb._grpc.v4 import * # noqa - from ydb._grpc.v4 import protos # noqa - sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] - sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v4.protos"] + # common way, outside of arcadia + if protobuf_version < Version("4.0"): + from ydb._grpc.v3 import * # noqa + + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v3"] + + from ydb._grpc.v3 import protos # noqa + + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v3.protos"] + else: + from ydb._grpc.v4 import * # noqa + + sys.modules["ydb._grpc.common"] = sys.modules["ydb._grpc.v4"] + + from ydb._grpc.v4 import protos # noqa + + sys.modules["ydb._grpc.common.protos"] = sys.modules["ydb._grpc.v4.protos"] diff --git a/kikimr/__init__.py b/ydb/_grpc/grpcwrapper/__init__.py similarity index 100% rename from kikimr/__init__.py rename to ydb/_grpc/grpcwrapper/__init__.py diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py new file mode 100644 index 00000000..6c624520 --- /dev/null +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import abc +import asyncio +import contextvars +import datetime +import functools +import typing +from typing import ( + Optional, + Any, + Iterator, + AsyncIterator, + Callable, + Iterable, + Union, + Coroutine, +) +from dataclasses import dataclass + +import grpc +from google.protobuf.message import Message +from google.protobuf.duration_pb2 import Duration as ProtoDuration +from google.protobuf.timestamp_pb2 import Timestamp as ProtoTimeStamp + +import ydb.aio + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_topic_pb2, ydb_issue_message_pb2 +else: + from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2 + +from ... 
import issues, connection
+
+
+class IFromProto(abc.ABC):
+    @staticmethod
+    @abc.abstractmethod
+    def from_proto(msg: Message) -> Any:
+        ...
+
+
+class IFromProtoWithProtoType(IFromProto):
+    @staticmethod
+    @abc.abstractmethod
+    def empty_proto_message() -> Message:
+        ...
+
+
+class IToProto(abc.ABC):
+    @abc.abstractmethod
+    def to_proto(self) -> Message:
+        ...
+
+
+class IFromPublic(abc.ABC):
+    @staticmethod
+    @abc.abstractmethod
+    def from_public(o: typing.Any) -> typing.Any:
+        ...
+
+
+class IToPublic(abc.ABC):
+    @abc.abstractmethod
+    def to_public(self) -> typing.Any:
+        ...
+
+
+class UnknownGrpcMessageError(issues.Error):
+    pass
+
+
+_stop_grpc_connection_marker = object()
+
+
+class QueueToIteratorAsyncIO:
+    __slots__ = ("_queue",)
+
+    def __init__(self, q: asyncio.Queue):
+        self._queue = q
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        item = await self._queue.get()
+        if item is _stop_grpc_connection_marker:
+            raise StopAsyncIteration()
+        return item
+
+
+class AsyncQueueToSyncIteratorAsyncIO:
+    __slots__ = (
+        "_loop",
+        "_queue",
+    )
+    _queue: asyncio.Queue
+
+    def __init__(self, q: asyncio.Queue):
+        self._loop = asyncio.get_running_loop()
+        self._queue = q
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
+        if item is _stop_grpc_connection_marker:
+            raise StopIteration()
+        return item
+
+
+class SyncIteratorToAsyncIterator:
+    def __init__(self, sync_iterator: Iterator):
+        self._sync_iterator = sync_iterator
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            res = await to_thread(self._sync_iterator.__next__)
+            return res
+        except StopIteration:
+            # the underlying sync iterator is exhausted: translate to async-iteration stop
+            raise StopAsyncIteration()
+
+
+class IGrpcWrapperAsyncIO(abc.ABC):
+    @abc.abstractmethod
+    async def receive(self) -> Any:
+        ...
+
+    @abc.abstractmethod
+    def write(self, wrap_message: IToProto):
+        ...
+
+    @abc.abstractmethod
+    def close(self):
+        ...
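A note for readers of this patch: the adapters above let the writer/reader code push outgoing protobuf messages onto an asyncio.Queue while gRPC consumes that queue as a request iterator; pushing _stop_grpc_connection_marker ends the stream (see GrpcWrapperAsyncIO below). A minimal standalone sketch of that queue-to-iterator pattern, with illustrative names that are not part of the SDK:

import asyncio

_STOP = object()  # sentinel, standing in for _stop_grpc_connection_marker


class QueueIterator:
    """Async-iterate items pushed onto an asyncio.Queue until the sentinel arrives."""

    def __init__(self, queue: asyncio.Queue):
        self._queue = queue

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is _STOP:
            raise StopAsyncIteration()
        return item


async def main():
    requests = asyncio.Queue()
    requests.put_nowait("request-1")  # stand-ins for outgoing protobuf messages
    requests.put_nowait("request-2")
    requests.put_nowait(_STOP)  # close() pushes the sentinel to finish the stream

    async for request in QueueIterator(requests):
        print("would hand to gRPC:", request)


asyncio.run(main())

AsyncQueueToSyncIteratorAsyncIO applies the same idea for the synchronous driver, blocking on run_coroutine_threadsafe instead of awaiting the queue.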
+ + +SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] + + +class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): + from_client_grpc: asyncio.Queue + from_server_grpc: AsyncIterator + convert_server_grpc_to_wrapper: Callable[[Any], Any] + _connection_state: str + _stream_call: Optional[ + Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"] + ] + + def __init__(self, convert_server_grpc_to_wrapper): + self.from_client_grpc = asyncio.Queue() + self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper + self._connection_state = "new" + self._stream_call = None + + async def start(self, driver: SupportedDriverType, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, stub, method) + else: + await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" + + def close(self): + self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) + if self._stream_call: + self._stream_call.cancel() + + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): + requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) + stream_call = await driver( + requests_iterator, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = stream_call.__aiter__() + + async def _start_sync_driver(self, driver: ydb.Driver, stub, method): + requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc) + stream_call = await to_thread( + driver, + requests_iterator, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__()) + + async def receive(self) -> Any: + # todo handle grpc exceptions and convert it to internal exceptions + try: + grpc_message = await self.from_server_grpc.__anext__() + except grpc.RpcError as e: + raise connection._rpc_error_handler(self._connection_state, e) + + issues._process_response(grpc_message) + + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" + + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) + + def write(self, wrap_message: IToProto): + grpc_message = wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) + + +@dataclass(init=False) +class ServerStatus(IFromProto): + __slots__ = ("_grpc_status_code", "_issues") + + def __init__( + self, + status: issues.StatusCode, + issues: Iterable[Any], + ): + self.status = status + self.issues = issues + + def __str__(self): + return self.__repr__() + + @staticmethod + def from_proto( + msg: Union[ + ydb_topic_pb2.StreamReadMessage.FromServer, + ydb_topic_pb2.StreamWriteMessage.FromServer, + ] + ) -> "ServerStatus": + return ServerStatus(msg.status, msg.issues) + + def is_success(self) -> bool: + return self.status == issues.StatusCode.SUCCESS + + @classmethod + def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage): + res = """code: %s message: "%s" """ % (issue.issue_code, issue.message) + if len(issue.issues) > 0: + d = ", " + res += d + d.join(str(sub_issue) for sub_issue in issue.issues) + return res + + +def callback_from_asyncio( + callback: Union[Callable, Coroutine] +) -> [asyncio.Future, asyncio.Task]: + loop = asyncio.get_running_loop() + + if asyncio.iscoroutinefunction(callback): + return loop.create_task(callback()) + else: + return loop.run_in_executor(None, callback) + + +async def to_thread(func, /, 
*args, **kwargs):
+    """Asynchronously run function *func* in a separate thread.
+
+    Any *args and **kwargs supplied for this function are directly passed
+    to *func*. Also, the current :class:`contextvars.Context` is propagated,
+    allowing context variables from the main thread to be accessed in the
+    separate thread.
+
+    Return a coroutine that can be awaited to get the eventual result of *func*.
+
+    Copy of asyncio.to_thread from Python 3.10.
+    """
+
+    loop = asyncio.get_running_loop()
+    ctx = contextvars.copy_context()
+    func_call = functools.partial(ctx.run, func, *args, **kwargs)
+    return await loop.run_in_executor(None, func_call)
+
+
+def proto_duration_from_timedelta(t: Optional[datetime.timedelta]) -> ProtoDuration:
+    if t is None:
+        return None
+    res = ProtoDuration()
+    res.FromTimedelta(t)
+    return res
+
+
+def proto_timestamp_from_datetime(t: Optional[datetime.datetime]) -> ProtoTimeStamp:
+    if t is None:
+        return None
+
+    res = ProtoTimeStamp()
+    res.FromDatetime(t)
+    return res
+
+
+def datetime_from_proto_timestamp(
+    ts: Optional[ProtoTimeStamp],
+) -> Optional[datetime.datetime]:
+    if ts is None:
+        return None
+    return ts.ToDatetime()
+
+
+def timedelta_from_proto_duration(
+    d: Optional[ProtoDuration],
+) -> Optional[datetime.timedelta]:
+    if d is None:
+        return None
+    return d.ToTimedelta()
diff --git a/ydb/_grpc/grpcwrapper/ydb_scheme.py b/ydb/_grpc/grpcwrapper/ydb_scheme.py
new file mode 100644
index 00000000..b9922035
--- /dev/null
+++ b/ydb/_grpc/grpcwrapper/ydb_scheme.py
@@ -0,0 +1,36 @@
+import datetime
+import enum
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Entry:
+    name: str
+    owner: str
+    type: "Entry.Type"
+    effective_permissions: "Permissions"
+    permissions: "Permissions"
+    size_bytes: int
+    created_at: datetime.datetime
+
+    class Type(enum.IntEnum):
+        UNSPECIFIED = 0
+        DIRECTORY = 1
+        TABLE = 2
+        PERS_QUEUE_GROUP = 3
+        DATABASE = 4
+        RTMR_VOLUME = 5
+        BLOCK_STORE_VOLUME = 6
+        COORDINATION_NODE = 7
+        COLUMN_STORE = 12
+        COLUMN_TABLE = 13
+        SEQUENCE = 15
+        REPLICATION = 16
+        TOPIC = 17
+
+
+@dataclass
+class Permissions:
+    subject: str
+    permission_names: List[str]
diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py
new file mode 100644
index 00000000..4784d486
--- /dev/null
+++ b/ydb/_grpc/grpcwrapper/ydb_topic.py
@@ -0,0 +1,1164 @@
+import datetime
+import enum
+import typing
+from dataclasses import dataclass, field
+from typing import List, Union, Dict, Optional
+
+from google.protobuf.message import Message
+
+from . import ydb_topic_public_types
+from ...
import scheme + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_scheme_pb2, ydb_topic_pb2 +else: + from ..common.protos import ydb_scheme_pb2, ydb_topic_pb2 + +from .common_utils import ( + IFromProto, + IFromProtoWithProtoType, + IToProto, + IToPublic, + IFromPublic, + ServerStatus, + UnknownGrpcMessageError, + proto_duration_from_timedelta, + proto_timestamp_from_datetime, + datetime_from_proto_timestamp, + timedelta_from_proto_duration, +) + + +class Codec(int, IToPublic): + CODEC_UNSPECIFIED = 0 + CODEC_RAW = 1 + CODEC_GZIP = 2 + CODEC_LZOP = 3 + CODEC_ZSTD = 4 + + @staticmethod + def from_proto_iterable(codecs: typing.Iterable[int]) -> List["Codec"]: + return [Codec(int(codec)) for codec in codecs] + + def to_public(self) -> ydb_topic_public_types.PublicCodec: + return ydb_topic_public_types.PublicCodec(int(self)) + + +@dataclass +class SupportedCodecs(IToProto, IFromProto, IToPublic): + codecs: List[Codec] + + def to_proto(self) -> ydb_topic_pb2.SupportedCodecs: + return ydb_topic_pb2.SupportedCodecs( + codecs=self.codecs, + ) + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.SupportedCodecs]) -> "SupportedCodecs": + if msg is None: + return SupportedCodecs(codecs=[]) + + return SupportedCodecs( + codecs=Codec.from_proto_iterable(msg.codecs), + ) + + def to_public(self) -> List[ydb_topic_public_types.PublicCodec]: + return list(map(Codec.to_public, self.codecs)) + + +@dataclass(order=True) +class OffsetsRange(IFromProto, IToProto): + """ + half-opened interval, include [start, end) offsets + """ + + __slots__ = ("start", "end") + + start: int # first offset + end: int # offset after last, included to range + + def __post_init__(self): + if self.end < self.start: + raise ValueError( + "offset end must be not less then start. 
Got start=%s end=%s" + % (self.start, self.end) + ) + + @staticmethod + def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange": + return OffsetsRange( + start=msg.start, + end=msg.end, + ) + + def to_proto(self) -> ydb_topic_pb2.OffsetsRange: + return ydb_topic_pb2.OffsetsRange( + start=self.start, + end=self.end, + ) + + def is_intersected_with(self, other: "OffsetsRange") -> bool: + return ( + self.start <= other.start < self.end + or self.start < other.end <= self.end + or other.start <= self.start < other.end + or other.start < self.end <= other.end + ) + + +@dataclass +class UpdateTokenRequest(IToProto): + token: str + + def to_proto(self) -> Message: + res = ydb_topic_pb2.UpdateTokenRequest() + res.token = self.token + return res + + +@dataclass +class UpdateTokenResponse(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: + return UpdateTokenResponse() + + +######################################################################################################################## +# StreamWrite +######################################################################################################################## + + +class StreamWriteMessage: + @dataclass() + class InitRequest(IToProto): + path: str + producer_id: str + write_session_meta: typing.Dict[str, str] + partitioning: "StreamWriteMessage.PartitioningType" + get_last_seq_no: bool + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.InitRequest: + proto = ydb_topic_pb2.StreamWriteMessage.InitRequest() + proto.path = self.path + proto.producer_id = self.producer_id + + if self.partitioning is None: + pass + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + else: + raise Exception( + "Bad partitioning type at StreamWriteMessage.InitRequest" + ) + + if self.write_session_meta: + for key in self.write_session_meta: + proto.write_session_meta[key] = self.write_session_meta[key] + + proto.get_last_seq_no = self.get_last_seq_no + return proto + + @dataclass + class InitResponse(IFromProto): + last_seq_no: Union[int, None] + session_id: str + partition_id: int + supported_codecs: typing.List[int] + status: ServerStatus = None + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.InitResponse, + ) -> "StreamWriteMessage.InitResponse": + codecs = [] # type: typing.List[int] + if msg.supported_codecs: + for codec in msg.supported_codecs.codecs: + codecs.append(codec) + + return StreamWriteMessage.InitResponse( + last_seq_no=msg.last_seq_no, + session_id=msg.session_id, + partition_id=msg.partition_id, + supported_codecs=codecs, + ) + + @dataclass + class WriteRequest(IToProto): + messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] + codec: int + + @dataclass + class MessageData(IToProto): + seq_no: int + created_at: datetime.datetime + data: bytes + uncompressed_size: int + partitioning: "StreamWriteMessage.PartitioningType" + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest.MessageData() + proto.seq_no = self.seq_no + proto.created_at.FromDatetime(self.created_at) + proto.data = self.data + proto.uncompressed_size = self.uncompressed_size + + if self.partitioning is None: + pass + elif isinstance( + 
self.partitioning, StreamWriteMessage.PartitioningPartitionID + ): + proto.partition_id = self.partitioning.partition_id + elif isinstance( + self.partitioning, StreamWriteMessage.PartitioningMessageGroupID + ): + proto.message_group_id = self.partitioning.message_group_id + else: + raise Exception( + "Bad partition at StreamWriteMessage.WriteRequest.MessageData" + ) + + return proto + + def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: + proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() + proto.codec = self.codec + + for message in self.messages: + proto_mess = proto.messages.add() + proto_mess.CopyFrom(message.to_proto()) + + return proto + + @dataclass + class WriteResponse(IFromProto): + partition_id: int + acks: typing.List["StreamWriteMessage.WriteResponse.WriteAck"] + write_statistics: "StreamWriteMessage.WriteResponse.WriteStatistics" + status: Optional[ServerStatus] = field(default=None) + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamWriteMessage.WriteResponse, + ) -> "StreamWriteMessage.WriteResponse": + acks = [] + for proto_ack in msg.acks: + ack = StreamWriteMessage.WriteResponse.WriteAck.from_proto(proto_ack) + acks.append(ack) + write_statistics = StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=msg.write_statistics.persisting_time.ToTimedelta(), + min_queue_wait_time=msg.write_statistics.min_queue_wait_time.ToTimedelta(), + max_queue_wait_time=msg.write_statistics.max_queue_wait_time.ToTimedelta(), + partition_quota_wait_time=msg.write_statistics.partition_quota_wait_time.ToTimedelta(), + topic_quota_wait_time=msg.write_statistics.topic_quota_wait_time.ToTimedelta(), + ) + return StreamWriteMessage.WriteResponse( + partition_id=msg.partition_id, + acks=acks, + write_statistics=write_statistics, + status=None, + ) + + @dataclass + class WriteAck(IFromProto): + seq_no: int + message_write_status: Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusWritten", + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped", + int, + ] + + @classmethod + def from_proto( + cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.WriteAck + ): + if proto_ack.HasField("written"): + message_write_status = ( + StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + proto_ack.written.offset + ) + ) + elif proto_ack.HasField("skipped"): + reason = proto_ack.skipped.reason + try: + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped( + reason=StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason.from_protobuf_code( + reason + ) + ) + except ValueError: + message_write_status = reason + else: + raise NotImplementedError("unexpected ack status") + + return StreamWriteMessage.WriteResponse.WriteAck( + seq_no=proto_ack.seq_no, + message_write_status=message_write_status, + ) + + @dataclass + class StatusWritten: + offset: int + + @dataclass + class StatusSkipped: + reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" + + class Reason(enum.Enum): + UNSPECIFIED = 0 + ALREADY_WRITTEN = 1 + + @classmethod + def from_protobuf_code( + cls, code: int + ) -> Union[ + "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason", + int, + ]: + try: + return StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason( + code + ) + except ValueError: + return code + + @dataclass + class WriteStatistics: + persisting_time: datetime.timedelta + min_queue_wait_time: datetime.timedelta + max_queue_wait_time: datetime.timedelta + partition_quota_wait_time: datetime.timedelta + 
topic_quota_wait_time: datetime.timedelta + + @dataclass + class PartitioningMessageGroupID: + message_group_id: str + + @dataclass + class PartitioningPartitionID: + partition_id: int + + PartitioningType = Union[PartitioningMessageGroupID, PartitioningPartitionID, None] + + @dataclass + class FromClient(IToProto): + value: "WriterMessagesFromClientToServer" + + def __init__(self, value: "WriterMessagesFromClientToServer"): + self.value = value + + def to_proto(self) -> Message: + res = ydb_topic_pb2.StreamWriteMessage.FromClient() + value = self.value + if isinstance(value, StreamWriteMessage.WriteRequest): + res.write_request.CopyFrom(value.to_proto()) + elif isinstance(value, StreamWriteMessage.InitRequest): + res.init_request.CopyFrom(value.to_proto()) + elif isinstance(value, UpdateTokenRequest): + res.update_token_request.CopyFrom(value.to_proto()) + else: + raise Exception("Unknown outcoming grpc message: %s" % value) + return res + + class FromServer(IFromProto): + @staticmethod + def from_proto(msg: ydb_topic_pb2.StreamWriteMessage.FromServer) -> typing.Any: + message_type = msg.WhichOneof("server_message") + if message_type == "write_response": + res = StreamWriteMessage.WriteResponse.from_proto(msg.write_response) + elif message_type == "init_response": + res = StreamWriteMessage.InitResponse.from_proto(msg.init_response) + elif message_type == "update_token_response": + res = UpdateTokenResponse.from_proto(msg.update_token_response) + else: + # todo log instead of exception - for allow add messages in the future + raise UnknownGrpcMessageError("Unexpected proto message: %s" % msg) + + res.status = ServerStatus(msg.status, msg.issues) + return res + + +WriterMessagesFromClientToServer = Union[ + StreamWriteMessage.InitRequest, StreamWriteMessage.WriteRequest, UpdateTokenRequest +] +WriterMessagesFromServerToClient = Union[ + StreamWriteMessage.InitResponse, + StreamWriteMessage.WriteResponse, + UpdateTokenResponse, +] + + +######################################################################################################################## +# StreamRead +######################################################################################################################## + + +class StreamReadMessage: + @dataclass + class PartitionSession(IFromProto): + partition_session_id: int + path: str + partition_id: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.PartitionSession, + ) -> "StreamReadMessage.PartitionSession": + return StreamReadMessage.PartitionSession( + partition_session_id=msg.partition_session_id, + path=msg.path, + partition_id=msg.partition_id, + ) + + @dataclass + class InitRequest(IToProto): + topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"] + consumer: str + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.InitRequest: + res = ydb_topic_pb2.StreamReadMessage.InitRequest() + res.consumer = self.consumer + for settings in self.topics_read_settings: + res.topics_read_settings.append(settings.to_proto()) + return res + + @dataclass + class TopicReadSettings(IToProto): + path: str + partition_ids: List[int] = field(default_factory=list) + max_lag_seconds: Union[datetime.timedelta, None] = None + read_from: Union[int, float, datetime.datetime, None] = None + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings: + res = ydb_topic_pb2.StreamReadMessage.InitRequest.TopicReadSettings() + res.path = self.path + res.partition_ids.extend(self.partition_ids) + if 
self.max_lag_seconds is not None: + res.max_lag = proto_duration_from_timedelta(self.max_lag_seconds) + return res + + @dataclass + class InitResponse(IFromProto): + session_id: str + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.InitResponse, + ) -> "StreamReadMessage.InitResponse": + return StreamReadMessage.InitResponse(session_id=msg.session_id) + + @dataclass + class ReadRequest(IToProto): + bytes_size: int + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest: + res = ydb_topic_pb2.StreamReadMessage.ReadRequest() + res.bytes_size = self.bytes_size + return res + + @dataclass + class ReadResponse(IFromProto): + partition_data: List["StreamReadMessage.ReadResponse.PartitionData"] + bytes_size: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse, + ) -> "StreamReadMessage.ReadResponse": + partition_data = [] + for proto_partition_data in msg.partition_data: + partition_data.append( + StreamReadMessage.ReadResponse.PartitionData.from_proto( + proto_partition_data + ) + ) + return StreamReadMessage.ReadResponse( + partition_data=partition_data, + bytes_size=msg.bytes_size, + ) + + @dataclass + class MessageData(IFromProto): + offset: int + seq_no: int + created_at: datetime.datetime + data: bytes + uncompresed_size: int + message_group_id: str + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, + ) -> "StreamReadMessage.ReadResponse.MessageData": + return StreamReadMessage.ReadResponse.MessageData( + offset=msg.offset, + seq_no=msg.seq_no, + created_at=msg.created_at.ToDatetime(), + data=msg.data, + uncompresed_size=msg.uncompressed_size, + message_group_id=msg.message_group_id, + ) + + @dataclass + class Batch(IFromProto): + message_data: List["StreamReadMessage.ReadResponse.MessageData"] + producer_id: str + write_session_meta: Dict[str, str] + codec: int + written_at: datetime.datetime + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch, + ) -> "StreamReadMessage.ReadResponse.Batch": + message_data = [] + for message in msg.message_data: + message_data.append( + StreamReadMessage.ReadResponse.MessageData.from_proto(message) + ) + return StreamReadMessage.ReadResponse.Batch( + message_data=message_data, + producer_id=msg.producer_id, + write_session_meta=dict(msg.write_session_meta), + codec=msg.codec, + written_at=msg.written_at.ToDatetime(), + ) + + @dataclass + class PartitionData(IFromProto): + partition_session_id: int + batches: List["StreamReadMessage.ReadResponse.Batch"] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData, + ) -> "StreamReadMessage.ReadResponse.PartitionData": + batches = [] + for proto_batch in msg.batches: + batches.append( + StreamReadMessage.ReadResponse.Batch.from_proto(proto_batch) + ) + return StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=msg.partition_session_id, + batches=batches, + ) + + @dataclass + class CommitOffsetRequest(IToProto): + commit_offsets: List["PartitionCommitOffset"] + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest( + commit_offsets=list( + map( + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset.to_proto, + self.commit_offsets, + ) + ), + ) + return res + + @dataclass + class PartitionCommitOffset(IToProto): + partition_session_id: int + offsets: List["OffsetsRange"] + + def to_proto( + self, + 
) -> ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset: + res = ydb_topic_pb2.StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=self.partition_session_id, + offsets=list(map(OffsetsRange.to_proto, self.offsets)), + ) + return res + + @dataclass + class CommitOffsetResponse(IFromProto): + partitions_committed_offsets: List[ + "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset" + ] + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse, + ) -> "StreamReadMessage.CommitOffsetResponse": + return StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=list( + map( + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset.from_proto, + msg.partitions_committed_offsets, + ) + ) + ) + + @dataclass + class PartitionCommittedOffset(IFromProto): + partition_session_id: int + committed_offset: int + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset, + ) -> "StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset": + return StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=msg.partition_session_id, + committed_offset=msg.committed_offset, + ) + + @dataclass + class PartitionSessionStatusRequest: + partition_session_id: int + + @dataclass + class PartitionSessionStatusResponse: + partition_session_id: int + partition_offsets: "OffsetsRange" + committed_offset: int + write_time_high_watermark: float + + @dataclass + class StartPartitionSessionRequest(IFromProto): + partition_session: "StreamReadMessage.PartitionSession" + committed_offset: int + partition_offsets: "OffsetsRange" + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest, + ) -> "StreamReadMessage.StartPartitionSessionRequest": + return StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession.from_proto( + msg.partition_session + ), + committed_offset=msg.committed_offset, + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + ) + + @dataclass + class StartPartitionSessionResponse(IToProto): + partition_session_id: int + read_offset: Optional[int] + commit_offset: Optional[int] + + def to_proto( + self, + ) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse: + res = ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse() + res.partition_session_id = self.partition_session_id + if self.read_offset is not None: + res.read_offset = self.read_offset + if self.commit_offset is not None: + res.commit_offset = self.commit_offset + return res + + @dataclass + class StopPartitionSessionRequest: + partition_session_id: int + graceful: bool + committed_offset: int + + @dataclass + class StopPartitionSessionResponse: + partition_session_id: int + + @dataclass + class FromClient(IToProto): + client_message: "ReaderMessagesFromClientToServer" + + def __init__(self, client_message: "ReaderMessagesFromClientToServer"): + self.client_message = client_message + + def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient: + res = ydb_topic_pb2.StreamReadMessage.FromClient() + if isinstance(self.client_message, StreamReadMessage.ReadRequest): + res.read_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, StreamReadMessage.CommitOffsetRequest): + res.commit_offset_request.CopyFrom(self.client_message.to_proto()) + elif 
isinstance(self.client_message, StreamReadMessage.InitRequest): + res.init_request.CopyFrom(self.client_message.to_proto()) + elif isinstance(self.client_message, UpdateTokenRequest): + res.update_token_request.CopyFrom(self.client_message.to_proto()) + elif isinstance( + self.client_message, StreamReadMessage.StartPartitionSessionResponse + ): + res.start_partition_session_response.CopyFrom( + self.client_message.to_proto() + ) + else: + raise NotImplementedError( + "Unknown message type: %s" % type(self.client_message) + ) + return res + + @dataclass + class FromServer(IFromProto): + server_message: "ReaderMessagesFromServerToClient" + server_status: ServerStatus + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.StreamReadMessage.FromServer, + ) -> "StreamReadMessage.FromServer": + mess_type = msg.WhichOneof("server_message") + server_status = ServerStatus.from_proto(msg) + if mess_type == "read_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.ReadResponse.from_proto( + msg.read_response + ), + ) + elif mess_type == "commit_offset_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.CommitOffsetResponse.from_proto( + msg.commit_offset_response + ), + ) + elif mess_type == "init_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.InitResponse.from_proto( + msg.init_response + ), + ) + elif mess_type == "start_partition_session_request": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto( + msg.start_partition_session_request + ), + ) + elif mess_type == "update_token_response": + return StreamReadMessage.FromServer( + server_status=server_status, + server_message=UpdateTokenResponse.from_proto( + msg.update_token_response + ), + ) + + # todo replace exception to log + raise NotImplementedError() + + +ReaderMessagesFromClientToServer = Union[ + StreamReadMessage.InitRequest, + StreamReadMessage.ReadRequest, + StreamReadMessage.CommitOffsetRequest, + StreamReadMessage.PartitionSessionStatusRequest, + UpdateTokenRequest, + StreamReadMessage.StartPartitionSessionResponse, + StreamReadMessage.StopPartitionSessionResponse, +] + +ReaderMessagesFromServerToClient = Union[ + StreamReadMessage.InitResponse, + StreamReadMessage.ReadResponse, + StreamReadMessage.CommitOffsetResponse, + StreamReadMessage.PartitionSessionStatusResponse, + UpdateTokenResponse, + StreamReadMessage.StartPartitionSessionRequest, + StreamReadMessage.StopPartitionSessionRequest, +] + + +@dataclass +class MultipleWindowsStat(IFromProto, IToPublic): + per_minute: int + per_hour: int + per_day: int + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.MultipleWindowsStat], + ) -> Optional["MultipleWindowsStat"]: + if msg is None: + return None + return MultipleWindowsStat( + per_minute=msg.per_minute, + per_hour=msg.per_hour, + per_day=msg.per_day, + ) + + def to_public(self) -> ydb_topic_public_types.PublicMultipleWindowsStat: + return ydb_topic_public_types.PublicMultipleWindowsStat( + per_minute=self.per_minute, + per_hour=self.per_hour, + per_day=self.per_day, + ) + + +@dataclass +class Consumer(IToProto, IFromProto, IFromPublic, IToPublic): + name: str + important: bool + read_from: typing.Optional[datetime.datetime] + supported_codecs: SupportedCodecs + attributes: Dict[str, str] + consumer_stats: 
typing.Optional["Consumer.ConsumerStats"] + + def to_proto(self) -> ydb_topic_pb2.Consumer: + return ydb_topic_pb2.Consumer( + name=self.name, + important=self.important, + read_from=proto_timestamp_from_datetime(self.read_from), + supported_codecs=self.supported_codecs.to_proto(), + attributes=self.attributes, + # consumer_stats - readonly field + ) + + @staticmethod + def from_proto(msg: Optional[ydb_topic_pb2.Consumer]) -> Optional["Consumer"]: + return Consumer( + name=msg.name, + important=msg.important, + read_from=datetime_from_proto_timestamp(msg.read_from), + supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), + attributes=dict(msg.attributes), + consumer_stats=Consumer.ConsumerStats.from_proto(msg.consumer_stats), + ) + + @staticmethod + def from_public(consumer: ydb_topic_public_types.PublicConsumer): + if consumer is None: + return None + + supported_codecs = [] + if consumer.supported_codecs is not None: + supported_codecs = consumer.supported_codecs + + return Consumer( + name=consumer.name, + important=consumer.important, + read_from=consumer.read_from, + supported_codecs=SupportedCodecs(codecs=supported_codecs), + attributes=consumer.attributes, + consumer_stats=None, + ) + + def to_public(self) -> ydb_topic_public_types.PublicConsumer: + return ydb_topic_public_types.PublicConsumer( + name=self.name, + important=self.important, + read_from=self.read_from, + supported_codecs=self.supported_codecs.to_public(), + attributes=self.attributes, + ) + + @dataclass + class ConsumerStats(IFromProto): + min_partitions_last_read_time: datetime.datetime + max_read_time_lag: datetime.timedelta + max_write_time_lag: datetime.timedelta + bytes_read: MultipleWindowsStat + + @staticmethod + def from_proto( + msg: ydb_topic_pb2.Consumer.ConsumerStats, + ) -> "Consumer.ConsumerStats": + return Consumer.ConsumerStats( + min_partitions_last_read_time=datetime_from_proto_timestamp( + msg.min_partitions_last_read_time + ), + max_read_time_lag=timedelta_from_proto_duration(msg.max_read_time_lag), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), + bytes_read=MultipleWindowsStat.from_proto(msg.bytes_read), + ) + + +@dataclass +class PartitioningSettings(IToProto, IFromProto): + min_active_partitions: int + partition_count_limit: int + + @staticmethod + def from_proto(msg: ydb_topic_pb2.PartitioningSettings) -> "PartitioningSettings": + return PartitioningSettings( + min_active_partitions=msg.min_active_partitions, + partition_count_limit=msg.partition_count_limit, + ) + + def to_proto(self) -> ydb_topic_pb2.PartitioningSettings: + return ydb_topic_pb2.PartitioningSettings( + min_active_partitions=self.min_active_partitions, + partition_count_limit=self.partition_count_limit, + ) + + +class MeteringMode(int, IFromProto, IFromPublic, IToPublic): + UNSPECIFIED = 0 + RESERVED_CAPACITY = 1 + REQUEST_UNITS = 2 + + @staticmethod + def from_public( + m: Optional[ydb_topic_public_types.PublicMeteringMode], + ) -> Optional["MeteringMode"]: + if m is None: + return None + + return MeteringMode(m) + + @staticmethod + def from_proto(code: Optional[int]) -> Optional["MeteringMode"]: + if code is None: + return None + + return MeteringMode(code) + + def to_public(self) -> ydb_topic_public_types.PublicMeteringMode: + try: + ydb_topic_public_types.PublicMeteringMode(int(self)) + except KeyError: + return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED + + +@dataclass +class CreateTopicRequest(IToProto, IFromPublic): + path: str + partitioning_settings: 
"PartitioningSettings" + retention_period: typing.Optional[datetime.timedelta] + retention_storage_mb: typing.Optional[int] + supported_codecs: "SupportedCodecs" + partition_write_speed_bytes_per_second: typing.Optional[int] + partition_write_burst_bytes: typing.Optional[int] + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: "MeteringMode" + + def to_proto(self) -> ydb_topic_pb2.CreateTopicRequest: + return ydb_topic_pb2.CreateTopicRequest( + path=self.path, + partitioning_settings=self.partitioning_settings.to_proto(), + retention_period=proto_duration_from_timedelta(self.retention_period), + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_proto(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=[consumer.to_proto() for consumer in self.consumers], + metering_mode=self.metering_mode, + ) + + @staticmethod + def from_public(req: ydb_topic_public_types.CreateTopicRequestParams): + supported_codecs = [] + + if req.supported_codecs is not None: + supported_codecs = req.supported_codecs + + consumers = [] + if req.consumers is not None: + for consumer in req.consumers: + if isinstance(consumer, str): + consumer = ydb_topic_public_types.PublicConsumer(name=consumer) + consumers.append(Consumer.from_public(consumer)) + + return CreateTopicRequest( + path=req.path, + partitioning_settings=PartitioningSettings( + min_active_partitions=req.min_active_partitions, + partition_count_limit=req.partition_count_limit, + ), + retention_period=req.retention_period, + retention_storage_mb=req.retention_storage_mb, + supported_codecs=SupportedCodecs( + codecs=supported_codecs, + ), + partition_write_speed_bytes_per_second=req.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=req.partition_write_burst_bytes, + attributes=req.attributes, + consumers=consumers, + metering_mode=MeteringMode.from_public(req.metering_mode), + ) + + +@dataclass +class CreateTopicResult: + pass + + +@dataclass +class DescribeTopicRequest: + path: str + include_stats: bool + + +@dataclass +class DescribeTopicResult(IFromProtoWithProtoType, IToPublic): + self_proto: ydb_scheme_pb2.Entry + partitioning_settings: PartitioningSettings + partitions: List["DescribeTopicResult.PartitionInfo"] + retention_period: datetime.timedelta + retention_storage_mb: int + supported_codecs: SupportedCodecs + partition_write_speed_bytes_per_second: int + partition_write_burst_bytes: int + attributes: Dict[str, str] + consumers: List["Consumer"] + metering_mode: MeteringMode + topic_stats: "DescribeTopicResult.TopicStats" + + @staticmethod + def from_proto(msg: ydb_topic_pb2.DescribeTopicResult) -> "DescribeTopicResult": + return DescribeTopicResult( + self_proto=msg.self, + partitioning_settings=PartitioningSettings.from_proto( + msg.partitioning_settings + ), + partitions=list( + map(DescribeTopicResult.PartitionInfo.from_proto, msg.partitions) + ), + retention_period=msg.retention_period, + retention_storage_mb=msg.retention_storage_mb, + supported_codecs=SupportedCodecs.from_proto(msg.supported_codecs), + partition_write_speed_bytes_per_second=msg.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=msg.partition_write_burst_bytes, + attributes=dict(msg.attributes), + consumers=list(map(Consumer.from_proto, msg.consumers)), + metering_mode=MeteringMode.from_proto(msg.metering_mode), + 
topic_stats=DescribeTopicResult.TopicStats.from_proto(msg.topic_stats), + ) + + @staticmethod + def empty_proto_message() -> ydb_topic_pb2.DescribeTopicResult: + return ydb_topic_pb2.DescribeTopicResult() + + def to_public(self) -> ydb_topic_public_types.PublicDescribeTopicResult: + return ydb_topic_public_types.PublicDescribeTopicResult( + self=scheme._wrap_scheme_entry(self.self_proto), + min_active_partitions=self.partitioning_settings.min_active_partitions, + partition_count_limit=self.partitioning_settings.partition_count_limit, + partitions=list( + map(DescribeTopicResult.PartitionInfo.to_public, self.partitions) + ), + retention_period=self.retention_period, + retention_storage_mb=self.retention_storage_mb, + supported_codecs=self.supported_codecs.to_public(), + partition_write_speed_bytes_per_second=self.partition_write_speed_bytes_per_second, + partition_write_burst_bytes=self.partition_write_burst_bytes, + attributes=self.attributes, + consumers=list(map(Consumer.to_public, self.consumers)), + metering_mode=self.metering_mode.to_public(), + topic_stats=self.topic_stats.to_public(), + ) + + @dataclass + class PartitionInfo(IFromProto, IToPublic): + partition_id: int + active: bool + child_partition_ids: List[int] + parent_partition_ids: List[int] + partition_stats: "PartitionStats" + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.PartitionInfo], + ) -> Optional["DescribeTopicResult.PartitionInfo"]: + if msg is None: + return None + + return DescribeTopicResult.PartitionInfo( + partition_id=msg.partition_id, + active=msg.active, + child_partition_ids=list(msg.child_partition_ids), + parent_partition_ids=list(msg.parent_partition_ids), + partition_stats=PartitionStats.from_proto(msg.partition_stats), + ) + + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo: + partition_stats = None + if self.partition_stats is not None: + partition_stats = self.partition_stats.to_public() + return ydb_topic_public_types.PublicDescribeTopicResult.PartitionInfo( + partition_id=self.partition_id, + active=self.active, + child_partition_ids=self.child_partition_ids, + parent_partition_ids=self.parent_partition_ids, + partition_stats=partition_stats, + ) + + @dataclass + class TopicStats(IFromProto, IToPublic): + store_size_bytes: int + min_last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.DescribeTopicResult.TopicStats], + ) -> Optional["DescribeTopicResult.TopicStats"]: + if msg is None: + return None + + return DescribeTopicResult.TopicStats( + store_size_bytes=msg.store_size_bytes, + min_last_write_time=datetime_from_proto_timestamp( + msg.min_last_write_time + ), + max_write_time_lag=timedelta_from_proto_duration( + msg.max_write_time_lag + ), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + ) + + def to_public( + self, + ) -> ydb_topic_public_types.PublicDescribeTopicResult.TopicStats: + return ydb_topic_public_types.PublicDescribeTopicResult.TopicStats( + store_size_bytes=self.store_size_bytes, + min_last_write_time=self.min_last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + ) + + +@dataclass +class PartitionStats(IFromProto, IToPublic): + partition_offsets: OffsetsRange + store_size_bytes: int + last_write_time: datetime.datetime + max_write_time_lag: datetime.timedelta + bytes_written: "MultipleWindowsStat" 
+ partition_node_id: int + + @staticmethod + def from_proto( + msg: Optional[ydb_topic_pb2.PartitionStats], + ) -> Optional["PartitionStats"]: + if msg is None: + return None + return PartitionStats( + partition_offsets=OffsetsRange.from_proto(msg.partition_offsets), + store_size_bytes=msg.store_size_bytes, + last_write_time=datetime_from_proto_timestamp(msg.last_write_time), + max_write_time_lag=timedelta_from_proto_duration(msg.max_write_time_lag), + bytes_written=MultipleWindowsStat.from_proto(msg.bytes_written), + partition_node_id=msg.partition_node_id, + ) + + def to_public(self) -> ydb_topic_public_types.PublicPartitionStats: + return ydb_topic_public_types.PublicPartitionStats( + partition_start=self.partition_offsets.start, + partition_end=self.partition_offsets.end, + store_size_bytes=self.store_size_bytes, + last_write_time=self.last_write_time, + max_write_time_lag=self.max_write_time_lag, + bytes_written=self.bytes_written.to_public(), + partition_node_id=self.partition_node_id, + ) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py new file mode 100644 index 00000000..4582f19a --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic_public_types.py @@ -0,0 +1,200 @@ +import datetime +import typing +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Optional, List, Union, Dict + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ..v4.protos import ydb_topic_pb2 +else: + from ..common.protos import ydb_topic_pb2 + +from .common_utils import IToProto +from ...scheme import SchemeEntry + + +@dataclass +# need similar struct to PublicDescribeTopicResult +class CreateTopicRequestParams: + path: str + min_active_partitions: Optional[int] + partition_count_limit: Optional[int] + retention_period: Optional[datetime.timedelta] + retention_storage_mb: Optional[int] + supported_codecs: Optional[List[Union["PublicCodec", int]]] + partition_write_speed_bytes_per_second: Optional[int] + partition_write_burst_bytes: Optional[int] + attributes: Optional[Dict[str, str]] + consumers: Optional[List[Union["PublicConsumer", str]]] + metering_mode: Optional["PublicMeteringMode"] + + +class PublicCodec(int): + """ + Codec value may contain any int number. + + Values below is only well-known predefined values, + but protocol support custom codecs. + """ + + UNSPECIFIED = 0 + RAW = 1 + GZIP = 2 + LZOP = 3 # Has not supported codec in standard library + ZSTD = 4 # Has not supported codec in standard library + + +class PublicMeteringMode(IntEnum): + UNSPECIFIED = 0 + RESERVED_CAPACITY = 1 + REQUEST_UNITS = 2 + + +@dataclass +class PublicConsumer: + name: str + important: bool = False + """ + Consumer may be marked as 'important'. It means messages for this consumer will never expire due to retention. + User should take care that such consumer never stalls, to prevent running out of disk space. + """ + + read_from: Optional[datetime.datetime] = None + "All messages with smaller server written_at timestamp will be skipped." + + supported_codecs: List[PublicCodec] = field(default_factory=lambda: list()) + """ + List of supported codecs by this consumer. + supported_codecs on topic must be contained inside this list. 
+ """ + + attributes: Dict[str, str] = field(default_factory=lambda: dict()) + "Attributes of consumer" + + +@dataclass +class DropTopicRequestParams(IToProto): + path: str + + def to_proto(self) -> ydb_topic_pb2.DropTopicRequest: + return ydb_topic_pb2.DropTopicRequest(path=self.path) + + +@dataclass +class DescribeTopicRequestParams(IToProto): + path: str + include_stats: bool + + def to_proto(self) -> ydb_topic_pb2.DescribeTopicRequest: + return ydb_topic_pb2.DescribeTopicRequest( + path=self.path, include_stats=self.include_stats + ) + + +@dataclass +# Need similar struct to CreateTopicRequestParams +class PublicDescribeTopicResult: + self: SchemeEntry + "Description of scheme object" + + min_active_partitions: int + "Minimum partition count auto merge would stop working at" + + partition_count_limit: int + "Limit for total partition count, including active (open for write) and read-only partitions" + + partitions: List["PublicDescribeTopicResult.PartitionInfo"] + "Partitions description" + + retention_period: datetime.timedelta + "How long data in partition should be stored" + + retention_storage_mb: int + "How much data in partition should be stored. Zero value means infinite limit" + + supported_codecs: List[PublicCodec] + "List of allowed codecs for writers" + + partition_write_speed_bytes_per_second: int + "Partition write speed in bytes per second" + + partition_write_burst_bytes: int + "Burst size for write in partition, in bytes" + + attributes: Dict[str, str] + """User and server attributes of topic. Server attributes starts from "_" and will be validated by server.""" + + consumers: List[PublicConsumer] + """List of consumers for this topic""" + + metering_mode: PublicMeteringMode + "Metering settings" + + topic_stats: "PublicDescribeTopicResult.TopicStats" + "Statistics of topic" + + @dataclass + class PartitionInfo: + partition_id: int + "Partition identifier" + + active: bool + "Is partition open for write" + + child_partition_ids: List[int] + "Ids of partitions which was formed when this partition was split or merged" + + parent_partition_ids: List[int] + "Ids of partitions from which this partition was formed by split or merge" + + partition_stats: Optional["PublicPartitionStats"] + "Stats for partition, filled only when include_stats in request is true" + + @dataclass + class TopicStats: + store_size_bytes: int + "Approximate size of topic" + + min_last_write_time: datetime.datetime + "Minimum of timestamps of last write among all partitions." + + max_write_time_lag: datetime.timedelta + """ + Maximum of differences between write timestamp and create timestamp for all messages, + written during last minute. + """ + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written statistics." + + +@dataclass +class PublicPartitionStats: + partition_start: int + "first message offset in the partition" + + partition_end: int + "offset after last stored message offset in the partition (last offset + 1)" + + store_size_bytes: int + "Approximate size of partition" + + last_write_time: datetime.datetime + "Timestamp of last write" + + max_write_time_lag: datetime.timedelta + "Maximum of differences between write timestamp and create timestamp for all messages, written during last minute." + + bytes_written: "PublicMultipleWindowsStat" + "How much bytes were written during several windows in this partition." + + partition_node_id: int + "Host where tablet for this partition works. Useful for debugging purposes." 
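# Illustrative note (not part of the patch): the classes above follow a two-layer
# pattern - the grpcwrapper classes in ydb_topic.py parse the raw protobuf via
# from_proto(), and to_public() converts them into these stable public dataclasses.
# A minimal sketch, assuming `raw_proto` is a hypothetical, already-unpacked
# ydb_topic_pb2.DescribeTopicResult message:
#
#   from ydb._grpc.grpcwrapper.ydb_topic import DescribeTopicResult
#
#   internal = DescribeTopicResult.from_proto(raw_proto)   # proto -> wrapper
#   public = internal.to_public()                           # wrapper -> public dataclass
#   for partition in public.partitions:
#       print(partition.partition_id, partition.active, partition.partition_stats)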
+ + +@dataclass +class PublicMultipleWindowsStat: + per_minute: int + per_hour: int + per_day: int diff --git a/ydb/_grpc/grpcwrapper/ydb_topic_test.py b/ydb/_grpc/grpcwrapper/ydb_topic_test.py new file mode 100644 index 00000000..bff9b43d --- /dev/null +++ b/ydb/_grpc/grpcwrapper/ydb_topic_test.py @@ -0,0 +1,27 @@ +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange + + +def test_offsets_range_intersected(): + # not intersected + for test in [(0, 1, 1, 2), (1, 2, 3, 5)]: + assert not OffsetsRange(test[0], test[1]).is_intersected_with( + OffsetsRange(test[2], test[3]) + ) + assert not OffsetsRange(test[2], test[3]).is_intersected_with( + OffsetsRange(test[0], test[1]) + ) + + # intersected + for test in [ + (1, 2, 1, 2), + (1, 10, 1, 2), + (1, 10, 2, 3), + (1, 10, 5, 15), + (10, 20, 5, 15), + ]: + assert OffsetsRange(test[0], test[1]).is_intersected_with( + OffsetsRange(test[2], test[3]) + ) + assert OffsetsRange(test[2], test[3]).is_intersected_with( + OffsetsRange(test[0], test[1]) + ) diff --git a/ydb/_sp_impl.py b/ydb/_sp_impl.py index a8529d73..5974a301 100644 --- a/ydb/_sp_impl.py +++ b/ydb/_sp_impl.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import collections from concurrent import futures -from six.moves import queue +import queue import time import threading from . import settings, issues, _utilities, tracing diff --git a/ydb/_sqlalchemy/__init__.py b/ydb/_sqlalchemy/__init__.py new file mode 100644 index 00000000..d8931a5d --- /dev/null +++ b/ydb/_sqlalchemy/__init__.py @@ -0,0 +1,327 @@ +""" +Experimental +Work in progress, breaking changes are possible. +""" +import ydb +import ydb._dbapi as dbapi + +import sqlalchemy as sa +from sqlalchemy import dialects +from sqlalchemy import Table +from sqlalchemy.exc import CompileError +from sqlalchemy.sql import functions, literal_column +from sqlalchemy.sql.compiler import ( + IdentifierPreparer, + GenericTypeCompiler, + SQLCompiler, + DDLCompiler, +) +from sqlalchemy.sql.elements import ClauseList +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.util.compat import inspect_getfullargspec + +from ydb._sqlalchemy.types import UInt32, UInt64 + + +SQLALCHEMY_VERSION = tuple(sa.__version__.split(".")) +SA_14 = SQLALCHEMY_VERSION >= ("1", "4") + + +class YqlIdentifierPreparer(IdentifierPreparer): + def __init__(self, dialect): + super(YqlIdentifierPreparer, self).__init__( + dialect, + initial_quote="`", + final_quote="`", + ) + + def _requires_quotes(self, value): + # Force all identifiers to get quoted unless already quoted. 
+ return not ( + value.startswith(self.initial_quote) and value.endswith(self.final_quote) + ) + + +class YqlTypeCompiler(GenericTypeCompiler): + def visit_VARCHAR(self, type_, **kw): + return "STRING" + + def visit_unicode(self, type_, **kw): + return "UTF8" + + def visit_NVARCHAR(self, type_, **kw): + return "UTF8" + + def visit_TEXT(self, type_, **kw): + return "UTF8" + + def visit_FLOAT(self, type_, **kw): + return "DOUBLE" + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + def visit_uint32(self, type_, **kw): + return "UInt32" + + def visit_uint64(self, type_, **kw): + return "UInt64" + + def visit_uint8(self, type_, **kw): + return "UInt8" + + def visit_INTEGER(self, type_, **kw): + return "Int64" + + def visit_NUMERIC(self, type_, **kw): + return "Int64" + + +class ParametrizedFunction(functions.Function): + __visit_name__ = "parametrized_function" + + def __init__(self, name, params, *args, **kwargs): + super(ParametrizedFunction, self).__init__(name, *args, **kwargs) + self._func_name = name + self._func_params = params + self.params_expr = ClauseList( + operator=functions.operators.comma_op, group_contents=True, *params + ).self_group() + + +class YqlCompiler(SQLCompiler): + def group_by_clause(self, select, **kw): + # Hack to ensure it is possible to define labels in groupby. + kw.update(within_columns_clause=True) + return super(YqlCompiler, self).group_by_clause(select, **kw) + + def visit_lambda(self, lambda_, **kw): + func = lambda_.func + spec = inspect_getfullargspec(func) + + if spec.varargs: + raise CompileError("Lambdas with *args are not supported") + + try: + keywords = spec.keywords + except AttributeError: + keywords = spec.varkw + + if keywords: + raise CompileError("Lambdas with **kwargs are not supported") + + text = "(" + ", ".join("$" + arg for arg in spec.args) + ")" + " -> " + + args = [literal_column("$" + arg) for arg in spec.args] + text += "{ RETURN " + self.process(func(*args), **kw) + " ;}" + + return text + + def visit_parametrized_function(self, func, **kwargs): + name = func.name + name_parts = [] + for name in name.split("::"): + fname = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + + name_parts.append(fname) + + name = "::".join(name_parts) + params = func.params_expr._compiler_dispatch(self, **kwargs) + args = self.function_argspec(func, **kwargs) + return "%(name)s%(params)s%(args)s" % dict(name=name, params=params, args=args) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + # Copypaste of `sa.sql.compiler.SQLCompiler.visit_function` with + # `::` as namespace separator instead of `.` + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) + else: + name = sa.sql.compiler.FUNCTIONS.get(func.__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, sa.sql.elements.quoted_name) + else name + ) + name = name + "%(expr)s" + return "::".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, sa.sql.elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + +class 
YqlDdlCompiler(DDLCompiler): + pass + + +def upsert(table): + return sa.sql.Insert(table) + + +COLUMN_TYPES = { + ydb.PrimitiveType.Int8: sa.INTEGER, + ydb.PrimitiveType.Int16: sa.INTEGER, + ydb.PrimitiveType.Int32: sa.INTEGER, + ydb.PrimitiveType.Int64: sa.INTEGER, + ydb.PrimitiveType.Uint8: sa.INTEGER, + ydb.PrimitiveType.Uint16: sa.INTEGER, + ydb.PrimitiveType.Uint32: UInt32, + ydb.PrimitiveType.Uint64: UInt64, + ydb.PrimitiveType.Float: sa.FLOAT, + ydb.PrimitiveType.Double: sa.FLOAT, + ydb.PrimitiveType.String: sa.TEXT, + ydb.PrimitiveType.Utf8: sa.TEXT, + ydb.PrimitiveType.Json: sa.JSON, + ydb.PrimitiveType.JsonDocument: sa.JSON, + ydb.DecimalType: sa.DECIMAL, + ydb.PrimitiveType.Yson: sa.TEXT, + ydb.PrimitiveType.Date: sa.DATE, + ydb.PrimitiveType.Datetime: sa.DATETIME, + ydb.PrimitiveType.Timestamp: sa.DATETIME, + ydb.PrimitiveType.Interval: sa.INTEGER, + ydb.PrimitiveType.Bool: sa.BOOLEAN, + ydb.PrimitiveType.DyNumber: sa.TEXT, +} + + +def _get_column_info(t): + nullable = False + if isinstance(t, ydb.OptionalType): + nullable = True + t = t.item + + if isinstance(t, ydb.DecimalType): + return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable + + return COLUMN_TYPES[t], nullable + + +class YqlDialect(DefaultDialect): + name = "yql" + supports_alter = False + max_identifier_length = 63 + supports_sane_rowcount = False + supports_statement_cache = False + + supports_native_enum = False + supports_native_boolean = True + supports_smallserial = False + + supports_sequences = False + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_default_values = False + supports_empty_insert = False + supports_multivalues_insert = True + default_paramstyle = "qmark" + + isolation_level = None + + preparer = YqlIdentifierPreparer + statement_compiler = YqlCompiler + ddl_compiler = YqlDdlCompiler + type_compiler = YqlTypeCompiler + + driver = ydb.Driver + + @staticmethod + def dbapi(): + return dbapi + + def _check_unicode_returns(self, *args, **kwargs): + # Normally, this would do 2 SQL queries, which isn't quite necessary. + return "conditional" + + def get_columns(self, connection, table_name, schema=None, **kw): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + qt = table_name.name if isinstance(table_name, Table) else table_name + + if SA_14: + raw_conn = connection.connection + else: + raw_conn = connection.raw_connection() + + columns = raw_conn.describe(qt) + as_compatible = [] + for column in columns: + col_type, nullable = _get_column_info(column.type) + as_compatible.append( + { + "name": column.name, + "type": col_type, + "nullable": nullable, + } + ) + + return as_compatible + + def has_table(self, connection, table_name, schema=None, **kwargs): + if schema is not None: + raise dbapi.errors.NotSupportedError("unsupported on non empty schema") + + quote = self.identifier_preparer.quote_identifier + qtable = quote(table_name) + + # TODO: use `get_columns` instead. 
+ statement = "SELECT * FROM " + qtable + try: + connection.execute(statement) + return True + except Exception: + return False + + def get_pk_constraint(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def get_foreign_keys(self, connection, table_name, schema=None, **kwargs): + # foreign keys unsupported + return [] + + def get_indexes(self, connection, table_name, schema=None, **kwargs): + # TODO: implement me + return [] + + def do_commit(self, dbapi_connection) -> None: + # TODO: needs to implement? + pass + + def do_execute(self, cursor, statement, parameters, context=None) -> None: + c = None + if context is not None and context.isddl: + c = {"isddl": True} + cursor.execute(statement, parameters, c) + + +def register_dialect( + name="yql", + module=__name__, + cls="YqlDialect", +): + return dialects.registry.register(name, module, cls) diff --git a/ydb/_sqlalchemy/types.py b/ydb/_sqlalchemy/types.py new file mode 100644 index 00000000..21748ec1 --- /dev/null +++ b/ydb/_sqlalchemy/types.py @@ -0,0 +1,28 @@ +from sqlalchemy.types import Integer +from sqlalchemy.sql import type_api +from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy import util, exc + + +class UInt32(Integer): + __visit_name__ = "uint32" + + +class UInt64(Integer): + __visit_name__ = "uint64" + + +class UInt8(Integer): + __visit_name__ = "uint8" + + +class Lambda(ColumnElement): + + __visit_name__ = "lambda" + + def __init__(self, func): + if not util.callable(func): + raise exc.ArgumentError("func must be callable") + + self.type = type_api.NULLTYPE + self.func = func diff --git a/kikimr/public/__init__.py b/ydb/_topic_common/__init__.py similarity index 100% rename from kikimr/public/__init__.py rename to ydb/_topic_common/__init__.py diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py new file mode 100644 index 00000000..9e8f1326 --- /dev/null +++ b/ydb/_topic_common/common.py @@ -0,0 +1,147 @@ +import asyncio +import concurrent.futures +import threading +import typing +from typing import Optional + +from .. 
import operation, issues +from .._grpc.grpcwrapper.common_utils import IFromProtoWithProtoType + +TimeoutType = typing.Union[int, float, None] + + +def wrap_operation(rpc_state, response_pb, driver=None): + return operation.Operation(rpc_state, response_pb, driver) + + +ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType) + + +def create_result_wrapper( + result_type: typing.Type[ResultType], +) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]: + def wrapper(rpc_state, response_pb, driver=None): + issues._process_response(response_pb.operation) + msg = result_type.empty_proto_message() + response_pb.operation.result.Unpack(msg) + return result_type.from_proto(msg) + + return wrapper + + +_shared_event_loop_lock = threading.Lock() +_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None + + +def _get_shared_event_loop() -> asyncio.AbstractEventLoop: + global _shared_event_loop + + if _shared_event_loop is not None: + return _shared_event_loop + + with _shared_event_loop_lock: + if _shared_event_loop is not None: + return _shared_event_loop + + event_loop_set_done = concurrent.futures.Future() + + def start_event_loop(): + event_loop = asyncio.new_event_loop() + event_loop_set_done.set_result(event_loop) + asyncio.set_event_loop(event_loop) + event_loop.run_forever() + + t = threading.Thread( + target=start_event_loop, + name="Common ydb topic event loop", + daemon=True, + ) + t.start() + + _shared_event_loop = event_loop_set_done.result() + return _shared_event_loop + + +class CallFromSyncToAsync: + _loop: asyncio.AbstractEventLoop + + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + + def unsafe_call_with_future( + self, coro: typing.Coroutine + ) -> concurrent.futures.Future: + """ + returned result from coro may be lost + """ + return asyncio.run_coroutine_threadsafe(coro, self._loop) + + def unsafe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): + """ + returned result from coro may be lost by race future cancel by timeout and return value from coroutine + """ + f = self.unsafe_call_with_future(coro) + try: + return f.result(timeout) + except concurrent.futures.TimeoutError: + raise TimeoutError() + finally: + if not f.done(): + f.cancel() + + def safe_call_with_result(self, coro: typing.Coroutine, timeout: TimeoutType): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + """ + + if timeout is not None and timeout <= 0: + return self._safe_call_fast(coro) + + async def call_coro(): + task = self._loop.create_task(coro) + try: + res = await asyncio.wait_for(task, timeout) + return res + except asyncio.TimeoutError: + try: + res = await task + return res + except asyncio.CancelledError: + pass + + # return builtin TimeoutError instead of asyncio.TimeoutError + raise TimeoutError() + + return asyncio.run_coroutine_threadsafe(call_coro(), self._loop).result() + + def _safe_call_fast(self, coro: typing.Coroutine): + """ + no lost returned value from coro, but may be slower especially timeout latency - it wait coroutine cancelation. + Wait coroutine result only one loop. 
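        (In practice the coroutine gets roughly one pass of the target event loop
        to produce a result before it is cancelled.)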
+ """ + res = concurrent.futures.Future() + + async def call_coro(): + try: + res.set_result(await coro) + except asyncio.CancelledError: + res.set_exception(TimeoutError()) + + coro_future = asyncio.run_coroutine_threadsafe(call_coro(), self._loop) + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), self._loop).result() + coro_future.cancel() + return res.result() + + def call_sync(self, callback: typing.Callable[[], typing.Any]) -> typing.Any: + result = concurrent.futures.Future() + + def call_callback(): + try: + res = callback() + result.set_result(res) + except BaseException as err: + result.set_exception(err) + + self._loop.call_soon_threadsafe(call_callback) + + return result.result() diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py new file mode 100644 index 00000000..b31f9af9 --- /dev/null +++ b/ydb/_topic_common/common_test.py @@ -0,0 +1,290 @@ +import asyncio +import threading +import time +import typing + +import grpc +import pytest + +from .common import CallFromSyncToAsync +from .._grpc.grpcwrapper.common_utils import ( + GrpcWrapperAsyncIO, + ServerStatus, + callback_from_asyncio, +) +from .. import issues + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ydb._grpc.v4.protos import ( + ydb_status_codes_pb2, + ydb_topic_pb2, + ) +else: + # noinspection PyUnresolvedReferences + from ydb._grpc.common.protos import ( + ydb_status_codes_pb2, + ydb_topic_pb2, + ) + + +@pytest.fixture() +def separate_loop(): + loop = asyncio.new_event_loop() + + def run_loop(): + loop.run_forever() + pass + + t = threading.Thread(target=run_loop, name="test separate loop") + t.start() + + yield loop + + loop.call_soon_threadsafe(lambda: loop.stop()) + t.join() + + +@pytest.mark.asyncio +class Test: + async def test_callback_from_asyncio(self): + class TestError(Exception): + pass + + def sync_success(): + return 1 + + assert await callback_from_asyncio(sync_success) == 1 + + def sync_failed(): + raise TestError() + + with pytest.raises(TestError): + await callback_from_asyncio(sync_failed) + + async def async_success(): + await asyncio.sleep(0) + return 1 + + assert await callback_from_asyncio(async_success) == 1 + + async def async_failed(): + await asyncio.sleep(0) + raise TestError() + + with pytest.raises(TestError): + await callback_from_asyncio(async_failed) + + +@pytest.mark.asyncio +class TestGrpcWrapperAsyncIO: + async def test_convert_grpc_errors_to_ydb(self): + class TestError(grpc.RpcError, grpc.Call): + def __init__(self): + pass + + def code(self): + return grpc.StatusCode.UNAUTHENTICATED + + def details(self): + return "test error" + + class FromServerMock: + async def __anext__(self): + raise TestError() + + wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper.from_server_grpc = FromServerMock() + + with pytest.raises(issues.Unauthenticated): + await wrapper.receive() + + async def convert_status_code_to_ydb_error(self): + class FromServerMock: + async def __anext__(self): + return ydb_topic_pb2.StreamReadMessage.FromServer( + status=ydb_status_codes_pb2.StatusIds.OVERLOADED, + issues=[], + ) + + wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper.from_server_grpc = FromServerMock() + + with pytest.raises(issues.Overloaded): + await wrapper.receive() + + +class TestServerStatus: + def test_success(self): + status = ServerStatus( + status=ydb_status_codes_pb2.StatusIds.SUCCESS, + issues=[], + ) + assert status.is_success() + assert issues._process_response(status) is None + + def test_failed(self): + status = 
ServerStatus( + status=ydb_status_codes_pb2.StatusIds.OVERLOADED, + issues=[], + ) + assert not status.is_success() + with pytest.raises(issues.Overloaded): + issues._process_response(status) + + +@pytest.mark.asyncio +class TestCallFromSyncToAsync: + @pytest.fixture() + def caller(self, separate_loop): + return CallFromSyncToAsync(separate_loop) + + def test_unsafe_call_with_future(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + f = caller.unsafe_call_with_future(callback()) + + assert f.result() == 1 + assert callback_loop is separate_loop + + def test_unsafe_call_with_result_ok(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + res = caller.unsafe_call_with_result(callback(), None) + + assert res == 1 + assert callback_loop is separate_loop + + def test_unsafe_call_with_result_timeout(self, separate_loop, caller): + timeout = 0.01 + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + await asyncio.sleep(1) + return 1 + + start = time.monotonic() + with pytest.raises(TimeoutError): + caller.unsafe_call_with_result(callback(), timeout) + finished = time.monotonic() + + assert callback_loop is separate_loop + assert finished - start > timeout + + def test_safe_call_with_result_ok(self, separate_loop, caller): + callback_loop = None + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return 1 + + res = caller.safe_call_with_result(callback(), 1) + + assert res == 1 + assert callback_loop is separate_loop + + def test_safe_call_with_result_timeout(self, separate_loop, caller): + timeout = 0.01 + callback_loop = None + cancelled = False + + async def callback(): + nonlocal callback_loop, cancelled + callback_loop = asyncio.get_running_loop() + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + return 1 + + start = time.monotonic() + with pytest.raises(TimeoutError): + caller.safe_call_with_result(callback(), timeout) + finished = time.monotonic() + + # wait one loop for handle task cancelation + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), separate_loop) + + assert callback_loop is separate_loop + assert finished - start > timeout + assert cancelled + + def test_safe_callback_with_0_timeout_ok(self, separate_loop, caller): + callback_loop = None + + async def f1(): + return 1 + + async def f2(): + return await f1() + + async def callback(): + nonlocal callback_loop + callback_loop = asyncio.get_running_loop() + return await f2() + + res = caller.safe_call_with_result(callback(), 0) + assert callback_loop is separate_loop + assert res == 1 + + def test_safe_callback_with_0_timeout_timeout(self, separate_loop, caller): + callback_loop = None + cancelled = False + + async def callback(): + try: + nonlocal callback_loop, cancelled + + callback_loop = asyncio.get_running_loop() + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + with pytest.raises(TimeoutError): + caller.safe_call_with_result(callback(), 0) + + assert callback_loop is separate_loop + assert cancelled + + def test_call_sync_ok(self, separate_loop, caller): + callback_eventloop = None + + def callback(): + nonlocal callback_eventloop + callback_eventloop = asyncio.get_running_loop() + return 1 + + res = 
caller.call_sync(callback) + assert callback_eventloop is separate_loop + assert res == 1 + + def test_call_sync_error(self, separate_loop, caller): + callback_eventloop = None + + class TestError(RuntimeError): + pass + + def callback(): + nonlocal callback_eventloop + callback_eventloop = asyncio.get_running_loop() + raise TestError + + with pytest.raises(TestError): + caller.call_sync(callback) + assert callback_eventloop is separate_loop diff --git a/ydb/_topic_common/test_helpers.py b/ydb/_topic_common/test_helpers.py new file mode 100644 index 00000000..96a812ab --- /dev/null +++ b/ydb/_topic_common/test_helpers.py @@ -0,0 +1,76 @@ +import asyncio +import time +import typing + +from .._grpc.grpcwrapper.common_utils import IToProto, IGrpcWrapperAsyncIO + + +class StreamMock(IGrpcWrapperAsyncIO): + from_server: asyncio.Queue + from_client: asyncio.Queue + _closed: bool + + def __init__(self): + self.from_server = asyncio.Queue() + self.from_client = asyncio.Queue() + self._closed = False + + async def receive(self) -> typing.Any: + if self._closed: + raise Exception("read from closed StreamMock") + + item = await self.from_server.get() + if item is None: + raise StopAsyncIteration() + if isinstance(item, Exception): + raise item + return item + + def write(self, wrap_message: IToProto): + if self._closed: + raise Exception("write to closed StreamMock") + self.from_client.put_nowait(wrap_message) + + def close(self): + if self._closed: + return + + self._closed = True + self.from_server.put_nowait(None) + + +class WaitConditionError(Exception): + pass + + +async def wait_condition( + f: typing.Callable[[], bool], + timeout: typing.Optional[typing.Union[float, int]] = None, +): + """ + timeout default is 1 second + if timeout is 0 - only counter work. It userful if test need fast timeout for condition (without wait full timeout) + """ + if timeout is None: + timeout = 1 + + minimal_loop_count_for_wait = 1000 + + start = time.monotonic() + counter = 0 + while (time.monotonic() - start < timeout) or counter < minimal_loop_count_for_wait: + counter += 1 + if f(): + return + await asyncio.sleep(0) + + raise WaitConditionError("Bad condition in test") + + +async def wait_for_fast( + awaitable: typing.Awaitable, + timeout: typing.Optional[typing.Union[float, int]] = None, +): + fut = asyncio.ensure_future(awaitable) + await wait_condition(lambda: fut.done(), timeout) + return fut.result() diff --git a/kikimr/public/sdk/__init__.py b/ydb/_topic_reader/__init__.py similarity index 100% rename from kikimr/public/sdk/__init__.py rename to ydb/_topic_reader/__init__.py diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py new file mode 100644 index 00000000..434eff35 --- /dev/null +++ b/ydb/_topic_reader/datatypes.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import abc +import asyncio +import bisect +import enum +from collections import deque +from dataclasses import dataclass, field +import datetime +from typing import Union, Any, List, Dict, Deque, Optional + +from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec +from ydb._topic_reader import topic_reader_asyncio + + +class ICommittable(abc.ABC): + @abc.abstractmethod + def _commit_get_partition_session(self) -> PartitionSession: + ... + + @abc.abstractmethod + def _commit_get_offsets_range(self) -> OffsetsRange: + ... 
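# Illustrative note (not part of the patch): PublicMessage and PublicBatch below
# implement ICommittable; they are the objects accepted by the public reader's
# commit()/commit_with_ack() methods. A minimal sketch with a hypothetical
# reader instance:
#
#   batch = await reader.receive_batch()
#   # ... process batch.messages ...
#   await reader.commit_with_ack(batch)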
+ + +class ISessionAlive(abc.ABC): + @property + @abc.abstractmethod + def alive(self) -> bool: + pass + + +@dataclass +class PublicMessage(ICommittable, ISessionAlive): + seqno: int + created_at: datetime.datetime + message_group_id: str + session_metadata: Dict[str, str] + offset: int + written_at: datetime.datetime + producer_id: str + data: Union[ + bytes, Any + ] # set as original decompressed bytes or deserialized object if deserializer set in reader + _partition_session: PartitionSession + _commit_start_offset: int + _commit_end_offset: int + + def _commit_get_partition_session(self) -> PartitionSession: + return self._partition_session + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self._commit_start_offset, self._commit_end_offset) + + # ISessionAlive implementation + @property + def alive(self) -> bool: + return not self._partition_session.closed + + +@dataclass +class PartitionSession: + id: int + state: "PartitionSession.State" + topic_path: str + partition_id: int + committed_offset: int # last commit offset, acked from server. Processed messages up to the field-1 offset. + reader_reconnector_id: int + reader_stream_id: int + _next_message_start_commit_offset: int = field(init=False) + + # todo: check if deque is optimal + _ack_waiters: Deque["PartitionSession.CommitAckWaiter"] = field( + init=False, default_factory=lambda: deque() + ) + + _state_changed: asyncio.Event = field( + init=False, default_factory=lambda: asyncio.Event(), compare=False + ) + _loop: Optional[asyncio.AbstractEventLoop] = field( + init=False + ) # may be None in tests + + def __post_init__(self): + self._next_message_start_commit_offset = self.committed_offset + + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + + def add_waiter(self, end_offset: int) -> "PartitionSession.CommitAckWaiter": + waiter = PartitionSession.CommitAckWaiter(end_offset, self._create_future()) + if end_offset <= self.committed_offset: + waiter._finish_ok() + return waiter + + # fast way + if len(self._ack_waiters) > 0 and self._ack_waiters[-1].end_offset < end_offset: + self._ack_waiters.append(waiter) + else: + bisect.insort(self._ack_waiters, waiter) + + return waiter + + def _create_future(self) -> asyncio.Future: + if self._loop: + return self._loop.create_future() + else: + return asyncio.Future() + + def ack_notify(self, offset: int): + self._ensure_not_closed() + + self.committed_offset = offset + + if len(self._ack_waiters) == 0: + # todo log warning + # must be never receive ack for not sended request + return + + while len(self._ack_waiters) > 0: + if self._ack_waiters[0].end_offset <= offset: + waiter = self._ack_waiters.popleft() + waiter._finish_ok() + else: + break + + def close(self): + if self.closed: + return + + self.state = PartitionSession.State.Stopped + exception = topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + for waiter in self._ack_waiters: + waiter._finish_error(exception) + + @property + def closed(self): + return self.state == PartitionSession.State.Stopped + + def _ensure_not_closed(self): + if self.state == PartitionSession.State.Stopped: + raise topic_reader_asyncio.TopicReaderCommitToExpiredPartition() + + class State(enum.Enum): + Active = 1 + GracefulShutdown = 2 + Stopped = 3 + + @dataclass(order=True) + class CommitAckWaiter: + end_offset: int + future: asyncio.Future = field(compare=False) + _done: bool = field(default=False, init=False) + _exception: Optional[Exception] = field(default=None, init=False) + 
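        # _finish_ok is called by PartitionSession when the offset is acked by the
        # server (or is already committed at add_waiter time); _finish_error is
        # called when the partition session is closed.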
+ def _finish_ok(self): + self._done = True + self.future.set_result(None) + + def _finish_error(self, error: Exception): + self._exception = error + self.future.set_exception(error) + + +@dataclass +class PublicBatch(ICommittable, ISessionAlive): + messages: List[PublicMessage] + _partition_session: PartitionSession + _bytes_size: int + _codec: Codec + + def _commit_get_partition_session(self) -> PartitionSession: + return self.messages[0]._commit_get_partition_session() + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange( + self.messages[0]._commit_get_offsets_range().start, + self.messages[-1]._commit_get_offsets_range().end, + ) + + def empty(self) -> bool: + return len(self.messages) == 0 + + # ISessionAlive implementation + @property + def alive(self) -> bool: + return not self._partition_session.closed + + def pop_message(self) -> PublicMessage: + return self.messages.pop(0) diff --git a/ydb/_topic_reader/datatypes_test.py b/ydb/_topic_reader/datatypes_test.py new file mode 100644 index 00000000..2ec1229f --- /dev/null +++ b/ydb/_topic_reader/datatypes_test.py @@ -0,0 +1,221 @@ +import asyncio +import copy +import functools +from collections import deque +from typing import List + +import pytest + +from ydb._topic_common.test_helpers import wait_condition +from ydb._topic_reader import topic_reader_asyncio +from ydb._topic_reader.datatypes import PartitionSession + + +@pytest.mark.asyncio +class TestPartitionSession: + session_comitted_offset = 10 + + @pytest.fixture + def session(self) -> PartitionSession: + return PartitionSession( + id=1, + state=PartitionSession.State.Active, + topic_path="", + partition_id=1, + committed_offset=self.session_comitted_offset, + reader_reconnector_id=1, + reader_stream_id=1, + ) + + @pytest.mark.parametrize( + "offsets_waited,notify_offset,offsets_notified,offsets_waited_rest", + [ + ([1], 1, [1], []), + ([1], 10, [1], []), + ([1, 2, 3], 10, [1, 2, 3], []), + ([1, 2, 10, 20], 10, [1, 2, 10], [20]), + ([10, 20], 1, [], [10, 20]), + ], + ) + async def test_ack_notify( + self, + session, + offsets_waited: List[int], + notify_offset: int, + offsets_notified: List[int], + offsets_waited_rest: List[int], + ): + notified = [] + + for offset in offsets_waited: + fut = asyncio.Future() + + def add_notify(future, notified_offset): + notified.append(notified_offset) + + fut.add_done_callback(functools.partial(add_notify, notified_offset=offset)) + waiter = PartitionSession.CommitAckWaiter(offset, fut) + session._ack_waiters.append(waiter) + + session.ack_notify(notify_offset) + assert session._ack_waiters == deque( + [ + PartitionSession.CommitAckWaiter(offset, asyncio.Future()) + for offset in offsets_waited_rest + ] + ) + + await wait_condition(lambda: len(notified) == len(offsets_notified)) + + notified.sort() + assert notified == offsets_notified + assert session.committed_offset == notify_offset + + # noinspection PyTypeChecker + @pytest.mark.parametrize( + "original,add,is_done,result", + [ + ( + [], + session_comitted_offset - 5, + True, + [], + ), + ( + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 0, + True, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + ], + ), + ( + [], + session_comitted_offset + 5, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 6, + False, + [ + 
PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 4, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 4, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + ], + ), + ( + [PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None)], + session_comitted_offset + 100, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + session_comitted_offset + 50, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 50, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), + ], + session_comitted_offset + 6, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 7, None), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + session_comitted_offset + 6, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter(session_comitted_offset + 6, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + ), + ( + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + session_comitted_offset + 99, + False, + [ + PartitionSession.CommitAckWaiter(session_comitted_offset + 5, None), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 99, None + ), + PartitionSession.CommitAckWaiter( + session_comitted_offset + 100, None + ), + ], + ), + ], + ) + def test_add_waiter( + self, + session, + original: List[PartitionSession.CommitAckWaiter], + add: int, + is_done: bool, + result: List[PartitionSession.CommitAckWaiter], + ): + session._ack_waiters = copy.deepcopy(original) + res = session.add_waiter(add) + assert result == session._ack_waiters + assert res.future.done() == is_done + + def test_close_notify_waiters(self, session): + waiter = session.add_waiter(session.committed_offset + 1) + session.close() + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + def test_close_twice(self, session): + session.close() + session.close() diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py new file mode 100644 index 00000000..b3d5637d --- /dev/null +++ b/ydb/_topic_reader/topic_reader.py @@ -0,0 +1,103 @@ +import concurrent.futures +import enum +import datetime +from dataclasses import dataclass +from typing import ( + Union, + Optional, + List, + Mapping, + Callable, +) + +from ..table import RetrySettings +from 
.._grpc.grpcwrapper.ydb_topic import StreamReadMessage, OffsetsRange + + +class Selector: + path: str + partitions: Union[None, int, List[int]] + read_from_timestamp_ms: Optional[int] + max_time_lag_ms: Optional[int] + + def __init__(self, path, *, partitions: Union[None, int, List[int]] = None): + self.path = path + self.partitions = partitions + + +@dataclass +class PublicReaderSettings: + consumer: str + topic: str + buffer_size_bytes: int = 50 * 1024 * 1024 + + decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None + """decoders: map[codec_code] func(encoded_bytes)->decoded_bytes""" + + # decoder_executor, must be set for handle non raw messages + decoder_executor: Optional[concurrent.futures.Executor] = None + update_token_interval: Union[int, float] = 3600 + + def _init_message(self) -> StreamReadMessage.InitRequest: + return StreamReadMessage.InitRequest( + topics_read_settings=[ + StreamReadMessage.InitRequest.TopicReadSettings( + path=self.topic, + ) + ], + consumer=self.consumer, + ) + + def _retry_settings(self) -> RetrySettings: + return RetrySettings(idempotent=True) + + +class Events: + class OnCommit: + topic: str + offset: int + + class OnPartitionGetStartOffsetRequest: + topic: str + partition_id: int + + class OnPartitionGetStartOffsetResponse: + start_offset: int + + class OnInitPartition: + pass + + class OnShutdownPatition: + pass + + +class RetryPolicy: + connection_timeout_sec: float + overload_timeout_sec: float + retry_access_denied: bool = False + + +class CommitResult: + topic: str + partition: int + offset: int + state: "CommitResult.State" + details: str # for humans only, content messages may be change in any time + + class State(enum.Enum): + UNSENT = 1 # commit didn't send to the server + SENT = 2 # commit was sent to server, but ack hasn't received + ACKED = 3 # ack from server is received + + +class SessionStat: + path: str + partition_id: str + partition_offsets: OffsetsRange + committed_offset: int + write_time_high_watermark: datetime.datetime + write_time_high_watermark_timestamp_nano: int + + +class StubEvent: + pass diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py new file mode 100644 index 00000000..9eda2fbf --- /dev/null +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -0,0 +1,658 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import gzip +import typing +from asyncio import Task +from collections import deque +from typing import Optional, Set, Dict, Union, Callable + +from .. import _apis, issues +from .._utilities import AtomicCounter +from ..aio import Driver +from ..issues import Error as YdbError, _process_response +from . import datatypes +from . import topic_reader +from .._grpc.grpcwrapper.common_utils import ( + IGrpcWrapperAsyncIO, + SupportedDriverType, + GrpcWrapperAsyncIO, +) +from .._grpc.grpcwrapper.ydb_topic import ( + StreamReadMessage, + UpdateTokenRequest, + UpdateTokenResponse, + Codec, +) +from .._errors import check_retriable_error + + +class TopicReaderError(YdbError): + pass + + +class TopicReaderUnexpectedCodec(YdbError): + pass + + +class TopicReaderCommitToExpiredPartition(TopicReaderError): + """ + Commit message when partition read session are dropped. + It is ok - the message/batch will not commit to server and will receive in other read session + (with this or other reader). 
+ """ + + def __init__(self, message: str = "Topic reader partition session is closed"): + super().__init__(message) + + +class TopicReaderStreamClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader stream is closed") + + +class TopicReaderClosedError(TopicReaderError): + def __init__(self): + super().__init__("Topic reader is closed already") + + +class PublicAsyncIOReader: + _loop: asyncio.AbstractEventLoop + _closed: bool + _reconnector: ReaderReconnector + + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = ReaderReconnector(driver, settings) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + def __del__(self): + if not self._closed: + self._loop.create_task(self.close(), name="close reader") + + async def receive_batch( + self, + ) -> typing.Union[datatypes.PublicBatch, None]: + """ + Get one messages batch from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_batch_nowait() + + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: + """ + Block until receive new message + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_message_nowait() + + def commit( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): + """ + Write commit message to a buffer. + + For the method no way check the commit result + (for example if lost connection - commits will not re-send and committed messages will receive again) + """ + self._reconnector.commit(batch) + + async def commit_with_ack( + self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch] + ): + """ + write commit message to a buffer and wait ack from the server. + + use asyncio.wait_for for wait with timeout. 
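        Example (timeout value is illustrative):
        await asyncio.wait_for(reader.commit_with_ack(batch), timeout=10)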
+ """ + waiter = self._reconnector.commit(batch) + await waiter.future + + async def close(self): + if self._closed: + raise TopicReaderClosedError() + + self._closed = True + await self._reconnector.close() + + +class ReaderReconnector: + _static_reader_reconnector_counter = AtomicCounter() + + _id: int + _settings: topic_reader.PublicReaderSettings + _driver: Driver + _background_tasks: Set[Task] + + _state_changed: asyncio.Event + _stream_reader: Optional["ReaderStream"] + _first_error: asyncio.Future[YdbError] + + def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + self._id = self._static_reader_reconnector_counter.inc_and_get() + + self._settings = settings + self._driver = driver + self._background_tasks = set() + + self._state_changed = asyncio.Event() + self._stream_reader = None + self._background_tasks.add(asyncio.create_task(self._connection_loop())) + self._first_error = asyncio.get_running_loop().create_future() + + async def _connection_loop(self): + attempt = 0 + while True: + try: + self._stream_reader = await ReaderStream.create( + self._id, self._driver, self._settings + ) + attempt = 0 + self._state_changed.set() + await self._stream_reader.wait_error() + except issues.Error as err: + retry_info = check_retriable_error( + err, self._settings._retry_settings(), attempt + ) + if not retry_info.is_retriable: + self._set_first_error(err) + return + await asyncio.sleep(retry_info.sleep_timeout_seconds) + + attempt += 1 + + async def wait_message(self): + while True: + if self._first_error.done(): + raise self._first_error.result() + + if self._stream_reader: + try: + await self._stream_reader.wait_messages() + return + except YdbError: + pass # handle errors in reconnection loop + + await self._state_changed.wait() + self._state_changed.clear() + + def receive_batch_nowait(self): + return self._stream_reader.receive_batch_nowait() + + def receive_message_nowait(self): + return self._stream_reader.receive_message_nowait() + + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + return self._stream_reader.commit(batch) + + async def close(self): + if self._stream_reader: + await self._stream_reader.close() + for task in self._background_tasks: + task.cancel() + + await asyncio.wait(self._background_tasks) + + def _set_first_error(self, err: issues.Error): + try: + self._first_error.set_result(err) + self._state_changed.set() + except asyncio.InvalidStateError: + # skip if already has result + pass + + +class ReaderStream: + _static_id_counter = AtomicCounter() + + _loop: asyncio.AbstractEventLoop + _id: int + _reader_reconnector_id: int + _session_id: str + _stream: Optional[IGrpcWrapperAsyncIO] + _started: bool + _background_tasks: Set[asyncio.Task] + _partition_sessions: Dict[int, datatypes.PartitionSession] + _buffer_size_bytes: int # use for init request, then for debug purposes only + _decode_executor: concurrent.futures.Executor + _decoders: Dict[ + int, typing.Callable[[bytes], bytes] + ] # dict[codec_code] func(encoded_bytes)->decoded_bytes + + if typing.TYPE_CHECKING: + _batches_to_decode: asyncio.Queue[datatypes.PublicBatch] + else: + _batches_to_decode: asyncio.Queue + + _state_changed: asyncio.Event + _closed: bool + _message_batches: typing.Deque[datatypes.PublicBatch] + _first_error: asyncio.Future[YdbError] + + _update_token_interval: Union[int, float] + _update_token_event: asyncio.Event + _get_token_function: Callable[[], str] + + def __init__( + self, + reader_reconnector_id: int, 
+ settings: topic_reader.PublicReaderSettings, + get_token_function: Optional[Callable[[], str]] = None, + ): + self._loop = asyncio.get_running_loop() + self._id = ReaderStream._static_id_counter.inc_and_get() + self._reader_reconnector_id = reader_reconnector_id + self._session_id = "not initialized" + self._stream = None + self._started = False + self._background_tasks = set() + self._partition_sessions = dict() + self._buffer_size_bytes = settings.buffer_size_bytes + self._decode_executor = settings.decoder_executor + + self._decoders = {Codec.CODEC_GZIP: gzip.decompress} + if settings.decoders: + self._decoders.update(settings.decoders) + + self._state_changed = asyncio.Event() + self._closed = False + self._first_error = asyncio.get_running_loop().create_future() + self._batches_to_decode = asyncio.Queue() + self._message_batches = deque() + + self._update_token_interval = settings.update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + + @staticmethod + async def create( + reader_reconnector_id: int, + driver: SupportedDriverType, + settings: topic_reader.PublicReaderSettings, + ) -> "ReaderStream": + stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) + + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead + ) + + creds = driver._credentials + reader = ReaderStream( + reader_reconnector_id, + settings, + get_token_function=creds.get_auth_token if creds else None, + ) + await reader._start(stream, settings._init_message()) + return reader + + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest + ): + if self._started: + raise TopicReaderError("Double start ReaderStream") + + self._started = True + self._stream = stream + + stream.write(StreamReadMessage.FromClient(client_message=init_message)) + init_response = await stream.receive() # type: StreamReadMessage.FromServer + if isinstance(init_response.server_message, StreamReadMessage.InitResponse): + self._session_id = init_response.server_message.session_id + else: + raise TopicReaderError( + "Unexpected message after InitRequest: %s", init_response + ) + + self._update_token_event.set() + + self._background_tasks.add( + asyncio.create_task(self._read_messages_loop(), name="read_messages_loop") + ) + self._background_tasks.add(asyncio.create_task(self._decode_batches_loop())) + if self._get_token_function: + self._background_tasks.add( + asyncio.create_task(self._update_token_loop(), name="update_token_loop") + ) + + async def wait_error(self): + raise await self._first_error + + async def wait_messages(self): + while True: + if self._get_first_error(): + raise self._get_first_error() + + if self._message_batches: + return + + await self._state_changed.wait() + self._state_changed.clear() + + def receive_batch_nowait(self): + if self._get_first_error(): + raise self._get_first_error() + + if not self._message_batches: + return None + + batch = self._message_batches.popleft() + self._buffer_release_bytes(batch._bytes_size) + return batch + + def receive_message_nowait(self): + try: + batch = self._message_batches[0] + message = batch.pop_message() + except IndexError: + return None + + if batch.empty(): + self._message_batches.popleft() + + return message + + def commit( + self, batch: datatypes.ICommittable + ) -> datatypes.PartitionSession.CommitAckWaiter: + partition_session = batch._commit_get_partition_session() + + if partition_session.reader_reconnector_id != 
self._reader_reconnector_id: + raise TopicReaderError("reader can commit only self-produced messages") + + if partition_session.reader_stream_id != self._id: + raise TopicReaderCommitToExpiredPartition( + "commit messages after reconnect to server" + ) + + if partition_session.id not in self._partition_sessions: + raise TopicReaderCommitToExpiredPartition( + "commit messages after server stop the partition read session" + ) + + commit_range = batch._commit_get_offsets_range() + waiter = partition_session.add_waiter(commit_range.end) + + if not waiter.future.done(): + client_message = StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[commit_range], + ) + ] + ) + self._stream.write( + StreamReadMessage.FromClient(client_message=client_message) + ) + + return waiter + + async def _read_messages_loop(self): + try: + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=self._buffer_size_bytes, + ), + ) + ) + while True: + message = ( + await self._stream.receive() + ) # type: StreamReadMessage.FromServer + _process_response(message.server_status) + + if isinstance(message.server_message, StreamReadMessage.ReadResponse): + self._on_read_response(message.server_message) + + elif isinstance( + message.server_message, StreamReadMessage.CommitOffsetResponse + ): + self._on_commit_response(message.server_message) + + elif isinstance( + message.server_message, + StreamReadMessage.StartPartitionSessionRequest, + ): + self._on_start_partition_session(message.server_message) + + elif isinstance( + message.server_message, + StreamReadMessage.StopPartitionSessionRequest, + ): + self._on_partition_session_stop(message.server_message) + + elif isinstance(message.server_message, UpdateTokenResponse): + self._update_token_event.set() + + else: + raise NotImplementedError( + "Unexpected type of StreamReadMessage.FromServer message: %s" + % message.server_message + ) + + self._state_changed.set() + except Exception as e: + self._set_first_error(e) + raise + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamReadMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() + + def _on_start_partition_session( + self, message: StreamReadMessage.StartPartitionSessionRequest + ): + try: + if ( + message.partition_session.partition_session_id + in self._partition_sessions + ): + raise TopicReaderError( + "Double start partition session: %s" + % message.partition_session.partition_session_id + ) + + self._partition_sessions[ + message.partition_session.partition_session_id + ] = datatypes.PartitionSession( + id=message.partition_session.partition_session_id, + state=datatypes.PartitionSession.State.Active, + topic_path=message.partition_session.path, + partition_id=message.partition_session.partition_id, + committed_offset=message.committed_offset, + reader_reconnector_id=self._reader_reconnector_id, + reader_stream_id=self._id, + ) + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=message.partition_session.partition_session_id, + read_offset=None, + commit_offset=None, + 
) + ), + ) + except YdbError as err: + self._set_first_error(err) + + def _on_partition_session_stop( + self, message: StreamReadMessage.StopPartitionSessionRequest + ): + if message.partition_session_id not in self._partition_sessions: + # may if receive stop partition with graceful=false after response on stop partition + # with graceful=true and remove partition from internal dictionary + return + + partition = self._partition_sessions.pop(message.partition_session_id) + partition.close() + + if message.graceful: + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=message.partition_session_id, + ) + ) + ) + + def _on_read_response(self, message: StreamReadMessage.ReadResponse): + self._buffer_consume_bytes(message.bytes_size) + + batches = self._read_response_to_batches(message) + for batch in batches: + self._batches_to_decode.put_nowait(batch) + + def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse): + for partition_offset in message.partitions_committed_offsets: + if partition_offset.partition_session_id not in self._partition_sessions: + continue + + session = self._partition_sessions[partition_offset.partition_session_id] + session.ack_notify(partition_offset.committed_offset) + + def _buffer_consume_bytes(self, bytes_size): + self._buffer_size_bytes -= bytes_size + + def _buffer_release_bytes(self, bytes_size): + self._buffer_size_bytes += bytes_size + self._stream.write( + StreamReadMessage.FromClient( + client_message=StreamReadMessage.ReadRequest( + bytes_size=bytes_size, + ) + ) + ) + + def _read_response_to_batches( + self, message: StreamReadMessage.ReadResponse + ) -> typing.List[datatypes.PublicBatch]: + batches = [] + + batch_count = sum(len(p.batches) for p in message.partition_data) + if batch_count == 0: + return batches + + bytes_per_batch = message.bytes_size // batch_count + additional_bytes_to_last_batch = ( + message.bytes_size - bytes_per_batch * batch_count + ) + + for partition_data in message.partition_data: + partition_session = self._partition_sessions[ + partition_data.partition_session_id + ] + for server_batch in partition_data.batches: + messages = [] + for message_data in server_batch.message_data: + mess = datatypes.PublicMessage( + seqno=message_data.seq_no, + created_at=message_data.created_at, + message_group_id=message_data.message_group_id, + session_metadata=server_batch.write_session_meta, + offset=message_data.offset, + written_at=server_batch.written_at, + producer_id=server_batch.producer_id, + data=message_data.data, + _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset, + _commit_end_offset=message_data.offset + 1, + ) + messages.append(mess) + partition_session._next_message_start_commit_offset = ( + mess._commit_end_offset + ) + + if messages: + batch = datatypes.PublicBatch( + messages=messages, + _partition_session=partition_session, + _bytes_size=bytes_per_batch, + _codec=Codec(server_batch.codec), + ) + batches.append(batch) + + batches[-1]._bytes_size += additional_bytes_to_last_batch + return batches + + async def _decode_batches_loop(self): + while True: + batch = await self._batches_to_decode.get() + await self._decode_batch_inplace(batch) + self._message_batches.append(batch) + self._state_changed.set() + + async def _decode_batch_inplace(self, batch): + if batch._codec == Codec.CODEC_RAW: + return + + try: + decode_func = self._decoders[batch._codec] + 
except KeyError: + raise TopicReaderUnexpectedCodec( + "Receive message with unexpected codec: %s" % batch._codec + ) + + decode_data_futures = [] + for message in batch.messages: + future = self._loop.run_in_executor( + self._decode_executor, decode_func, message.data + ) + decode_data_futures.append(future) + + decoded_data = await asyncio.gather(*decode_data_futures) + for index, message in enumerate(batch.messages): + message.data = decoded_data[index] + + batch._codec = Codec.CODEC_RAW + + def _set_first_error(self, err: YdbError): + try: + self._first_error.set_result(err) + self._state_changed.set() + except asyncio.InvalidStateError: + # skip later set errors + pass + + def _get_first_error(self) -> Optional[YdbError]: + if self._first_error.done(): + return self._first_error.result() + + async def close(self): + if self._closed: + return + self._closed = True + + self._set_first_error(TopicReaderStreamClosedError()) + self._state_changed.set() + self._stream.close() + + for session in self._partition_sessions.values(): + session.close() + + for task in self._background_tasks: + task.cancel() + await asyncio.wait(self._background_tasks) diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py new file mode 100644 index 00000000..0134f38b --- /dev/null +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -0,0 +1,1197 @@ +import asyncio +import concurrent.futures +import copy +import datetime +import gzip +import typing +from collections import deque +from dataclasses import dataclass +from unittest import mock + +import pytest + +from ydb import issues +from . import datatypes, topic_reader_asyncio +from .datatypes import PublicBatch, PublicMessage +from .topic_reader import PublicReaderSettings +from .topic_reader_asyncio import ReaderStream, ReaderReconnector +from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus +from .._grpc.grpcwrapper.ydb_topic import ( + StreamReadMessage, + Codec, + OffsetsRange, + UpdateTokenRequest, + UpdateTokenResponse, +) +from .._topic_common.test_helpers import ( + StreamMock, + wait_condition, + wait_for_fast, + WaitConditionError, +) + +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from .._grpc.v4.protos import ydb_status_codes_pb2 +else: + from .._grpc.common.protos import ydb_status_codes_pb2 + + +@pytest.fixture(autouse=True) +def handle_exceptions(event_loop): + def handler(loop, context): + print(context) + + event_loop.set_exception_handler(handler) + + +@pytest.fixture() +def default_executor(): + executor = concurrent.futures.ThreadPoolExecutor( + max_workers=2, thread_name_prefix="decoder_executor" + ) + yield executor + executor.shutdown() + + +def stub_partition_session(): + return datatypes.PartitionSession( + id=0, + state=datatypes.PartitionSession.State.Active, + topic_path="asd", + partition_id=1, + committed_offset=0, + reader_reconnector_id=415, + reader_stream_id=513, + ) + + +def stub_message(id: int): + return PublicMessage( + seqno=id, + created_at=datetime.datetime(2023, 3, 18, 14, 15), + message_group_id="", + session_metadata={}, + offset=0, + written_at=datetime.datetime(2023, 3, 18, 14, 15), + producer_id="", + data=bytes(), + _partition_session=stub_partition_session(), + _commit_start_offset=0, + _commit_end_offset=1, + ) + + +@pytest.fixture() +def default_reader_settings(default_executor): + return PublicReaderSettings( + consumer="test-consumer", + topic="test-topic", + decoder_executor=default_executor, + 
) + + +class StreamMockForReader(StreamMock): + async def receive(self) -> StreamReadMessage.FromServer: + return await super(self).receive() + + def write(self, message: StreamReadMessage.FromClient): + return super().write(message) + + +@pytest.mark.asyncio +class TestReaderStream: + default_batch_size = 1 + partition_session_id = 2 + partition_session_committed_offset = 10 + second_partition_session_id = 12 + second_partition_session_offset = 50 + default_reader_reconnector_id = 4 + + @pytest.fixture() + def stream(self): + return StreamMock() + + @pytest.fixture() + def partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ) -> datatypes.PartitionSession: + partition_session = datatypes.PartitionSession( + id=2, + topic_path=default_reader_settings.topic, + partition_id=4, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.partition_session_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, + ) + + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + + @pytest.fixture() + def second_partition_session( + self, default_reader_settings, stream_reader_started: ReaderStream + ): + partition_session = datatypes.PartitionSession( + id=12, + topic_path=default_reader_settings.topic, + partition_id=10, + state=datatypes.PartitionSession.State.Active, + committed_offset=self.second_partition_session_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader_started._id, + ) + + assert partition_session.id not in stream_reader_started._partition_sessions + stream_reader_started._partition_sessions[ + partition_session.id + ] = partition_session + + return stream_reader_started._partition_sessions[partition_session.id] + + async def get_started_reader(self, stream, *args, **kwargs) -> ReaderStream: + reader = ReaderStream(self.default_reader_reconnector_id, *args, **kwargs) + init_message = object() + + # noinspection PyTypeChecker + start = asyncio.create_task(reader._start(stream, init_message)) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse( + session_id="test-session" + ), + ) + ) + + init_request = await wait_for_fast(stream.from_client.get()) + assert init_request.client_message == init_message + + read_request = await wait_for_fast(stream.from_client.get()) + assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest) + + await start + + await asyncio.sleep(0) + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + return reader + + @pytest.fixture() + async def stream_reader_started( + self, stream, default_reader_settings + ) -> ReaderStream: + return await self.get_started_reader(stream, default_reader_settings) + + @pytest.fixture() + async def stream_reader(self, stream_reader_started: ReaderStream): + yield stream_reader_started + + assert stream_reader_started._get_first_error() is None + await stream_reader_started.close() + + @pytest.fixture() + async def stream_reader_finish_with_error( + self, stream_reader_started: ReaderStream + ): + yield stream_reader_started + + assert stream_reader_started._get_first_error() is not None + await 
stream_reader_started.close() + + @staticmethod + def create_message( + partition_session: typing.Optional[datatypes.PartitionSession], + seqno: int, + offset_delta: int, + ): + return PublicMessage( + seqno=seqno, + created_at=datetime.datetime(2023, 2, 3, 14, 15), + message_group_id="test-message-group", + session_metadata={}, + offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, + written_at=datetime.datetime(2023, 2, 3, 14, 16), + producer_id="test-producer-id", + data=bytes(), + _partition_session=partition_session, + _commit_start_offset=partition_session._next_message_start_commit_offset + + offset_delta + - 1, + _commit_end_offset=partition_session._next_message_start_commit_offset + + offset_delta, + ) + + async def send_message(self, stream_reader, message: PublicMessage): + def batch_count(): + return len(stream_reader._message_batches) + + initial_batches = batch_count() + + stream = stream_reader._stream # type: StreamMock + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=message._partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=message.offset, + seq_no=message.seqno, + created_at=message.created_at, + data=message.data, + uncompresed_size=len(message.data), + message_group_id=message.message_group_id, + ) + ], + producer_id=message.producer_id, + write_session_meta=message.session_metadata, + codec=Codec.CODEC_RAW, + written_at=message.written_at, + ) + ], + ) + ], + bytes_size=self.default_batch_size, + ), + ) + ) + await wait_condition(lambda: batch_count() > initial_batches) + + async def test_unknown_error(self, stream, stream_reader_finish_with_error): + class TestError(Exception): + pass + + test_err = TestError() + stream.from_server.put_nowait(test_err) + + with pytest.raises(TestError): + await wait_for_fast(stream_reader_finish_with_error.wait_messages()) + + with pytest.raises(TestError): + stream_reader_finish_with_error.receive_batch_nowait() + + @pytest.mark.parametrize( + "commit,send_range", + [ + ( + OffsetsRange( + partition_session_committed_offset, + partition_session_committed_offset + 1, + ), + True, + ), + ( + OffsetsRange( + partition_session_committed_offset - 1, + partition_session_committed_offset, + ), + False, + ), + ], + ) + async def test_send_commit_messages( + self, + stream, + stream_reader: ReaderStream, + partition_session, + commit: OffsetsRange, + send_range: bool, + ): + @dataclass + class Commitable(datatypes.ICommittable): + start: int + end: int + + def _commit_get_partition_session(self) -> datatypes.PartitionSession: + return partition_session + + def _commit_get_offsets_range(self) -> OffsetsRange: + return OffsetsRange(self.start, self.end) + + start_ack_waiters = partition_session._ack_waiters.copy() + + waiter = stream_reader.commit(Commitable(commit.start, commit.end)) + + async def wait_message(): + return await wait_for_fast(stream.from_client.get(), timeout=0) + + if send_range: + msg = await wait_message() # type: StreamReadMessage.FromClient + assert msg.client_message == StreamReadMessage.CommitOffsetRequest( + commit_offsets=[ + StreamReadMessage.CommitOffsetRequest.PartitionCommitOffset( + partition_session_id=partition_session.id, + offsets=[commit], + ) + ] + ) + 
assert partition_session._ack_waiters[-1].end_offset == commit.end + else: + assert waiter.future.done() + + with pytest.raises(WaitConditionError): + msg = await wait_message() + pass + assert start_ack_waiters == partition_session._ack_waiters + + async def test_commit_ack_received( + self, stream_reader, stream, partition_session, second_partition_session + ): + offset1 = self.partition_session_committed_offset + 1 + waiter1 = partition_session.add_waiter(offset1) + + offset2 = self.second_partition_session_offset + 2 + waiter2 = second_partition_session.add_waiter(offset2) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.CommitOffsetResponse( + partitions_committed_offsets=[ + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=partition_session.id, + committed_offset=offset1, + ), + StreamReadMessage.CommitOffsetResponse.PartitionCommittedOffset( + partition_session_id=second_partition_session.id, + committed_offset=offset2, + ), + ] + ), + ) + ) + + await wait_for_fast(waiter1.future) + await wait_for_fast(waiter2.future) + + async def test_close_ack_waiters_when_close_stream_reader( + self, stream_reader_started: ReaderStream, partition_session + ): + waiter = partition_session.add_waiter( + self.partition_session_committed_offset + 1 + ) + await wait_for_fast(stream_reader_started.close()) + + with pytest.raises(topic_reader_asyncio.TopicReaderCommitToExpiredPartition): + waiter.future.result() + + async def test_commit_ranges_for_received_messages( + self, stream, stream_reader_started: ReaderStream, partition_session + ): + m1 = self.create_message(partition_session, 1, 1) + m2 = self.create_message(partition_session, 2, 10) + m2._commit_start_offset = m1.offset + 1 + + await self.send_message(stream_reader_started, m1) + await self.send_message(stream_reader_started, m2) + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m1] + + await stream_reader_started.wait_messages() + received = stream_reader_started.receive_batch_nowait().messages + assert received == [m2] + + # noinspection PyTypeChecker + @pytest.mark.parametrize( + "batch,data_out", + [ + ( + PublicBatch( + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123")], + ), + ( + PublicBatch( + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ) + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123")], + ), + ( + PublicBatch( + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"123", + _partition_session=None, + 
_commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=rb"456", + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + [bytes(rb"123"), bytes(rb"456")], + ), + ( + PublicBatch( + messages=[ + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"123"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + PublicMessage( + seqno=1, + created_at=datetime.datetime(2023, 3, 14, 15, 41), + message_group_id="", + session_metadata={}, + offset=1, + written_at=datetime.datetime(2023, 3, 14, 15, 42), + producer_id="asd", + data=gzip.compress(rb"456"), + _partition_session=None, + _commit_start_offset=5, + _commit_end_offset=15, + ), + ], + _partition_session=None, + _bytes_size=0, + _codec=Codec.CODEC_GZIP, + ), + [bytes(rb"123"), bytes(rb"456")], + ), + ], + ) + async def test_decode_loop( + self, stream_reader, batch: PublicBatch, data_out: typing.List[bytes] + ): + assert len(batch.messages) == len(data_out) + + expected = copy.deepcopy(batch) + expected._codec = Codec.CODEC_RAW + + for index, data in enumerate(data_out): + expected.messages[index].data = data + + await wait_for_fast(stream_reader._decode_batch_inplace(batch)) + + assert batch == expected + + async def test_error_from_status_code( + self, stream, stream_reader_finish_with_error + ): + # noinspection PyTypeChecker + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus( + status=issues.StatusCode.OVERLOADED, + issues=[], + ), + server_message=None, + ) + ) + + with pytest.raises(issues.Overloaded): + await wait_for_fast(stream_reader_finish_with_error.wait_messages()) + + with pytest.raises(issues.Overloaded): + stream_reader_finish_with_error.receive_batch_nowait() + + async def test_init_reader(self, stream, default_reader_settings): + reader = ReaderStream( + self.default_reader_reconnector_id, default_reader_settings + ) + init_message = StreamReadMessage.InitRequest( + consumer="test-consumer", + topics_read_settings=[ + StreamReadMessage.InitRequest.TopicReadSettings( + path="/local/test-topic", + partition_ids=[], + max_lag_seconds=None, + read_from=None, + ) + ], + ) + start_task = asyncio.create_task(reader._start(stream, init_message)) + + sent_message = await wait_for_fast(stream.from_client.get()) + expected_sent_init_message = StreamReadMessage.FromClient( + client_message=init_message + ) + assert sent_message == expected_sent_init_message + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.InitResponse(session_id="test"), + ) + ) + + await start_task + + read_request = await wait_for_fast(stream.from_client.get()) + assert read_request.client_message == StreamReadMessage.ReadRequest( + bytes_size=default_reader_settings.buffer_size_bytes, + ) + + assert reader._session_id == "test" + await reader.close() + + async def test_start_partition( + self, + stream_reader: ReaderStream, + stream, + default_reader_settings, + 
partition_session, + ): + def session_count(): + return len(stream_reader._partition_sessions) + + initial_session_count = session_count() + + test_partition_id = partition_session.partition_id + 1 + test_partition_session_id = partition_session.id + 1 + test_topic_path = default_reader_settings.topic + "-asd" + test_partition_committed_offset = 18 + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StartPartitionSessionRequest( + partition_session=StreamReadMessage.PartitionSession( + partition_session_id=test_partition_session_id, + path=test_topic_path, + partition_id=test_partition_id, + ), + committed_offset=test_partition_committed_offset, + partition_offsets=OffsetsRange( + start=0, + end=0, + ), + ), + ), + ) + response = await wait_for_fast(stream.from_client.get()) + assert response == StreamReadMessage.FromClient( + client_message=StreamReadMessage.StartPartitionSessionResponse( + partition_session_id=test_partition_session_id, + read_offset=None, + commit_offset=None, + ) + ) + + assert len(stream_reader._partition_sessions) == initial_session_count + 1 + assert stream_reader._partition_sessions[ + test_partition_session_id + ] == datatypes.PartitionSession( + id=test_partition_session_id, + state=datatypes.PartitionSession.State.Active, + topic_path=test_topic_path, + partition_id=test_partition_id, + committed_offset=test_partition_committed_offset, + reader_reconnector_id=self.default_reader_reconnector_id, + reader_stream_id=stream_reader._id, + ) + + async def test_partition_stop_force(self, stream, stream_reader, partition_session): + def session_count(): + return len(stream_reader._partition_sessions) + + initial_session_count = session_count() + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=False, + committed_offset=0, + ), + ) + ) + + await asyncio.sleep(0) # wait next loop + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + assert session_count() == initial_session_count - 1 + assert partition_session.id not in stream_reader._partition_sessions + + async def test_partition_stop_graceful( + self, stream, stream_reader, partition_session + ): + def session_count(): + return len(stream_reader._partition_sessions) + + initial_session_count = session_count() + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=True, + committed_offset=0, + ), + ) + ) + + resp = await wait_for_fast( + stream.from_client.get() + ) # type: StreamReadMessage.FromClient + assert session_count() == initial_session_count - 1 + assert partition_session.id not in stream_reader._partition_sessions + assert resp.client_message == StreamReadMessage.StopPartitionSessionResponse( + partition_session_id=partition_session.id + ) + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.StopPartitionSessionRequest( + partition_session_id=partition_session.id, + graceful=False, + committed_offset=0, + ), + ) + ) + + await 
asyncio.sleep(0) # wait next loop + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + async def test_receive_message_from_server( + self, + stream_reader, + stream, + partition_session: datatypes.PartitionSession, + second_partition_session, + ): + def reader_batch_count(): + return len(stream_reader._message_batches) + + initial_buffer_size = stream_reader._buffer_size_bytes + initial_batch_count = reader_batch_count() + + bytes_size = 10 + created_at = datetime.datetime(2020, 1, 1, 18, 12) + written_at = datetime.datetime(2023, 2, 1, 18, 12) + producer_id = "test-producer-id" + data = "123".encode() + session_meta = {"a": "b"} + message_group_id = "test-message-group-id" + + expected_message_offset = partition_session.committed_offset + + stream.from_server.put_nowait( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=StreamReadMessage.ReadResponse( + bytes_size=bytes_size, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=expected_message_offset, + seq_no=2, + created_at=created_at, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at, + ) + ], + ) + ], + ), + ) + ), + + await wait_condition(lambda: reader_batch_count() == initial_batch_count + 1) + + assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size + + last_batch = stream_reader._message_batches[-1] + assert last_batch == PublicBatch( + messages=[ + PublicMessage( + seqno=2, + created_at=created_at, + message_group_id=message_group_id, + session_metadata=session_meta, + offset=expected_message_offset, + written_at=written_at, + producer_id=producer_id, + data=data, + _partition_session=partition_session, + _commit_start_offset=expected_message_offset, + _commit_end_offset=expected_message_offset + 1, + ) + ], + _partition_session=partition_session, + _bytes_size=bytes_size, + _codec=Codec.CODEC_RAW, + ) + + async def test_read_batches( + self, stream_reader, partition_session, second_partition_session + ): + created_at = datetime.datetime(2020, 2, 1, 18, 12) + created_at2 = datetime.datetime(2020, 2, 2, 18, 12) + created_at3 = datetime.datetime(2020, 2, 3, 18, 12) + created_at4 = datetime.datetime(2020, 2, 4, 18, 12) + written_at = datetime.datetime(2023, 3, 1, 18, 12) + written_at2 = datetime.datetime(2023, 3, 2, 18, 12) + producer_id = "test-producer-id" + producer_id2 = "test-producer-id" + data = "123".encode() + data2 = "1235".encode() + session_meta = {"a": "b"} + session_meta2 = {"b": "c"} + + message_group_id = "test-message-group-id" + message_group_id2 = "test-message-group-id-2" + + partition1_mess1_expected_offset = partition_session.committed_offset + partition2_mess1_expected_offset = second_partition_session.committed_offset + partition2_mess2_expected_offset = second_partition_session.committed_offset + 1 + partition2_mess3_expected_offset = second_partition_session.committed_offset + 2 + + batches = stream_reader._read_response_to_batches( + StreamReadMessage.ReadResponse( + bytes_size=3, + partition_data=[ + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ 
+ StreamReadMessage.ReadResponse.MessageData( + offset=partition1_mess1_expected_offset, + seq_no=3, + created_at=created_at, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at, + ) + ], + ), + StreamReadMessage.ReadResponse.PartitionData( + partition_session_id=second_partition_session.id, + batches=[ + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=partition2_mess1_expected_offset, + seq_no=2, + created_at=created_at2, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id, + ) + ], + producer_id=producer_id, + write_session_meta=session_meta, + codec=Codec.CODEC_RAW, + written_at=written_at2, + ), + StreamReadMessage.ReadResponse.Batch( + message_data=[ + StreamReadMessage.ReadResponse.MessageData( + offset=partition2_mess2_expected_offset, + seq_no=3, + created_at=created_at3, + data=data2, + uncompresed_size=len(data2), + message_group_id=message_group_id, + ), + StreamReadMessage.ReadResponse.MessageData( + offset=partition2_mess3_expected_offset, + seq_no=5, + created_at=created_at4, + data=data, + uncompresed_size=len(data), + message_group_id=message_group_id2, + ), + ], + producer_id=producer_id2, + write_session_meta=session_meta2, + codec=Codec.CODEC_RAW, + written_at=written_at2, + ), + ], + ), + ], + ) + ) + + last0 = batches[0] + last1 = batches[1] + last2 = batches[2] + + assert last0 == PublicBatch( + messages=[ + PublicMessage( + seqno=3, + created_at=created_at, + message_group_id=message_group_id, + session_metadata=session_meta, + offset=partition1_mess1_expected_offset, + written_at=written_at, + producer_id=producer_id, + data=data, + _partition_session=partition_session, + _commit_start_offset=partition1_mess1_expected_offset, + _commit_end_offset=partition1_mess1_expected_offset + 1, + ) + ], + _partition_session=partition_session, + _bytes_size=1, + _codec=Codec.CODEC_RAW, + ) + assert last1 == PublicBatch( + messages=[ + PublicMessage( + seqno=2, + created_at=created_at2, + message_group_id=message_group_id, + session_metadata=session_meta, + offset=partition2_mess1_expected_offset, + written_at=written_at2, + producer_id=producer_id, + data=data, + _partition_session=second_partition_session, + _commit_start_offset=partition2_mess1_expected_offset, + _commit_end_offset=partition2_mess1_expected_offset + 1, + ) + ], + _partition_session=second_partition_session, + _bytes_size=1, + _codec=Codec.CODEC_RAW, + ) + assert last2 == PublicBatch( + messages=[ + PublicMessage( + seqno=3, + created_at=created_at3, + message_group_id=message_group_id, + session_metadata=session_meta2, + offset=partition2_mess2_expected_offset, + written_at=written_at2, + producer_id=producer_id2, + data=data2, + _partition_session=second_partition_session, + _commit_start_offset=partition2_mess2_expected_offset, + _commit_end_offset=partition2_mess2_expected_offset + 1, + ), + PublicMessage( + seqno=5, + created_at=created_at4, + message_group_id=message_group_id2, + session_metadata=session_meta2, + offset=partition2_mess3_expected_offset, + written_at=written_at2, + producer_id=producer_id, + data=data, + _partition_session=second_partition_session, + _commit_start_offset=partition2_mess3_expected_offset, + _commit_end_offset=partition2_mess3_expected_offset + 1, + ), + ], + _partition_session=second_partition_session, + _bytes_size=1, + 
_codec=Codec.CODEC_RAW, + ) + + @pytest.mark.parametrize( + "batches_before,expected_message,batches_after", + [ + ([], None, []), + ( + [ + PublicBatch( + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + stub_message(1), + [], + ), + ( + [ + PublicBatch( + messages=[stub_message(1), stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( + messages=[stub_message(2)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + messages=[stub_message(3), stub_message(4)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + ), + ( + [ + PublicBatch( + messages=[stub_message(1)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + PublicBatch( + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ), + ], + stub_message(1), + [ + PublicBatch( + messages=[stub_message(2), stub_message(3)], + _partition_session=stub_partition_session(), + _bytes_size=0, + _codec=Codec.CODEC_RAW, + ) + ], + ), + ], + ) + async def test_read_message( + self, + stream_reader, + batches_before: typing.List[datatypes.PublicBatch], + expected_message: PublicMessage, + batches_after: typing.List[datatypes.PublicBatch], + ): + stream_reader._message_batches = deque(batches_before) + mess = stream_reader.receive_message_nowait() + + assert mess == expected_message + assert list(stream_reader._message_batches) == batches_after + + async def test_receive_batch_nowait(self, stream, stream_reader, partition_session): + assert stream_reader.receive_batch_nowait() is None + + mess1 = self.create_message(partition_session, 1, 1) + await self.send_message(stream_reader, mess1) + + mess2 = self.create_message(partition_session, 2, 1) + await self.send_message(stream_reader, mess2) + + initial_buffer_size = stream_reader._buffer_size_bytes + + received = stream_reader.receive_batch_nowait() + assert received == PublicBatch( + messages=[mess1], + _partition_session=mess1._partition_session, + _bytes_size=self.default_batch_size, + _codec=Codec.CODEC_RAW, + ) + + received = stream_reader.receive_batch_nowait() + assert received == PublicBatch( + messages=[mess2], + _partition_session=mess2._partition_session, + _bytes_size=self.default_batch_size, + _codec=Codec.CODEC_RAW, + ) + + assert ( + stream_reader._buffer_size_bytes + == initial_buffer_size + 2 * self.default_batch_size + ) + + assert ( + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message + ) + assert ( + StreamReadMessage.ReadRequest(self.default_batch_size) + == stream.from_client.get_nowait().client_message + ) + + with pytest.raises(asyncio.QueueEmpty): + stream.from_client.get_nowait() + + async def test_update_token(self, stream): + settings = PublicReaderSettings( + consumer="test-consumer", + topic="test-topic", + update_token_interval=0.1, + ) + reader = await self.get_started_reader( + stream, settings, get_token_function=lambda: "foo-bar" + ) + + assert stream.from_client.empty() + + expected = 
StreamReadMessage.FromClient(UpdateTokenRequest(token="foo-bar")) + got = await wait_for_fast(stream.from_client.get()) + assert expected == got, "send update token request" + + await asyncio.sleep(0.2) + assert stream.from_client.empty(), "no answer - no new update request" + + await stream.from_server.put( + StreamReadMessage.FromServer( + server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []), + server_message=UpdateTokenResponse(), + ) + ) + + got = await wait_for_fast(stream.from_client.get()) + assert expected == got + + await reader.close() + + +@pytest.mark.asyncio +class TestReaderReconnector: + async def test_reconnect_on_repeatable_error(self, monkeypatch): + test_error = issues.Overloaded("test error") + + async def wait_error(): + raise test_error + + reader_stream_mock_with_error = mock.Mock(ReaderStream) + reader_stream_mock_with_error.wait_error = mock.AsyncMock( + side_effect=wait_error + ) + + async def wait_messages(): + raise test_error + + reader_stream_mock_with_error.wait_messages = mock.AsyncMock( + side_effect=wait_messages + ) + + reader_stream_with_messages = mock.Mock(ReaderStream) + reader_stream_with_messages.wait_error.return_value = asyncio.Future() + reader_stream_with_messages.wait_messages.return_value = None + + stream_index = 0 + + async def stream_create( + reader_reconnector_id: int, + driver: SupportedDriverType, + settings: PublicReaderSettings, + ): + nonlocal stream_index + stream_index += 1 + if stream_index == 1: + return reader_stream_mock_with_error + elif stream_index == 2: + return reader_stream_with_messages + else: + raise Exception("unexpected create stream") + + with mock.patch.object(ReaderStream, "create", stream_create): + reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) + await wait_for_fast(reconnector.wait_message()) + + reader_stream_mock_with_error.wait_error.assert_any_await() + reader_stream_mock_with_error.wait_messages.assert_any_await() diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py new file mode 100644 index 00000000..cea7e36c --- /dev/null +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -0,0 +1,165 @@ +import asyncio +import concurrent.futures +import typing +from typing import List, Union, Optional + +from ydb._grpc.grpcwrapper.common_utils import SupportedDriverType +from ydb._topic_common.common import ( + _get_shared_event_loop, + CallFromSyncToAsync, + TimeoutType, +) +from ydb._topic_reader import datatypes +from ydb._topic_reader.datatypes import PublicBatch +from ydb._topic_reader.topic_reader import ( + PublicReaderSettings, + CommitResult, +) +from ydb._topic_reader.topic_reader_asyncio import ( + PublicAsyncIOReader, + TopicReaderClosedError, +) + + +class TopicReaderSync: + _caller: CallFromSyncToAsync + _async_reader: PublicAsyncIOReader + _closed: bool + + def __init__( + self, + driver: SupportedDriverType, + settings: PublicReaderSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_reader(): + return PublicAsyncIOReader(driver, settings) + + self._async_reader = asyncio.run_coroutine_threadsafe( + create_reader(), loop + ).result() + + def __del__(self): + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def receive_message( + self, *, timeout: TimeoutType = 
None
+    ) -> datatypes.PublicMessage:
+        """
+        Block until a new message is received.
+        It has no async_ version to prevent losing messages; use async_wait_message as a signal that new batches are available.
+        receive_message(timeout=0) may return None even right after async_wait_message() completed, because the partition
+        or the connection to the server may have been lost in between.
+
+        if no new message arrives within timeout seconds (default - infinite): raise TimeoutError()
+        if timeout <= 0 - wait only one event loop cycle, without any i/o operations or pauses, and take messages from the internal buffer only.
+        """
+        self._check_closed()
+
+        return self._caller.safe_call_with_result(
+            self._async_reader.receive_message(), timeout
+        )
+
+    def async_wait_message(self) -> concurrent.futures.Future:
+        """
+        Return a future which completes when the reader has at least one message in its queue.
+        If the reader already has a message, the future is returned already completed.
+
+        It is possible to receive the signal that a message is available but then get no message when trying to receive it,
+        for example if the message expired between the event and the read attempt (e.g. the connection broke).
+        """
+        self._check_closed()
+
+        return self._caller.unsafe_call_with_future(
+            self._async_reader._reconnector.wait_message()
+        )
+
+    def receive_batch(
+        self,
+        *,
+        max_messages: typing.Union[int, None] = None,
+        max_bytes: typing.Union[int, None] = None,
+        timeout: Union[float, None] = None,
+    ) -> Union[PublicBatch, None]:
+        """
+        Get one message batch from the reader.
+        It has no async_ version to prevent losing messages; use async_wait_message as a signal that new batches are available.
+
+        if no new message arrives within timeout seconds (default - infinite): raise TimeoutError()
+        if timeout <= 0 - wait only one event loop cycle, without any i/o operations or pauses, and take messages from the internal buffer only.
+        """
+        self._check_closed()
+
+        return self._caller.safe_call_with_result(
+            self._async_reader.receive_batch(),
+            timeout,
+        )
+
+    def commit(
+        self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+    ):
+        """
+        Put a commit message into the internal buffer.
+
+        The method gives no way to check the commit result
+        (for example, if the connection is lost, commits are not re-sent and the committed messages will be received again).
+        """
+        self._check_closed()
+
+        self._caller.call_sync(self._async_reader.commit(mess))
+
+    def commit_with_ack(
+        self,
+        mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch],
+        timeout: TimeoutType = None,
+    ) -> Union[CommitResult, List[CommitResult]]:
+        """
+        Write a commit message to the buffer and wait for the ack from the server.
+
+        if the ack is not received within timeout seconds (default - infinite): raise TimeoutError()
+        """
+        self._check_closed()
+
+        return self._caller.unsafe_call_with_result(
+            self._async_reader.commit_with_ack(mess), timeout
+        )
+
+    def async_commit_with_ack(
+        self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
+    ) -> concurrent.futures.Future:
+        """
+        Write a commit message to the buffer and return a Future for waiting the result.
+ """ + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_reader.commit_with_ack(mess) + ) + + def close(self, *, timeout: TimeoutType = None): + if self._closed: + return + + self._closed = True + + self._caller.safe_call_with_result(self._async_reader.close(), timeout) + + def _check_closed(self): + if self._closed: + raise TopicReaderClosedError() diff --git a/kikimr/public/sdk/python/__init__.py b/ydb/_topic_writer/__init__.py similarity index 100% rename from kikimr/public/sdk/python/__init__.py rename to ydb/_topic_writer/__init__.py diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py new file mode 100644 index 00000000..b94ff46b --- /dev/null +++ b/ydb/_topic_writer/topic_writer.py @@ -0,0 +1,213 @@ +import concurrent.futures +import datetime +import enum +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import List, Union, Optional, Any, Dict + +import typing + +import ydb.aio +from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage +from .._grpc.grpcwrapper.common_utils import IToProto +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec + +Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] + + +@dataclass +class PublicWriterSettings: + """ + Settings for topic writer. + + order of fields IS NOT stable, use keywords only + """ + + topic: str + producer_id: Optional[str] = None + session_metadata: Optional[Dict[str, str]] = None + partition_id: Optional[int] = None + auto_seqno: bool = True + auto_created_at: bool = True + codec: Optional[PublicCodec] = None # default mean auto-select + encoder_executor: Optional[ + concurrent.futures.Executor + ] = None # default shared client executor pool + encoders: Optional[ + typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]] + ] = None + update_token_interval: Union[int, float] = 3600 + + def __post_init__(self): + if self.producer_id is None: + self.producer_id = uuid.uuid4().hex + + +@dataclass +class PublicWriteResult: + @dataclass(eq=True) + class Written: + __slots__ = "offset" + offset: int + + @dataclass(eq=True) + class Skipped: + pass + + +PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped] + + +class WriterSettings(PublicWriterSettings): + def __init__(self, settings: PublicWriterSettings): + self.__dict__ = settings.__dict__.copy() + + def create_init_request(self) -> StreamWriteMessage.InitRequest: + return StreamWriteMessage.InitRequest( + path=self.topic, + producer_id=self.producer_id, + write_session_meta=self.session_metadata, + partitioning=self.get_partitioning(), + get_last_seq_no=True, + ) + + def get_partitioning(self) -> StreamWriteMessage.PartitioningType: + if self.partition_id is not None: + return StreamWriteMessage.PartitioningPartitionID(self.partition_id) + return StreamWriteMessage.PartitioningMessageGroupID(self.producer_id) + + +class SendMode(Enum): + ASYNC = 1 + SYNC = 2 + + +@dataclass +class PublicWriterInitInfo: + __slots__ = ("last_seqno", "supported_codecs") + last_seqno: Optional[int] + supported_codecs: List[PublicCodec] + + +class PublicMessage: + seqno: Optional[int] + created_at: Optional[datetime.datetime] + data: "PublicMessage.SimpleMessageSourceType" + + SimpleMessageSourceType = Union[str, bytes] # Will be extend + + def __init__( + self, + data: SimpleMessageSourceType, + *, + seqno: Optional[int] = None, + created_at: Optional[datetime.datetime] = None, + ): + self.seqno = seqno + self.created_at = 
created_at + self.data = data + + @staticmethod + def _create_message(data: Message) -> "PublicMessage": + if isinstance(data, PublicMessage): + return data + return PublicMessage(data=data) + + +class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto): + codec: PublicCodec + + def __init__(self, mess: PublicMessage): + super().__init__( + seq_no=mess.seqno, + created_at=mess.created_at, + data=mess.data, + uncompressed_size=len(mess.data), + partitioning=None, + ) + self.codec = PublicCodec.RAW + + def get_bytes(self) -> bytes: + if self.data is None: + return bytes() + if isinstance(self.data, bytes): + return self.data + if isinstance(self.data, str): + return self.data.encode("utf-8") + raise ValueError("Bad data type") + + def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData: + data = self.get_bytes() + return StreamWriteMessage.WriteRequest.MessageData( + seq_no=self.seq_no, + created_at=self.created_at, + data=data, + uncompressed_size=len(data), + partitioning=None, # unsupported by server now + ) + + +class MessageSendResult: + offset: Optional[int] + write_status: "MessageWriteStatus" + + +class MessageWriteStatus(enum.Enum): + Written = 1 + AlreadyWritten = 2 + + +class RetryPolicy: + connection_timeout_sec: float + overload_timeout_sec: float + retry_access_denied: bool = False + + +class TopicWriterError(ydb.Error): + def __init__(self, message: str): + super(TopicWriterError, self).__init__(message) + + +class TopicWriterClosedError(ydb.Error): + def __init__(self): + super().__init__("Topic writer already closed") + + +class TopicWriterRepeatableError(TopicWriterError): + pass + + +class TopicWriterStopped(TopicWriterError): + def __init__(self): + super(TopicWriterStopped, self).__init__( + "topic writer was stopped by call close" + ) + + +def default_serializer_message_content(data: Any) -> bytes: + if data is None: + return bytes() + if isinstance(data, bytes): + return data + if isinstance(data, bytearray): + return bytes(data) + if isinstance(data, str): + return data.encode(encoding="utf-8") + raise ValueError("can't serialize type %s to bytes" % type(data)) + + +def messages_to_proto_requests( + messages: List[InternalMessage], +) -> List[StreamWriteMessage.FromClient]: + # todo split by proto message size and codec + res = [] + for msg in messages: + req = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + messages=[msg.to_message_data()], + codec=msg.codec, + ) + ) + res.append(req) + return res diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py new file mode 100644 index 00000000..666fc11b --- /dev/null +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -0,0 +1,680 @@ +import asyncio +import concurrent.futures +import datetime +import gzip +import typing +from collections import deque +from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable + +import ydb +from .topic_writer import ( + PublicWriterSettings, + WriterSettings, + PublicMessage, + PublicWriterInitInfo, + InternalMessage, + TopicWriterStopped, + TopicWriterError, + messages_to_proto_requests, + PublicWriteResultTypes, + Message, +) +from .. 
import (
+    _apis,
+    issues,
+    check_retriable_error,
+    RetrySettings,
+)
+from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
+from .._grpc.grpcwrapper.ydb_topic import (
+    UpdateTokenRequest,
+    UpdateTokenResponse,
+    StreamWriteMessage,
+    WriterMessagesFromServerToClient,
+)
+from .._grpc.grpcwrapper.common_utils import (
+    IGrpcWrapperAsyncIO,
+    SupportedDriverType,
+    GrpcWrapperAsyncIO,
+)
+
+
+class WriterAsyncIO:
+    _loop: asyncio.AbstractEventLoop
+    _reconnector: "WriterAsyncIOReconnector"
+    _closed: bool
+
+    def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
+        self._loop = asyncio.get_running_loop()
+        self._closed = False
+        self._reconnector = WriterAsyncIOReconnector(
+            driver=driver, settings=WriterSettings(settings)
+        )
+
+    async def __aenter__(self) -> "WriterAsyncIO":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+    def __del__(self):
+        if self._closed or self._loop.is_closed():
+            return
+
+        self._loop.call_soon(self.close)
+
+    async def close(self, *, flush: bool = True):
+        if self._closed:
+            return
+
+        self._closed = True
+
+        await self._reconnector.close(flush)
+
+    async def write_with_ack(
+        self,
+        messages: Union[Message, List[Message]],
+    ) -> Union[PublicWriteResultTypes, List[PublicWriteResultTypes]]:
+        """
+        THIS IS A SLOW WAY AND A BAD CHOICE IN MOST CASES.
+        It is recommended to use write (optionally followed by flush) or write_with_ack_future and wait on the returned futures.
+
+        Send one or several messages to the server and wait for the acks.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        futures = await self.write_with_ack_future(messages)
+        if not isinstance(futures, list):
+            futures = [futures]
+
+        await asyncio.wait(futures)
+        results = [f.result() for f in futures]
+
+        return results if isinstance(messages, list) else results[0]
+
+    async def write_with_ack_future(
+        self,
+        messages: Union[Message, List[Message]],
+    ) -> Union[asyncio.Future, List[asyncio.Future]]:
+        """
+        Send one or several messages to the server.
+        Return a future (or a list of futures) which can be awaited to check the send result.
+
+        Usually this method is fast, but it can block while the internal buffer is full.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        input_single_message = not isinstance(messages, list)
+        converted_messages = []
+        if isinstance(messages, list):
+            for m in messages:
+                converted_messages.append(PublicMessage._create_message(m))
+        else:
+            converted_messages = [PublicMessage._create_message(messages)]
+
+        futures = await self._reconnector.write_with_ack_future(converted_messages)
+        if input_single_message:
+            return futures[0]
+        else:
+            return futures
+
+    async def write(
+        self,
+        messages: Union[Message, List[Message]],
+    ):
+        """
+        Send one or several messages to the server.
+        The messages are only put into an internal buffer.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        await self.write_with_ack_future(messages)
+
+    async def flush(self):
+        """
+        Force send all messages from the internal buffer and wait for acks from the server for all of
+        them.
+
+        For a wait with timeout use asyncio.wait_for.
+        """
+        return await self._reconnector.flush()
+
+    async def wait_init(self) -> PublicWriterInitInfo:
+        """
+        Wait until a real connection to the server is established.
+ + For wait with timeout use asyncio.wait_for() + """ + return await self._reconnector.wait_init() + + +class WriterAsyncIOReconnector: + _closed: bool + _loop: asyncio.AbstractEventLoop + _credentials: Union[ydb.credentials.Credentials, None] + _driver: ydb.aio.Driver + _init_message: StreamWriteMessage.InitRequest + _init_info: asyncio.Future + _stream_connected: asyncio.Event + _settings: WriterSettings + _codec: PublicCodec + _codec_functions: Dict[PublicCodec, Callable[[bytes], bytes]] + _encode_executor: Optional[concurrent.futures.Executor] + _codec_selector_batch_num: int + _codec_selector_last_codec: Optional[PublicCodec] + _codec_selector_check_batches_interval: int + + _last_known_seq_no: int + if typing.TYPE_CHECKING: + _messages_for_encode: asyncio.Queue[List[InternalMessage]] + else: + _messages_for_encode: asyncio.Queue + _messages: Deque[InternalMessage] + _messages_future: Deque[asyncio.Future] + _new_messages: asyncio.Queue + _stop_reason: asyncio.Future + _background_tasks: List[asyncio.Task] + + def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + self._closed = False + self._loop = asyncio.get_running_loop() + self._driver = driver + self._credentials = driver._credentials + self._init_message = settings.create_init_request() + self._new_messages = asyncio.Queue() + self._init_info = self._loop.create_future() + self._stream_connected = asyncio.Event() + self._settings = settings + + self._codec_functions = { + PublicCodec.RAW: lambda data: data, + PublicCodec.GZIP: gzip.compress, + } + + if settings.encoders: + self._codec_functions.update(settings.encoders) + + self._encode_executor = settings.encoder_executor + + self._codec_selector_batch_num = 0 + self._codec_selector_last_codec = None + self._codec_selector_check_batches_interval = 10000 + + self._codec = self._settings.codec + if self._codec and self._codec not in self._codec_functions: + known_codecs = sorted(self._codec_functions.keys()) + raise ValueError( + "Unknown codec for writer: %s, supported codecs: %s" + % (self._codec, known_codecs) + ) + + self._last_known_seq_no = 0 + self._messages_for_encode = asyncio.Queue() + self._messages = deque() + self._messages_future = deque() + self._new_messages = asyncio.Queue() + self._stop_reason = self._loop.create_future() + self._background_tasks = [ + asyncio.create_task(self._connection_loop(), name="connection_loop"), + asyncio.create_task(self._encode_loop(), name="encode_loop"), + ] + + async def close(self, flush: bool): + if self._closed: + return + + if flush: + await self.flush() + + self._closed = True + self._stop(TopicWriterStopped()) + + for task in self._background_tasks: + task.cancel() + await asyncio.wait(self._background_tasks) + + # if work was stopped before close by error - raise the error + try: + self._check_stop() + except TopicWriterStopped: + pass + + async def wait_init(self) -> PublicWriterInitInfo: + done, _ = await asyncio.wait( + [self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED + ) + res = done.pop() # type: asyncio.Future + res_val = res.result() + + if isinstance(res_val, BaseException): + raise res_val + + return res_val + + async def wait_stop(self) -> Exception: + return await self._stop_reason + + async def write_with_ack_future( + self, messages: List[PublicMessage] + ) -> List[asyncio.Future]: + # todo check internal buffer limit + self._check_stop() + + if self._settings.auto_seqno: + await self.wait_init() + + internal_messages = self._prepare_internal_messages(messages) + 
messages_future = [self._loop.create_future() for _ in internal_messages] + + self._messages_future.extend(messages_future) + + if self._codec == PublicCodec.RAW: + self._add_messages_to_send_queue(internal_messages) + else: + self._messages_for_encode.put_nowait(internal_messages) + + return messages_future + + def _add_messages_to_send_queue(self, internal_messages: List[InternalMessage]): + self._messages.extend(internal_messages) + for m in internal_messages: + self._new_messages.put_nowait(m) + + def _prepare_internal_messages( + self, messages: List[PublicMessage] + ) -> List[InternalMessage]: + if self._settings.auto_created_at: + now = datetime.datetime.now() + else: + now = None + + res = [] + for m in messages: + internal_message = InternalMessage(m) + if self._settings.auto_seqno: + if internal_message.seq_no is None: + self._last_known_seq_no += 1 + internal_message.seq_no = self._last_known_seq_no + else: + raise TopicWriterError( + "Explicit seqno and auto_seq setting is mutual exclusive" + ) + else: + if internal_message.seq_no is None or internal_message.seq_no == 0: + raise TopicWriterError( + "Empty seqno and auto_seq setting is disabled" + ) + elif internal_message.seq_no <= self._last_known_seq_no: + raise TopicWriterError( + "Message seqno is duplicated: %s" % internal_message.seq_no + ) + else: + self._last_known_seq_no = internal_message.seq_no + + if self._settings.auto_created_at: + if internal_message.created_at is not None: + raise TopicWriterError( + "Explicit set auto_created_at and setting auto_created_at is mutual exclusive" + ) + else: + internal_message.created_at = now + + res.append(internal_message) + + return res + + def _check_stop(self): + if self._stop_reason.done(): + raise self._stop_reason.result() + + async def _connection_loop(self): + retry_settings = RetrySettings() # todo + + while True: + attempt = 0 # todo calc and reset + tasks = [] + + # noinspection PyBroadException + stream_writer = None + try: + stream_writer = await WriterAsyncIOStream.create( + self._driver, + self._init_message, + self._settings.update_token_interval, + ) + try: + self._last_known_seq_no = stream_writer.last_seqno + self._init_info.set_result( + PublicWriterInitInfo( + last_seqno=stream_writer.last_seqno, + supported_codecs=stream_writer.supported_codecs, + ) + ) + except asyncio.InvalidStateError: + pass + + self._stream_connected.set() + + send_loop = asyncio.create_task( + self._send_loop(stream_writer), name="writer send loop" + ) + receive_loop = asyncio.create_task( + self._read_loop(stream_writer), name="writer receive loop" + ) + + tasks = [send_loop, receive_loop] + done, _ = await asyncio.wait( + [send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED + ) + await stream_writer.close() + done.pop().result() + except issues.Error as err: + # todo log error + print(err) + + err_info = check_retriable_error(err, retry_settings, attempt) + if not err_info.is_retriable: + self._stop(err) + return + + await asyncio.sleep(err_info.sleep_timeout_seconds) + + except (asyncio.CancelledError, Exception) as err: + self._stop(err) + return + finally: + if stream_writer: + await stream_writer.close() + for task in tasks: + task.cancel() + await asyncio.wait(tasks) + + async def _encode_loop(self): + while True: + messages = await self._messages_for_encode.get() + while not self._messages_for_encode.empty(): + messages.extend(self._messages_for_encode.get_nowait()) + + batch_codec = await self._codec_selector(messages) + await self._encode_data_inplace(batch_codec, 
messages) + self._add_messages_to_send_queue(messages) + + async def _encode_data_inplace( + self, codec: PublicCodec, messages: List[InternalMessage] + ): + if codec == PublicCodec.RAW: + return + + eventloop = asyncio.get_running_loop() + encode_waiters = [] + encoder_function = self._codec_functions[codec] + + for message in messages: + encoded_data_futures = eventloop.run_in_executor( + self._encode_executor, encoder_function, message.get_bytes() + ) + encode_waiters.append(encoded_data_futures) + + encoded_datas = await asyncio.gather(*encode_waiters) + + for index, data in enumerate(encoded_datas): + message = messages[index] + message.codec = codec + message.data = data + + async def _codec_selector(self, messages: List[InternalMessage]) -> PublicCodec: + if self._codec is not None: + return self._codec + + if self._codec_selector_last_codec is None: + available_codecs = await self._get_available_codecs() + + # use every of available encoders at start for prevent problems + # with rare used encoders (on writer or reader side) + if self._codec_selector_batch_num < len(available_codecs): + codec = available_codecs[self._codec_selector_batch_num] + else: + codec = await self._codec_selector_by_check_compress(messages) + self._codec_selector_last_codec = codec + else: + if ( + self._codec_selector_batch_num + % self._codec_selector_check_batches_interval + == 0 + ): + self._codec_selector_last_codec = ( + await self._codec_selector_by_check_compress(messages) + ) + codec = self._codec_selector_last_codec + self._codec_selector_batch_num += 1 + return codec + + async def _get_available_codecs(self) -> List[PublicCodec]: + info = await self.wait_init() + topic_supported_codecs = info.supported_codecs + if not topic_supported_codecs: + topic_supported_codecs = [PublicCodec.RAW, PublicCodec.GZIP] + + res = [] + for codec in topic_supported_codecs: + if codec in self._codec_functions: + res.append(codec) + + if not res: + raise TopicWriterError("Writer does not support topic's codecs") + + res.sort() + + return res + + async def _codec_selector_by_check_compress( + self, messages: List[InternalMessage] + ) -> PublicCodec: + """ + Try to compress messages and choose codec with the smallest result size. + """ + + test_messages = messages[:10] + + available_codecs = await self._get_available_codecs() + if len(available_codecs) == 1: + return available_codecs[0] + + def get_compressed_size(codec) -> int: + s = 0 + f = self._codec_functions[codec] + + for m in test_messages: + encoded = f(m.get_bytes()) + s += len(encoded) + + return s + + def select_codec() -> PublicCodec: + min_codec = available_codecs[0] + min_size = get_compressed_size(min_codec) + for codec in available_codecs[1:]: + size = get_compressed_size(codec) + if size < min_size: + min_codec = codec + min_size = size + return min_codec + + loop = asyncio.get_running_loop() + codec = await loop.run_in_executor(self._encode_executor, select_codec) + return codec + + async def _read_loop(self, writer: "WriterAsyncIOStream"): + while True: + resp = await writer.receive() + + for ack in resp.acks: + self._handle_receive_ack(ack) + + def _handle_receive_ack(self, ack): + current_message = self._messages.popleft() + message_future = self._messages_future.popleft() + if current_message.seq_no != ack.seq_no: + raise TopicWriterError( + "internal error - receive unexpected ack. 
Expected seqno: %s, received seqno: %s" + % (current_message.seq_no, ack.seq_no) + ) + message_future.set_result( + None + ) # todo - return result with offset or skip status + + async def _send_loop(self, writer: "WriterAsyncIOStream"): + try: + messages = list(self._messages) + + last_seq_no = 0 + for m in messages: + writer.write([m]) + last_seq_no = m.seq_no + + while True: + m = await self._new_messages.get() # type: InternalMessage + if m.seq_no > last_seq_no: + writer.write([m]) + except Exception as e: + self._stop(e) + finally: + pass + + def _stop(self, reason: Exception): + if reason is None: + raise Exception("writer stop reason can not be None") + + if self._stop_reason.done(): + return + + self._stop_reason.set_result(reason) + + async def flush(self): + self._check_stop() + if not self._messages_future: + return + + # wait last message + await asyncio.wait((self._messages_future[-1],)) + + +class WriterAsyncIOStream: + # todo slots + _closed: bool + + last_seqno: int + supported_codecs: Optional[List[PublicCodec]] + + _stream: IGrpcWrapperAsyncIO + _requests: asyncio.Queue + _responses: AsyncIterator + + _update_token_interval: Optional[Union[int, float]] + _update_token_task: Optional[asyncio.Task] + _update_token_event: asyncio.Event + _get_token_function: Optional[Callable[[], str]] + + def __init__( + self, + update_token_interval: Optional[Union[int, float]] = None, + get_token_function: Optional[Callable[[], str]] = None, + ): + self._closed = False + + self._update_token_interval = update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + self._update_token_task = None + + async def close(self): + if self._closed: + return + self._closed = True + + if self._update_token_task: + self._update_token_task.cancel() + await asyncio.wait([self._update_token_task]) + + self._stream.close() + + @staticmethod + async def create( + driver: SupportedDriverType, + init_request: StreamWriteMessage.InitRequest, + update_token_interval: Optional[Union[int, float]] = None, + ) -> "WriterAsyncIOStream": + stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + + await stream.start( + driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite + ) + + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + ) + await writer._start(stream, init_request) + return writer + + async def receive(self) -> StreamWriteMessage.WriteResponse: + while True: + item = await self._stream.receive() + + if isinstance(item, StreamWriteMessage.WriteResponse): + return item + if isinstance(item, UpdateTokenResponse): + self._update_token_event.set() + continue + + # todo log unknown messages instead of raise exception + raise Exception("Unknown message while read writer answers: %s" % item) + + async def _start( + self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest + ): + stream.write(StreamWriteMessage.FromClient(init_message)) + + resp = await stream.receive() + self._ensure_ok(resp) + if not isinstance(resp, StreamWriteMessage.InitResponse): + raise TopicWriterError("Unexpected answer for init request: %s" % resp) + + self.last_seqno = resp.last_seq_no + self.supported_codecs = [PublicCodec(codec) for codec in resp.supported_codecs] + + self._stream = stream + + if self._update_token_interval is not None: + self._update_token_event.set() + self._update_token_task = 
asyncio.create_task( + self._update_token_loop(), name="update_token_loop" + ) + + @staticmethod + def _ensure_ok(message: WriterMessagesFromServerToClient): + if not message.status.is_success(): + raise TopicWriterError( + "status error from server in writer: %s", message.status + ) + + def write(self, messages: List[InternalMessage]): + if self._closed: + raise RuntimeError("Can not write on closed stream.") + + for request in messages_to_proto_requests(messages): + self._stream.write(request) + + async def _update_token_loop(self): + while True: + await asyncio.sleep(self._update_token_interval) + await self._update_token(token=self._get_token_function()) + + async def _update_token(self, token: str): + await self._update_token_event.wait() + try: + msg = StreamWriteMessage.FromClient(UpdateTokenRequest(token)) + self._stream.write(msg) + finally: + self._update_token_event.clear() diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py new file mode 100644 index 00000000..b5b3fcc8 --- /dev/null +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -0,0 +1,743 @@ +from __future__ import annotations + +import asyncio +import copy +import dataclasses +import datetime +import gzip +import typing +from queue import Queue, Empty +from typing import List, Callable, Optional +from unittest import mock + +import freezegun +import pytest + +from .. import aio +from .. import StatusCode, issues +from .._grpc.grpcwrapper.ydb_topic import ( + Codec, + StreamWriteMessage, + UpdateTokenRequest, + UpdateTokenResponse, +) +from .._grpc.grpcwrapper.common_utils import ServerStatus +from .topic_writer import ( + InternalMessage, + PublicMessage, + WriterSettings, + PublicWriterSettings, + PublicWriterInitInfo, + PublicWriteResult, + TopicWriterError, +) +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec +from .._topic_common.test_helpers import StreamMock, wait_for_fast + +from .topic_writer_asyncio import ( + WriterAsyncIOStream, + WriterAsyncIOReconnector, + WriterAsyncIO, +) + +from ..credentials import AnonymousCredentials + + +@pytest.fixture +def default_driver() -> aio.Driver: + driver = mock.Mock(spec=aio.Driver) + driver._credentials = AnonymousCredentials() + return driver + + +@pytest.mark.asyncio +class TestWriterAsyncIOStream: + @dataclasses.dataclass + class WriterWithMockedStream: + writer: WriterAsyncIOStream + stream: StreamMock + + @pytest.fixture + def stream(self): + stream = StreamMock() + yield stream + stream.close() + + @staticmethod + async def get_started_writer(stream, *args, **kwargs) -> WriterAsyncIOStream: + stream.from_server.put_nowait( + StreamWriteMessage.InitResponse( + last_seq_no=4, + session_id="123", + partition_id=3, + supported_codecs=[Codec.CODEC_RAW, Codec.CODEC_GZIP], + status=ServerStatus(StatusCode.SUCCESS, []), + ) + ) + + writer = WriterAsyncIOStream(*args, **kwargs) + await writer._start( + stream, + init_message=StreamWriteMessage.InitRequest( + path="/local/test", + producer_id="producer-id", + write_session_meta={"a": "b"}, + partitioning=StreamWriteMessage.PartitioningMessageGroupID( + message_group_id="message-group-id" + ), + get_last_seq_no=False, + ), + ) + await stream.from_client.get() + return writer + + @pytest.fixture + async def writer_and_stream(self, stream) -> WriterWithMockedStream: + writer = await self.get_started_writer(stream) + + yield TestWriterAsyncIOStream.WriterWithMockedStream( + stream=stream, + writer=writer, + ) + + await writer.close() + + async def 
test_init_writer(self, stream): + init_seqno = 4 + init_message = StreamWriteMessage.InitRequest( + path="/local/test", + producer_id="producer-id", + write_session_meta={"a": "b"}, + partitioning=StreamWriteMessage.PartitioningMessageGroupID( + message_group_id="message-group-id" + ), + get_last_seq_no=False, + ) + stream.from_server.put_nowait( + StreamWriteMessage.InitResponse( + last_seq_no=init_seqno, + session_id="123", + partition_id=0, + supported_codecs=[], + status=ServerStatus(StatusCode.SUCCESS, []), + ) + ) + + writer = WriterAsyncIOStream() + await writer._start(stream, init_message) + + sent_message = await stream.from_client.get() + expected_message = StreamWriteMessage.FromClient(init_message) + + assert expected_message == sent_message + assert writer.last_seqno == init_seqno + + await writer.close() + + async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): + data = "123".encode() + now = datetime.datetime.now() + writer_and_stream.writer.write( + [ + InternalMessage( + PublicMessage( + seqno=1, + created_at=now, + data=data, + ) + ) + ] + ) + + expected_message = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + codec=Codec.CODEC_RAW, + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=1, + created_at=now, + data=data, + uncompressed_size=len(data), + partitioning=None, + ) + ], + ) + ) + + sent_message = await writer_and_stream.stream.from_client.get() + assert expected_message == sent_message + + async def test_update_token(self, stream: StreamMock): + writer = await self.get_started_writer( + stream, update_token_interval=0.1, get_token_function=lambda: "foo-bar" + ) + assert stream.from_client.empty() + + expected = StreamWriteMessage.FromClient(UpdateTokenRequest(token="foo-bar")) + got = await wait_for_fast(stream.from_client.get()) + assert expected == got, "send update token request" + + await asyncio.sleep(0.2) + assert stream.from_client.empty(), "no answer - no new update request" + + await stream.from_server.put(UpdateTokenResponse()) + receive_task = asyncio.create_task(writer.receive()) + + got = await wait_for_fast(stream.from_client.get()) + assert expected == got + + receive_task.cancel() + await asyncio.wait([receive_task]) + + +@pytest.mark.asyncio +class TestWriterAsyncIOReconnector: + init_last_seqno = 0 + time_for_mocks = 1678046714.639387 + + class StreamWriterMock: + last_seqno: int + supported_codecs: List[PublicCodec] + + from_client: asyncio.Queue + from_server: asyncio.Queue + + _closed: bool + + def __init__( + self, + update_token_interval: Optional[int, float] = None, + get_token_function: Optional[Callable[[], str]] = None, + ): + self.last_seqno = 0 + self.from_server = asyncio.Queue() + self.from_client = asyncio.Queue() + self._closed = False + self.supported_codecs = [] + + def write(self, messages: typing.List[InternalMessage]): + if self._closed: + raise Exception("write to closed StreamWriterMock") + + self.from_client.put_nowait(messages) + + async def receive(self) -> StreamWriteMessage.WriteResponse: + if self._closed: + raise Exception("read from closed StreamWriterMock") + + item = await self.from_server.get() + if isinstance(item, Exception): + raise item + return item + + async def close(self): + if self._closed: + return + self._closed = True + + @pytest.fixture(autouse=True) + async def stream_writer_double_queue(self, monkeypatch): + class DoubleQueueWriters: + _first: Queue + _second: Queue + + def __init__(self): + self._first = Queue() + self._second = 
Queue() + + def get_first(self): + try: + return self._first.get_nowait() + except Empty: + self._create() + return self.get_first() + + def get_second(self): + try: + return self._second.get_nowait() + except Empty: + self._create() + return self.get_second() + + def _create(self): + writer = TestWriterAsyncIOReconnector.StreamWriterMock() + writer.last_seqno = TestWriterAsyncIOReconnector.init_last_seqno + self._first.put_nowait(writer) + self._second.put_nowait(writer) + + res = DoubleQueueWriters() + + async def async_create(driver, init_message, token_getter): + return res.get_first() + + monkeypatch.setattr(WriterAsyncIOStream, "create", async_create) + return res + + @pytest.fixture + def get_stream_writer( + self, stream_writer_double_queue + ) -> typing.Callable[[], "TestWriterAsyncIOReconnector.StreamWriterMock"]: + return stream_writer_double_queue.get_second + + @pytest.fixture + def default_settings(self) -> WriterSettings: + return WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + update_token_interval=3600, + ) + ) + + @pytest.fixture + def default_write_statistic( + self, + ) -> StreamWriteMessage.WriteResponse.WriteStatistics: + return StreamWriteMessage.WriteResponse.WriteStatistics( + persisting_time=datetime.timedelta(milliseconds=1), + min_queue_wait_time=datetime.timedelta(milliseconds=2), + max_queue_wait_time=datetime.timedelta(milliseconds=3), + partition_quota_wait_time=datetime.timedelta(milliseconds=4), + topic_quota_wait_time=datetime.timedelta(milliseconds=5), + ) + + def make_default_ack_message(self, seq_no=1) -> StreamWriteMessage.WriteResponse: + return StreamWriteMessage.WriteResponse( + partition_id=1, + acks=[ + StreamWriteMessage.WriteResponse.WriteAck( + seq_no=seq_no, + message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten( + offset=1 + ), + ) + ], + write_statistics=self.default_write_statistic, + ) + + @pytest.fixture + async def reconnector( + self, default_driver, default_settings + ) -> WriterAsyncIOReconnector: + return WriterAsyncIOReconnector(default_driver, default_settings) + + async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( + self, + reconnector: WriterAsyncIOReconnector, + get_stream_writer, + default_write_statistic, + ): + now = datetime.datetime.now() + data = "123".encode() + + message1 = PublicMessage( + data=data, + seqno=1, + created_at=now, + ) + message2 = PublicMessage( + data=data, + seqno=2, + created_at=now, + ) + await reconnector.write_with_ack_future([message1, message2]) + + # sent to first stream + stream_writer = get_stream_writer() + + messages = await stream_writer.from_client.get() + assert [InternalMessage(message1)] == messages + messages = await stream_writer.from_client.get() + assert [InternalMessage(message2)] == messages + + # ack first message + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + stream_writer.from_server.put_nowait(issues.Overloaded("test")) + + second_writer = get_stream_writer() + second_sent_msg = await second_writer.from_client.get() + + expected_messages = [InternalMessage(message2)] + assert second_sent_msg == expected_messages + + second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) + await reconnector.close(flush=True) + + async def test_stop_on_unexpected_exception( + self, reconnector: WriterAsyncIOReconnector, get_stream_writer + ): + class 
TestException(Exception): + pass + + stream_writer = get_stream_writer() + stream_writer.from_server.put_nowait(TestException()) + + message = PublicMessage( + data="123", + seqno=3, + ) + + with pytest.raises(TestException): + + async def wait_stop(): + while True: + await reconnector.write_with_ack_future([message]) + await asyncio.sleep(0.1) + + await asyncio.wait_for(wait_stop(), 1) + + with pytest.raises(TestException): + await reconnector.close(flush=False) + + async def test_wait_init(self, default_driver, default_settings, get_stream_writer): + init_seqno = 100 + expected_init_info = PublicWriterInitInfo( + last_seqno=init_seqno, supported_codecs=[] + ) + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno + ): + reconnector = WriterAsyncIOReconnector(default_driver, default_settings) + info = await reconnector.wait_init() + assert info == expected_init_info + + reconnector._stream_connected.clear() + + # force reconnect + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", init_seqno + 1 + ): + stream_writer = get_stream_writer() + stream_writer.from_server.put_nowait( + issues.Overloaded("test") + ) # some retriable error + await reconnector._stream_connected.wait() + + info = await reconnector.wait_init() + assert info == expected_init_info + + await reconnector.close(flush=False) + + async def test_write_message( + self, reconnector: WriterAsyncIOReconnector, get_stream_writer + ): + stream_writer = get_stream_writer() + message = PublicMessage( + data="123", + seqno=3, + ) + await reconnector.write_with_ack_future([message]) + + sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1) + assert sent_messages == [InternalMessage(message)] + + await reconnector.close(flush=False) + + async def test_auto_seq_no( + self, default_driver, default_settings, get_stream_writer + ): + last_seq_no = 100 + with mock.patch.object( + TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no + ): + settings = copy.deepcopy(default_settings) + settings.auto_seqno = True + + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + await reconnector.write_with_ack_future([PublicMessage(data="123")]) + await reconnector.write_with_ack_future([PublicMessage(data="456")]) + + stream_writer = get_stream_writer() + + sent = await stream_writer.from_client.get() + assert [ + InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123")) + ] == sent + + sent = await stream_writer.from_client.get() + assert [ + InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456")) + ] == sent + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack_future( + [PublicMessage(seqno=last_seq_no + 3, data="123")] + ) + + await reconnector.close(flush=False) + + async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector): + await reconnector.write_with_ack_future([PublicMessage(seqno=10, data="123")]) + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack_future( + [PublicMessage(seqno=9, data="123")] + ) + + with pytest.raises(TopicWriterError): + await reconnector.write_with_ack_future( + [PublicMessage(seqno=10, data="123")] + ) + + await reconnector.write_with_ack_future([PublicMessage(seqno=11, data="123")]) + + await reconnector.close(flush=False) + + @freezegun.freeze_time("2022-01-13 20:50:00", tz_offset=0) + async def test_auto_created_at( + self, default_driver, default_settings, get_stream_writer + ): + now = datetime.datetime.now() + + settings = 
copy.deepcopy(default_settings) + settings.auto_created_at = True + reconnector = WriterAsyncIOReconnector(default_driver, settings) + await reconnector.write_with_ack_future([PublicMessage(seqno=4, data="123")]) + + stream_writer = get_stream_writer() + sent = await stream_writer.from_client.get() + + assert [ + InternalMessage(PublicMessage(seqno=4, data="123", created_at=now)) + ] == sent + await reconnector.close(flush=False) + + @pytest.mark.parametrize( + "codec,write_datas,expected_codecs,expected_datas", + [ + ( + PublicCodec.RAW, + [b"123"], + [PublicCodec.RAW], + [b"123"], + ), + ( + PublicCodec.GZIP, + [b"123"], + [PublicCodec.GZIP], + [gzip.compress(b"123", mtime=time_for_mocks)], + ), + ( + None, + [b"123", b"456", b"789", b"0" * 1000], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.RAW, PublicCodec.RAW], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + b"789", + b"0" * 1000, + ], + ), + ( + None, + [b"123", b"456", b"789" * 1000, b"0"], + [PublicCodec.RAW, PublicCodec.GZIP, PublicCodec.GZIP, PublicCodec.GZIP], + [ + b"123", + gzip.compress(b"456", mtime=time_for_mocks), + gzip.compress(b"789" * 1000, mtime=time_for_mocks), + gzip.compress(b"0", mtime=time_for_mocks), + ], + ), + ], + ) + async def test_select_codecs( + self, + default_driver: aio.Driver, + default_settings: WriterSettings, + monkeypatch, + write_datas: List[typing.Optional[bytes]], + codec: typing.Optional[PublicCodec], + expected_codecs: List[PublicCodec], + expected_datas: List[bytes], + ): + assert len(write_datas) == len(expected_datas) + assert len(expected_codecs) == len(expected_datas) + + settings = copy.copy(default_settings) + settings.codec = codec + settings.auto_seqno = True + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + added_messages = asyncio.Queue() # type: asyncio.Queue[List[InternalMessage]] + + def add_messages(_self, messages: typing.List[InternalMessage]): + added_messages.put_nowait(messages) + + monkeypatch.setattr( + WriterAsyncIOReconnector, "_add_messages_to_send_queue", add_messages + ) + monkeypatch.setattr( + "time.time", lambda: TestWriterAsyncIOReconnector.time_for_mocks + ) + + for i in range(len(expected_datas)): + await reconnector.write_with_ack_future( + [PublicMessage(data=write_datas[i])] + ) + mess = await asyncio.wait_for(added_messages.get(), timeout=600) + mess = mess[0] + + assert mess.codec == expected_codecs[i] + assert mess.get_bytes() == expected_datas[i] + + await reconnector.close(flush=False) + + @pytest.mark.parametrize( + "codec,datas", + [ + ( + PublicCodec.RAW, + [b"123", b"456", b"789", b"0"], + ), + ( + PublicCodec.GZIP, + [b"123", b"456", b"789", b"0"], + ), + ], + ) + async def test_encode_data_inplace( + self, + reconnector: WriterAsyncIOReconnector, + codec: PublicCodec, + datas: List[bytes], + ): + f = reconnector._codec_functions[codec] + expected_datas = [f(data) for data in datas] + + messages = [InternalMessage(PublicMessage(data)) for data in datas] + await reconnector._encode_data_inplace(codec, messages) + + for index, mess in enumerate(messages): + assert mess.codec == codec + assert mess.get_bytes() == expected_datas[index] + + async def test_custom_encoder( + self, default_driver, default_settings, get_stream_writer + ): + codec = 10001 + + settings = copy.copy(default_settings) + settings.encoders = {codec: lambda x: bytes(reversed(x))} + settings.codec = codec + reconnector = WriterAsyncIOReconnector(default_driver, settings) + + now = datetime.datetime.now() + seqno = self.init_last_seqno + 1 + 
+ await reconnector.write_with_ack_future( + [PublicMessage(data=b"123", seqno=seqno, created_at=now)] + ) + + stream_writer = get_stream_writer() + sent_messages = await wait_for_fast(stream_writer.from_client.get()) + + expected_mess = InternalMessage( + PublicMessage(data=b"321", seqno=seqno, created_at=now) + ) + expected_mess.codec = codec + + assert sent_messages == [expected_mess] + + await reconnector.close(flush=False) + + +@pytest.mark.asyncio +class TestWriterAsyncIO: + class ReconnectorMock: + lock: asyncio.Lock + messages: typing.List[InternalMessage] + futures: typing.List[asyncio.Future] + messages_writted: asyncio.Event + + def __init__(self): + self.lock = asyncio.Lock() + self.messages = [] + self.futures = [] + self.messages_writted = asyncio.Event() + + async def write_with_ack_future(self, messages: typing.List[InternalMessage]): + async with self.lock: + futures = [asyncio.Future() for _ in messages] + self.messages.extend(messages) + self.futures.extend(futures) + self.messages_writted.set() + return futures + + async def close(self): + pass + + @pytest.fixture + def default_settings(self) -> PublicWriterSettings: + return PublicWriterSettings( + topic="/local/topic", + producer_id="producer-id", + ) + + @pytest.fixture(autouse=True) + def mock_reconnector_init(self, monkeypatch, reconnector): + def t(cls, driver, settings): + return reconnector + + monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", t) + + @pytest.fixture + def reconnector(self, monkeypatch) -> TestWriterAsyncIO.ReconnectorMock: + reconnector = TestWriterAsyncIO.ReconnectorMock() + return reconnector + + @pytest.fixture + async def writer(self, default_driver, default_settings): + return WriterAsyncIO(default_driver, default_settings) + + async def test_write(self, writer: WriterAsyncIO, reconnector): + m = PublicMessage(seqno=1, data="123") + res = await writer.write(m) + assert res is None + + assert reconnector.messages == [m] + + async def test_write_with_futures(self, writer: WriterAsyncIO, reconnector): + m = PublicMessage(seqno=1, data="123") + res = await writer.write_with_ack_future(m) + + assert reconnector.messages == [m] + assert asyncio.isfuture(res) + + async def test_write_with_ack(self, writer: WriterAsyncIO, reconnector): + reconnector.messages_writted.clear() + + async def ack_first_message(): + await reconnector.messages_writted.wait() + async with reconnector.lock: + reconnector.futures[0].set_result(PublicWriteResult.Written(offset=1)) + + asyncio.create_task(ack_first_message()) + + m = PublicMessage(seqno=1, data="123") + res = await writer.write_with_ack(m) + + assert res == PublicWriteResult.Written(offset=1) + + reconnector.messages_writted.clear() + async with reconnector.lock: + reconnector.messages.clear() + reconnector.futures.clear() + + async def ack_next_messages(): + await reconnector.messages_writted.wait() + async with reconnector.lock: + reconnector.futures[0].set_result(PublicWriteResult.Written(offset=2)) + reconnector.futures[1].set_result(PublicWriteResult.Skipped()) + + asyncio.create_task(ack_next_messages()) + + res = await writer.write_with_ack( + [PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")] + ) + assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()] diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py new file mode 100644 index 00000000..e6b51238 --- /dev/null +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -0,0 +1,123 @@ +from __future__ import 
annotations + +import asyncio +from concurrent.futures import Future +from typing import Union, List, Optional + +from .._grpc.grpcwrapper.common_utils import SupportedDriverType +from .topic_writer import ( + PublicWriterSettings, + PublicWriterInitInfo, + PublicWriteResult, + Message, + TopicWriterClosedError, +) + +from .topic_writer_asyncio import WriterAsyncIO +from .._topic_common.common import ( + _get_shared_event_loop, + TimeoutType, + CallFromSyncToAsync, +) + + +class WriterSync: + _caller: CallFromSyncToAsync + _async_writer: WriterAsyncIO + _closed: bool + + def __init__( + self, + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + ): + + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_async_writer(): + return WriterAsyncIO(driver, settings) + + self._async_writer = self._caller.safe_call_with_result( + create_async_writer(), None + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self, *, flush: bool = True, timeout: TimeoutType = None): + if self._closed: + return + + self._closed = True + + self._caller.safe_call_with_result( + self._async_writer.close(flush=flush), timeout + ) + + def _check_closed(self): + if self._closed: + raise TopicWriterClosedError() + + def async_flush(self) -> Future: + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.flush()) + + def flush(self, *, timeout=None): + self._check_closed() + + return self._caller.unsafe_call_with_result(self._async_writer.flush(), timeout) + + def async_wait_init(self) -> Future[PublicWriterInitInfo]: + self._check_closed() + + return self._caller.unsafe_call_with_future(self._async_writer.wait_init()) + + def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo: + self._check_closed() + + return self._caller.unsafe_call_with_result( + self._async_writer.wait_init(), timeout + ) + + def write( + self, + messages: Union[Message, List[Message]], + timeout: TimeoutType = None, + ): + self._check_closed() + + self._caller.safe_call_with_result(self._async_writer.write(messages), timeout) + + def async_write_with_ack( + self, + messages: Union[Message, List[Message]], + ) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]: + self._check_closed() + + return self._caller.unsafe_call_with_future( + self._async_writer.write_with_ack(messages) + ) + + def write_with_ack( + self, + messages: Union[Message, List[Message]], + timeout: Union[float, None] = None, + ) -> Union[PublicWriteResult, List[PublicWriteResult]]: + self._check_closed() + + return self._caller.unsafe_call_with_result( + self._async_writer.write_with_ack(messages), timeout=timeout + ) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 32419b1b..0b72a198 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- -import six +import threading import codecs from concurrent import futures import functools import hashlib import collections +import urllib.parse from . 
import ydb_version try: @@ -55,8 +56,8 @@ def parse_connection_string(connection_string): # default is grpcs cs = _grpcs_protocol + cs - p = six.moves.urllib.parse.urlparse(connection_string) - b = six.moves.urllib.parse.parse_qs(p.query) + p = urllib.parse.urlparse(connection_string) + b = urllib.parse.parse_qs(p.query) database = b.get("database", []) assert len(database) > 0 @@ -77,11 +78,9 @@ def decorator(*args, **kwargs): def get_query_hash(yql_text): try: - return hashlib.sha256( - six.text_type(yql_text, "utf-8").encode("utf-8") - ).hexdigest() + return hashlib.sha256(str(yql_text, "utf-8").encode("utf-8")).hexdigest() except TypeError: - return hashlib.sha256(six.text_type(yql_text).encode("utf-8")).hexdigest() + return hashlib.sha256(str(yql_text).encode("utf-8")).hexdigest() class LRUCache(object): @@ -159,3 +158,17 @@ def next(self): def __next__(self): return self._next() + + +class AtomicCounter: + _lock: threading.Lock + _value: int + + def __init__(self, initial_value: int = 0): + self._lock = threading.Lock() + self._value = initial_value + + def inc_and_get(self) -> int: + with self._lock: + self._value += 1 + return self._value diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py index 88ab738c..fbfcfaaf 100644 --- a/ydb/aio/connection.py +++ b/ydb/aio/connection.py @@ -1,5 +1,6 @@ import logging import asyncio +import typing from typing import Any, Tuple, Callable, Iterable import collections import grpc @@ -24,11 +25,19 @@ from ydb.settings import BaseRequestSettings from ydb import issues +# Workaround for good IDE and universal for runtime +if typing.TYPE_CHECKING: + from ydb._grpc.v4 import ydb_topic_v1_pb2_grpc +else: + from ydb._grpc.common import ydb_topic_v1_pb2_grpc + + _stubs_list = ( _apis.TableService.Stub, _apis.SchemeService.Stub, _apis.DiscoveryService.Stub, _apis.CmsService.Stub, + ydb_topic_v1_pb2_grpc.TopicServiceStub, ) logger = logging.getLogger(__name__) diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index e9840440..93868b27 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -3,7 +3,6 @@ import abc import asyncio import logging -import six from ydb import issues, credentials logger = logging.getLogger(__name__) @@ -55,7 +54,6 @@ def submit(self, callback): asyncio.ensure_future(self._wrapped_execution(callback)) -@six.add_metaclass(abc.ABCMeta) class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 3bf6cca8..1aa3ad27 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -1,42 +1,7 @@ -import os - from . 
import pool, scheme, table import ydb -from ydb.driver import get_config - - -def default_credentials(credentials=None): - if credentials is not None: - return credentials - - service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS") - if service_account_key_file is not None: - from .iam import ServiceAccountCredentials - - return ServiceAccountCredentials.from_file(service_account_key_file) - - anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1" - if anonymous_credetials: - return ydb.credentials.AnonymousCredentials() - - metadata_credentials = os.getenv("YDB_METADATA_CREDENTIALS", "0") == "1" - if metadata_credentials: - from .iam import MetadataUrlCredentials - - return MetadataUrlCredentials() - - access_token = os.getenv("YDB_ACCESS_TOKEN_CREDENTIALS") - if access_token is not None: - return ydb.credentials.AccessTokenCredentials(access_token) - - # (legacy instantiation) - creds = ydb.auth_helpers.construct_credentials_from_environ() - if creds is not None: - return creds - - from .iam import MetadataUrlCredentials - - return MetadataUrlCredentials() +from .. import _utilities +from ydb.driver import get_config, default_credentials class DriverConfig(ydb.DriverConfig): @@ -56,7 +21,7 @@ def default_from_endpoint_and_database( def default_from_connection_string( cls, connection_string, root_certificates=None, credentials=None, **kwargs ): - endpoint, database = ydb.parse_connection_string(connection_string) + endpoint, database = _utilities.parse_connection_string(connection_string) return cls( endpoint, database, @@ -67,6 +32,8 @@ def default_from_connection_string( class Driver(pool.ConnectionPool): + _credentials: ydb.Credentials # used for topic clients + def __init__( self, driver_config=None, @@ -77,6 +44,8 @@ def __init__( credentials=None, **kwargs ): + from .. import topic # local import for prevent cycle import error + config = get_config( driver_config, connection_string, @@ -89,5 +58,8 @@ def __init__( super(Driver, self).__init__(config) + self._credentials = config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, config.table_client_settings) + self.topic_client = topic.TopicClientAsyncIO(self, config.topic_client_settings) diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 51b650f2..b56c0660 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -3,7 +3,6 @@ import abc import logging -import six from ydb.iam import auth from .credentials import AbstractExpiringTokenCredentials @@ -24,7 +23,6 @@ aiohttp = None -@six.add_metaclass(abc.ABCMeta) class TokenServiceCredentials(AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None): super(TokenServiceCredentials, self).__init__() diff --git a/ydb/aio/table.py b/ydb/aio/table.py index 92ed9812..06f8ca7c 100644 --- a/ydb/aio/table.py +++ b/ydb/aio/table.py @@ -13,7 +13,6 @@ _scan_query_request_factory, _wrap_scan_query_response, BaseTxContext, - _allow_split_transaction, ) from . 
import _utilities from ydb import _apis, _session_impl @@ -121,9 +120,7 @@ async def alter_table( set_read_replicas_settings, ) - def transaction( - self, tx_mode=None, *, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, *, allow_split_transactions=None): return TxContext( self._driver, self._state, diff --git a/ydb/auth_helpers.py b/ydb/auth_helpers.py index 5d889555..6399c3cf 100644 --- a/ydb/auth_helpers.py +++ b/ydb/auth_helpers.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- import os -from . import credentials, tracing -import warnings - def read_bytes(f): with open(f, "rb") as fr: @@ -15,43 +12,3 @@ def load_ydb_root_certificate(): if path is not None and os.path.exists(path): return read_bytes(path) return None - - -def construct_credentials_from_environ(tracer=None): - tracer = tracer if tracer is not None else tracing.Tracer(None) - warnings.warn( - "using construct_credentials_from_environ method for credentials instantiation is deprecated and will be " - "removed in the future major releases. Please instantialize credentials by default or provide correct credentials " - "instance to the Driver." - ) - - # dynamically import required authentication libraries - if ( - os.getenv("USE_METADATA_CREDENTIALS") is not None - and int(os.getenv("USE_METADATA_CREDENTIALS")) == 1 - ): - import ydb.iam - - tracing.trace(tracer, {"credentials.metadata": True}) - return ydb.iam.MetadataUrlCredentials() - - if os.getenv("YDB_TOKEN") is not None: - tracing.trace(tracer, {"credentials.access_token": True}) - return credentials.AuthTokenCredentials(os.getenv("YDB_TOKEN")) - - if os.getenv("SA_KEY_FILE") is not None: - - import ydb.iam - - tracing.trace(tracer, {"credentials.sa_key_file": True}) - root_certificates_file = os.getenv("SSL_ROOT_CERTIFICATES_FILE", None) - iam_channel_credentials = {} - if root_certificates_file is not None: - iam_channel_credentials = { - "root_certificates": read_bytes(root_certificates_file) - } - return ydb.iam.ServiceAccountCredentials.from_file( - os.getenv("SA_KEY_FILE"), - iam_channel_credentials=iam_channel_credentials, - iam_endpoint=os.getenv("IAM_ENDPOINT", "iam.api.cloud.yandex.net:443"), - ) diff --git a/ydb/convert.py b/ydb/convert.py index b231bb10..81348d31 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import decimal from google.protobuf import struct_pb2 -import six from . 
import issues, types, _apis @@ -13,7 +12,7 @@ _DecimalInfRepr = 10**35 _DecimalSignedInfRepr = -(10**35) _primitive_type_by_id = {} -_default_allow_truncated_result = True +_default_allow_truncated_result = False def _initialize(): @@ -82,9 +81,7 @@ def _pb_to_list(type_pb, value_pb, table_client_settings): def _pb_to_tuple(type_pb, value_pb, table_client_settings): return tuple( _to_native_value(item_type, item_value, table_client_settings) - for item_type, item_value in six.moves.zip( - type_pb.tuple_type.elements, value_pb.items - ) + for item_type, item_value in zip(type_pb.tuple_type.elements, value_pb.items) ) @@ -107,7 +104,7 @@ class _Struct(_DotDict): def _pb_to_struct(type_pb, value_pb, table_client_settings): result = _Struct() - for member, item in six.moves.zip(type_pb.struct_type.members, value_pb.items): + for member, item in zip(type_pb.struct_type.members, value_pb.items): result[member.name] = _to_native_value(member.type, item, table_client_settings) return result @@ -202,9 +199,7 @@ def _list_to_pb(type_pb, value): def _tuple_to_pb(type_pb, value): value_pb = _apis.ydb_value.Value() - for element_type, element_value in six.moves.zip( - type_pb.tuple_type.elements, value - ): + for element_type, element_value in zip(type_pb.tuple_type.elements, value): value_item_proto = value_pb.items.add() value_item_proto.MergeFrom(_from_native_value(element_type, element_value)) return value_pb @@ -290,7 +285,7 @@ def parameters_to_pb(parameters_types, parameters_values): return {} param_values_pb = {} - for name, type_pb in six.iteritems(parameters_types): + for name, type_pb in parameters_types.items(): result = _apis.ydb_value.TypedValue() ttype = type_pb if isinstance(type_pb, types.AbstractTypeBuilder): @@ -332,7 +327,7 @@ def from_message(cls, message, table_client_settings=None, snapshot=None): for row_proto in message.rows: row = _Row(message.columns) - for column, value, column_info in six.moves.zip( + for column, value, column_info in zip( message.columns, row_proto.items, column_parsers ): v_type = value.WhichOneof("value") @@ -400,9 +395,7 @@ def __init__(self, columns, proto_row, table_client_settings, parsers): super(_LazyRow, self).__init__() self._columns = columns self._table_client_settings = table_client_settings - for i, (column, row_item) in enumerate( - six.moves.zip(self._columns, proto_row.items) - ): + for i, (column, row_item) in enumerate(zip(self._columns, proto_row.items)): super(_LazyRow, self).__setitem__( column.name, _LazyRowItem(row_item, column.type, table_client_settings, parsers[i]), diff --git a/ydb/credentials.py b/ydb/credentials.py index 8e22fe2a..2a2dea3b 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import abc -import six +import typing + from . import tracing, issues, connection from . 
import settings as settings_impl import threading @@ -9,8 +10,7 @@ import time # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_auth_pb2 from ._grpc.v4 import ydb_auth_v1_pb2_grpc else: @@ -22,15 +22,13 @@ logger = logging.getLogger(__name__) -@six.add_metaclass(abc.ABCMeta) -class AbstractCredentials(object): +class AbstractCredentials(abc.ABC): """ An abstract class that provides auth metadata """ -@six.add_metaclass(abc.ABCMeta) -class Credentials(object): +class Credentials(abc.ABC): def __init__(self, tracer=None): self.tracer = tracer if tracer is not None else tracing.Tracer(None) @@ -41,6 +39,12 @@ def auth_metadata(self): """ pass + def get_auth_token(self) -> str: + for header, token in self.auth_metadata(): + if header == YDB_AUTH_TICKET_HEADER: + return token + return "" + class OneToManyValue(object): def __init__(self): @@ -87,7 +91,6 @@ def cleanup(self): self._can_schedule = True -@six.add_metaclass(abc.ABCMeta) class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) diff --git a/ydb/dbapi/cursor.py b/ydb/dbapi/cursor.py index 71175abf..eb26dc2b 100644 --- a/ydb/dbapi/cursor.py +++ b/ydb/dbapi/cursor.py @@ -5,8 +5,6 @@ import itertools import logging -import six - import ydb from .errors import DatabaseError @@ -42,7 +40,7 @@ def render_datetime(value): def render(value): if value is None: return "NULL" - if isinstance(value, six.string_types): + if isinstance(value, str): return render_str(value) if isinstance(value, datetime.datetime): return render_datetime(value) diff --git a/ydb/default_pem.py b/ydb/default_pem.py index 92286ba2..b8272efd 100644 --- a/ydb/default_pem.py +++ b/ydb/default_pem.py @@ -1,6 +1,3 @@ -import six - - data = """ # Issuer: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA # Subject: CN=GlobalSign Root CA O=GlobalSign nv-sa OU=Root CA @@ -4686,6 +4683,4 @@ def load_default_pem(): global data - if six.PY3: - return data.encode("utf-8") - return data + return data.encode("utf-8") diff --git a/ydb/driver.py b/ydb/driver.py index 9b3fa99c..e3274687 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -1,15 +1,11 @@ # -*- coding: utf-8 -*- from . import credentials as credentials_impl, table, scheme, pool from . import tracing -import six import os import grpc from . 
import _utilities -if six.PY2: - Any = None -else: - from typing import Any # noqa +from typing import Any # noqa class RPCCompression: @@ -23,10 +19,17 @@ class RPCCompression: def default_credentials(credentials=None, tracer=None): tracer = tracer if tracer is not None else tracing.Tracer(None) with tracer.trace("Driver.default_credentials") as ctx: - if credentials is not None: + if credentials is None: + ctx.trace({"credentials.anonymous": True}) + return credentials_impl.AnonymousCredentials() + else: ctx.trace({"credentials.prepared": True}) return credentials + +def credentials_from_env_variables(tracer=None): + tracer = tracer if tracer is not None else tracing.Tracer(None) + with tracer.trace("Driver.credentials_from_env_variables") as ctx: service_account_key_file = os.getenv("YDB_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS") if service_account_key_file is not None: ctx.trace({"credentials.service_account_key_file": True}) @@ -51,9 +54,7 @@ def default_credentials(credentials=None, tracer=None): ctx.trace({"credentials.access_token": True}) return credentials_impl.AuthTokenCredentials(access_token) - import ydb.iam - - return ydb.iam.MetadataUrlCredentials(tracer=tracer) + return default_credentials(None, tracer) class DriverConfig(object): @@ -70,6 +71,7 @@ class DriverConfig(object): "grpc_keep_alive_timeout", "secure_channel", "table_client_settings", + "topic_client_settings", "endpoints", "primary_user_agent", "tracer", @@ -92,6 +94,7 @@ def __init__( private_key=None, grpc_keep_alive_timeout=None, table_client_settings=None, + topic_client_settings=None, endpoints=None, primary_user_agent="python-library", tracer=None, @@ -138,6 +141,7 @@ def __init__( self.private_key = private_key self.grpc_keep_alive_timeout = grpc_keep_alive_timeout self.table_client_settings = table_client_settings + self.topic_client_settings = topic_client_settings self.primary_user_agent = primary_user_agent self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name @@ -228,6 +232,8 @@ def __init__( :param database: A database path :param credentials: A credentials. If not specifed credentials constructed by default. """ + from . import topic # local import for prevent cycle import error + driver_config = get_config( driver_config, connection_string, @@ -238,5 +244,9 @@ def __init__( ) super(Driver, self).__init__(driver_config) + + self._credentials = driver_config.credentials + self.scheme_client = scheme.SchemeClient(self) self.table_client = table.TableClient(self, driver_config.table_client_settings) + self.topic_client = topic.TopicClient(self, driver_config.topic_client_settings) diff --git a/ydb/export.py b/ydb/export.py index 30898cbb..bc35bd28 100644 --- a/ydb/export.py +++ b/ydb/export.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_export_pb2 from ._grpc.v4 import ydb_export_v1_pb2_grpc else: diff --git a/ydb/global_settings.py b/ydb/global_settings.py index de8b0b1b..8edac3f4 100644 --- a/ydb/global_settings.py +++ b/ydb/global_settings.py @@ -1,16 +1,24 @@ +import warnings + from . import convert from . 
import table def global_allow_truncated_result(enabled: bool = True): - """ - call global_allow_truncated_result(False) for more safe execution and compatible with future changes - """ + if convert._default_allow_truncated_result == enabled: + return + + if enabled: + warnings.warn("Global allow truncated response is deprecated behaviour.") + convert._default_allow_truncated_result = enabled def global_allow_split_transactions(enabled: bool): - """ - call global_allow_truncated_result(False) for more safe execution and compatible with future changes - """ - table._allow_split_transaction = enabled + if table._default_allow_split_transaction == enabled: + return + + if enabled: + warnings.warn("Global allow split transaction is deprecated behaviour.") + + table._default_allow_split_transaction = enabled diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 06b07e91..50d98b4b 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -3,7 +3,6 @@ import grpc import time import abc -import six from datetime import datetime import json import os @@ -45,7 +44,6 @@ def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): ) -@six.add_metaclass(abc.ABCMeta) class TokenServiceCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self, iam_endpoint=None, iam_channel_credentials=None, tracer=None): super(TokenServiceCredentials, self).__init__(tracer) @@ -84,8 +82,7 @@ def _make_token_request(self): return {"access_token": response.iam_token, "expires_in": expires_in} -@six.add_metaclass(abc.ABCMeta) -class BaseJWTCredentials(object): +class BaseJWTCredentials(abc.ABC): def __init__(self, account_id, access_key_id, private_key): self._account_id = account_id self._jwt_expiration_timeout = 60.0 * 60 diff --git a/ydb/import_client.py b/ydb/import_client.py index d1ccc99a..d94294ca 100644 --- a/ydb/import_client.py +++ b/ydb/import_client.py @@ -1,12 +1,12 @@ import enum +import typing from . import _apis from . import settings_impl as s_impl # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_import_pb2 from ._grpc.v4 import ydb_import_v1_pb2_grpc else: diff --git a/ydb/issues.py b/ydb/issues.py index 2e128d5a..100af01d 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from google.protobuf import text_format import enum -from six.moves import queue +import queue from . import _apis diff --git a/ydb/pool.py b/ydb/pool.py index dfda0adf..007aa94d 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- +import abc import threading import logging from concurrent import futures import collections import random -import six - from . 
import connection as connection_impl, issues, resolver, _utilities, tracing -from abc import abstractmethod, ABCMeta +from abc import abstractmethod +from .connection import Connection logger = logging.getLogger(__name__) @@ -127,7 +127,7 @@ def subscribe(self): return subscription @tracing.with_trace() - def get(self, preferred_endpoint=None): + def get(self, preferred_endpoint=None) -> Connection: with self.lock: if ( preferred_endpoint is not None @@ -295,8 +295,7 @@ def run(self): self.logger.info("Successfully terminated discovery process") -@six.add_metaclass(ABCMeta) -class IConnectionPool: +class IConnectionPool(abc.ABC): @abstractmethod def __init__(self, driver_config): """ diff --git a/ydb/public/__init__.py b/ydb/public/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ydb/public/api/__init__.py b/ydb/public/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ydb/public/api/grpc/__init__.py b/ydb/public/api/grpc/__init__.py deleted file mode 100644 index 08dd1066..00000000 --- a/ydb/public/api/grpc/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ydb._grpc.common import * # noqa -import sys -import warnings - -sys.modules["ydb.public.api.grpc"] = sys.modules["ydb._grpc.common"] -warnings.warn( - "using ydb.public.api.grpc module is deprecated. Don't use direct grpc dependencies." -) diff --git a/ydb/public/api/protos/__init__.py b/ydb/public/api/protos/__init__.py deleted file mode 100644 index 204ff3b9..00000000 --- a/ydb/public/api/protos/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from ydb._grpc.common.protos import * # noqa -import sys -import warnings - -sys.modules["ydb.public.api.protos"] = sys.modules["ydb._grpc.common.protos"] -warnings.warn( - "using ydb.public.api.protos module is deprecated. Don't use direct grpc dependencies." -) diff --git a/ydb/public/stub.txt b/ydb/public/stub.txt new file mode 100644 index 00000000..6c4e84e7 --- /dev/null +++ b/ydb/public/stub.txt @@ -0,0 +1 @@ +the folder must not use for prevent issues with intersect with old protobuf generate packages. diff --git a/ydb/scheme.py b/ydb/scheme.py index 88eca78c..96a6c25f 100644 --- a/ydb/scheme.py +++ b/ydb/scheme.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import abc import enum -import six from abc import abstractmethod from . import issues, operation, settings as settings_impl, _apis @@ -347,8 +346,7 @@ def _wrap_describe_path_response(rpc_state, response): return _wrap_scheme_entry(message.self) -@six.add_metaclass(abc.ABCMeta) -class ISchemeClient: +class ISchemeClient(abc.ABC): @abstractmethod def __init__(self, driver): pass diff --git a/ydb/scripting.py b/ydb/scripting.py index 9fed037a..13132430 100644 --- a/ydb/scripting.py +++ b/ydb/scripting.py @@ -1,6 +1,7 @@ +import typing + # Workaround for good IDE and universal for runtime -# noinspection PyUnreachableCode -if False: +if typing.TYPE_CHECKING: from ._grpc.v4.protos import ydb_scripting_pb2 from ._grpc.v4 import ydb_scripting_v1_pb2_grpc else: diff --git a/ydb/table.py b/ydb/table.py index 73e1dcab..799a5426 100644 --- a/ydb/table.py +++ b/ydb/table.py @@ -7,7 +7,6 @@ import random import enum -import six from . 
import ( issues, convert, @@ -28,7 +27,7 @@ except ImportError: interceptor = None -_allow_split_transaction = True +_default_allow_split_transaction = False logger = logging.getLogger(__name__) @@ -770,8 +769,7 @@ def with_compaction_policy(self, compaction_policy): return self -@six.add_metaclass(abc.ABCMeta) -class AbstractTransactionModeBuilder(object): +class AbstractTransactionModeBuilder(abc.ABC): @property @abc.abstractmethod def name(self): @@ -949,7 +947,7 @@ def retry_operation_impl(callee, retry_settings=None, *args, **kwargs): retry_settings = RetrySettings() if retry_settings is None else retry_settings status = None - for attempt in six.moves.range(retry_settings.max_retries + 1): + for attempt in range(retry_settings.max_retries + 1): try: result = YdbRetryOperationFinalResult(callee(*args, **kwargs)) yield result @@ -1103,8 +1101,7 @@ def _scan_query_request_factory(query, parameters=None, settings=None): ) -@six.add_metaclass(abc.ABCMeta) -class ISession: +class ISession(abc.ABC): @abstractmethod def __init__(self, driver, table_client_settings): pass @@ -1184,9 +1181,7 @@ def execute_scheme(self, yql_text, settings=None): pass @abstractmethod - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): pass @abstractmethod @@ -1268,8 +1263,7 @@ def describe_table(self, path, settings=None): pass -@six.add_metaclass(abc.ABCMeta) -class ITableClient: +class ITableClient(abc.ABC): def __init__(self, driver, table_client_settings=None): pass @@ -1691,9 +1685,7 @@ def execute_scheme(self, yql_text, settings=None): self._state.endpoint, ) - def transaction( - self, tx_mode=None, allow_split_transactions=_allow_split_transaction - ): + def transaction(self, tx_mode=None, allow_split_transactions=None): return TxContext( self._driver, self._state, @@ -2100,8 +2092,7 @@ def async_describe_table(self, path, settings=None): ) -@six.add_metaclass(abc.ABCMeta) -class ITxContext: +class ITxContext(abc.ABC): @abstractmethod def __init__(self, driver, session_state, session, tx_mode=None): """ @@ -2231,7 +2222,7 @@ def __init__( session, tx_mode=None, *, - allow_split_transactions=_allow_split_transaction + allow_split_transactions=None ): """ An object that provides a simple transaction context manager that allows statements execution @@ -2418,7 +2409,13 @@ def _check_split(self, allow=""): Deny all operaions with transaction after commit/rollback. Exception: double commit and double rollbacks, because it is safe """ - if self._allow_split_transactions: + allow_split_transaction = ( + self._allow_split_transactions + if self._allow_split_transactions is not None + else _default_allow_split_transaction + ) + + if allow_split_transaction: return if self._finished != "" and self._finished != allow: diff --git a/ydb/topic.py b/ydb/topic.py new file mode 100644 index 00000000..7dde70ff --- /dev/null +++ b/ydb/topic.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +__all__ = [ + "TopicClient", + "TopicClientAsyncIO", + "TopicClientSettings", + "TopicCodec", + "TopicConsumer", + "TopicDescription", + "TopicError", + "TopicMeteringMode", + "TopicReader", + "TopicReaderAsyncIO", + "TopicReaderSettings", + "TopicStatWindow", + "TopicWriter", + "TopicWriterAsyncIO", + "TopicWriterMessage", + "TopicWriterSettings", +] + +import concurrent.futures +import datetime +from dataclasses import dataclass +from typing import List, Union, Mapping, Optional, Dict, Callable + +from . 
diff --git a/ydb/topic.py b/ydb/topic.py
new file mode 100644
index 00000000..7dde70ff
--- /dev/null
+++ b/ydb/topic.py
@@ -0,0 +1,384 @@
+from __future__ import annotations
+
+__all__ = [
+    "TopicClient",
+    "TopicClientAsyncIO",
+    "TopicClientSettings",
+    "TopicCodec",
+    "TopicConsumer",
+    "TopicDescription",
+    "TopicError",
+    "TopicMeteringMode",
+    "TopicReader",
+    "TopicReaderAsyncIO",
+    "TopicReaderSettings",
+    "TopicStatWindow",
+    "TopicWriter",
+    "TopicWriterAsyncIO",
+    "TopicWriterMessage",
+    "TopicWriterSettings",
+]
+
+import concurrent.futures
+import datetime
+from dataclasses import dataclass
+from typing import List, Union, Mapping, Optional, Dict, Callable
+
+from . import aio, Credentials, _apis, issues
+
+from . import driver
+
+from ._topic_reader.topic_reader import (
+    PublicReaderSettings as TopicReaderSettings,
+)
+
+from ._topic_reader.topic_reader_sync import TopicReaderSync as TopicReader
+
+from ._topic_reader.topic_reader_asyncio import (
+    PublicAsyncIOReader as TopicReaderAsyncIO,
+)
+
+from ._topic_writer.topic_writer import (  # noqa: F401
+    PublicWriterSettings as TopicWriterSettings,
+    PublicMessage as TopicWriterMessage,
+    RetryPolicy as TopicWriterRetryPolicy,
+)
+
+from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter
+
+from ._topic_common.common import (
+    wrap_operation as _wrap_operation,
+    create_result_wrapper as _create_result_wrapper,
+)
+
+from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO
+
+from ._grpc.grpcwrapper import ydb_topic as _ydb_topic
+from ._grpc.grpcwrapper import ydb_topic_public_types as _ydb_topic_public_types
+from ._grpc.grpcwrapper.ydb_topic_public_types import (  # noqa: F401
+    PublicDescribeTopicResult as TopicDescription,
+    PublicMultipleWindowsStat as TopicStatWindow,
+    PublicPartitionStats as TopicPartitionStats,
+    PublicCodec as TopicCodec,
+    PublicConsumer as TopicConsumer,
+    PublicMeteringMode as TopicMeteringMode,
+)
+
+
+class TopicClientAsyncIO:
+    _closed: bool
+    _driver: aio.Driver
+    _credentials: Union[Credentials, None]
+    _settings: TopicClientSettings
+    _executor: concurrent.futures.Executor
+
+    def __init__(
+        self, driver: aio.Driver, settings: Optional[TopicClientSettings] = None
+    ):
+        if not settings:
+            settings = TopicClientSettings()
+        self._closed = False
+        self._driver = driver
+        self._settings = settings
+        self._executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=settings.encode_decode_threads_count,
+            thread_name_prefix="topic_asyncio_executor",
+        )
+
+    def __del__(self):
+        self.close()
+
+    async def create_topic(
+        self,
+        path: str,
+        min_active_partitions: Optional[int] = None,
+        partition_count_limit: Optional[int] = None,
+        retention_period: Optional[datetime.timedelta] = None,
+        retention_storage_mb: Optional[int] = None,
+        supported_codecs: Optional[List[Union[TopicCodec, int]]] = None,
+        partition_write_speed_bytes_per_second: Optional[int] = None,
+        partition_write_burst_bytes: Optional[int] = None,
+        attributes: Optional[Dict[str, str]] = None,
+        consumers: Optional[List[Union[TopicConsumer, str]]] = None,
+        metering_mode: Optional[TopicMeteringMode] = None,
+    ):
+        """
+        create topic command
+
+        :param path: full path to topic
+        :param min_active_partitions: Minimum partition count auto merge would stop working at.
+        :param partition_count_limit: Limit for total partition count, including active (open for write)
+            and read-only partitions.
+        :param retention_period: How long data in partition should be stored
+        :param retention_storage_mb: How much data in partition should be stored
+        :param supported_codecs: List of allowed codecs for writers. Writes with a codec not from this list are forbidden.
+            An empty list disables codec compatibility checks for the topic.
+        :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second
+        :param partition_write_burst_bytes: Burst size for write in partition, in bytes
+        :param attributes: User and server attributes of topic.
+            Server attributes start with "_" and are validated by the server.
+        :param consumers: List of consumers for this topic
+        :param metering_mode: Metering mode for the topic in a serverless database
+        """
+        args = locals().copy()
+        del args["self"]
+        req = _ydb_topic_public_types.CreateTopicRequestParams(**args)
+        req = _ydb_topic.CreateTopicRequest.from_public(req)
+        await self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.CreateTopic,
+            _wrap_operation,
+        )
+
+    async def describe_topic(
+        self, path: str, include_stats: bool = False
+    ) -> TopicDescription:
+        args = locals().copy()
+        del args["self"]
+        req = _ydb_topic_public_types.DescribeTopicRequestParams(**args)
+        res = await self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.DescribeTopic,
+            _create_result_wrapper(_ydb_topic.DescribeTopicResult),
+        )  # type: _ydb_topic.DescribeTopicResult
+        return res.to_public()
+
+    async def drop_topic(self, path: str):
+        req = _ydb_topic_public_types.DropTopicRequestParams(path=path)
+        await self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.DropTopic,
+            _wrap_operation,
+        )
+
+    def reader(
+        self,
+        topic: str,
+        consumer: str,
+        buffer_size_bytes: int = 50 * 1024 * 1024,
+        # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+        decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        decoder_executor: Optional[
+            concurrent.futures.Executor
+        ] = None,  # default shared client executor pool
+    ) -> TopicReaderAsyncIO:
+
+        if not decoder_executor:
+            decoder_executor = self._executor
+
+        args = locals()
+        del args["self"]
+
+        settings = TopicReaderSettings(**args)
+
+        return TopicReaderAsyncIO(self._driver, settings)
+
+    def writer(
+        self,
+        topic,
+        *,
+        producer_id: Optional[str] = None,  # default - random
+        session_metadata: Mapping[str, str] = None,
+        partition_id: Union[int, None] = None,
+        auto_seqno: bool = True,
+        auto_created_at: bool = True,
+        codec: Optional[TopicCodec] = None,  # default means auto-select
+        encoders: Optional[
+            Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]
+        ] = None,
+        encoder_executor: Optional[
+            concurrent.futures.Executor
+        ] = None,  # default shared client executor pool
+    ) -> TopicWriterAsyncIO:
+        args = locals()
+        del args["self"]
+
+        settings = TopicWriterSettings(**args)
+
+        if not settings.encoder_executor:
+            settings.encoder_executor = self._executor
+
+        return TopicWriterAsyncIO(self._driver, settings)
+
+    def close(self):
+        if self._closed:
+            return
+
+        self._closed = True
+        self._executor.shutdown(wait=False, cancel_futures=True)
+
+    def _check_closed(self):
+        if not self._closed:
+            return
+
+        raise RuntimeError("Topic client closed")
+
+
+class TopicClient:
+    _closed: bool
+    _driver: driver.Driver
+    _credentials: Union[Credentials, None]
+    _settings: TopicClientSettings
+    _executor: concurrent.futures.Executor
+
+    def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings]):
+        if not settings:
+            settings = TopicClientSettings()
+
+        self._closed = False
+        self._driver = driver
+        self._settings = settings
+        self._executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=settings.encode_decode_threads_count,
+            thread_name_prefix="topic_asyncio_executor",
+        )
+
+    def __del__(self):
+        self.close()
+
+    def create_topic(
+        self,
+        path: str,
+        min_active_partitions: Optional[int] = None,
+        partition_count_limit: Optional[int] = None,
+        retention_period: Optional[datetime.timedelta] = None,
+        retention_storage_mb: Optional[int] = None,
+        supported_codecs: Optional[List[Union[TopicCodec, int]]] = None,
+        partition_write_speed_bytes_per_second: Optional[int] = None,
+        partition_write_burst_bytes: Optional[int] = None,
+        attributes: Optional[Dict[str, str]] = None,
+        consumers: Optional[List[Union[TopicConsumer, str]]] = None,
+        metering_mode: Optional[TopicMeteringMode] = None,
+    ):
+        """
+        create topic command
+
+        :param path: full path to topic
+        :param min_active_partitions: Minimum partition count auto merge would stop working at.
+        :param partition_count_limit: Limit for total partition count, including active (open for write)
+            and read-only partitions.
+        :param retention_period: How long data in partition should be stored
+        :param retention_storage_mb: How much data in partition should be stored
+        :param supported_codecs: List of allowed codecs for writers. Writes with a codec not from this list are forbidden.
+            An empty list disables codec compatibility checks for the topic.
+        :param partition_write_speed_bytes_per_second: Partition write speed in bytes per second
+        :param partition_write_burst_bytes: Burst size for write in partition, in bytes
+        :param attributes: User and server attributes of topic.
+            Server attributes start with "_" and are validated by the server.
+        :param consumers: List of consumers for this topic
+        :param metering_mode: Metering mode for the topic in a serverless database
+        """
+        args = locals().copy()
+        del args["self"]
+        self._check_closed()
+
+        req = _ydb_topic_public_types.CreateTopicRequestParams(**args)
+        req = _ydb_topic.CreateTopicRequest.from_public(req)
+        self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.CreateTopic,
+            _wrap_operation,
+        )
+
+    def describe_topic(
+        self, path: str, include_stats: bool = False
+    ) -> TopicDescription:
+        args = locals().copy()
+        del args["self"]
+        self._check_closed()
+
+        req = _ydb_topic_public_types.DescribeTopicRequestParams(**args)
+        res = self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.DescribeTopic,
+            _create_result_wrapper(_ydb_topic.DescribeTopicResult),
+        )  # type: _ydb_topic.DescribeTopicResult
+        return res.to_public()
+
+    def drop_topic(self, path: str):
+        self._check_closed()
+
+        req = _ydb_topic_public_types.DropTopicRequestParams(path=path)
+        self._driver(
+            req.to_proto(),
+            _apis.TopicService.Stub,
+            _apis.TopicService.DropTopic,
+            _wrap_operation,
+        )
+
+    def reader(
+        self,
+        topic: str,
+        consumer: str,
+        buffer_size_bytes: int = 50 * 1024 * 1024,
+        # decoders: map[codec_code] func(encoded_bytes)->decoded_bytes
+        decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None,
+        decoder_executor: Optional[
+            concurrent.futures.Executor
+        ] = None,  # default shared client executor pool
+    ) -> TopicReader:
+        if not decoder_executor:
+            decoder_executor = self._executor
+
+        args = locals()
+        del args["self"]
+        self._check_closed()
+
+        settings = TopicReaderSettings(**args)
+
+        return TopicReader(self._driver, settings)
+
+    def writer(
+        self,
+        topic,
+        *,
+        producer_id: Optional[str] = None,  # default - random
+        session_metadata: Mapping[str, str] = None,
+        partition_id: Union[int, None] = None,
+        auto_seqno: bool = True,
+        auto_created_at: bool = True,
+        codec: Optional[TopicCodec] = None,  # default means auto-select
+        encoders: Optional[
+            Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]
+        ] = None,
+        encoder_executor: Optional[
+            concurrent.futures.Executor
+        ] = None,  # default shared client executor pool
+    ) -> TopicWriter:
+        args = locals()
+        del args["self"]
+        self._check_closed()
+
+        settings = TopicWriterSettings(**args)
+
+        if not settings.encoder_executor:
+            settings.encoder_executor = self._executor
+
+        return TopicWriter(self._driver, settings)
+
+    def close(self):
+        if self._closed:
+            return
+
+        self._closed = True
+        self._executor.shutdown(wait=False, cancel_futures=True)
+
+    def _check_closed(self):
+        if not self._closed:
+            return
+
+        raise RuntimeError("Topic client closed")
+
+
+@dataclass
+class TopicClientSettings:
+    encode_decode_threads_count: int = 4
+
+
+class TopicError(issues.Error):
+    pass
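A hedged usage sketch of the control-plane methods defined above. How a topic client is normally obtained from a Driver is not shown in this diff, so one is constructed directly here, and the endpoint, database, and topic names are placeholders:

    import datetime
    import ydb
    from ydb.topic import TopicClient, TopicClientSettings

    driver = ydb.Driver(endpoint="grpc://localhost:2136", database="/local")
    driver.wait(timeout=5)

    client = TopicClient(driver, TopicClientSettings())

    client.create_topic(
        "/local/demo-topic",
        min_active_partitions=2,
        retention_period=datetime.timedelta(hours=1),
        consumers=["demo-consumer"],
    )
    description = client.describe_topic("/local/demo-topic", include_stats=True)

    # Reader/writer construction mirrors the signatures above; the actual
    # read/write calls live in the _topic_reader/_topic_writer modules.
    writer = client.writer("/local/demo-topic", producer_id="demo-producer")
    reader = client.reader("/local/demo-topic", consumer="demo-consumer")

    client.drop_topic("/local/demo-topic")
    client.close()
    driver.stop()
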
diff --git a/ydb/tornado/__init__.py b/ydb/tornado/__init__.py
deleted file mode 100644
index eea7bdc4..00000000
--- a/ydb/tornado/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-try:
-    from .tornado_helpers import *  # noqa
-except ImportError:
-    pass
diff --git a/ydb/tornado/tornado_helpers.py b/ydb/tornado/tornado_helpers.py
deleted file mode 100644
index 973c4435..00000000
--- a/ydb/tornado/tornado_helpers.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- coding: utf-8 -*-
-try:
-    import tornado.concurrent
-    import tornado.ioloop
-    import tornado.gen
-    from tornado.concurrent import TracebackFuture
-except ImportError:
-    tornado = None
-
-from .table import retry_operation_impl, YdbRetryOperationSleepOpt
-
-
-def as_tornado_future(foreign_future, timeout=None):
-    """
-    Return tornado.concurrent.Future wrapped python concurrent.future (foreign_future).
-    Cancel execution original future after given timeout
-    """
-    result_future = tornado.concurrent.Future()
-    timeout_timer = set()
-    if timeout:
-
-        def on_timeout():
-            timeout_timer.clear()
-            foreign_future.cancel()
-
-        timeout_timer.add(
-            tornado.ioloop.IOLoop.current().call_later(timeout, on_timeout)
-        )
-
-    def copy_to_result_future(foreign_future):
-        try:
-            to_remove = timeout_timer.pop()
-            tornado.ioloop.IOLoop.current().remove_timeout(to_remove)
-        except KeyError:
-            pass
-
-        if result_future.done():
-            return
-
-        if (
-            isinstance(foreign_future, TracebackFuture)
-            and isinstance(result_future, TracebackFuture)
-            and result_future.exc_info() is not None
-        ):
-            result_future.set_exc_info(foreign_future.exc_info())
-        elif foreign_future.cancelled():
-            result_future.set_exception(tornado.gen.TimeoutError())
-        elif foreign_future.exception() is not None:
-            result_future.set_exception(foreign_future.exception())
-        else:
-            result_future.set_result(foreign_future.result())
-
-    tornado.ioloop.IOLoop.current().add_future(foreign_future, copy_to_result_future)
-    return result_future
-
-
-async def retry_operation(callee, retry_settings=None, *args, **kwargs):
-    opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs)
-
-    for next_opt in opt_generator:
-        if isinstance(next_opt, YdbRetryOperationSleepOpt):
-            await tornado.gen.sleep(next_opt.timeout)
-        else:
-            try:
-                return await next_opt.result
-            except Exception as e:
-                next_opt.set_exception(e)
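The tornado helpers are removed without a replacement in this diff; the same retry loop can be reproduced on top of asyncio using retry_operation_impl from ydb/table.py. This mirrors the deleted helper and is only an illustration, not an API shipped by the SDK; the callee is assumed to return an awaitable:

    import asyncio
    from ydb.table import retry_operation_impl, YdbRetryOperationSleepOpt

    async def retry_operation_async(callee, retry_settings=None, *args, **kwargs):
        # Drive the shared retry state machine, replacing tornado.gen.sleep
        # with asyncio.sleep.
        opt_generator = retry_operation_impl(callee, retry_settings, *args, **kwargs)
        for next_opt in opt_generator:
            if isinstance(next_opt, YdbRetryOperationSleepOpt):
                await asyncio.sleep(next_opt.timeout)
            else:
                try:
                    return await next_opt.result
                except Exception as e:
                    next_opt.set_exception(e)
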
diff --git a/ydb/types.py b/ydb/types.py
index a62c8a74..5ffa16e6 100644
--- a/ydb/types.py
+++ b/ydb/types.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 import abc
 import enum
-import six
 import json
 from . import _utilities, _apis
 from datetime import date, datetime, timedelta
@@ -13,13 +12,6 @@
 _SECONDS_IN_DAY = 60 * 60 * 24
 _EPOCH = datetime(1970, 1, 1)
 
-if six.PY3:
-    _from_bytes = None
-else:
-
-    def _from_bytes(x, table_client_settings):
-        return _utilities.from_bytes(x)
-
 
 def _from_date(x, table_client_settings):
     if (
@@ -52,8 +44,6 @@ def _from_json(x, table_client_settings):
         and table_client_settings._native_json_in_result_sets
     ):
         return json.loads(x)
-    if _from_bytes is not None:
-        return _from_bytes(x, table_client_settings)
     return x
 
 
@@ -122,7 +112,7 @@ class PrimitiveType(enum.Enum):
     Float = _apis.primitive_types.FLOAT, "float_value"
     String = _apis.primitive_types.STRING, "bytes_value"
 
-    Utf8 = _apis.primitive_types.UTF8, "text_value", _from_bytes
+    Utf8 = _apis.primitive_types.UTF8, "text_value"
 
     Yson = _apis.primitive_types.YSON, "bytes_value"
     Json = _apis.primitive_types.JSON, "text_value", _from_json
@@ -152,7 +142,7 @@ class PrimitiveType(enum.Enum):
         _to_interval,
     )
 
-    DyNumber = _apis.primitive_types.DYNUMBER, "text_value", _from_bytes
+    DyNumber = _apis.primitive_types.DYNUMBER, "text_value"
 
     def __init__(self, idn, proto_field, to_obj=None, from_obj=None):
         self._idn_ = idn
diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py
index e783fc72..9b3d0a8c 100644
--- a/ydb/ydb_version.py
+++ b/ydb/ydb_version.py
@@ -1 +1 @@
-VERSION = "2.13.3"
+VERSION = "3.0.1b11"