From 229641727c98cb2446d5435f2beef3fc931b0cc2 Mon Sep 17 00:00:00 2001 From: kabulov kozim Date: Thu, 18 Apr 2024 10:43:37 +0300 Subject: [PATCH 1/5] add declare for yql statement variables --- ydb_sqlalchemy/sqlalchemy/__init__.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index e5e3ca2..6d3ad80 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -584,11 +584,18 @@ class YqlDialect(StrCompileDialect): def import_dbapi(cls: Any): return dbapi.YdbDBApi() - def __init__(self, json_serializer=None, json_deserializer=None, **kwargs): + def __init__( + self, + json_serializer=None, + json_deserializer=None, + add_declare_for_yql_stmt_vars=False, + **kwargs + ): super().__init__(**kwargs) self._json_deserializer = json_deserializer self._json_serializer = json_serializer + self._add_declare_for_yql_stmt_vars = add_declare_for_yql_stmt_vars def _describe_table(self, connection, table_name, schema=None): if schema is not None: @@ -697,6 +704,12 @@ def _format_variables( formatted_statement = formatted_statement.replace("%%", "%") return formatted_statement, formatted_parameters + def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): + declarations = "\n".join([ + f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items() + ]) + return f"{declarations}\n{statement}" + def _make_ydb_operation( self, statement: str, @@ -710,6 +723,8 @@ def _make_ydb_operation( parameters_types = context.compiled.get_bind_types(parameters) parameters_types = {f"${k}": v for k, v in parameters_types.items()} statement, parameters = self._format_variables(statement, parameters, execute_many) + if self._add_declare_for_yql_stmt_vars: + statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types) return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters statement, parameters = self._format_variables(statement, parameters, execute_many) From 01b15ca910d402eb79446f0a0e329c3614f681b9 Mon Sep 17 00:00:00 2001 From: kabulov kozim Date: Thu, 18 Apr 2024 12:32:03 +0300 Subject: [PATCH 2/5] add tests --- test/test_core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_core.py b/test/test_core.py index 8d9a53f..8e7cdc1 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -83,6 +83,11 @@ def test_sa_crud(self, connection): (5, "c"), ] + def test_sa_crud_with_add_declare(self): + engine = sa.create_engine(config.db_url, add_declare_for_yql_stmt_vars=True) + with engine.connect() as connection: + self.test_sa_crud(connection) + class TestSimpleSelect(TablesTest): __backend__ = True From 7a9dd7726ed402ef1acf0effc648073bdb593639 Mon Sep 17 00:00:00 2001 From: kabulov kozim Date: Thu, 18 Apr 2024 12:32:19 +0300 Subject: [PATCH 3/5] fix style --- ydb_sqlalchemy/sqlalchemy/__init__.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 6d3ad80..a0df96b 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -584,13 +584,7 @@ class YqlDialect(StrCompileDialect): def import_dbapi(cls: Any): return dbapi.YdbDBApi() - def __init__( - self, - json_serializer=None, - json_deserializer=None, - add_declare_for_yql_stmt_vars=False, - **kwargs - ): + def __init__(self, json_serializer=None, json_deserializer=None, add_declare_for_yql_stmt_vars=False, **kwargs): super().__init__(**kwargs) self._json_deserializer = json_deserializer @@ -705,9 +699,9 @@ def _format_variables( return formatted_statement, formatted_parameters def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): - declarations = "\n".join([ - f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items() - ]) + declarations = "\n".join( + [f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()] + ) return f"{declarations}\n{statement}" def _make_ydb_operation( From 85f8c91336064f65804d3a5aa5fc10ee2ab08d7c Mon Sep 17 00:00:00 2001 From: kabulov kozim Date: Thu, 18 Apr 2024 13:57:36 +0300 Subject: [PATCH 4/5] set note for readability --- ydb_sqlalchemy/sqlalchemy/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index a0df96b..da16ee8 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -584,12 +584,13 @@ class YqlDialect(StrCompileDialect): def import_dbapi(cls: Any): return dbapi.YdbDBApi() - def __init__(self, json_serializer=None, json_deserializer=None, add_declare_for_yql_stmt_vars=False, **kwargs): + def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs): super().__init__(**kwargs) self._json_deserializer = json_deserializer self._json_serializer = json_serializer - self._add_declare_for_yql_stmt_vars = add_declare_for_yql_stmt_vars + # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed + self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars def _describe_table(self, connection, table_name, schema=None): if schema is not None: From 109bd03b136220f478c3e4b88c12707518d9c13e Mon Sep 17 00:00:00 2001 From: kabulov kozim Date: Thu, 18 Apr 2024 14:08:22 +0300 Subject: [PATCH 5/5] fix notes and tests --- test/test_core.py | 2 +- ydb_sqlalchemy/sqlalchemy/__init__.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 8e7cdc1..52825c5 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -84,7 +84,7 @@ def test_sa_crud(self, connection): ] def test_sa_crud_with_add_declare(self): - engine = sa.create_engine(config.db_url, add_declare_for_yql_stmt_vars=True) + engine = sa.create_engine(config.db_url, _add_declare_for_yql_stmt_vars=True) with engine.connect() as connection: self.test_sa_crud(connection) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index da16ee8..33321b8 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -589,7 +589,8 @@ def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_fo self._json_deserializer = json_deserializer self._json_serializer = json_serializer - # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed + # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed. + # no need in declare in yql statement here since ydb 24-1 self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars def _describe_table(self, connection, table_name, schema=None):