Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
refactor(db)!: move sqlalchemy config into db sub-package.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt committed Nov 13, 2022
1 parent 3189358 commit fd3b6f6
Show file tree
Hide file tree
Showing 15 changed files with 137 additions and 115 deletions.
4 changes: 2 additions & 2 deletions src/starlite_saqlalchemy/__init__.py
Expand Up @@ -25,13 +25,13 @@ def example_handler() -> dict:
from . import (
cache,
compression,
db,
dependencies,
dto,
exceptions,
health,
log,
openapi,
orm,
redis,
repository,
sentry,
Expand All @@ -47,13 +47,13 @@ def example_handler() -> dict:
"PluginConfig",
"cache",
"compression",
"db",
"dependencies",
"dto",
"exceptions",
"health",
"log",
"openapi",
"orm",
"redis",
"repository",
"sentry",
Expand Down
81 changes: 81 additions & 0 deletions src/starlite_saqlalchemy/db/__init__.py
@@ -0,0 +1,81 @@
"""Database connectivity and transaction management for the application."""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any
from uuid import UUID

from orjson import dumps, loads
from sqlalchemy import event
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool

from starlite_saqlalchemy import settings

from . import orm

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession

__all__ = ["async_session_factory", "engine", "orm"]


def _default(val: Any) -> str:
if isinstance(val, UUID):
return str(val)
raise TypeError()


engine = create_async_engine(
settings.db.URL,
echo=settings.db.ECHO,
echo_pool=settings.db.ECHO_POOL,
json_serializer=partial(dumps, default=_default),
max_overflow=settings.db.POOL_MAX_OVERFLOW,
pool_size=settings.db.POOL_SIZE,
pool_timeout=settings.db.POOL_TIMEOUT,
poolclass=NullPool if settings.db.POOL_DISABLE else None,
)
"""Configure via [DatabaseSettings][starlite_saqlalchemy.settings.DatabaseSettings]. Overrides
default JSON serializer to use `orjson`. See
[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] for detailed instructions.
"""
async_session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(engine)
"""
Database session factory. See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""


@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any: # pragma: no cover
"""Using orjson for serialization of the json column values means that the
output is binary, not `str` like `json.dumps` would output.
SQLAlchemy expects that the json serializer returns `str` and calls `.encode()` on the value to
turn it to bytes before writing to the JSONB column. I'd need to either wrap `orjson.dumps` to
return a `str` so that SQLAlchemy could then convert it to binary, or do the following, which
changes the behaviour of the dialect to expect a binary value from the serializer.
See Also https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 pylint: disable=line-too-long
"""

def encoder(bin_value: bytes) -> bytes:
# \x01 is the prefix for jsonb used by PostgreSQL.
# asyncpg requires it when format='binary'
return b"\x01" + bin_value

def decoder(bin_value: bytes) -> Any:
# the byte is the \x01 prefix for jsonb used by PostgreSQL.
# asyncpg returns it when format='binary'
return loads(bin_value[1:])

dbapi_connection.await_(
dbapi_connection.driver_connection.set_type_codec(
"jsonb",
encoder=encoder,
decoder=decoder,
schema="pg_catalog",
format="binary",
)
)
File renamed without changes.
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/repository/sqlalchemy.py
Expand Up @@ -26,7 +26,7 @@
from sqlalchemy.engine import Result
from sqlalchemy.ext.asyncio import AsyncSession

from starlite_saqlalchemy import orm
from starlite_saqlalchemy.db import orm
from starlite_saqlalchemy.repository.types import FilterTypes

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/service.py
Expand Up @@ -10,8 +10,8 @@
import logging
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from starlite_saqlalchemy.db import async_session_factory
from starlite_saqlalchemy.repository.sqlalchemy import ModelT
from starlite_saqlalchemy.sqlalchemy_plugin import async_session_factory
from starlite_saqlalchemy.worker import queue

if TYPE_CHECKING:
Expand Down
81 changes: 5 additions & 76 deletions src/starlite_saqlalchemy/sqlalchemy_plugin.py
@@ -1,93 +1,22 @@
"""Database connectivity and transaction management for the application."""
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from typing import TYPE_CHECKING, cast

from orjson import dumps, loads
from sqlalchemy import event
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from starlite.plugins.sql_alchemy import SQLAlchemyConfig, SQLAlchemyPlugin
from starlite.plugins.sql_alchemy.config import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)

from starlite_saqlalchemy import settings
from starlite_saqlalchemy import db, settings

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from starlite.datastructures.state import State
from starlite.types import Message, Scope

__all__ = [
"async_session_factory",
"config",
"engine",
"plugin",
]


def _default(val: Any) -> str:
if isinstance(val, UUID):
return str(val)
raise TypeError()


engine = create_async_engine(
settings.db.URL,
echo=settings.db.ECHO,
echo_pool=settings.db.ECHO_POOL,
json_serializer=partial(dumps, default=_default),
max_overflow=settings.db.POOL_MAX_OVERFLOW,
pool_size=settings.db.POOL_SIZE,
pool_timeout=settings.db.POOL_TIMEOUT,
poolclass=NullPool if settings.db.POOL_DISABLE else None,
)
"""Configure via [DatabaseSettings][starlite_saqlalchemy.settings.DatabaseSettings]. Overrides
default JSON serializer to use `orjson`. See
[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] for detailed instructions.
"""
async_session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(engine)
"""
Database session factory. See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""


@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
"""Using orjson for serialization of the json column values means that the
output is binary, not `str` like `json.dumps` would output.
SQLAlchemy expects that the json serializer returns `str` and calls `.encode()` on the value to
turn it to bytes before writing to the JSONB column. I'd need to either wrap `orjson.dumps` to
return a `str` so that SQLAlchemy could then convert it to binary, or do the following, which
changes the behaviour of the dialect to expect a binary value from the serializer.
See Also https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 pylint: disable=line-too-long
"""

def encoder(bin_value: bytes) -> bytes:
# \x01 is the prefix for jsonb used by PostgreSQL.
# asyncpg requires it when format='binary'
return b"\x01" + bin_value

def decoder(bin_value: bytes) -> Any:
# the byte is the \x01 prefix for jsonb used by PostgreSQL.
# asyncpg returns it when format='binary'
return loads(bin_value[1:])

dbapi_connection.await_(
dbapi_connection.driver_connection.set_type_codec(
"jsonb",
encoder=encoder,
decoder=decoder,
schema="pg_catalog",
format="binary",
)
)
__all__ = ["config", "plugin"]


async def before_send_handler(message: "Message", _: "State", scope: "Scope") -> None:
Expand Down Expand Up @@ -115,8 +44,8 @@ async def before_send_handler(message: "Message", _: "State", scope: "Scope") ->
config = SQLAlchemyConfig(
before_send_handler=before_send_handler,
dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY,
engine_instance=engine,
session_maker_instance=async_session_factory,
engine_instance=db.engine,
session_maker_instance=db.async_session_factory,
)

plugin = SQLAlchemyPlugin(config=config)
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/testing/repository.py
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from uuid import uuid4

from starlite_saqlalchemy.orm import Base
from starlite_saqlalchemy.db.orm import Base
from starlite_saqlalchemy.repository.abc import AbstractRepository
from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Expand Up @@ -14,7 +14,7 @@
from sqlalchemy.pool import NullPool
from starlite import Provide, Router

from starlite_saqlalchemy import orm, sqlalchemy_plugin, worker
from starlite_saqlalchemy import db, sqlalchemy_plugin, worker
from tests.utils import controllers

if TYPE_CHECKING:
Expand Down Expand Up @@ -169,7 +169,7 @@ async def _seed_db(
engine: The SQLAlchemy engine instance.
"""
# get models into metadata
metadata = orm.Base.registry.metadata
metadata = db.orm.Base.registry.metadata
author_table = metadata.tables["author"]
async with engine.begin() as conn:
await conn.run_sync(metadata.create_all)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Expand Up @@ -53,6 +53,7 @@ def _author_repository(raw_authors: list[dict[str, Any]], monkeypatch: pytest.Mo
collection[getattr(author, AuthorRepository.id_attribute)] = author
monkeypatch.setattr(AuthorRepository, "collection", collection)
monkeypatch.setattr(domain, "Repository", AuthorRepository)
monkeypatch.setattr(domain.Service, "repository_type", AuthorRepository)


@pytest.fixture()
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_db.py
@@ -0,0 +1,21 @@
"""Tests for db module."""
# pylint: disable=protected-access
from __future__ import annotations

from uuid import uuid4

import pytest

from starlite_saqlalchemy import db


def test_serializer_default() -> None:
"""Test _default() function serializes UUID."""
val = uuid4()
assert db._default(val) == str(val)


def test_serializer_raises_type_err() -> None:
"""Test _default() function raises ValueError."""
with pytest.raises(TypeError):
db._default(None)
2 changes: 1 addition & 1 deletion tests/unit/test_orm.py
Expand Up @@ -2,7 +2,7 @@
import datetime
from unittest.mock import MagicMock

from starlite_saqlalchemy import orm
from starlite_saqlalchemy.db import orm
from tests.utils.domain import Author, CreateDTO


Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_serializer.py
Expand Up @@ -9,6 +9,7 @@


def test_pg_uuid_serialization() -> None:
"""Test response serializer handles PG UUID."""
py_uuid = uuid4()
pg_uuid = pgproto.UUID(py_uuid.bytes)
assert serializer.default_serializer(pg_uuid) == str(py_uuid)

0 comments on commit fd3b6f6

Please sign in to comment.