Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sqlmodel/ext/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .session import AsyncSession

__all__ = [
"AsyncSession",
]
77 changes: 47 additions & 30 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,79 @@
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload

from sqlalchemy import util
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
from sqlalchemy.ext.asyncio import engine
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
from sqlalchemy.util.concurrency import greenlet_spawn
from sqlmodel.sql.base import Executable

from ...engine.result import ScalarResult
from ...engine.result import Result, ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.expression import Select, SelectOfScalar

_T = TypeVar("_T")


class AsyncSession(_AsyncSession):
sync_session: Session

def __init__(
self,
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
**kw: Any,
):
# All the same code of the original AsyncSession
kw["future"] = True
if bind:
self.bind = bind
bind = engine._get_sync_engine_or_connection(bind) # type: ignore

if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine_or_connection(b) # type: ignore
for key, b in binds.items()
}

self.sync_session = self._proxied = self._assign_proxied( # type: ignore
Session(bind=bind, binds=binds, **kw) # type: ignore
)
opts = dict(expire_on_commit=False)
super().__init__(bind, binds, sync_session_class=Session, **{**opts, **kw})

@overload
async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Select[_T],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Result[_T]:
...

@overload
async def exec(
self,
statement: SelectOfScalar[_T],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> ScalarResult[_T]:
# TODO: the documentation says execution_options accepts a dict, but only
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
...

return await greenlet_spawn(
self.sync_session.exec,
async def exec(
self,
statement: Union[
Select[_T],
SelectOfScalar[_T],
Executable[_T],
],
*,
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
**kw: Any,
) -> Union[Result[_T], ScalarResult[_T]]:
results = await super().execute(
statement,
params=params,
execution_options=execution_options,
bind_arguments=bind_arguments,
_parent_execute_state=_parent_execute_state,
_add_event=_add_event,
**kw,
)
if isinstance(statement, SelectOfScalar):
return results.scalars() # type: ignore
return results # type: ignore