diff --git a/ellar_sql/__init__.py b/ellar_sql/__init__.py index 0051886..252099c 100644 --- a/ellar_sql/__init__.py +++ b/ellar_sql/__init__.py @@ -1,6 +1,6 @@ """EllarSQL Module adds support for SQLAlchemy and Alembic package to your Ellar application""" -__version__ = "0.1.6" +__version__ = "0.1.8" from .model.database_binds import get_all_metadata, get_metadata from .module import EllarSQLModule diff --git a/ellar_sql/module.py b/ellar_sql/module.py index 7491e54..2af2034 100644 --- a/ellar_sql/module.py +++ b/ellar_sql/module.py @@ -27,33 +27,32 @@ def _raise_exception(): return _raise_exception -async def _session_cleanup( - db_service: EllarSQLService, session: t.Union[Session, AsyncSession] -) -> None: - res = session.close() - if isinstance(res, t.Coroutine): - await res - - res = db_service.session_factory.remove() - if isinstance(res, t.Coroutine): - await res - - @as_middleware async def session_middleware( context: IHostContext, call_next: t.Callable[..., t.Coroutine] ): connection = context.switch_to_http_connection().get_client() - db_service = context.get_service_provider().get(EllarSQLService) - session = db_service.session_factory() + # Create a NEW session for this request + session = db_service.session_factory() connection.state.session = session try: await call_next() + except Exception as ex: + # Only rollback if session is still active + if session.is_active and session.in_transaction(): + res = session.rollback() + if isinstance(res, t.Coroutine): + await res + raise ex finally: - await _session_cleanup(db_service, session) + # Always clean up + if session.is_active: + res = session.close() + if isinstance(res, t.Coroutine): + await res @Module( diff --git a/ellar_sql/pagination/base.py b/ellar_sql/pagination/base.py index e66fc86..e538ada 100644 --- a/ellar_sql/pagination/base.py +++ b/ellar_sql/pagination/base.py @@ -286,7 +286,7 @@ async def _close_session(self) -> None: def _get_session(self) -> t.Union[sa_orm.Session, AsyncSession, t.Any]: self._created_session = True service = current_injector.get(EllarSQLService) - return service.get_scoped_session()() + return service.session_factory_maker()() def _query_items(self) -> t.List[t.Any]: if self._is_async: diff --git a/ellar_sql/query/utils.py b/ellar_sql/query/utils.py index 7df73e0..ef79177 100644 --- a/ellar_sql/query/utils.py +++ b/ellar_sql/query/utils.py @@ -19,7 +19,7 @@ async def get_or_404( ) -> _O: """ """ db_service = current_injector.get(EllarSQLService) - session = db_service.get_scoped_session()() + session = db_service.session_factory_maker()() value = session.get(entity, ident, **kwargs) @@ -39,7 +39,7 @@ async def get_or_none( ) -> t.Optional[_O]: """ """ db_service = current_injector.get(EllarSQLService) - session = db_service.get_scoped_session()() + session = db_service.session_factory_maker()() value = session.get(entity, ident, **kwargs) diff --git a/ellar_sql/services/base.py b/ellar_sql/services/base.py index 5983e86..869e8f6 100644 --- a/ellar_sql/services/base.py +++ b/ellar_sql/services/base.py @@ -1,6 +1,5 @@ import os import typing as t -from threading import get_ident from weakref import WeakKeyDictionary import sqlalchemy as sa @@ -14,7 +13,6 @@ ) from sqlalchemy.ext.asyncio import ( AsyncSession, - async_scoped_session, async_sessionmaker, ) @@ -32,6 +30,10 @@ class EllarSQLService: + session_factory: t.Union[ + sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession] + ] + def __init__( self, databases: t.Union[str, t.Dict[str, t.Any]], @@ -61,7 +63,7 @@ def __init__( self._has_async_engine_driver: bool = False self._setup(databases, models=models, echo=echo) - self.session_factory = self.get_scoped_session() + self.session_factory = self.session_factory_maker() @property def has_async_engine_driver(self) -> bool: @@ -177,24 +179,16 @@ def reflect(self, *databases: str) -> None: continue metadata_engine.reflect() - def get_scoped_session( + def session_factory_maker( self, **extra_options: t.Any, - ) -> t.Union[ - sa_orm.scoped_session[sa_orm.Session], - async_scoped_session[t.Union[AsyncSession, t.Any]], - ]: + ) -> t.Union[sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]]: options = self._session_options.copy() options.update(extra_options) - scope = options.pop("scopefunc", get_ident) - - factory = self._make_session_factory(options) - - if self.has_async_engine_driver: - return async_scoped_session(factory, scope) # type:ignore[arg-type] + scope = options.pop("scopefunc", None) # noqa: F841 - return sa_orm.scoped_session(factory, scope) # type:ignore[arg-type] + return self._make_session_factory(options) def _make_session_factory( self, options: t.Dict[str, t.Any] diff --git a/tests/test_model_export.py b/tests/test_model_export.py index e88c4c1..ef3346f 100644 --- a/tests/test_model_export.py +++ b/tests/test_model_export.py @@ -43,7 +43,7 @@ def test_model_export_without_filter(self, db_service, ignore_base): "id": 1, "name": "Ellar", } - db_service.session_factory.close() + # db_service.session_factory.close() def test_model_exclude_none(self, db_service, ignore_base): user_factory = get_model_factory(db_service) @@ -59,7 +59,7 @@ def test_model_exclude_none(self, db_service, ignore_base): "id": 1, "name": "Ellar", } - db_service.session_factory.close() + # db_service.session_factory.close() def test_model_export_include(self, db_service, ignore_base): user_factory = get_model_factory(db_service) @@ -73,7 +73,7 @@ def test_model_export_include(self, db_service, ignore_base): "id", "name", } - db_service.session_factory.close() + # db_service.session_factory.close() def test_model_export_exclude(self, db_service, ignore_base): user_factory = get_model_factory(db_service) @@ -83,7 +83,7 @@ def test_model_export_exclude(self, db_service, ignore_base): user = user_factory() assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"} - db_service.session_factory.close() + # db_service.session_factory.close() @pytest.mark.asyncio @@ -105,7 +105,6 @@ async def test_model_export_without_filter_async( "id": 1, "name": "Ellar", } - await db_service_async.session_factory.close() async def test_model_exclude_none_async(self, db_service_async, ignore_base): user_factory = get_model_factory(db_service_async) @@ -121,7 +120,6 @@ async def test_model_exclude_none_async(self, db_service_async, ignore_base): "id": 1, "name": "Ellar", } - await db_service_async.session_factory.close() async def test_model_export_include_async(self, db_service_async, ignore_base): user_factory = get_model_factory(db_service_async) @@ -135,7 +133,6 @@ async def test_model_export_include_async(self, db_service_async, ignore_base): "id", "name", } - await db_service_async.session_factory.close() async def test_model_export_exclude_async(self, db_service_async, ignore_base): user_factory = get_model_factory(db_service_async) @@ -145,4 +142,3 @@ async def test_model_export_exclude_async(self, db_service_async, ignore_base): user = user_factory() assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"} - await db_service_async.session_factory.close() diff --git a/tests/test_model_factory.py b/tests/test_model_factory.py index 19e1221..2998b78 100644 --- a/tests/test_model_factory.py +++ b/tests/test_model_factory.py @@ -64,13 +64,14 @@ class Meta(group_meta): class TestModelFactory: def test_model_factory(self, db_service, ignore_base): + session = db_service.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, group={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, ) @@ -79,15 +80,16 @@ def test_model_factory(self, db_service, ignore_base): group = group_factory() assert group.dict().keys() == {"name", "user_id", "id"} - db_service.session_factory.close() + session.close() def test_model_factory_session_none(self, db_service, ignore_base): + session = db_service.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, }, group={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, }, ) @@ -96,16 +98,17 @@ def test_model_factory_session_none(self, db_service, ignore_base): group = group_factory() assert f"" == repr(group) assert group.dict().keys() != {"name", "user_id", "id"} - db_service.session_factory.close() + session.close() def test_model_factory_session_flush(self, db_service, ignore_base): + session = db_service.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, group={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, ) @@ -114,16 +117,17 @@ def test_model_factory_session_flush(self, db_service, ignore_base): group = group_factory() assert group.dict().keys() == {"name", "user_id", "id"} - db_service.session_factory.close() + session.close() def test_model_factory_get_or_create(self, db_service, ignore_base): + session = db_service.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, group={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, "sqlalchemy_get_or_create": ("name",), }, @@ -141,18 +145,19 @@ def test_model_factory_get_or_create(self, db_service, ignore_base): assert group.dict() == group2.dict() != group3.dict() - db_service.session_factory.close() + session.close() def test_model_factory_get_or_create_for_integrity_error( self, db_service, ignore_base ): + session = db_service.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, group={ - "sqlalchemy_session_factory": db_service.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, "sqlalchemy_get_or_create": ("name",), }, @@ -164,19 +169,20 @@ def test_model_factory_get_or_create_for_integrity_error( with pytest.raises(IntegrityError): group_factory(name="new group", user=group.user) - db_service.session_factory.close() + session.close() @pytest.mark.asyncio class TestModelFactoryAsync: async def test_model_factory_async(self, db_service_async, ignore_base): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_COMMIT, }, ) @@ -185,17 +191,18 @@ async def test_model_factory_async(self, db_service_async, ignore_base): group = group_factory() assert group.dict().keys() == {"name", "user_id", "id"} - await db_service_async.session_factory.close() + await session.close() async def test_model_factory_session_none_async( self, db_service_async, ignore_base ): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, }, ) @@ -204,18 +211,19 @@ async def test_model_factory_session_none_async( group = group_factory() assert f"" == repr(group) assert group.dict().keys() != {"name", "user_id", "id"} - await db_service_async.session_factory.close() + await session.close() async def test_model_factory_session_flush_async( self, db_service_async, ignore_base ): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, ) @@ -224,18 +232,19 @@ async def test_model_factory_session_flush_async( group = group_factory() assert group.dict().keys() == {"name", "user_id", "id"} - await db_service_async.session_factory.close() + await session.close() async def test_model_factory_get_or_create_async( self, db_service_async, ignore_base ): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, "sqlalchemy_get_or_create": ("name",), }, @@ -252,19 +261,19 @@ async def test_model_factory_get_or_create_async( assert group.id == group2.id assert group.dict() == group2.dict() != group3.dict() - - await db_service_async.session_factory.close() + await session.close() async def test_model_factory_get_or_create_for_integrity_error_async( self, db_service_async, ignore_base ): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, "sqlalchemy_get_or_create": ("name",), }, @@ -275,19 +284,19 @@ async def test_model_factory_get_or_create_for_integrity_error_async( with pytest.raises(IntegrityError): group_factory(name="new group", user=group.user) - - await db_service_async.session_factory.close() + await session.close() async def test_model_factory_get_or_create_raises_error_for_missing_field_async( self, db_service_async, ignore_base ): + session = db_service_async.session_factory() group_factory = create_group_model( user={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, }, group={ - "sqlalchemy_session_factory": db_service_async.session_factory, + "sqlalchemy_session": session, "sqlalchemy_session_persistence": SESSION_PERSISTENCE_FLUSH, "sqlalchemy_get_or_create": ("name", "user_id"), }, @@ -295,5 +304,4 @@ async def test_model_factory_get_or_create_raises_error_for_missing_field_async( db_service_async.create_all() with pytest.raises(FactoryError): group_factory() - - await db_service_async.session_factory.close() + await session.close()