diff --git a/test/test_orm.py b/test/test_orm.py new file mode 100644 index 0000000..10aae2c --- /dev/null +++ b/test/test_orm.py @@ -0,0 +1,55 @@ +import pytest +import sqlalchemy as sa +from types import MethodType +from sqlalchemy import Column, Integer, Unicode +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.testing.fixtures import TablesTest, config + + +class TestDirectories(TablesTest): + __backend__ = True + + def prepare_table(self, engine): + base = declarative_base() + + class Table(base): + __tablename__ = "dir/test" + id = Column(Integer, primary_key=True) + text = Column(Unicode) + + base.metadata.create_all(engine) + session = sessionmaker(bind=engine)() + session.add(Table(id=2, text="foo")) + session.commit() + return base, Table, session + + def try_update(self, session, Table): + row = session.query(Table).first() + row.text = "bar" + session.commit() + return row + + def drop_table(self, base, engine): + base.metadata.drop_all(engine) + + def bind_old_method_to_dialect(self, dialect): + def _handle_column_name(self, variable): + return variable + + dialect._handle_column_name = MethodType(_handle_column_name, dialect) + + def test_directories(self): + engine_good = sa.create_engine(config.db_url) + base, Table, session = self.prepare_table(engine_good) + row = self.try_update(session, Table) + assert row.id == 2 + assert row.text == "bar" + self.drop_table(base, engine_good) + + engine_bad = sa.create_engine(config.db_url) + self.bind_old_method_to_dialect(engine_bad.dialect) + base, Table, session = self.prepare_table(engine_bad) + with pytest.raises(Exception) as excinfo: + self.try_update(session, Table) + assert "Unknown name: $dir" in str(excinfo.value) + self.drop_table(base, engine_bad) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 33321b8..e20a2e3 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -584,7 +584,13 @@ class YqlDialect(StrCompileDialect): def import_dbapi(cls: Any): return dbapi.YdbDBApi() - def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs): + def __init__( + self, + json_serializer=None, + json_deserializer=None, + _add_declare_for_yql_stmt_vars=False, + **kwargs, + ): super().__init__(**kwargs) self._json_deserializer = json_deserializer @@ -673,6 +679,9 @@ def do_rollback(self, dbapi_connection: dbapi.Connection) -> None: def do_commit(self, dbapi_connection: dbapi.Connection) -> None: dbapi_connection.commit() + def _handle_column_name(self, variable): + return "`" + variable + "`" + def _format_variables( self, statement: str, @@ -694,7 +703,9 @@ def _format_variables( variable_names = set(parameters.keys()) formatted_parameters = {f"${k}": v for k, v in parameters.items()} - formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names} + formatted_variable_names = { + variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names + } formatted_statement = formatted_statement % formatted_variable_names formatted_statement = formatted_statement.replace("%%", "%") @@ -702,7 +713,10 @@ def _format_variables( def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): declarations = "\n".join( - [f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()] + [ + f"DECLARE $`{param_name[1:]}` as {str(param_type)};" + for param_name, param_type in parameters_types.items() + ] ) return f"{declarations}\n{statement}"