Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def test_sa_crud(self, connection):
(5, "c"),
]

def test_sa_crud_with_add_declare(self):
engine = sa.create_engine(config.db_url, _add_declare_for_yql_stmt_vars=True)
with engine.connect() as connection:
self.test_sa_crud(connection)


class TestSimpleSelect(TablesTest):
__backend__ = True
Expand Down
13 changes: 12 additions & 1 deletion ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,14 @@ class YqlDialect(StrCompileDialect):
def import_dbapi(cls: Any):
return dbapi.YdbDBApi()

def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs):
super().__init__(**kwargs)

self._json_deserializer = json_deserializer
self._json_serializer = json_serializer
# NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed.
# no need in declare in yql statement here since ydb 24-1
self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars

def _describe_table(self, connection, table_name, schema=None):
if schema is not None:
Expand Down Expand Up @@ -697,6 +700,12 @@ def _format_variables(
formatted_statement = formatted_statement.replace("%%", "%")
return formatted_statement, formatted_parameters

def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
declarations = "\n".join(
[f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()]
)
return f"{declarations}\n{statement}"

def _make_ydb_operation(
self,
statement: str,
Expand All @@ -710,6 +719,8 @@ def _make_ydb_operation(
parameters_types = context.compiled.get_bind_types(parameters)
parameters_types = {f"${k}": v for k, v in parameters_types.items()}
statement, parameters = self._format_variables(statement, parameters, execute_many)
if self._add_declare_for_yql_stmt_vars:
statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types)
return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters

statement, parameters = self._format_variables(statement, parameters, execute_many)
Expand Down