This repository has been archived by the owner on Sep 12, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
/
sqlalchemy_plugin.py
51 lines (40 loc) · 1.62 KB
/
sqlalchemy_plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Database connectivity and transaction management for the application."""
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from starlite.plugins.sql_alchemy import SQLAlchemyConfig, SQLAlchemyPlugin
from starlite.plugins.sql_alchemy.config import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from starlite_saqlalchemy import db, settings
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from starlite.datastructures.state import State
from starlite.types import Message, Scope
__all__ = ["config", "plugin"]
async def before_send_handler(message: "Message", _: "State", scope: "Scope") -> None:
    """Finalize the request's database session based on the outgoing message.

    When the message starts an HTTP response, the session found in the scope
    is committed for a 2xx status and rolled back for any other status. When
    the message type is one of ``SESSION_TERMINUS_ASGI_EVENTS``, the session
    is closed and removed from the scope, even if the commit or rollback
    raised.

    Args:
        message: ASGI message about to be sent.
        _: Unused; present to satisfy the handler signature.
        scope: ASGI scope, looked up for the session under
            ``SESSION_SCOPE_KEY``.
    """
    db_session = cast("AsyncSession | None", scope.get(SESSION_SCOPE_KEY))
    if db_session is None:
        # No session is stored in this scope, so there is nothing to finalize.
        return

    try:
        if message["type"] == "http.response.start":
            succeeded = 200 <= message["status"] < 300
            # Only the chosen method is looked up and awaited.
            finalize = db_session.commit if succeeded else db_session.rollback
            await finalize()
    finally:
        # Runs regardless of the outcome above so the session is still
        # released when commit/rollback fails.
        if message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            await db_session.close()
            del scope[SESSION_SCOPE_KEY]  # type:ignore[misc]
# Plugin configuration: installs this module's `before_send_handler`, takes
# the session dependency key from the application settings, and uses the
# engine and async session factory exposed by the `db` module.
config = SQLAlchemyConfig(
    before_send_handler=before_send_handler,
    dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY,
    engine_instance=db.engine,
    session_maker_instance=db.async_session_factory,
)
# Plugin instance built from the configuration above; exported via `__all__`.
plugin = SQLAlchemyPlugin(config=config)