Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat(repository): make abc more general. (#82)
Browse files Browse the repository at this point in the history
The ABC was typed to accept a sqlalchemy session object. Changed to
instead receive arbitrary kwargs, but if the base class actually ever
receives kwargs it will error out due to super call to
`object.__init__()`.

The service object accepts arbitrary kwargs that are passed through to
the repository, but doesn't care what they are.

This all means that the only thing that knows and cares about the
sqlalchemy session, is the sqlalchemy repository, and that feels right.

One facet of this approach is that it makes the concept of the
transaction an implementation detail. The sqlalchemy repo has the
concept of session/transaction, but a repository doesn't _have_ to
understand those things. This is consistent with the testing repository
implementation, so happy to see how the pattern pans out.

Closes #54
  • Loading branch information
peterschutt committed Nov 5, 2022
1 parent ab5531f commit 9e89434
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 16 deletions.
6 changes: 2 additions & 4 deletions src/starlite_saqlalchemy/repository/abc.py
Expand Up @@ -8,8 +8,6 @@
from starlite_saqlalchemy.repository.exceptions import RepositoryNotFoundException

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession

from .types import FilterTypes

__all__ = ["AbstractRepository"]
Expand All @@ -26,8 +24,8 @@ class AbstractRepository(Generic[T], metaclass=ABCMeta):
id_attribute = "id"
"""Name of the primary identifying attribute on `model_type`."""

def __init__(self, session: AsyncSession) -> None:
self.session = session
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@abstractmethod
async def add(self, data: T) -> T:
Expand Down
7 changes: 5 additions & 2 deletions src/starlite_saqlalchemy/repository/sqlalchemy.py
Expand Up @@ -69,8 +69,11 @@ class SQLAlchemyRepository(AbstractRepository[ModelT]):

model_type: type[ModelT]

def __init__(self, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None) -> None:
super().__init__(session)
def __init__(
self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.session = session
self._select = select(self.model_type) if select_ is None else select_

async def add(self, data: ModelT) -> ModelT:
Expand Down
5 changes: 2 additions & 3 deletions src/starlite_saqlalchemy/service.py
Expand Up @@ -19,7 +19,6 @@
if TYPE_CHECKING:
from pydantic import BaseModel
from saq.types import Context
from sqlalchemy.ext.asyncio import AsyncSession

from starlite_saqlalchemy.repository.abc import AbstractRepository
from starlite_saqlalchemy.repository.types import FilterTypes
Expand Down Expand Up @@ -67,8 +66,8 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
f"__{model_type.__tablename__}DTO", model_type, dto.Purpose.READ
)

def __init__(self, session: AsyncSession) -> None:
self.repository: AbstractRepository[ModelT] = self.repository_type(session)
def __init__(self, **repo_kwargs: Any) -> None:
self.repository: AbstractRepository[ModelT] = self.repository_type(**repo_kwargs)

# noinspection PyMethodMayBeStatic
async def authorize_create(self, data: ModelT) -> ModelT:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_service.py
Expand Up @@ -52,7 +52,7 @@ async def test_enqueue_service_callback(
"""Tests that job enqueued with desired arguments."""
enqueue_mock = AsyncMock()
monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock)
service_instance = domain.Service(sqlalchemy_plugin.async_session_factory())
service_instance = domain.Service(session=sqlalchemy_plugin.async_session_factory())
await service_instance.enqueue_callback(
service.Operation.UPDATE, domain.Author(**raw_authors[0])
)
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/utils.py
Expand Up @@ -12,8 +12,6 @@
if TYPE_CHECKING:
from collections import abc

from sqlalchemy.ext.asyncio import AsyncSession

from starlite_saqlalchemy.repository.types import FilterTypes

BaseT = TypeVar("BaseT", bound=Base)
Expand All @@ -27,10 +25,8 @@ class GenericMockRepository(AbstractRepository[BaseT], Generic[BaseT]):

collection: "abc.MutableMapping[abc.Hashable, BaseT]" = {}

def __init__(
self, session: "AsyncSession", id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any
) -> None:
super().__init__(session)
def __init__(self, id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any) -> None:
super().__init__()
self._id_factory = id_factory

def _find_or_raise_not_found(self, id_: Any) -> BaseT:
Expand Down

0 comments on commit 9e89434

Please sign in to comment.