Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ellar_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""EllarSQL Module adds support for SQLAlchemy and Alembic package to your Ellar application"""

__version__ = "0.1.6"
__version__ = "0.1.8"

from .model.database_binds import get_all_metadata, get_metadata
from .module import EllarSQLModule
Expand Down
29 changes: 14 additions & 15 deletions ellar_sql/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,32 @@ def _raise_exception():
return _raise_exception


async def _session_cleanup(
db_service: EllarSQLService, session: t.Union[Session, AsyncSession]
) -> None:
res = session.close()
if isinstance(res, t.Coroutine):
await res

res = db_service.session_factory.remove()
if isinstance(res, t.Coroutine):
await res


@as_middleware
async def session_middleware(
context: IHostContext, call_next: t.Callable[..., t.Coroutine]
):
connection = context.switch_to_http_connection().get_client()

db_service = context.get_service_provider().get(EllarSQLService)
session = db_service.session_factory()

# Create a NEW session for this request
session = db_service.session_factory()
connection.state.session = session

try:
await call_next()
except Exception as ex:
# Only rollback if session is still active
if session.is_active and session.in_transaction():
res = session.rollback()
if isinstance(res, t.Coroutine):
await res
raise ex
finally:
await _session_cleanup(db_service, session)
# Always clean up
if session.is_active:
res = session.close()
if isinstance(res, t.Coroutine):
await res


@Module(
Expand Down
2 changes: 1 addition & 1 deletion ellar_sql/pagination/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ async def _close_session(self) -> None:
def _get_session(self) -> t.Union[sa_orm.Session, AsyncSession, t.Any]:
self._created_session = True
service = current_injector.get(EllarSQLService)
return service.get_scoped_session()()
return service.session_factory_maker()()

def _query_items(self) -> t.List[t.Any]:
if self._is_async:
Expand Down
4 changes: 2 additions & 2 deletions ellar_sql/query/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def get_or_404(
) -> _O:
""" """
db_service = current_injector.get(EllarSQLService)
session = db_service.get_scoped_session()()
session = db_service.session_factory_maker()()

value = session.get(entity, ident, **kwargs)

Expand All @@ -39,7 +39,7 @@ async def get_or_none(
) -> t.Optional[_O]:
""" """
db_service = current_injector.get(EllarSQLService)
session = db_service.get_scoped_session()()
session = db_service.session_factory_maker()()

value = session.get(entity, ident, **kwargs)

Expand Down
24 changes: 9 additions & 15 deletions ellar_sql/services/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import typing as t
from threading import get_ident
from weakref import WeakKeyDictionary

import sqlalchemy as sa
Expand All @@ -14,7 +13,6 @@
)
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
async_sessionmaker,
)

Expand All @@ -32,6 +30,10 @@


class EllarSQLService:
session_factory: t.Union[
sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]
]

def __init__(
self,
databases: t.Union[str, t.Dict[str, t.Any]],
Expand Down Expand Up @@ -61,7 +63,7 @@ def __init__(
self._has_async_engine_driver: bool = False

self._setup(databases, models=models, echo=echo)
self.session_factory = self.get_scoped_session()
self.session_factory = self.session_factory_maker()

@property
def has_async_engine_driver(self) -> bool:
Expand Down Expand Up @@ -177,24 +179,16 @@ def reflect(self, *databases: str) -> None:
continue
metadata_engine.reflect()

def get_scoped_session(
def session_factory_maker(
self,
**extra_options: t.Any,
) -> t.Union[
sa_orm.scoped_session[sa_orm.Session],
async_scoped_session[t.Union[AsyncSession, t.Any]],
]:
) -> t.Union[sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]]:
options = self._session_options.copy()
options.update(extra_options)

scope = options.pop("scopefunc", get_ident)

factory = self._make_session_factory(options)

if self.has_async_engine_driver:
return async_scoped_session(factory, scope) # type:ignore[arg-type]
scope = options.pop("scopefunc", None) # noqa: F841

return sa_orm.scoped_session(factory, scope) # type:ignore[arg-type]
return self._make_session_factory(options)

def _make_session_factory(
self, options: t.Dict[str, t.Any]
Expand Down
12 changes: 4 additions & 8 deletions tests/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_model_export_without_filter(self, db_service, ignore_base):
"id": 1,
"name": "Ellar",
}
db_service.session_factory.close()
# db_service.session_factory.close()

def test_model_exclude_none(self, db_service, ignore_base):
user_factory = get_model_factory(db_service)
Expand All @@ -59,7 +59,7 @@ def test_model_exclude_none(self, db_service, ignore_base):
"id": 1,
"name": "Ellar",
}
db_service.session_factory.close()
# db_service.session_factory.close()

def test_model_export_include(self, db_service, ignore_base):
user_factory = get_model_factory(db_service)
Expand All @@ -73,7 +73,7 @@ def test_model_export_include(self, db_service, ignore_base):
"id",
"name",
}
db_service.session_factory.close()
# db_service.session_factory.close()

def test_model_export_exclude(self, db_service, ignore_base):
user_factory = get_model_factory(db_service)
Expand All @@ -83,7 +83,7 @@ def test_model_export_exclude(self, db_service, ignore_base):
user = user_factory()

assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"}
db_service.session_factory.close()
# db_service.session_factory.close()


@pytest.mark.asyncio
Expand All @@ -105,7 +105,6 @@ async def test_model_export_without_filter_async(
"id": 1,
"name": "Ellar",
}
await db_service_async.session_factory.close()

async def test_model_exclude_none_async(self, db_service_async, ignore_base):
user_factory = get_model_factory(db_service_async)
Expand All @@ -121,7 +120,6 @@ async def test_model_exclude_none_async(self, db_service_async, ignore_base):
"id": 1,
"name": "Ellar",
}
await db_service_async.session_factory.close()

async def test_model_export_include_async(self, db_service_async, ignore_base):
user_factory = get_model_factory(db_service_async)
Expand All @@ -135,7 +133,6 @@ async def test_model_export_include_async(self, db_service_async, ignore_base):
"id",
"name",
}
await db_service_async.session_factory.close()

async def test_model_export_exclude_async(self, db_service_async, ignore_base):
user_factory = get_model_factory(db_service_async)
Expand All @@ -145,4 +142,3 @@ async def test_model_export_exclude_async(self, db_service_async, ignore_base):
user = user_factory()

assert user.dict(exclude={"email", "name"}).keys() == {"address", "city", "id"}
await db_service_async.session_factory.close()
Loading