Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ellar_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""EllarSQL Module adds support for SQLAlchemy and Alembic package to your Ellar application"""

__version__ = "0.1.3"
__version__ = "0.1.5"

from .model.database_binds import get_all_metadata, get_metadata
from .module import EllarSQLModule
Expand Down
36 changes: 36 additions & 0 deletions ellar_sql/factory/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import typing as t

import factory
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
from ellar.threading import run_as_sync
Expand Down Expand Up @@ -118,3 +120,37 @@ def _save(
cls._session_execute(session.commit)
cls._session_execute(session.refresh, obj)
return obj


class EllarSQLSubFactoryId(factory.SubFactory):
"""
A SubFactory that returns the id of the created object.
"""

def evaluate(self, instance, step, extra):
value = super().evaluate(instance, step, extra)
if inspect.isawaitable(value):

async def resolve_value():
resolved_value = await value
return resolved_value.id

return resolve_value()
return value


class EllarSQLSubFactory(factory.SubFactory):
"""
A SubFactory that returns the created object.
"""

def evaluate(self, instance, step, extra):
value = super().evaluate(instance, step, extra)
if inspect.isawaitable(value):

async def resolve_value():
resolved_value = await value
return resolved_value

return resolve_value()
return value
23 changes: 14 additions & 9 deletions ellar_sql/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ def _raise_exception():
return _raise_exception


async def _session_cleanup(
db_service: EllarSQLService, session: t.Union[Session, AsyncSession]
) -> None:
res = db_service.session_factory.remove()
if isinstance(res, t.Coroutine):
await res

res = session.close()
if isinstance(res, t.Coroutine):
await res


@as_middleware
async def session_middleware(
context: IHostContext, call_next: t.Callable[..., t.Coroutine]
Expand All @@ -40,15 +52,8 @@ async def session_middleware(

try:
await call_next()
except Exception as ex:
res = session.rollback()
if isinstance(res, t.Coroutine):
await res
raise ex

res = db_service.session_factory.remove()
if isinstance(res, t.Coroutine):
await res
finally:
await _session_cleanup(db_service, session)


@Module(
Expand Down
4 changes: 2 additions & 2 deletions requirements-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ autoflake
ellar-cli >= 0.3.7
factory-boy >= 3.3.0
httpx
Pillow >=10.4.0, <11.2.0
mypy == 1.15.0
Pillow >=10.4.0, <11.2.0
pytest >= 7.1.3,< 9.0.0
pytest-asyncio
pytest-cov >= 2.12.0,< 7.0.0
ruff ==0.9.9
ruff ==0.13.3
2 changes: 1 addition & 1 deletion samples/db-learning/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def tm():
test_module = Test.create_test_module(modules=[ApplicationModule])
app = test_module.create_application()

with execute_async_context_manager(app.application_context()):
with execute_async_context_manager(app.with_injector_context()):
yield test_module


Expand Down