From fd3b6f661438862552d9988ecda190d3856a80a0 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 13 Nov 2022 14:04:09 +1000 Subject: [PATCH] refactor(db)!: move sqlalchemy config into `db` sub-package. --- src/starlite_saqlalchemy/__init__.py | 4 +- src/starlite_saqlalchemy/db/__init__.py | 81 +++++++++++++++++++ src/starlite_saqlalchemy/{ => db}/orm.py | 0 .../repository/sqlalchemy.py | 2 +- src/starlite_saqlalchemy/service.py | 2 +- src/starlite_saqlalchemy/sqlalchemy_plugin.py | 81 ++----------------- .../testing/repository.py | 2 +- tests/integration/conftest.py | 4 +- tests/unit/conftest.py | 1 + tests/unit/test_db.py | 21 +++++ tests/unit/test_orm.py | 2 +- tests/unit/test_serializer.py | 1 + tests/unit/test_service.py | 40 ++++----- tests/unit/test_setting.py | 1 + tests/utils/domain.py | 10 +-- 15 files changed, 137 insertions(+), 115 deletions(-) create mode 100644 src/starlite_saqlalchemy/db/__init__.py rename src/starlite_saqlalchemy/{ => db}/orm.py (100%) create mode 100644 tests/unit/test_db.py diff --git a/src/starlite_saqlalchemy/__init__.py b/src/starlite_saqlalchemy/__init__.py index e61c7ca8..2752b240 100644 --- a/src/starlite_saqlalchemy/__init__.py +++ b/src/starlite_saqlalchemy/__init__.py @@ -25,13 +25,13 @@ def example_handler() -> dict: from . import ( cache, compression, + db, dependencies, dto, exceptions, health, log, openapi, - orm, redis, repository, sentry, @@ -47,13 +47,13 @@ def example_handler() -> dict: "PluginConfig", "cache", "compression", + "db", "dependencies", "dto", "exceptions", "health", "log", "openapi", - "orm", "redis", "repository", "sentry", diff --git a/src/starlite_saqlalchemy/db/__init__.py b/src/starlite_saqlalchemy/db/__init__.py new file mode 100644 index 00000000..bbad44e2 --- /dev/null +++ b/src/starlite_saqlalchemy/db/__init__.py @@ -0,0 +1,81 @@ +"""Database connectivity and transaction management for the application.""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from orjson import dumps, loads +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool + +from starlite_saqlalchemy import settings + +from . import orm + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + +__all__ = ["async_session_factory", "engine", "orm"] + + +def _default(val: Any) -> str: + if isinstance(val, UUID): + return str(val) + raise TypeError() + + +engine = create_async_engine( + settings.db.URL, + echo=settings.db.ECHO, + echo_pool=settings.db.ECHO_POOL, + json_serializer=partial(dumps, default=_default), + max_overflow=settings.db.POOL_MAX_OVERFLOW, + pool_size=settings.db.POOL_SIZE, + pool_timeout=settings.db.POOL_TIMEOUT, + poolclass=NullPool if settings.db.POOL_DISABLE else None, +) +"""Configure via [DatabaseSettings][starlite_saqlalchemy.settings.DatabaseSettings]. Overrides +default JSON serializer to use `orjson`. See +[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] for detailed instructions. +""" +async_session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(engine) +""" +Database session factory. See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker]. +""" + + +@event.listens_for(engine.sync_engine, "connect") +def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any: # pragma: no cover + """Using orjson for serialization of the json column values means that the + output is binary, not `str` like `json.dumps` would output. + + SQLAlchemy expects that the json serializer returns `str` and calls `.encode()` on the value to + turn it to bytes before writing to the JSONB column. I'd need to either wrap `orjson.dumps` to + return a `str` so that SQLAlchemy could then convert it to binary, or do the following, which + changes the behaviour of the dialect to expect a binary value from the serializer. + + See Also https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 pylint: disable=line-too-long + """ + + def encoder(bin_value: bytes) -> bytes: + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + bin_value + + def decoder(bin_value: bytes) -> Any: + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return loads(bin_value[1:]) + + dbapi_connection.await_( + dbapi_connection.driver_connection.set_type_codec( + "jsonb", + encoder=encoder, + decoder=decoder, + schema="pg_catalog", + format="binary", + ) + ) diff --git a/src/starlite_saqlalchemy/orm.py b/src/starlite_saqlalchemy/db/orm.py similarity index 100% rename from src/starlite_saqlalchemy/orm.py rename to src/starlite_saqlalchemy/db/orm.py diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index 1d9d7592..94a16b1f 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -26,7 +26,7 @@ from sqlalchemy.engine import Result from sqlalchemy.ext.asyncio import AsyncSession - from starlite_saqlalchemy import orm + from starlite_saqlalchemy.db import orm from starlite_saqlalchemy.repository.types import FilterTypes __all__ = [ diff --git a/src/starlite_saqlalchemy/service.py b/src/starlite_saqlalchemy/service.py index 4322201b..f9ebb848 100644 --- a/src/starlite_saqlalchemy/service.py +++ b/src/starlite_saqlalchemy/service.py @@ -10,8 +10,8 @@ import logging from typing import TYPE_CHECKING, Any, Generic, TypeVar +from starlite_saqlalchemy.db import async_session_factory from starlite_saqlalchemy.repository.sqlalchemy import ModelT -from starlite_saqlalchemy.sqlalchemy_plugin import async_session_factory from starlite_saqlalchemy.worker import queue if TYPE_CHECKING: diff --git a/src/starlite_saqlalchemy/sqlalchemy_plugin.py b/src/starlite_saqlalchemy/sqlalchemy_plugin.py index cf43d570..0b59ba83 100644 --- a/src/starlite_saqlalchemy/sqlalchemy_plugin.py +++ b/src/starlite_saqlalchemy/sqlalchemy_plugin.py @@ -1,93 +1,22 @@ """Database connectivity and transaction management for the application.""" from __future__ import annotations -from functools import partial -from typing import TYPE_CHECKING, Any, cast -from uuid import UUID +from typing import TYPE_CHECKING, cast -from orjson import dumps, loads -from sqlalchemy import event -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool from starlite.plugins.sql_alchemy import SQLAlchemyConfig, SQLAlchemyPlugin from starlite.plugins.sql_alchemy.config import ( SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, ) -from starlite_saqlalchemy import settings +from starlite_saqlalchemy import db, settings if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession from starlite.datastructures.state import State from starlite.types import Message, Scope -__all__ = [ - "async_session_factory", - "config", - "engine", - "plugin", -] - - -def _default(val: Any) -> str: - if isinstance(val, UUID): - return str(val) - raise TypeError() - - -engine = create_async_engine( - settings.db.URL, - echo=settings.db.ECHO, - echo_pool=settings.db.ECHO_POOL, - json_serializer=partial(dumps, default=_default), - max_overflow=settings.db.POOL_MAX_OVERFLOW, - pool_size=settings.db.POOL_SIZE, - pool_timeout=settings.db.POOL_TIMEOUT, - poolclass=NullPool if settings.db.POOL_DISABLE else None, -) -"""Configure via [DatabaseSettings][starlite_saqlalchemy.settings.DatabaseSettings]. Overrides -default JSON serializer to use `orjson`. See -[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] for detailed instructions. -""" -async_session_factory: async_sessionmaker[AsyncSession] = async_sessionmaker(engine) -""" -Database session factory. See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker]. -""" - - -@event.listens_for(engine.sync_engine, "connect") -def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any: - """Using orjson for serialization of the json column values means that the - output is binary, not `str` like `json.dumps` would output. - - SQLAlchemy expects that the json serializer returns `str` and calls `.encode()` on the value to - turn it to bytes before writing to the JSONB column. I'd need to either wrap `orjson.dumps` to - return a `str` so that SQLAlchemy could then convert it to binary, or do the following, which - changes the behaviour of the dialect to expect a binary value from the serializer. - - See Also https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 pylint: disable=line-too-long - """ - - def encoder(bin_value: bytes) -> bytes: - # \x01 is the prefix for jsonb used by PostgreSQL. - # asyncpg requires it when format='binary' - return b"\x01" + bin_value - - def decoder(bin_value: bytes) -> Any: - # the byte is the \x01 prefix for jsonb used by PostgreSQL. - # asyncpg returns it when format='binary' - return loads(bin_value[1:]) - - dbapi_connection.await_( - dbapi_connection.driver_connection.set_type_codec( - "jsonb", - encoder=encoder, - decoder=decoder, - schema="pg_catalog", - format="binary", - ) - ) +__all__ = ["config", "plugin"] async def before_send_handler(message: "Message", _: "State", scope: "Scope") -> None: @@ -115,8 +44,8 @@ async def before_send_handler(message: "Message", _: "State", scope: "Scope") -> config = SQLAlchemyConfig( before_send_handler=before_send_handler, dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY, - engine_instance=engine, - session_maker_instance=async_session_factory, + engine_instance=db.engine, + session_maker_instance=db.async_session_factory, ) plugin = SQLAlchemyPlugin(config=config) diff --git a/src/starlite_saqlalchemy/testing/repository.py b/src/starlite_saqlalchemy/testing/repository.py index 778614f7..19953e44 100644 --- a/src/starlite_saqlalchemy/testing/repository.py +++ b/src/starlite_saqlalchemy/testing/repository.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 -from starlite_saqlalchemy.orm import Base +from starlite_saqlalchemy.db.orm import Base from starlite_saqlalchemy.repository.abc import AbstractRepository from starlite_saqlalchemy.repository.exceptions import RepositoryConflictException diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 639544a7..0e8b1153 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -14,7 +14,7 @@ from sqlalchemy.pool import NullPool from starlite import Provide, Router -from starlite_saqlalchemy import orm, sqlalchemy_plugin, worker +from starlite_saqlalchemy import db, sqlalchemy_plugin, worker from tests.utils import controllers if TYPE_CHECKING: @@ -169,7 +169,7 @@ async def _seed_db( engine: The SQLAlchemy engine instance. """ # get models into metadata - metadata = orm.Base.registry.metadata + metadata = db.orm.Base.registry.metadata author_table = metadata.tables["author"] async with engine.begin() as conn: await conn.run_sync(metadata.create_all) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 63dc8b3c..ef378f4b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -53,6 +53,7 @@ def _author_repository(raw_authors: list[dict[str, Any]], monkeypatch: pytest.Mo collection[getattr(author, AuthorRepository.id_attribute)] = author monkeypatch.setattr(AuthorRepository, "collection", collection) monkeypatch.setattr(domain, "Repository", AuthorRepository) + monkeypatch.setattr(domain.Service, "repository_type", AuthorRepository) @pytest.fixture() diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py new file mode 100644 index 00000000..312d73f3 --- /dev/null +++ b/tests/unit/test_db.py @@ -0,0 +1,21 @@ +"""Tests for db module.""" +# pylint: disable=protected-access +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from starlite_saqlalchemy import db + + +def test_serializer_default() -> None: + """Test _default() function serializes UUID.""" + val = uuid4() + assert db._default(val) == str(val) + + +def test_serializer_raises_type_err() -> None: + """Test _default() function raises ValueError.""" + with pytest.raises(TypeError): + db._default(None) diff --git a/tests/unit/test_orm.py b/tests/unit/test_orm.py index daf57601..fd25ba22 100644 --- a/tests/unit/test_orm.py +++ b/tests/unit/test_orm.py @@ -2,7 +2,7 @@ import datetime from unittest.mock import MagicMock -from starlite_saqlalchemy import orm +from starlite_saqlalchemy.db import orm from tests.utils.domain import Author, CreateDTO diff --git a/tests/unit/test_serializer.py b/tests/unit/test_serializer.py index 72765e87..9054358d 100644 --- a/tests/unit/test_serializer.py +++ b/tests/unit/test_serializer.py @@ -9,6 +9,7 @@ def test_pg_uuid_serialization() -> None: + """Test response serializer handles PG UUID.""" py_uuid = uuid4() pg_uuid = pgproto.UUID(py_uuid.bytes) assert serializer.default_serializer(pg_uuid) == str(py_uuid) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 4067ca01..dfae2931 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -9,7 +9,7 @@ import orjson import pytest -from starlite_saqlalchemy import service, sqlalchemy_plugin, worker +from starlite_saqlalchemy import db, service, worker from tests.utils import domain if TYPE_CHECKING: @@ -19,31 +19,22 @@ ServiceType = service.Service[domain.Author] -@pytest.fixture(name="service_obj") -def fx_service() -> ServiceType: - """Service object backed by mock repository.""" - - class Service(service.Service[domain.Author]): - repository_type = domain.Repository - - return Service() - - -async def test_service_create(service_obj: ServiceType) -> None: +async def test_service_create() -> None: """Test repository create action.""" - resp = await service_obj.create(domain.Author(name="someone", dob=date.min)) + resp = await domain.Service().create(domain.Author(name="someone", dob=date.min)) assert resp.name == "someone" assert resp.dob == date.min -async def test_service_list(service_obj: ServiceType) -> None: +async def test_service_list() -> None: """Test repository list action.""" - resp = await service_obj.list() + resp = await domain.Service().list() assert len(resp) == 2 -async def test_service_update(service_obj: ServiceType) -> None: +async def test_service_update() -> None: """Test repository update action.""" + service_obj = domain.Service() author, _ = await service_obj.list() assert author.name == "Agatha Christie" author.name = "different" @@ -51,8 +42,9 @@ async def test_service_update(service_obj: ServiceType) -> None: assert resp.name == "different" -async def test_service_upsert_update(service_obj: ServiceType) -> None: +async def test_service_upsert_update() -> None: """Test repository upsert action for update.""" + service_obj = domain.Service() author, _ = await service_obj.list() assert author.name == "Agatha Christie" author.name = "different" @@ -61,23 +53,25 @@ async def test_service_upsert_update(service_obj: ServiceType) -> None: assert resp.name == "different" -async def test_service_upsert_create(service_obj: ServiceType) -> None: +async def test_service_upsert_create() -> None: """Test repository upsert action for create.""" author = domain.Author(id=uuid4(), name="New Author") - resp = await service_obj.upsert(author.id, author) + resp = await domain.Service().upsert(author.id, author) assert resp.id == author.id assert resp.name == "New Author" -async def test_service_get(service_obj: ServiceType) -> None: +async def test_service_get() -> None: """Test repository get action.""" + service_obj = domain.Service() author, _ = await service_obj.list() retrieved = await service_obj.get(author.id) assert author is retrieved -async def test_service_delete(service_obj: ServiceType) -> None: +async def test_service_delete() -> None: """Test repository delete action.""" + service_obj = domain.Service() author, _ = await service_obj.list() deleted = await service_obj.delete(author.id) assert author is deleted @@ -108,7 +102,7 @@ async def test_make_service_callback( async def test_make_service_callback_raises_runtime_error( - raw_authors: list[dict[str, Any]], monkeypatch: "MonkeyPatch" + raw_authors: list[dict[str, Any]] ) -> None: """Tests loading and retrieval of service object types.""" with pytest.raises(RuntimeError): @@ -125,7 +119,7 @@ async def test_enqueue_service_callback(monkeypatch: "MonkeyPatch") -> None: """Tests that job enqueued with desired arguments.""" enqueue_mock = AsyncMock() monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock) - service_instance = domain.Service(session=sqlalchemy_plugin.async_session_factory()) + service_instance = domain.Service(session=db.async_session_factory()) await service_instance.enqueue_background_task("receive_callback", raw_obj={"a": "b"}) enqueue_mock.assert_called_once_with( "make_service_callback", diff --git a/tests/unit/test_setting.py b/tests/unit/test_setting.py index 26c9aaec..b5230dfe 100644 --- a/tests/unit/test_setting.py +++ b/tests/unit/test_setting.py @@ -1,3 +1,4 @@ +"""Test settings module.""" from starlite_saqlalchemy import settings diff --git a/tests/utils/domain.py b/tests/utils/domain.py index e5a289b4..c248d8e9 100644 --- a/tests/utils/domain.py +++ b/tests/utils/domain.py @@ -3,12 +3,11 @@ from sqlalchemy.orm import Mapped -from starlite_saqlalchemy import dto, orm, service +from starlite_saqlalchemy import db, dto, service from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository -from starlite_saqlalchemy.worker import queue -class Author(orm.Base): # pylint: disable=too-few-public-methods +class Author(db.orm.Base): # pylint: disable=too-few-public-methods """The Author domain object.""" name: Mapped[str] @@ -26,11 +25,6 @@ class Service(service.Service[Author]): repository_type = Repository - async def create(self, data: Author) -> Author: - created = await super().create(data) - await queue.enqueue("author_created", data=ReadDTO.from_orm(created).dict()) - return data - CreateDTO = dto.factory("AuthorCreateDTO", Author, purpose=dto.Purpose.WRITE, exclude={"id"}) """