From bea552b6ce9c15cc1380ebb584175df92999a858 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Mon, 5 Feb 2024 19:31:21 +0100 Subject: [PATCH 01/13] Add optional type casting to limit-offset clause --- test/test_suite.py | 73 +++++-------------------- ydb_sqlalchemy/sqlalchemy/__init__.py | 79 +++++++++++++++++++++++---- ydb_sqlalchemy/sqlalchemy/types.py | 24 +++++++- 3 files changed, 106 insertions(+), 70 deletions(-) diff --git a/test/test_suite.py b/test/test_suite.py index dc61109..f093e30 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -6,13 +6,12 @@ from sqlalchemy.testing import is_true, is_false from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures -from sqlalchemy.testing.suite import func, column, literal_column, select, exists +from sqlalchemy.testing.suite import func, column, literal_column, select, exists, union from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String from sqlalchemy.testing.suite.test_select import ( ExistsTest as _ExistsTest, LikeFunctionsTest as _LikeFunctionsTest, - CompoundSelectTest as _CompoundSelectTest, ) from sqlalchemy.testing.suite.test_reflection import ( HasTableTest as _HasTableTest, @@ -49,7 +48,6 @@ from sqlalchemy.testing.suite.test_insert import InsertBehaviorTest as _InsertBehaviorTest from sqlalchemy.testing.suite.test_ddl import LongNameBlowoutTest as _LongNameBlowoutTest from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest -from sqlalchemy.testing.suite.test_deprecations import DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types @@ -294,20 +292,6 @@ def test_not_regexp_match(self): self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10, 11}) -class CompoundSelectTest(_CompoundSelectTest): - @pytest.mark.skip("limit don't work") - def test_distinct_selectable_in_unions(self): - pass - - @pytest.mark.skip("limit don't work") - def test_limit_offset_in_unions_from_alias(self): - pass - - @pytest.mark.skip("limit don't work") - def test_limit_offset_aliased_selectable_in_unions(self): - pass - - class EscapingTest(_EscapingTest): @provide_metadata def test_percent_sign_round_trip(self): @@ -364,45 +348,23 @@ def test_group_by_composed(self): class FetchLimitOffsetTest(_FetchLimitOffsetTest): - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_limit(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_limit_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_expr_limit_simple_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") def test_limit_render_multiple_times(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_offset(self, connection): - pass + """ + YQL does not support scalar subquery, so test was refiled with simple subquery + """ + table = self.tables.some_table + stmt = select(table.c.id).limit(1).subquery() - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_offset_zero(self, connection): - pass + u = union(select(stmt), select(stmt)).subquery().select() - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit_expr_offset(self, connection): - pass + self._assert_result( + connection, + u, + [ + (1,), + (1,), + ], + ) class InsertBehaviorTest(_InsertBehaviorTest): @@ -539,8 +501,3 @@ class RowFetchTest(_RowFetchTest): @pytest.mark.skip("scalar subquery unsupported") def test_row_w_scalar_select(self, connection): pass - - -@pytest.mark.skip("TODO: try it after limit/offset tests would fixed") -class DeprecatedCompoundSelectTest(_DeprecatedCompoundSelectTest): - pass diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 297d4ff..202590c 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -24,7 +24,7 @@ from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext from sqlalchemy.util.compat import inspect_getfullargspec -from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict +from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict, Type from . import types @@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw): def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): return "BOOL" + def visit_uint64(self, type_: types.UInt64, **kw): + return "UInt64" + def visit_uint32(self, type_: types.UInt32, **kw): return "UInt32" - def visit_uint64(self, type_: types.UInt64, **kw): - return "UInt64" + def visit_uint16(self, type_: types.UInt16, **kw): + return "UInt16" def visit_uint8(self, type_: types.UInt8, **kw): return "UInt8" + def visit_int64(self, type_: types.UInt64, **kw): + return "Int64" + + def visit_int32(self, type_: types.UInt32, **kw): + return "Int32" + + def visit_int16(self, type_: types.UInt16, **kw): + return "Int16" + + def visit_int8(self, type_: types.UInt8, **kw): + return "Int8" + def visit_INTEGER(self, type_: sa.INTEGER, **kw): return "Int64" @@ -134,8 +149,28 @@ def get_ydb_type( if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): ydb_type = ydb.PrimitiveType.Utf8 + + # Integers + elif isinstance(type_, types.UInt64): + ydb_type = ydb.PrimitiveType.Uint64 + elif isinstance(type_, types.UInt32): + ydb_type = ydb.PrimitiveType.Uint32 + elif isinstance(type_, types.UInt16): + ydb_type = ydb.PrimitiveType.Uint16 + elif isinstance(type_, types.UInt8): + ydb_type = ydb.PrimitiveType.Uint8 + elif isinstance(type_, types.Int64): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, types.Int32): + ydb_type = ydb.PrimitiveType.Int32 + elif isinstance(type_, types.Int16): + ydb_type = ydb.PrimitiveType.Int16 + elif isinstance(type_, types.Int8): + ydb_type = ydb.PrimitiveType.Int8 elif isinstance(type_, sa.Integer): ydb_type = ydb.PrimitiveType.Int64 + # Integers + elif isinstance(type_, sa.JSON): ydb_type = ydb.PrimitiveType.Json elif isinstance(type_, sa.DateTime): @@ -188,6 +223,32 @@ def group_by_clause(self, select, **kw): kw.update(within_columns_clause=True) return super(YqlCompiler, self).group_by_clause(select, **kw) + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + limit_clause = self._maybe_cast( + select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + text += "\n LIMIT " + self.process(limit_clause, **kw) + if select._offset_clause is not None: + offset_clause = self._maybe_cast( + select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + if select._limit_clause is None: + text += "\n LIMIT NULL" + text += " OFFSET " + self.process(offset_clause, **kw) + return text + + def _maybe_cast( + self, + element: Any, + cast_to: Type[sa.types.TypeEngine], + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, + ) -> Any: + if not hasattr(element, "type") or not isinstance(element.type, skip_types): + return sa.Cast(element, cast_to) + return element + def render_literal_value(self, value, type_): if isinstance(value, str): value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) @@ -277,16 +338,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool: def _guess_bound_variable_type_by_parameters( self, bind: sa.BindParameter, post_compile_bind_values: list ) -> Optional[sa.types.TypeEngine]: - if not bind.expanding: - if isinstance(bind.type, sa.types.NullType): - return None - bind_type = bind.type - else: + bind_type = bind.type + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): not_null_values = [v for v in post_compile_bind_values if v is not None] if not_null_values: bind_type = sa.BindParameter("", not_null_values[0]).type - else: - return None + + if isinstance(bind_type, sa.types.NullType): + return None return bind_type diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 61fa5ca..105ae8b 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -3,18 +3,38 @@ from typing import Mapping, Any, Union, Type +class UInt64(types.Integer): + __visit_name__ = "uint64" + + class UInt32(types.Integer): __visit_name__ = "uint32" -class UInt64(types.Integer): - __visit_name__ = "uint64" +class UInt16(types.Integer): + __visit_name__ = "uint32" class UInt8(types.Integer): __visit_name__ = "uint8" +class Int64(types.Integer): + __visit_name__ = "int64" + + +class Int32(types.Integer): + __visit_name__ = "int32" + + +class Int16(types.Integer): + __visit_name__ = "int32" + + +class Int8(types.Integer): + __visit_name__ = "int8" + + class ListType(ARRAY): __visit_name__ = "list_type" From a576a67d653e50755676a3be8a8212e5efe03bad Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Mon, 5 Feb 2024 19:49:50 +0100 Subject: [PATCH 02/13] Add test of types --- test/test_core.py | 15 +++++++++++++++ ydb_sqlalchemy/sqlalchemy/__init__.py | 8 ++++---- ydb_sqlalchemy/sqlalchemy/types.py | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 1ff9f15..f7337df 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -209,6 +209,21 @@ def test_select_types(self, connection): row = connection.execute(sa.select(tb)).fetchone() assert row == (1, "Hello World!", 3.5, True, now, today) + def test_integer_types(self, connection): + stmt = sa.Select( + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint64", 64, types.UInt64))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int8", -8, types.Int8))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int16", -16, types.Int16))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int32", -32, types.Int32))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int64", -64, types.Int64))), + ) + + result = connection.execute(stmt).fetchone() + assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64") + class TestWithClause(TablesTest): run_create_tables = "each" diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 202590c..9602567 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -99,16 +99,16 @@ def visit_uint16(self, type_: types.UInt16, **kw): def visit_uint8(self, type_: types.UInt8, **kw): return "UInt8" - def visit_int64(self, type_: types.UInt64, **kw): + def visit_int64(self, type_: types.Int64, **kw): return "Int64" - def visit_int32(self, type_: types.UInt32, **kw): + def visit_int32(self, type_: types.Int32, **kw): return "Int32" - def visit_int16(self, type_: types.UInt16, **kw): + def visit_int16(self, type_: types.Int16, **kw): return "Int16" - def visit_int8(self, type_: types.UInt8, **kw): + def visit_int8(self, type_: types.Int8, **kw): return "Int8" def visit_INTEGER(self, type_: sa.INTEGER, **kw): diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 105ae8b..82d44c8 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -12,7 +12,7 @@ class UInt32(types.Integer): class UInt16(types.Integer): - __visit_name__ = "uint32" + __visit_name__ = "uint16" class UInt8(types.Integer): From bf1a57977d4aec5c63c70afe55d6cffda2acbf59 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Mon, 5 Feb 2024 19:58:33 +0100 Subject: [PATCH 03/13] Always skip cast_to type --- ydb_sqlalchemy/sqlalchemy/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 9602567..1aa21a6 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -245,6 +245,10 @@ def _maybe_cast( cast_to: Type[sa.types.TypeEngine], skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, ) -> Any: + if not skip_types: + skip_types = (cast_to,) + if cast_to not in skip_types: + skip_types = (*skip_types, cast_to) if not hasattr(element, "type") or not isinstance(element.type, skip_types): return sa.Cast(element, cast_to) return element From dd913596f69857d07f8a0d092e5a54be7bb3f002 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Wed, 24 Jan 2024 19:47:37 +0100 Subject: [PATCH 04/13] Add .vscode to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index bd6ad26..adcbed1 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # PyCharm .idea/ + +# VSCode +.vscode From 9bd6243a7cc41148a80ee953558efd5040369b71 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Wed, 24 Jan 2024 19:48:01 +0100 Subject: [PATCH 05/13] Migrate to class based dbapi --- test_dbapi/conftest.py | 2 +- test_dbapi/test_dbapi.py | 2 +- ydb_sqlalchemy/dbapi/__init__.py | 44 ++++++++++++++------------- ydb_sqlalchemy/sqlalchemy/__init__.py | 2 +- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/test_dbapi/conftest.py b/test_dbapi/conftest.py index 7a9f5a3..a64a68c 100644 --- a/test_dbapi/conftest.py +++ b/test_dbapi/conftest.py @@ -5,6 +5,6 @@ @pytest.fixture(scope="module") def connection(): - conn = dbapi.connect(host="localhost", port="2136", database="/local") + conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") yield conn conn.close() diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index ad354d5..b67dafb 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -69,7 +69,7 @@ def test_cursor_raw_query(connection): def test_errors(connection): with pytest.raises(dbapi.InterfaceError): - dbapi.connect("localhost:2136", database="/local666") + dbapi.YdbDBApi().connect("localhost:2136", database="/local666") cur = connection.cursor() diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index f06e15f..add2be8 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -13,25 +13,27 @@ NotSupportedError, ) -apilevel = "1.0" +class YdbDBApi: + def __init__(self): + self.paramstyle = "pyformat" + self.threadsafety = 0 + self.apilevel = "1.0" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name, value in { + "Warning": Warning, + "Error": Error, + "InterfaceError": InterfaceError, + "DatabaseError": DatabaseError, + "DataError": DataError, + "OperationalError": OperationalError, + "IntegrityError": IntegrityError, + "InternalError": InternalError, + "ProgrammingError": ProgrammingError, + "NotSupportedError": NotSupportedError, + }.items(): + setattr(self, name, value) -threadsafety = 0 - -paramstyle = "pyformat" - -errors = ( - Warning, - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, -) - - -def connect(*args, **kwargs): - return Connection(*args, **kwargs) + def connect(self, *args, **kwargs): + return Connection(*args, **kwargs) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 1aa21a6..2e34f43 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -542,7 +542,7 @@ class YqlDialect(StrCompileDialect): @classmethod def import_dbapi(cls: Any): - return dbapi + return dbapi.YdbDBApi() def _describe_table(self, connection, table_name, schema=None): if schema is not None: From b5215359d24ff12843205cacddeef0fb288ad68d Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Thu, 25 Jan 2024 10:11:22 +0100 Subject: [PATCH 06/13] Implement AsyncCursor --- ydb_sqlalchemy/dbapi/__init__.py | 5 +- ydb_sqlalchemy/dbapi/cursor.py | 166 +++++++++++++++++++++---------- 2 files changed, 119 insertions(+), 52 deletions(-) diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index add2be8..f6d42cc 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,5 +1,5 @@ from .connection import Connection, IsolationLevel # noqa: F401 -from .cursor import Cursor, YdbQuery # noqa: F401 +from .cursor import Cursor, AsyncCursor, YdbQuery # noqa: F401 from .errors import ( Warning, Error, @@ -13,13 +13,14 @@ NotSupportedError, ) + class YdbDBApi: def __init__(self): self.paramstyle = "pyformat" self.threadsafety = 0 self.apilevel = "1.0" self._init_dbapi_attributes() - + def _init_dbapi_attributes(self): for name, value in { "Warning": Warning, diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 4ae9565..ab68066 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,9 +1,13 @@ import dataclasses import itertools import logging -from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable +import functools +from typing import Any, Mapping, Optional, Sequence, Union, Dict +import collections.abc +from sqlalchemy import util import ydb +import ydb.aio from .errors import ( InternalError, @@ -31,10 +35,45 @@ class YdbQuery: is_ddl: bool = False -class Cursor(object): +def _handle_ydb_errors(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: + raise IntegrityError(e.message, e.issues, e.status) from e + except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: + raise NotSupportedError(e.message, e.issues, e.status) from e + except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: + raise ProgrammingError(e.message, e.issues, e.status) from e + except ( + ydb.issues.TruncatedResponseError, + ydb.issues.ConnectionError, + ydb.issues.Aborted, + ydb.issues.Unavailable, + ydb.issues.Overloaded, + ydb.issues.Undetermined, + ydb.issues.Timeout, + ydb.issues.Cancelled, + ydb.issues.SessionBusy, + ydb.issues.SessionExpired, + ydb.issues.SessionPoolEmpty, + ) as e: + raise OperationalError(e.message, e.issues, e.status) from e + except ydb.issues.GenericError as e: + raise DataError(e.message, e.issues, e.status) from e + except ydb.issues.InternalError as e: + raise InternalError(e.message, e.issues, e.status) from e + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) from e + + return wrapper + + +class Cursor: def __init__( self, - session_pool: ydb.SessionPool, + session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], tx_context: Optional[ydb.BaseTxContext] = None, ): self.session_pool = session_pool @@ -54,12 +93,9 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = logger.info("execute sql: %s, params: %s", query, parameters) if is_ddl: - chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query) + chunks = self._execute_ddl(query) else: - if self.tx_context: - chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context) - else: - chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters) + chunks = self._execute_dml(query, parameters) rows = self._rows_iterable(chunks) # Prefetch the description: @@ -74,57 +110,54 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = self.rows = rows - @classmethod + @_handle_ydb_errors def _execute_dml( - cls, - session: ydb.Session, - query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]] = None, - tx_context: Optional[ydb.BaseTxContext] = None, + self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None ) -> ydb.convert.ResultSets: prepared_query = query if isinstance(query, str) and parameters: - prepared_query = session.prepare(query) + if self.tx_context: + prepared_query = self._run_operation_in_session(self._prepare, query) + else: + prepared_query = self._retry_operation_in_pool(self._prepare, query) - if tx_context: - return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters) + if self.tx_context: + return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) - return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True) + return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters) - @classmethod - def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets: - return cls._handle_ydb_errors(session.execute_scheme, query) + @_handle_ydb_errors + def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: + return self._retry_operation_in_pool(self._execute_scheme, query) @staticmethod - def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any: - try: - return callee(*args, **kwargs) - except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: - raise IntegrityError(e.message, e.issues, e.status) from e - except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: - raise NotSupportedError(e.message, e.issues, e.status) from e - except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: - raise ProgrammingError(e.message, e.issues, e.status) from e - except ( - ydb.issues.TruncatedResponseError, - ydb.issues.ConnectionError, - ydb.issues.Aborted, - ydb.issues.Unavailable, - ydb.issues.Overloaded, - ydb.issues.Undetermined, - ydb.issues.Timeout, - ydb.issues.Cancelled, - ydb.issues.SessionBusy, - ydb.issues.SessionExpired, - ydb.issues.SessionPoolEmpty, - ) as e: - raise OperationalError(e.message, e.issues, e.status) from e - except ydb.issues.GenericError as e: - raise DataError(e.message, e.issues, e.status) from e - except ydb.issues.InternalError as e: - raise InternalError(e.message, e.issues, e.status) from e - except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e + def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets: + return session.execute_scheme(query) + + @staticmethod + def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery: + return session.prepare(query) + + @staticmethod + def _execute_in_tx( + tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return tx_context.execute(prepared_query, parameters, commit_tx=False) + + @staticmethod + def _execute_in_session( + session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return session.transaction().execute(prepared_query, parameters, commit_tx=True) + + def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): + return callee(self.tx_context, *args, **kwargs) + + def _run_operation_in_session(self, callee: collections.abc.Callable, *args, **kwargs): + return callee(self.tx_context.session, *args, **kwargs) + + def _retry_operation_in_pool(self, callee: collections.abc.Callable, *args, **kwargs): + return self.session_pool.retry_operation_sync(callee, None, *args, **kwargs) def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): try: @@ -186,3 +219,36 @@ def close(self): @property def rowcount(self): return len(self._ensure_prefetched()) + + +class AsyncCursor(Cursor): + _await = staticmethod(util.await_only) + + @staticmethod + async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets: + return await session.execute_scheme(query) + + @staticmethod + async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery: + return await session.prepare(query) + + @staticmethod + async def _execute_in_tx( + tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return await tx_context.execute(prepared_query, parameters, commit_tx=False) + + @staticmethod + async def _execute_in_session( + session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return await session.transaction().execute(prepared_query, parameters, commit_tx=True) + + def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(callee(self.tx_context, *args, **kwargs)) + + def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(callee(self.tx_context.session, *args, **kwargs)) + + def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(self.session_pool.retry_operation(callee, None, *args, **kwargs)) From 7197868bac76aedbb4b8c8f18e03d2f000a07e50 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Thu, 25 Jan 2024 18:09:11 +0100 Subject: [PATCH 07/13] Implement Async driver + tests --- setup.cfg | 2 + setup.py | 2 + test-requirements.txt | 1 + test/conftest.py | 2 + test/test_core.py | 64 ++++++++-- test_dbapi/conftest.py | 10 -- test_dbapi/test_dbapi.py | 175 ++++++++++++++++---------- tox.ini | 2 +- ydb_sqlalchemy/dbapi/__init__.py | 7 +- ydb_sqlalchemy/dbapi/connection.py | 92 ++++++++------ ydb_sqlalchemy/dbapi/cursor.py | 50 +++++++- ydb_sqlalchemy/sqlalchemy/__init__.py | 21 +++- 12 files changed, 295 insertions(+), 133 deletions(-) delete mode 100644 test_dbapi/conftest.py diff --git a/setup.cfg b/setup.cfg index d926c4e..650a160 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,3 +7,5 @@ profile_file=test/profiles.txt [db] default=yql+ydb://localhost:2136/local +ydb=yql+ydb://localhost:2136/local +ydb_async=yql+ydb_async://localhost:2136/local diff --git a/setup.py b/setup.py index 1cc3fb0..61878f5 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ entry_points={ "sqlalchemy.dialects": [ "yql.ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", + "yql.ydb_async=ydb_sqlalchemy.sqlalchemy:AsyncYqlDialect", + "ydb_async=ydb_sqlalchemy.sqlalchemy:AsyncYqlDialect", "ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", "yql=ydb_sqlalchemy.sqlalchemy:YqlDialect", ] diff --git a/test-requirements.txt b/test-requirements.txt index d345613..ceb987e 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -9,3 +9,4 @@ dockerpty==0.4.1 flake8==3.9.2 black==23.3.0 pytest-cov +pytest-asyncio diff --git a/test/conftest.py b/test/conftest.py index 0f8b014..5c0ed41 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,6 +2,8 @@ from sqlalchemy.dialects import registry registry.register("yql.ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") +registry.register("yql.ydb_async", "ydb_sqlalchemy.sqlalchemy", "AsyncYqlDialect") +registry.register("ydb_async", "ydb_sqlalchemy.sqlalchemy", "AsyncYqlDialect") registry.register("ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") registry.register("yql", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") pytest.register_assert_rewrite("sqlalchemy.testing.assertions") diff --git a/test/test_core.py b/test/test_core.py index f7337df..050b5a1 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,3 +1,4 @@ +import asyncio from datetime import date, datetime from decimal import Decimal from typing import NamedTuple @@ -21,6 +22,8 @@ def clear_sql(stm): class TestText(TestBase): + __backend__ = True + def test_sa_text(self, connection): rs = connection.execute(sa.text("SELECT 1 AS value")) assert rs.fetchone() == (1,) @@ -38,6 +41,8 @@ def test_sa_text(self, connection): class TestCrud(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -82,6 +87,8 @@ def test_sa_crud(self, connection): class TestSimpleSelect(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -174,6 +181,8 @@ def test_sa_select_simple(self, connection): class TestTypes(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -226,6 +235,7 @@ def test_integer_types(self, connection): class TestWithClause(TablesTest): + __backend__ = True run_create_tables = "each" @staticmethod @@ -238,10 +248,7 @@ def _create_table_and_get_desc(connection, metadata, **kwargs): ) table.create(connection) - session: ydb.Session = connection.connection.driver_connection.session_pool.acquire() - table_description = session.describe_table("/local/" + table.name) - connection.connection.driver_connection.session_pool.release(session) - return table_description + return connection.connection.driver_connection.describe(table.name) @pytest.mark.parametrize( "auto_partitioning_by_size,res", @@ -389,6 +396,8 @@ def test_several_keys(self, connection, metadata): class TestTransaction(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata: sa.MetaData): Table( @@ -477,6 +486,8 @@ def test_not_interactive_transaction( class TestTransactionIsolationLevel(TestBase): + __backend__ = True + class IsolationSettings(NamedTuple): ydb_mode: ydb.AbstractTransactionModeBuilder interactive: bool @@ -508,7 +519,10 @@ def test_connection_set(self, connection_no_trans: sa.Connection): class TestEngine(TestBase): - @pytest.fixture(scope="module") + __backend__ = True + __only_on__ = "yql+ydb" + + @pytest.fixture(scope="class") def ydb_driver(self): url = config.db_url driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) @@ -520,13 +534,14 @@ def ydb_driver(self): driver.stop() - @pytest.fixture(scope="module") + @pytest.fixture(scope="class") def ydb_pool(self, ydb_driver): session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) - yield session_pool - - session_pool.stop() + try: + yield session_pool + finally: + session_pool.stop() def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) @@ -559,7 +574,36 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): assert not ydb_driver._stopped +class TestAsyncEngine(TestEngine): + __only_on__ = "yql+ydb_async" + + @pytest.fixture(scope="class") + def ydb_driver(self): + loop = asyncio.get_event_loop() + url = config.db_url + driver = ydb.aio.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) + try: + loop.run_until_complete(driver.wait(timeout=5, fail_fast=True)) + yield driver + finally: + loop.run_until_complete(driver.stop()) + + loop.run_until_complete(driver.stop()) + + @pytest.fixture(scope="class") + def ydb_pool(self, ydb_driver): + session_pool = ydb.aio.SessionPool(ydb_driver, size=5) + + try: + yield session_pool + finally: + loop = asyncio.get_event_loop() + loop.run_until_complete(session_pool.stop()) + + class TestUpsert(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -659,6 +703,8 @@ def test_upsert_from_select(self, connection, metadata): class TestUpsertDoesNotReplaceInsert(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( diff --git a/test_dbapi/conftest.py b/test_dbapi/conftest.py deleted file mode 100644 index a64a68c..0000000 --- a/test_dbapi/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest - -import ydb_sqlalchemy.dbapi as dbapi - - -@pytest.fixture(scope="module") -def connection(): - conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") - yield conn - conn.close() diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index b67dafb..6805468 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -1,98 +1,141 @@ import pytest +import pytest_asyncio import ydb import ydb_sqlalchemy.dbapi as dbapi from contextlib import suppress +import sqlalchemy.util as util -def test_connection(connection): - connection.commit() - connection.rollback() +class BaseDBApiTestSuit: + def _test_connection(self, connection: dbapi.Connection): + connection.commit() + connection.rollback() - cur = connection.cursor() - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + cur = connection.cursor() + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - assert not connection.check_exists("/local/foo") - with pytest.raises(dbapi.ProgrammingError): - connection.describe("/local/foo") + assert not connection.check_exists("/local/foo") + with pytest.raises(dbapi.ProgrammingError): + connection.describe("/local/foo") - cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) + cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) - assert connection.check_exists("/local/foo") + assert connection.check_exists("/local/foo") - col = connection.describe("/local/foo").columns[0] - assert col.name == "id" - assert col.type == ydb.PrimitiveType.Int64 + col = connection.describe("/local/foo").columns[0] + assert col.name == "id" + assert col.type == ydb.PrimitiveType.Int64 - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - cur.close() + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + cur.close() + + def _test_cursor_raw_query(self, connection: dbapi.Connection): + cur = connection.cursor() + assert cur + + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) + + cur.execute( + dbapi.YdbQuery( + """ + DECLARE $data AS List>; + + INSERT INTO test SELECT id, text FROM AS_TABLE($data); + """, + parameters_types={ + "$data": ydb.ListType( + ydb.StructType() + .add_member("id", ydb.PrimitiveType.Int64) + .add_member("text", ydb.PrimitiveType.Utf8) + ) + }, + ), + { + "$data": [ + {"id": 17, "text": "seventeen"}, + {"id": 21, "text": "twenty one"}, + ] + }, + ) + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) -def test_cursor_raw_query(connection): - cur = connection.cursor() - assert cur + cur.close() - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + def _test_errors(self, connection: dbapi.Connection): + with pytest.raises(dbapi.InterfaceError): + dbapi.YdbDBApi().connect("localhost:2136", database="/local666") - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) - - cur.execute( - dbapi.YdbQuery( - """ - DECLARE $data AS List>; - - INSERT INTO test SELECT id, text FROM AS_TABLE($data); - """, - parameters_types={ - "$data": ydb.ListType( - ydb.StructType() - .add_member("id", ydb.PrimitiveType.Int64) - .add_member("text", ydb.PrimitiveType.Utf8) - ) - }, - ), - { - "$data": [ - {"id": 17, "text": "seventeen"}, - {"id": 21, "text": "twenty one"}, - ] - }, - ) + cur = connection.cursor() + + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) - cur.close() + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) -def test_errors(connection): - with pytest.raises(dbapi.InterfaceError): - dbapi.YdbDBApi().connect("localhost:2136", database="/local666") + with pytest.raises(dbapi.ProgrammingError): + cur.execute(dbapi.YdbQuery("SELECT * FROM test")) - cur = connection.cursor() + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) + + cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) + with pytest.raises(dbapi.IntegrityError): + cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - with suppress(dbapi.DatabaseError): cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + cur.close() - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) +class TestSyncConnection(BaseDBApiTestSuit): + @pytest.fixture(scope="class") + def sync_connection(self) -> dbapi.Connection: + conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") + try: + yield conn + finally: + conn.close() - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) + def test_connection(self, sync_connection: dbapi.Connection): + self._test_connection(sync_connection) - with pytest.raises(dbapi.ProgrammingError): - cur.execute(dbapi.YdbQuery("SELECT * FROM test")) + def test_cursor_raw_query(self, sync_connection: dbapi.Connection): + return self._test_cursor_raw_query(sync_connection) - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) + def test_errors(self, sync_connection: dbapi.Connection): + return self._test_errors(sync_connection) - cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - with pytest.raises(dbapi.IntegrityError): - cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - cur.close() +@pytest.mark.asyncio(scope="class") +class TestAsyncConnection(BaseDBApiTestSuit): + @pytest_asyncio.fixture(scope="class") + async def async_connection(self) -> dbapi.AsyncConnection: + def connect(): + return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local") + + conn = await util.greenlet_spawn(connect) + try: + yield conn + finally: + await util.greenlet_spawn(conn.close) + + async def test_connection(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_connection, async_connection) + + async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_cursor_raw_query, async_connection) + + async def test_errors(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_errors, async_connection) diff --git a/tox.ini b/tox.ini index a776ff6..63e760e 100644 --- a/tox.ini +++ b/tox.ini @@ -25,7 +25,7 @@ ignore_errors = True commands = docker-compose up -d python {toxinidir}/wait_container_ready.py - pytest -v test + pytest -v test --dbdriver ydb --dbdriver ydb_async pytest -v test_dbapi pytest -v ydb_sqlalchemy docker-compose down diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index f6d42cc..9c0b139 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,4 +1,4 @@ -from .connection import Connection, IsolationLevel # noqa: F401 +from .connection import Connection, AsyncConnection, IsolationLevel # noqa: F401 from .cursor import Cursor, AsyncCursor, YdbQuery # noqa: F401 from .errors import ( Warning, @@ -36,5 +36,8 @@ def _init_dbapi_attributes(self): }.items(): setattr(self, name, value) - def connect(self, *args, **kwargs): + def connect(self, *args, **kwargs) -> Connection: return Connection(*args, **kwargs) + + def async_connect(self, *args, **kwargs) -> AsyncConnection: + return AsyncConnection(*args, **kwargs) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index 43e6273..8cb1eb0 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -1,10 +1,11 @@ import posixpath -from typing import Optional, NamedTuple, Any +from typing import Optional, NamedTuple, Any, List import ydb - -from .cursor import Cursor -from .errors import InterfaceError, ProgrammingError, DatabaseError, InternalError, NotSupportedError +import sqlalchemy.util as util +import collections.abc +from .cursor import Cursor, AsyncCursor +from .errors import InterfaceError, InternalError, NotSupportedError class IsolationLevel: @@ -17,6 +18,14 @@ class IsolationLevel: class Connection: + _await = staticmethod(util.await_only) + + _is_async = False + _ydb_driver_class = ydb.Driver + _ydb_session_pool_class = ydb.SessionPool + _ydb_table_client_class = ydb.TableClient + _cursor_class = Cursor + def __init__( self, host: str = "", @@ -31,37 +40,36 @@ def __init__( if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually self._shared_session_pool = True self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool") - self.driver = self.session_pool._pool_impl._driver - self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings()) + self.driver = ( + self.session_pool._driver + if hasattr(self.session_pool, "_driver") + else self.session_pool._pool_impl._driver + ) + self.driver.table_client = self._ydb_table_client_class(self.driver, self._get_table_client_settings()) else: self._shared_session_pool = False self.driver = self._create_driver() - self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1) + self.session_pool = self._ydb_session_pool_class(self.driver, size=5) self.interactive_transaction: bool = False # AUTOCOMMIT self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return Cursor(self.session_pool, self.tx_context) + return self._cursor_class(self.session_pool, self.tx_context) - def describe(self, table_path): - full_path = posixpath.join(self.database, table_path) - try: - return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path)) - except ydb.issues.SchemeError as e: - raise ProgrammingError(e.message, e.issues, e.status) from e - except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e - except Exception as e: - raise DatabaseError(f"Failed to describe table {table_path}") from e + def describe(self, table_path: str) -> ydb.TableDescription: + abs_table_path = posixpath.join(self.database, table_path) + cursor = self.cursor() + return cursor.describe_table(abs_table_path) - def check_exists(self, table_path): - try: - self.driver.scheme_client.describe_path(table_path) - return True - except ydb.SchemeError: - return False + def check_exists(self, table_path: str) -> ydb.SchemeEntry: + cursor = self.cursor() + return cursor.check_exists(table_path) + + def get_table_names(self) -> List[str]: + cursor = self.cursor() + return cursor.get_table_names() def set_isolation_level(self, isolation_level: str): class IsolationSettings(NamedTuple): @@ -105,28 +113,34 @@ def get_isolation_level(self) -> str: def begin(self): self.tx_context = None if self.interactive_transaction: - session = self.session_pool.acquire(blocking=True) + session = self._maybe_await(self.session_pool.acquire) self.tx_context = session.transaction(self.tx_mode) - self.tx_context.begin() + self._maybe_await(self.tx_context.begin) def commit(self): if self.tx_context and self.tx_context.tx_id: - self.tx_context.commit() - self.session_pool.release(self.tx_context.session) + self._maybe_await(self.tx_context.commit) + self._maybe_await(self.session_pool.release, self.tx_context.session) self.tx_context = None def rollback(self): if self.tx_context and self.tx_context.tx_id: - self.tx_context.rollback() - self.session_pool.release(self.tx_context.session) + self._maybe_await(self.tx_context.rollback) + self._maybe_await(self.session_pool.release, self.tx_context.session) self.tx_context = None def close(self): self.rollback() if not self._shared_session_pool: - self.session_pool.stop() + self._maybe_await(self.session_pool.stop) self._stop_driver() + @classmethod + def _maybe_await(cls, callee: collections.abc.Callable, *args, **kwargs) -> Any: + if cls._is_async: + return cls._await(callee(*args, **kwargs)) + return callee(*args, **kwargs) + def _get_table_client_settings(self) -> ydb.TableClientSettings: return ( ydb.TableClientSettings() @@ -143,15 +157,23 @@ def _create_driver(self): database=self.database, table_client_settings=self._get_table_client_settings(), ) - driver = ydb.Driver(driver_config) + driver = self._ydb_driver_class(driver_config) try: - driver.wait(timeout=5, fail_fast=True) + self._maybe_await(driver.wait, timeout=5, fail_fast=True) except ydb.Error as e: raise InterfaceError(e.message, e.issues, e.status) from e except Exception as e: - driver.stop() + self._maybe_await(driver.stop) raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e return driver def _stop_driver(self): - self.driver.stop() + self._maybe_await(self.driver.stop) + + +class AsyncConnection(Connection): + _is_async = True + _ydb_driver_class = ydb.aio.Driver + _ydb_session_pool_class = ydb.aio.SessionPool + _ydb_table_client_class = ydb.aio.table.TableClient + _cursor_class = AsyncCursor diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index ab68066..ce0659c 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -2,7 +2,7 @@ import itertools import logging import functools -from typing import Any, Mapping, Optional, Sequence, Union, Dict +from typing import Any, Mapping, Optional, Sequence, Union, Dict, List import collections.abc from sqlalchemy import util @@ -37,9 +37,9 @@ class YdbQuery: def _handle_ydb_errors(func): @functools.wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(*args, **kwargs): try: - return func(self, *args, **kwargs) + return func(*args, **kwargs) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: raise IntegrityError(e.message, e.issues, e.status) from e except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: @@ -66,6 +66,8 @@ def wrapper(self, *args, **kwargs): raise InternalError(e.message, e.issues, e.status) from e except ydb.Error as e: raise DatabaseError(e.message, e.issues, e.status) from e + except Exception as e: + raise DatabaseError("Failed to execute query") from e return wrapper @@ -83,6 +85,22 @@ def __init__( self.rows = None self._rows_prefetched = None + @_handle_ydb_errors + def describe_table(self, abs_table_path: str) -> ydb.TableDescription: + return self._retry_operation_in_pool(self._describe_table, abs_table_path) + + def check_exists(self, table_path: str) -> bool: + try: + self._retry_operation_in_pool(self._describe_path, table_path) + return True + except ydb.SchemeError: + return False + + @_handle_ydb_errors + def get_table_names(self) -> List[str]: + directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory) + return [child.name for child in directory.children if child.is_table()] + def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): if operation.is_ddl or not operation.parameters_types: query = operation.yql_text @@ -134,6 +152,18 @@ def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets: return session.execute_scheme(query) + @staticmethod + def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription: + return session.describe_table(abs_table_path) + + @staticmethod + def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry: + return session._driver.scheme_client.describe_path(table_path) + + @staticmethod + def _list_directory(session: ydb.Session) -> ydb.Directory: + return session._driver.scheme_client.list_directory(session._driver._driver_config.database) + @staticmethod def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery: return session.prepare(query) @@ -224,6 +254,18 @@ def rowcount(self): class AsyncCursor(Cursor): _await = staticmethod(util.await_only) + @staticmethod + async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription: + return await session.describe_table(abs_table_path) + + @staticmethod + async def _describe_path(session: ydb.aio.table.Session, table_path: str) -> ydb.SchemeEntry: + return await session._driver.scheme_client.describe_path(table_path) + + @staticmethod + async def _list_directory(session: ydb.aio.table.Session) -> ydb.Directory: + return await session._driver.scheme_client.list_directory(session._driver._driver_config.database) + @staticmethod async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets: return await session.execute_scheme(query) @@ -251,4 +293,4 @@ def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, ** return self._await(callee(self.tx_context.session, *args, **kwargs)) def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs): - return self._await(self.session_pool.retry_operation(callee, None, *args, **kwargs)) + return self._await(self.session_pool.retry_operation(callee, *args, **kwargs)) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 2e34f43..88c0272 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -573,15 +573,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): return as_compatible @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def get_table_names(self, connection, schema=None, **kw) -> List[str]: if schema: raise dbapi.NotSupportedError("unsupported on non empty schema") - driver = connection.connection.driver_connection.driver - db_path = driver._driver_config.database - children = driver.scheme_client.list_directory(db_path).children - - return [child.name for child in children if child.is_table()] + raw_conn = connection.connection + return raw_conn.get_table_names() @reflection.cache def has_table(self, connection, table_name, schema=None, **kwargs): @@ -615,6 +612,9 @@ def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str: def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str: return dbapi_connection.get_isolation_level() + def connect(self, *cargs, **cparams): + return self.loaded_dbapi.connect(*cargs, **cparams) + def do_begin(self, dbapi_connection: dbapi.Connection) -> None: dbapi_connection.begin() @@ -697,3 +697,12 @@ def do_execute( ) -> None: operation, parameters = self._make_ydb_operation(statement, context, parameters, execute_many=False) cursor.execute(operation, parameters) + + +class AsyncYqlDialect(YqlDialect): + driver = "ydb_async" + is_async = True + supports_statement_cache = False + + def connect(self, *cargs, **cparams): + return self.loaded_dbapi.async_connect(*cargs, **cparams) From 4a1edac9de58f60b2ba6f777966c594586d04bd8 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Thu, 25 Jan 2024 18:15:40 +0100 Subject: [PATCH 08/13] Add isort --- README.md | 1 + examples/example.py | 7 +- examples/fill_tables.py | 3 +- examples/models.py | 1 - pyproject.toml | 3 + test-requirements.txt | 1 + test/test_core.py | 10 ++- test/test_inspect.py | 3 +- test/test_suite.py | 95 +++++++++++++++++---------- test_dbapi/test_dbapi.py | 8 +-- tox.ini | 5 ++ ydb_sqlalchemy/dbapi/__init__.py | 14 ++-- ydb_sqlalchemy/dbapi/connection.py | 9 +-- ydb_sqlalchemy/dbapi/cursor.py | 18 ++--- ydb_sqlalchemy/sqlalchemy/__init__.py | 20 +++--- ydb_sqlalchemy/sqlalchemy/types.py | 5 +- 16 files changed, 118 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index a2fcef6..8722ff8 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ tox -e style Reformat code: ```bash +$ tox -e isort $ tox -e black-format ``` diff --git a/examples/example.py b/examples/example.py index 5d6a786..96f75bc 100644 --- a/examples/example.py +++ b/examples/example.py @@ -1,11 +1,10 @@ import datetime import logging -import sqlalchemy as sa -from sqlalchemy import orm, exc, sql -from sqlalchemy import Table, Column, Integer, String, Float, TIMESTAMP +import sqlalchemy as sa from fill_tables import fill_all_tables, to_days -from models import Base, Series, Episodes +from models import Base, Episodes, Series +from sqlalchemy import TIMESTAMP, Column, Float, Integer, String, Table, exc, orm, sql def describe_table(engine, name): diff --git a/examples/fill_tables.py b/examples/fill_tables.py index 5a9eb95..b047e65 100644 --- a/examples/fill_tables.py +++ b/examples/fill_tables.py @@ -1,7 +1,6 @@ import iso8601 - import sqlalchemy as sa -from models import Base, Series, Seasons, Episodes +from models import Base, Episodes, Seasons, Series def to_days(date): diff --git a/examples/models.py b/examples/models.py index a02349a..09f1882 100644 --- a/examples/models.py +++ b/examples/models.py @@ -1,7 +1,6 @@ import sqlalchemy.orm as orm from sqlalchemy import Column, Integer, Unicode - Base = orm.declarative_base() diff --git a/pyproject.toml b/pyproject.toml index 55ec8d7..85c3b07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,5 @@ [tool.black] line-length = 120 + +[tool.isort] +profile = "black" diff --git a/test-requirements.txt b/test-requirements.txt index ceb987e..21e0953 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -10,3 +10,4 @@ flake8==3.9.2 black==23.3.0 pytest-cov pytest-asyncio +isort==5.13.2 diff --git a/test/test_core.py b/test/test_core.py index 050b5a1..646f765 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -4,17 +4,15 @@ from typing import NamedTuple import pytest - import sqlalchemy as sa -from sqlalchemy import Table, Column, Integer, Unicode, String -from sqlalchemy.testing.fixtures import TestBase, TablesTest, config - import ydb +from sqlalchemy import Column, Integer, String, Table, Unicode +from sqlalchemy.testing.fixtures import TablesTest, TestBase, config from ydb._grpc.v4.protos import ydb_common_pb2 -from ydb_sqlalchemy import dbapi, IsolationLevel -from ydb_sqlalchemy.sqlalchemy import types +from ydb_sqlalchemy import IsolationLevel, dbapi from ydb_sqlalchemy import sqlalchemy as ydb_sa +from ydb_sqlalchemy.sqlalchemy import types def clear_sql(stm): diff --git a/test/test_inspect.py b/test/test_inspect.py index 1fe61b8..0d4c9a7 100644 --- a/test/test_inspect.py +++ b/test/test_inspect.py @@ -1,6 +1,5 @@ import sqlalchemy as sa - -from sqlalchemy import Table, Column, Integer, Unicode, Numeric +from sqlalchemy import Column, Integer, Numeric, Table, Unicode from sqlalchemy.testing.fixtures import TablesTest diff --git a/test/test_suite.py b/test/test_suite.py index f093e30..329e932 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1,53 +1,80 @@ import pytest import sqlalchemy as sa import sqlalchemy.testing.suite.test_types - +from sqlalchemy.testing import is_false, is_true from sqlalchemy.testing.suite import * # noqa: F401, F403 - -from sqlalchemy.testing import is_true, is_false -from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures -from sqlalchemy.testing.suite import func, column, literal_column, select, exists, union -from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String - -from sqlalchemy.testing.suite.test_select import ( - ExistsTest as _ExistsTest, - LikeFunctionsTest as _LikeFunctionsTest, +from sqlalchemy.testing.suite import ( + Column, + Integer, + MetaData, + String, + Table, + column, + config, + eq_, + exists, + fixtures, + func, + inspect, + literal_column, + provide_metadata, + requirements, + select, + testing, +) +from sqlalchemy.testing.suite.test_ddl import ( + LongNameBlowoutTest as _LongNameBlowoutTest, +) +from sqlalchemy.testing.suite.test_dialect import ( + DifficultParametersTest as _DifficultParametersTest, +) +from sqlalchemy.testing.suite.test_dialect import EscapingTest as _EscapingTest +from sqlalchemy.testing.suite.test_insert import ( + InsertBehaviorTest as _InsertBehaviorTest, ) from sqlalchemy.testing.suite.test_reflection import ( - HasTableTest as _HasTableTest, - HasIndexTest as _HasIndexTest, ComponentReflectionTest as _ComponentReflectionTest, - CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite.test_reflection import ( ComponentReflectionTestExtra as _ComponentReflectionTestExtra, +) +from sqlalchemy.testing.suite.test_reflection import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite.test_reflection import HasIndexTest as _HasIndexTest +from sqlalchemy.testing.suite.test_reflection import HasTableTest as _HasTableTest +from sqlalchemy.testing.suite.test_reflection import ( QuotedNameArgumentTest as _QuotedNameArgumentTest, ) +from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest +from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest +from sqlalchemy.testing.suite.test_select import ( + FetchLimitOffsetTest as _FetchLimitOffsetTest, +) +from sqlalchemy.testing.suite.test_select import JoinTest as _JoinTest +from sqlalchemy.testing.suite.test_select import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite.test_select import OrderByLabelTest as _OrderByLabelTest +from sqlalchemy.testing.suite.test_types import BinaryTest as _BinaryTest +from sqlalchemy.testing.suite.test_types import DateTest as _DateTest from sqlalchemy.testing.suite.test_types import ( - IntegerTest as _IntegerTest, - NumericTest as _NumericTest, - BinaryTest as _BinaryTest, - TrueDivTest as _TrueDivTest, - TimeTest as _TimeTest, - StringTest as _StringTest, - NativeUUIDTest as _NativeUUIDTest, - TimeMicrosecondsTest as _TimeMicrosecondsTest, DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest, - DateTest as _DateTest, +) +from sqlalchemy.testing.suite.test_types import ( DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest, - DateTimeTest as _DateTimeTest, - TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_dialect import ( - EscapingTest as _EscapingTest, - DifficultParametersTest as _DifficultParametersTest, +from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest +from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest +from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest +from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest +from sqlalchemy.testing.suite.test_types import StringTest as _StringTest +from sqlalchemy.testing.suite.test_types import ( + TimeMicrosecondsTest as _TimeMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_select import ( - JoinTest as _JoinTest, - OrderByLabelTest as _OrderByLabelTest, - FetchLimitOffsetTest as _FetchLimitOffsetTest, +from sqlalchemy.testing.suite.test_types import ( + TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_insert import InsertBehaviorTest as _InsertBehaviorTest -from sqlalchemy.testing.suite.test_ddl import LongNameBlowoutTest as _LongNameBlowoutTest -from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest +from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest +from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index 6805468..6f01ae0 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -1,11 +1,11 @@ +from contextlib import suppress + import pytest import pytest_asyncio - +import sqlalchemy.util as util import ydb -import ydb_sqlalchemy.dbapi as dbapi -from contextlib import suppress -import sqlalchemy.util as util +import ydb_sqlalchemy.dbapi as dbapi class BaseDBApiTestSuit: diff --git a/tox.ini b/tox.ini index 63e760e..41ee405 100644 --- a/tox.ini +++ b/tox.ini @@ -60,6 +60,11 @@ skip_install = true commands = black ydb_sqlalchemy examples test test_dbapi +[testenv:isort] +skip_install = true +commands = + isort ydb_sqlalchemy examples test test_dbapi + [testenv:style] ignore_errors = True commands = diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index 9c0b139..f8fffe7 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,16 +1,16 @@ -from .connection import Connection, AsyncConnection, IsolationLevel # noqa: F401 -from .cursor import Cursor, AsyncCursor, YdbQuery # noqa: F401 +from .connection import AsyncConnection, Connection, IsolationLevel # noqa: F401 +from .cursor import AsyncCursor, Cursor, YdbQuery # noqa: F401 from .errors import ( - Warning, - Error, - InterfaceError, DatabaseError, DataError, - OperationalError, + Error, IntegrityError, + InterfaceError, InternalError, - ProgrammingError, NotSupportedError, + OperationalError, + ProgrammingError, + Warning, ) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index 8cb1eb0..e198924 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -1,10 +1,11 @@ +import collections.abc import posixpath -from typing import Optional, NamedTuple, Any, List +from typing import Any, List, NamedTuple, Optional -import ydb import sqlalchemy.util as util -import collections.abc -from .cursor import Cursor, AsyncCursor +import ydb + +from .cursor import AsyncCursor, Cursor from .errors import InterfaceError, InternalError, NotSupportedError diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index ce0659c..4d2a038 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,22 +1,22 @@ +import collections.abc import dataclasses +import functools import itertools import logging -import functools -from typing import Any, Mapping, Optional, Sequence, Union, Dict, List -import collections.abc -from sqlalchemy import util +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import ydb import ydb.aio +from sqlalchemy import util from .errors import ( - InternalError, - IntegrityError, - DataError, DatabaseError, - ProgrammingError, - OperationalError, + DataError, + IntegrityError, + InternalError, NotSupportedError, + OperationalError, + ProgrammingError, ) logger = logging.getLogger(__name__) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 88c0272..78bb7de 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -4,27 +4,27 @@ """ import collections import collections.abc -import ydb -import ydb_sqlalchemy.dbapi as dbapi -from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS -from ydb_sqlalchemy.sqlalchemy.dml import Upsert +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import sqlalchemy as sa +import ydb +from sqlalchemy.engine import reflection +from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect from sqlalchemy.exc import CompileError, NoSuchTableError from sqlalchemy.sql import functions, literal_column from sqlalchemy.sql.compiler import ( - selectable, + DDLCompiler, IdentifierPreparer, - StrSQLTypeCompiler, StrSQLCompiler, - DDLCompiler, + StrSQLTypeCompiler, + selectable, ) from sqlalchemy.sql.elements import ClauseList -from sqlalchemy.engine import reflection -from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext from sqlalchemy.util.compat import inspect_getfullargspec -from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict, Type +import ydb_sqlalchemy.dbapi as dbapi +from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS +from ydb_sqlalchemy.sqlalchemy.dml import Upsert from . import types diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 82d44c8..94f957b 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -1,6 +1,7 @@ -from sqlalchemy import exc, ColumnElement, ARRAY, types +from typing import Any, Mapping, Type, Union + +from sqlalchemy import ARRAY, ColumnElement, exc, types from sqlalchemy.sql import type_api -from typing import Mapping, Any, Union, Type class UInt64(types.Integer): From 41eb884cd28de97c01d5f5fe3a53683a94a289e5 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Fri, 2 Feb 2024 13:40:33 +0100 Subject: [PATCH 09/13] Fix typo --- test/test_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 646f765..9175b7c 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -586,8 +586,6 @@ def ydb_driver(self): finally: loop.run_until_complete(driver.stop()) - loop.run_until_complete(driver.stop()) - @pytest.fixture(scope="class") def ydb_pool(self, ydb_driver): session_pool = ydb.aio.SessionPool(ydb_driver, size=5) From 946a3f10c03627a3afd93a6f9afd6b0fb7f546b4 Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Fri, 2 Feb 2024 13:40:55 +0100 Subject: [PATCH 10/13] Provide tx_mode to one-time transactions --- test_dbapi/test_dbapi.py | 61 ++++++++++++++++++++++++++++-- ydb_sqlalchemy/dbapi/connection.py | 2 +- ydb_sqlalchemy/dbapi/cursor.py | 18 ++++++--- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index 6f01ae0..08e4a9b 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -9,6 +9,28 @@ class BaseDBApiTestSuit: + def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool): + connection.cursor().execute( + dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True) + ) + connection.set_isolation_level(isolation_level) + + cursor = connection.cursor() + + connection.begin() + + query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)") + if read_only: + with pytest.raises(dbapi.DatabaseError): + cursor.execute(query) + else: + cursor.execute(query) + + connection.rollback() + + connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + connection.cursor().close() + def _test_connection(self, connection: dbapi.Connection): connection.commit() connection.rollback() @@ -100,7 +122,7 @@ def _test_errors(self, connection: dbapi.Connection): class TestSyncConnection(BaseDBApiTestSuit): - @pytest.fixture(scope="class") + @pytest.fixture def sync_connection(self) -> dbapi.Connection: conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") try: @@ -108,6 +130,20 @@ def sync_connection(self) -> dbapi.Connection: finally: conn.close() + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection): + self._test_isolation_level_read_only(sync_connection, isolation_level, read_only) + def test_connection(self, sync_connection: dbapi.Connection): self._test_connection(sync_connection) @@ -118,9 +154,8 @@ def test_errors(self, sync_connection: dbapi.Connection): return self._test_errors(sync_connection) -@pytest.mark.asyncio(scope="class") class TestAsyncConnection(BaseDBApiTestSuit): - @pytest_asyncio.fixture(scope="class") + @pytest_asyncio.fixture async def async_connection(self) -> dbapi.AsyncConnection: def connect(): return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local") @@ -131,11 +166,31 @@ def connect(): finally: await util.greenlet_spawn(conn.close) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + async def test_isolation_level_read_only( + self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection + ): + await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only) + + @pytest.mark.asyncio async def test_connection(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_connection, async_connection) + @pytest.mark.asyncio async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_cursor_raw_query, async_connection) + @pytest.mark.asyncio async def test_errors(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_errors, async_connection) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index e198924..a90239f 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -57,7 +57,7 @@ def __init__( self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return self._cursor_class(self.session_pool, self.tx_context) + return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context) def describe(self, table_path: str) -> ydb.TableDescription: abs_table_path = posixpath.join(self.database, table_path) diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 4d2a038..9d74424 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -76,9 +76,11 @@ class Cursor: def __init__( self, session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], + tx_mode: ydb.AbstractTransactionModeBuilder, tx_context: Optional[ydb.BaseTxContext] = None, ): self.session_pool = session_pool + self.tx_mode = tx_mode self.tx_context = tx_context self.description = None self.arraysize = 1 @@ -142,7 +144,7 @@ def _execute_dml( if self.tx_context: return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) - return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters) + return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters) @_handle_ydb_errors def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: @@ -176,9 +178,12 @@ def _execute_in_tx( @staticmethod def _execute_in_session( - session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + session: ydb.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: - return session.transaction().execute(prepared_query, parameters, commit_tx=True) + return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): return callee(self.tx_context, *args, **kwargs) @@ -282,9 +287,12 @@ async def _execute_in_tx( @staticmethod async def _execute_in_session( - session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + session: ydb.aio.table.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: - return await session.transaction().execute(prepared_query, parameters, commit_tx=True) + return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): return self._await(callee(self.tx_context, *args, **kwargs)) From fc9f8fbdd09df66b6815dfb0985f19a60e67278c Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Fri, 2 Feb 2024 13:43:11 +0100 Subject: [PATCH 11/13] Add original error to dbapi exceptions --- ydb_sqlalchemy/__init__.py | 1 + ydb_sqlalchemy/dbapi/connection.py | 2 +- ydb_sqlalchemy/dbapi/cursor.py | 16 ++++++++-------- ydb_sqlalchemy/dbapi/errors.py | 26 +++++++++++++++++++------- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index 2e5fbab..7e39278 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -1 +1,2 @@ from .dbapi import IsolationLevel # noqa: F401 +from .sqlalchemy import Upsert, types, upsert # noqa: F401 diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index a90239f..df73c4c 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -162,7 +162,7 @@ def _create_driver(self): try: self._maybe_await(driver.wait, timeout=5, fail_fast=True) except ydb.Error as e: - raise InterfaceError(e.message, e.issues, e.status) from e + raise InterfaceError(e.message, original_error=e) from e except Exception as e: self._maybe_await(driver.stop) raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 9d74424..27e6593 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -41,11 +41,11 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: - raise IntegrityError(e.message, e.issues, e.status) from e + raise IntegrityError(e.message, original_error=e) from e except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: - raise NotSupportedError(e.message, e.issues, e.status) from e + raise NotSupportedError(e.message, original_error=e) from e except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: - raise ProgrammingError(e.message, e.issues, e.status) from e + raise ProgrammingError(e.message, original_error=e) from e except ( ydb.issues.TruncatedResponseError, ydb.issues.ConnectionError, @@ -59,13 +59,13 @@ def wrapper(*args, **kwargs): ydb.issues.SessionExpired, ydb.issues.SessionPoolEmpty, ) as e: - raise OperationalError(e.message, e.issues, e.status) from e + raise OperationalError(e.message, original_error=e) from e except ydb.issues.GenericError as e: - raise DataError(e.message, e.issues, e.status) from e + raise DataError(e.message, original_error=e) from e except ydb.issues.InternalError as e: - raise InternalError(e.message, e.issues, e.status) from e + raise InternalError(e.message, original_error=e) from e except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e + raise DatabaseError(e.message, original_error=e) from e except Exception as e: raise DatabaseError("Failed to execute query") from e @@ -214,7 +214,7 @@ def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany yield row[::] except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e + raise DatabaseError(e.message, original_error=e) from e def _ensure_prefetched(self): if self.rows is not None and self._rows_prefetched is None: diff --git a/ydb_sqlalchemy/dbapi/errors.py b/ydb_sqlalchemy/dbapi/errors.py index a67c0ce..70b55eb 100644 --- a/ydb_sqlalchemy/dbapi/errors.py +++ b/ydb_sqlalchemy/dbapi/errors.py @@ -1,15 +1,27 @@ +from typing import Optional, List + +import ydb +from google.protobuf.message import Message + + class Warning(Exception): pass class Error(Exception): - def __init__(self, message, issues=None, status=None): + def __init__( + self, + message: str, + original_error: Optional[ydb.Error] = None, + ): super(Error, self).__init__(message) - pretty_issues = _pretty_issues(issues) - self.issues = issues - self.message = pretty_issues or message - self.status = status + self.original_error = original_error + if original_error: + pretty_issues = _pretty_issues(original_error.issues) + self.issues = original_error.issues + self.message = pretty_issues or message + self.status = original_error.status class InterfaceError(Error): @@ -44,7 +56,7 @@ class NotSupportedError(DatabaseError): pass -def _pretty_issues(issues): +def _pretty_issues(issues: List[Message]) -> str: if issues is None: return None @@ -56,7 +68,7 @@ def _pretty_issues(issues): return "\n" + "\n".join(children_messages) -def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False): +def _get_messages(issue: Message, max_depth: int = 100, indent: int = 2, depth: int = 0, root: bool = False) -> str: if depth >= max_depth: return None From 80a61f1d22992b9d2dc2a0029d62bf58c1c1fbea Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Tue, 6 Feb 2024 11:02:47 +0100 Subject: [PATCH 12/13] Resolve rebase conflicts --- ydb_sqlalchemy/dbapi/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ydb_sqlalchemy/dbapi/errors.py b/ydb_sqlalchemy/dbapi/errors.py index 70b55eb..79faba8 100644 --- a/ydb_sqlalchemy/dbapi/errors.py +++ b/ydb_sqlalchemy/dbapi/errors.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional import ydb from google.protobuf.message import Message From 4064f82df035a53e69acb1a74a2797ec8a8331cf Mon Sep 17 00:00:00 2001 From: Roman Tretiak Date: Tue, 6 Feb 2024 11:06:29 +0100 Subject: [PATCH 13/13] Add missing import --- test/test_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_suite.py b/test/test_suite.py index 329e932..e504058 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -21,6 +21,7 @@ requirements, select, testing, + union, ) from sqlalchemy.testing.suite.test_ddl import ( LongNameBlowoutTest as _LongNameBlowoutTest,