From c7300583a8534b23d50380ecb731b8b220374318 Mon Sep 17 00:00:00 2001 From: Sergii Maksymov Date: Fri, 24 Jun 2022 16:46:20 +0200 Subject: [PATCH] use AsyncSession from SQLAlchemy without custom constructor --- sqlmodel/ext/asyncio/__init__.py | 5 +++ sqlmodel/ext/asyncio/session.py | 77 +++++++++++++++++++------------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/sqlmodel/ext/asyncio/__init__.py b/sqlmodel/ext/asyncio/__init__.py index e69de29bb2..0d4b7f19c4 100644 --- a/sqlmodel/ext/asyncio/__init__.py +++ b/sqlmodel/ext/asyncio/__init__.py @@ -0,0 +1,5 @@ +from .session import AsyncSession + +__all__ = [ + "AsyncSession", +] diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 80267b25e5..4f37a9783c 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -1,62 +1,79 @@ -from typing import Any, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload from sqlalchemy import util from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession -from sqlalchemy.ext.asyncio import engine from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine -from sqlalchemy.util.concurrency import greenlet_spawn from sqlmodel.sql.base import Executable -from ...engine.result import ScalarResult +from ...engine.result import Result, ScalarResult from ...orm.session import Session -from ...sql.expression import Select +from ...sql.expression import Select, SelectOfScalar _T = TypeVar("_T") class AsyncSession(_AsyncSession): - sync_session: Session - def __init__( self, bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, **kw: Any, ): - # All the same code of the original AsyncSession - kw["future"] = True - if bind: - self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) # type: ignore - - if binds: - self.binds = binds - binds = { - key: engine._get_sync_engine_or_connection(b) # type: ignore - for key, b in binds.items() - } - - self.sync_session = self._proxied = self._assign_proxied( # type: ignore - Session(bind=bind, binds=binds, **kw) # type: ignore - ) + opts = dict(expire_on_commit=False) + super().__init__(bind, binds, sync_session_class=Session, **{**opts, **kw}) + @overload async def exec( self, - statement: Union[Select[_T], Executable[_T]], + statement: Select[_T], + *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Result[_T]: + ... + + @overload + async def exec( + self, + statement: SelectOfScalar[_T], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, **kw: Any, ) -> ScalarResult[_T]: - # TODO: the documentation says execution_options accepts a dict, but only - # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? - execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore + ... - return await greenlet_spawn( - self.sync_session.exec, + async def exec( + self, + statement: Union[ + Select[_T], + SelectOfScalar[_T], + Executable[_T], + ], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Union[Result[_T], ScalarResult[_T]]: + results = await super().execute( statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, **kw, ) + if isinstance(statement, SelectOfScalar): + return results.scalars() # type: ignore + return results # type: ignore