Skip to content

Commit

Permalink
Merge pull request #41 Fix orm for tables in directories from kabulov…
Browse files Browse the repository at this point in the history
…/fix_orm_for_tables_in_directories_2
  • Loading branch information
rekby committed Apr 22, 2024
2 parents 6dc5578 + db5ce93 commit 368c0ba
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
55 changes: 55 additions & 0 deletions test/test_orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import sqlalchemy as sa
from types import MethodType
from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.testing.fixtures import TablesTest, config


class TestDirectories(TablesTest):
__backend__ = True

def prepare_table(self, engine):
base = declarative_base()

class Table(base):
__tablename__ = "dir/test"
id = Column(Integer, primary_key=True)
text = Column(Unicode)

base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add(Table(id=2, text="foo"))
session.commit()
return base, Table, session

def try_update(self, session, Table):
row = session.query(Table).first()
row.text = "bar"
session.commit()
return row

def drop_table(self, base, engine):
base.metadata.drop_all(engine)

def bind_old_method_to_dialect(self, dialect):
def _handle_column_name(self, variable):
return variable

dialect._handle_column_name = MethodType(_handle_column_name, dialect)

def test_directories(self):
engine_good = sa.create_engine(config.db_url)
base, Table, session = self.prepare_table(engine_good)
row = self.try_update(session, Table)
assert row.id == 2
assert row.text == "bar"
self.drop_table(base, engine_good)

engine_bad = sa.create_engine(config.db_url)
self.bind_old_method_to_dialect(engine_bad.dialect)
base, Table, session = self.prepare_table(engine_bad)
with pytest.raises(Exception) as excinfo:
self.try_update(session, Table)
assert "Unknown name: $dir" in str(excinfo.value)
self.drop_table(base, engine_bad)
20 changes: 17 additions & 3 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,13 @@ class YqlDialect(StrCompileDialect):
def import_dbapi(cls: Any):
return dbapi.YdbDBApi()

def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs):
def __init__(
self,
json_serializer=None,
json_deserializer=None,
_add_declare_for_yql_stmt_vars=False,
**kwargs,
):
super().__init__(**kwargs)

self._json_deserializer = json_deserializer
Expand Down Expand Up @@ -673,6 +679,9 @@ def do_rollback(self, dbapi_connection: dbapi.Connection) -> None:
def do_commit(self, dbapi_connection: dbapi.Connection) -> None:
dbapi_connection.commit()

def _handle_column_name(self, variable):
return "`" + variable + "`"

def _format_variables(
self,
statement: str,
Expand All @@ -694,15 +703,20 @@ def _format_variables(
variable_names = set(parameters.keys())
formatted_parameters = {f"${k}": v for k, v in parameters.items()}

formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names}
formatted_variable_names = {
variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names
}
formatted_statement = formatted_statement % formatted_variable_names

formatted_statement = formatted_statement.replace("%%", "%")
return formatted_statement, formatted_parameters

def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
declarations = "\n".join(
[f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()]
[
f"DECLARE $`{param_name[1:]}` as {str(param_type)};"
for param_name, param_type in parameters_types.items()
]
)
return f"{declarations}\n{statement}"

Expand Down

0 comments on commit 368c0ba

Please sign in to comment.