diff --git a/test/test_core.py b/test/test_core.py index 8d9a53f..52825c5 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -83,6 +83,11 @@ def test_sa_crud(self, connection): (5, "c"), ] + def test_sa_crud_with_add_declare(self): + engine = sa.create_engine(config.db_url, _add_declare_for_yql_stmt_vars=True) + with engine.connect() as connection: + self.test_sa_crud(connection) + class TestSimpleSelect(TablesTest): __backend__ = True diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index e5e3ca2..33321b8 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -584,11 +584,14 @@ class YqlDialect(StrCompileDialect): def import_dbapi(cls: Any): return dbapi.YdbDBApi() - def __init__(self, json_serializer=None, json_deserializer=None, **kwargs): + def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs): super().__init__(**kwargs) self._json_deserializer = json_deserializer self._json_serializer = json_serializer + # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed. + # no need in declare in yql statement here since ydb 24-1 + self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars def _describe_table(self, connection, table_name, schema=None): if schema is not None: @@ -697,6 +700,12 @@ def _format_variables( formatted_statement = formatted_statement.replace("%%", "%") return formatted_statement, formatted_parameters + def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): + declarations = "\n".join( + [f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()] + ) + return f"{declarations}\n{statement}" + def _make_ydb_operation( self, statement: str, @@ -710,6 +719,8 @@ def _make_ydb_operation( parameters_types = context.compiled.get_bind_types(parameters) parameters_types = {f"${k}": v for k, v in parameters_types.items()} statement, parameters = self._format_variables(statement, parameters, execute_many) + if self._add_declare_for_yql_stmt_vars: + statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types) return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters statement, parameters = self._format_variables(statement, parameters, execute_many)