diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 93064d3e..75733cd0 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -8,8 +8,6 @@ from starlite_saqlalchemy.repository.exceptions import RepositoryNotFoundException if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - from .types import FilterTypes __all__ = ["AbstractRepository"] @@ -26,8 +24,8 @@ class AbstractRepository(Generic[T], metaclass=ABCMeta): id_attribute = "id" """Name of the primary identifying attribute on `model_type`.""" - def __init__(self, session: AsyncSession) -> None: - self.session = session + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) @abstractmethod async def add(self, data: T) -> T: diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index a7bb2fc9..89f540f4 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -69,8 +69,11 @@ class SQLAlchemyRepository(AbstractRepository[ModelT]): model_type: type[ModelT] - def __init__(self, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None) -> None: - super().__init__(session) + def __init__( + self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.session = session self._select = select(self.model_type) if select_ is None else select_ async def add(self, data: ModelT) -> ModelT: diff --git a/src/starlite_saqlalchemy/service.py b/src/starlite_saqlalchemy/service.py index ce3e4116..437024bc 100644 --- a/src/starlite_saqlalchemy/service.py +++ b/src/starlite_saqlalchemy/service.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from pydantic import BaseModel from saq.types import Context - from sqlalchemy.ext.asyncio import AsyncSession from starlite_saqlalchemy.repository.abc import AbstractRepository from starlite_saqlalchemy.repository.types import FilterTypes @@ -67,8 +66,8 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: f"__{model_type.__tablename__}DTO", model_type, dto.Purpose.READ ) - def __init__(self, session: AsyncSession) -> None: - self.repository: AbstractRepository[ModelT] = self.repository_type(session) + def __init__(self, **repo_kwargs: Any) -> None: + self.repository: AbstractRepository[ModelT] = self.repository_type(**repo_kwargs) # noinspection PyMethodMayBeStatic async def authorize_create(self, data: ModelT) -> ModelT: diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 6574077f..45db1275 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -52,7 +52,7 @@ async def test_enqueue_service_callback( """Tests that job enqueued with desired arguments.""" enqueue_mock = AsyncMock() monkeypatch.setattr(worker.queue, "enqueue", enqueue_mock) - service_instance = domain.Service(sqlalchemy_plugin.async_session_factory()) + service_instance = domain.Service(session=sqlalchemy_plugin.async_session_factory()) await service_instance.enqueue_callback( service.Operation.UPDATE, domain.Author(**raw_authors[0]) ) diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 4e40b692..e38fa308 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -12,8 +12,6 @@ if TYPE_CHECKING: from collections import abc - from sqlalchemy.ext.asyncio import AsyncSession - from starlite_saqlalchemy.repository.types import FilterTypes BaseT = TypeVar("BaseT", bound=Base) @@ -27,10 +25,8 @@ class GenericMockRepository(AbstractRepository[BaseT], Generic[BaseT]): collection: "abc.MutableMapping[abc.Hashable, BaseT]" = {} - def __init__( - self, session: "AsyncSession", id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any - ) -> None: - super().__init__(session) + def __init__(self, id_factory: "abc.Callable[[], Any]" = uuid4, **_: Any) -> None: + super().__init__() self._id_factory = id_factory def _find_or_raise_not_found(self, id_: Any) -> BaseT: