diff --git a/README.md b/README.md index b258514..e426989 100644 --- a/README.md +++ b/README.md @@ -10,18 +10,18 @@ ## Project Status -- [ ] Overall Completion - 80% done +- [ ] Overall Completion - 85% done - [ ] Tests - 90% Complete -- [x] Model class to create SQLAlchemy Models and Declarative Base -- [x] SQLAlchemy model export to dictionary through `model.dict(exclude={'x', 'y', 'z'})` -- [x] Support multiple database configurations in Models and in Query -- [x] Session to manage Model metadata -- [x] Service to manage Engine and Session creation and Migration initialization for async and sync Engines and Sessions -- [x] Alembic env.py with async first `run_migrations_online` action -- [x] Expose all alembic commands to Ellar CLI -- [x] Module to config and setup SQLAlchemy dependencies and migration -- [ ] SQLAlchemy Pagination for both templating and API routes -- [x] File and Image SQLAlchemy Columns integrated with ellar storage +- [x] Model class that transforms to SQLAlchemy Models or DeclarativeBase based on configuration +- [x] Pydantic-like way to exporting model to dictionary object eg:`model.dict(exclude={'x', 'y', 'z'})` +- [x] Support multiple database useful in models and queries +- [x] Session management during request scope and outside request +- [x] Service to manage SQLAlchemy Engine and Session creation, and Migration for async and sync Engines and Sessions +- [x] Async first approach to database migration using Alembic +- [x] Expose all Alembic commands to Ellar CLI +- [x] Module to config and setup SQLAlchemy dependencies and migration path +- [x] SQLAlchemy Pagination for both Templating and API routes +- [x] File and Image SQLAlchemy Columns integration with ellar storage - [ ] SQLAlchemy Django Like Query - [ ] Documentation diff --git a/ellar_sqlalchemy/__init__.py b/ellar_sqlalchemy/__init__.py index bee507f..82c1352 100644 --- a/ellar_sqlalchemy/__init__.py +++ b/ellar_sqlalchemy/__init__.py @@ -3,6 +3,15 @@ __version__ = "0.0.1" from .module import EllarSQLAlchemyModule +from .pagination import LimitOffsetPagination, PageNumberPagination, paginate +from .query import ( + first_or_404, + first_or_404_async, + get_or_404, + get_or_404_async, + one_or_404, + one_or_404_async, +) from .schemas import MigrationOption, SQLAlchemyConfig from .services import EllarSQLAlchemyService @@ -11,4 +20,13 @@ "EllarSQLAlchemyService", "SQLAlchemyConfig", "MigrationOption", + "get_or_404_async", + "get_or_404", + "first_or_404_async", + "first_or_404", + "one_or_404_async", + "one_or_404", + "paginate", + "PageNumberPagination", + "LimitOffsetPagination", ] diff --git a/ellar_sqlalchemy/constant.py b/ellar_sqlalchemy/constant.py index 0e95305..52f94ba 100644 --- a/ellar_sqlalchemy/constant.py +++ b/ellar_sqlalchemy/constant.py @@ -5,7 +5,7 @@ DATABASE_KEY = "__database__" TABLE_KEY = "__table__" ABSTRACT_KEY = "__abstract__" - +PAGINATION_OPTIONS = "__PAGINATION_OPTIONS__" NAMING_CONVERSION = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", diff --git a/ellar_sqlalchemy/pagination/__init__.py b/ellar_sqlalchemy/pagination/__init__.py new file mode 100644 index 0000000..598216e --- /dev/null +++ b/ellar_sqlalchemy/pagination/__init__.py @@ -0,0 +1,11 @@ +from .base import Paginator, PaginatorBase +from .decorator import paginate +from .view import LimitOffsetPagination, PageNumberPagination + +__all__ = [ + "Paginator", + "PaginatorBase", + "paginate", + "PageNumberPagination", + "LimitOffsetPagination", +] diff --git a/ellar_sqlalchemy/pagination/base.py b/ellar_sqlalchemy/pagination/base.py new file mode 100644 index 0000000..212c42d --- /dev/null +++ b/ellar_sqlalchemy/pagination/base.py @@ -0,0 +1,325 @@ +import typing as t +from abc import abstractmethod +from math import ceil + +import ellar.common as ecm +import sqlalchemy as sa +import sqlalchemy.orm as sa_orm +from ellar.app import current_injector +from ellar.threading import execute_coroutine_with_sync_worker +from sqlalchemy.ext.asyncio import AsyncSession + +from ellar_sqlalchemy.model.base import ModelBase +from ellar_sqlalchemy.services import EllarSQLAlchemyService + + +class PaginatorBase: + def __init__( + self, + page: int = 1, + per_page: int = 20, + max_per_page: t.Optional[int] = 100, + error_out: bool = True, + count: bool = True, + ) -> None: + page, per_page = self._prepare_page_args( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + ) + + self.page: int = page + """The current page.""" + + self.per_page: int = per_page + """The maximum number of items on a page.""" + + self.max_per_page: t.Optional[int] = max_per_page + """The maximum allowed value for ``per_page``.""" + + items = self._query_items() + + if not items and page != 1 and error_out: + raise ecm.NotFound() + + self.items: t.List[t.Any] = items + """The items on the current page. Iterating over the pagination object is + equivalent to iterating over the items. + """ + + if count: + total = self._query_count() + else: + total = None + + self.total: t.Optional[int] = total + """The total number of items across all pages.""" + + def _prepare_page_args( + self, + *, + page: int, + per_page: int, + max_per_page: t.Optional[int], + error_out: bool, + ) -> t.Tuple[int, int]: + if max_per_page is not None: + per_page = min(per_page, max_per_page) + + if page < 1: + if error_out: + raise ecm.NotFound() + else: + page = 1 + + if per_page < 1: + if error_out: + raise ecm.NotFound() + else: + per_page = 20 + + return page, per_page + + @property + def _query_offset(self) -> int: + """ """ + return (self.page - 1) * self.per_page + + @abstractmethod + def _query_items(self) -> t.List[t.Any]: + """Execute the query to get the items on the current page.""" + + @abstractmethod + def _get_init_kwargs(self) -> t.Dict[str, t.Any]: + """Returns dictionary of other attributes a child class requires for initialization""" + + @abstractmethod + def _query_count(self) -> int: + """Execute the query to get the total number of items.""" + + @property + def first(self) -> int: + """The number of the first item on the page, starting from 1, or 0 if there are + no items. + """ + if len(self.items) == 0: + return 0 + + return (self.page - 1) * self.per_page + 1 + + @property + def last(self) -> int: + """The number of the last item on the page, starting from 1, inclusive, or 0 if + there are no items. + """ + first = self.first + return max(first, first + len(self.items) - 1) + + @property + def pages(self) -> int: + """The total number of pages.""" + if self.total == 0 or self.total is None: + return 0 + + return ceil(self.total / self.per_page) + + @property + def has_prev(self) -> bool: + """`True` if this is not the first page.""" + return self.page > 1 + + @property + def prev_num(self) -> t.Optional[int]: + """The previous page number, or `None` if this is the first page.""" + if not self.has_prev: + return None + + return self.page - 1 + + def prev(self, *, error_out: bool = False) -> "PaginatorBase": + """Query the pagination object for the previous page. + + :param error_out: Raise `404 Not Found` error if no items are returned + and `page` is not 1, or if `page` or `per_page` is less than 1. + """ + init_kwargs = self._get_init_kwargs() + init_kwargs.update( + page=self.page - 1, + per_page=self.per_page, + error_out=error_out, + count=False, + ) + p = self.__class__(**init_kwargs) + p.total = self.total + return p + + @property + def has_next(self) -> bool: + """`True` if this is not the last page.""" + return self.page < self.pages + + @property + def next_num(self) -> t.Optional[int]: + """The next page number, or `None` if this is the last page.""" + if not self.has_next: + return None + + return self.page + 1 + + def next(self, *, error_out: bool = False) -> "PaginatorBase": + """ + Query the Pagination object for the next page. + + :param error_out: Raise `404 Not Found` error if no items are returned + and `page` is not 1, or if `page` or `per_page` is less than 1. + """ + init_kwargs = self._get_init_kwargs() + init_kwargs.update( + page=self.page + 1, + per_page=self.per_page, + max_per_page=self.max_per_page, + error_out=error_out, + count=False, + ) + p = self.__class__(**init_kwargs) + p.total = self.total + return p + + def iter_pages( + self, + *, + left_edge: int = 2, + left_current: int = 2, + right_current: int = 4, + right_edge: int = 2, + ) -> t.Iterator[t.Optional[int]]: + """ + Yield page numbers for a pagination widget. + A None represents skipped pages between the edges and middle. + + For example, if there are 20 pages and the current page is 7, the following + values are yielded. + + For example: + 1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20 + + :param left_edge: How many pages to show from the first page. + :param left_current: How many pages to show left of the current page. + :param right_current: How many pages to show right of the current page. + :param right_edge: How many pages to show from the last page. + + """ + pages_end = self.pages + 1 + + if pages_end == 1: + return + + left_end = min(1 + left_edge, pages_end) + yield from range(1, left_end) + + if left_end == pages_end: + return + + mid_start = max(left_end, self.page - left_current) + mid_end = min(self.page + right_current + 1, pages_end) + + if mid_start - left_end > 0: + yield None + + yield from range(mid_start, mid_end) + + if mid_end == pages_end: + return + + right_start = max(mid_end, pages_end - right_edge) + + if right_start - mid_end > 0: + yield None + + yield from range(right_start, pages_end) + + def __iter__(self) -> t.Iterator[t.Any]: + yield from self.items + + +class Paginator(PaginatorBase): + def __init__( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]], + session: t.Optional[t.Union[sa_orm.Session, AsyncSession]] = None, + page: int = 1, + per_page: int = 20, + max_per_page: t.Optional[int] = 100, + error_out: bool = True, + count: bool = True, + ) -> None: + if isinstance(model, type) and issubclass(model, ModelBase): + self._select = sa.select(model) + else: + self._select = t.cast(sa.sql.Select, model) + + self._created_session = False + + self._session: t.Union[sa_orm.Session, AsyncSession] = ( + session or self._get_session() + ) + self._is_async = self._session.get_bind().dialect.is_async + + super().__init__( + page=page, + per_page=per_page, + max_per_page=max_per_page, + error_out=error_out, + count=count, + ) + + if self._created_session: + self._session.close() # session usage is done but only if Paginator created the session + + def _get_session(self) -> t.Union[sa_orm.Session, AsyncSession, t.Any]: + self._created_session = True + service = current_injector.get(EllarSQLAlchemyService) + return service.get_scoped_session()() + + def _query_items(self) -> t.List[t.Any]: + if self._is_async: + res = execute_coroutine_with_sync_worker(self._query_items_async()) + return list(res) + return self._query_items_sync() + + def _query_items_sync(self) -> t.List[t.Any]: + select = self._select.limit(self.per_page).offset(self._query_offset) + return list(self._session.execute(select).unique().scalars()) + + async def _query_items_async(self) -> t.List[t.Any]: + session = t.cast(AsyncSession, self._session) + + select = self._select.limit(self.per_page).offset(self._query_offset) + res = await session.execute(select) + + return list(res.unique().scalars()) + + def _query_count(self) -> int: + if self._is_async: + res = execute_coroutine_with_sync_worker(self._query_count_async()) + return int(res) + return self._query_count_sync() + + def _query_count_sync(self) -> int: + sub = self._select.options(sa_orm.lazyload("*")).order_by(None).subquery() + out = self._session.execute( + sa.select(sa.func.count()).select_from(sub) + ).scalar() + return out # type:ignore[return-value] + + async def _query_count_async(self) -> int: + session = t.cast(AsyncSession, self._session) + + sub = self._select.options(sa_orm.lazyload("*")).order_by(None).subquery() + + out = await session.execute(sa.select(sa.func.count()).select_from(sub)) + return out.scalar() # type:ignore[return-value] + + def _get_init_kwargs(self) -> t.Dict[str, t.Any]: + return {"model": self._select} diff --git a/ellar_sqlalchemy/pagination/decorator.py b/ellar_sqlalchemy/pagination/decorator.py new file mode 100644 index 0000000..cc5c996 --- /dev/null +++ b/ellar_sqlalchemy/pagination/decorator.py @@ -0,0 +1,191 @@ +import asyncio +import functools +import typing as t +import uuid + +import ellar.common as ecm +import sqlalchemy as sa +from ellar.common import set_metadata +from ellar.common.constants import EXTRA_ROUTE_ARGS_KEY, RESPONSE_OVERRIDE_KEY +from pydantic import BaseModel + +from ellar_sqlalchemy.model.base import ModelBase + +from .view import PageNumberPagination, PaginationBase + + +def paginate( + pagination_class: t.Optional[t.Type[PaginationBase]] = None, + model: t.Optional[t.Type[ModelBase]] = None, + template_context: bool = False, + item_schema: t.Optional[t.Type[BaseModel]] = None, + **paginator_options: t.Any, +) -> t.Callable: + """ + =========ROUTE FUNCTION DECORATOR ============== + + :param pagination_class: Pagination Class of type PaginationBase + :param model: SQLAlchemy Model or SQLAlchemy Select Statement + :param template_context: If True adds `paginator` object to templating context data + :param item_schema: Pagination Object Schema for serializing object and creating response schema documentation + :param paginator_options: Other keyword args for initializing `pagination_class` + :return: TCallable + """ + paginator_options.update(model=model) + + def _wraps(func: t.Callable) -> t.Callable: + operation_class = ( + _AsyncPaginationOperation + if asyncio.iscoroutinefunction(func) + else _PaginationOperation + ) + operation = operation_class( + route_function=func, + pagination_class=pagination_class or PageNumberPagination, + template_context=template_context, + item_schema=item_schema, + paginator_options=paginator_options, + ) + + return operation.as_view + + return _wraps + + +class _PaginationOperation: + def __init__( + self, + route_function: t.Callable, + pagination_class: t.Type[PaginationBase], + paginator_options: t.Dict[str, t.Any], + template_context: bool = False, + item_schema: t.Optional[t.Type[BaseModel]] = None, + ) -> None: + self._original_route_function = route_function + self._pagination_view = pagination_class(**paginator_options) + _, _, view = self._get_route_function_wrapper(template_context, item_schema) + self.as_view = functools.wraps(route_function)(view) + + def _prepare_template_response( + self, res: t.Any + ) -> t.Tuple[ + t.Optional[t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]]], t.Dict[str, t.Any] + ]: + if isinstance(res, tuple): + filter_query, extra_context = res + assert isinstance( + extra_context, dict + ), "When using as `template_context`, route function should return a tuple(select, {})" + + elif isinstance(res, dict): + filter_query = None + extra_context = res + + elif ( + isinstance(res, sa.sql.Select) + or isinstance(res, type) + and issubclass(res, ModelBase) + ): + filter_query = res + extra_context = {} + else: + raise RuntimeError( + f"Invalid datastructure returned from route function. - {res}" + ) + + return filter_query, extra_context + + def _get_route_function_wrapper( + self, template_context: bool, item_schema: t.Type[BaseModel] + ) -> t.Tuple[ecm.params.ExtraEndpointArg, ecm.params.ExtraEndpointArg, t.Callable]: + unique_id = str(uuid.uuid4()) + # use unique_id to make the kwargs difficult to collide with any route function parameter + _paginate_args = ecm.params.ExtraEndpointArg( + name=f"paginate_{unique_id[:-6]}", + annotation=self._pagination_view.get_annotation(), + ) + # use unique_id to make the kwargs difficult to collide with any route function parameter + execution_context = ecm.params.ExtraEndpointArg( + name=f"context_{unique_id[:-6]}", + annotation=ecm.Inject[ecm.IExecutionContext], + ) + + set_metadata(EXTRA_ROUTE_ARGS_KEY, [_paginate_args, execution_context])( + self._original_route_function + ) + + if not template_context and not item_schema: + raise ecm.exceptions.ImproperConfiguration( + "Must supply value for either `template_context` or `item_schema`" + ) + + if not template_context: + # if pagination is not for template context, then we create a response schema for the api response + response_schema = self._pagination_view.get_output_schema(item_schema) + ecm.set_metadata(RESPONSE_OVERRIDE_KEY, {200: response_schema})( + self._original_route_function + ) + + def as_view(*args: t.Any, **kw: t.Any) -> t.Any: + func_kwargs = dict(**kw) + paginate_input = _paginate_args.resolve(func_kwargs) + context: ecm.IExecutionContext = execution_context.resolve(func_kwargs) + + items = self._original_route_function(*args, **func_kwargs) + + if not template_context: + return self._pagination_view.api_paginate( + items, + paginate_input, + context.switch_to_http_connection().get_request(), + ) + + filter_query, extra_context = self._prepare_template_response(items) + + pagination_context = self._pagination_view.pagination_context( + filter_query, + paginate_input, + context.switch_to_http_connection().get_request(), + ) + extra_context.update(pagination_context) + + return extra_context + + return _paginate_args, execution_context, as_view + + +class _AsyncPaginationOperation(_PaginationOperation): + def _get_route_function_wrapper( + self, template_context: bool, item_schema: t.Type[BaseModel] + ) -> t.Tuple[ecm.params.ExtraEndpointArg, ecm.params.ExtraEndpointArg, t.Callable]: + _paginate_args, execution_context, _ = super()._get_route_function_wrapper( + template_context, item_schema + ) + + async def as_view(*args: t.Any, **kw: t.Any) -> t.Any: + func_kwargs = dict(**kw) + + paginate_input = _paginate_args.resolve(func_kwargs) + context: ecm.IExecutionContext = execution_context.resolve(func_kwargs) + + items = await self._original_route_function(*args, **func_kwargs) + + if not template_context: + return self._pagination_view.api_paginate( + items, + paginate_input, + context.switch_to_http_connection().get_request(), + ) + + filter_query, extra_context = self._prepare_template_response(items) + + pagination_context = self._pagination_view.pagination_context( + filter_query, + paginate_input, + context.switch_to_http_connection().get_request(), + ) + extra_context.update(pagination_context) + + return extra_context + + return _paginate_args, execution_context, as_view diff --git a/ellar_sqlalchemy/pagination/utils.py b/ellar_sqlalchemy/pagination/utils.py new file mode 100644 index 0000000..394b3d0 --- /dev/null +++ b/ellar_sqlalchemy/pagination/utils.py @@ -0,0 +1,25 @@ +from urllib import parse + + +def replace_query_param(url: str, key: str, val: int) -> str: + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = parse.urlsplit(str(url)) + query_dict = parse.parse_qs(query, keep_blank_values=True) + query_dict[str(key)] = [str("{}".format(val))] + query = parse.urlencode(sorted(query_dict.items()), doseq=True) + return parse.urlunsplit((scheme, netloc, path, query, fragment)) + + +def remove_query_param(url: str, key: str) -> str: + """ + Given a URL and a key/val pair, remove an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = parse.urlsplit(str(url)) + query_dict = parse.parse_qs(query, keep_blank_values=True) + query_dict.pop(key, None) + query = parse.urlencode(sorted(query_dict.items()), doseq=True) + return parse.urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/ellar_sqlalchemy/pagination/view.py b/ellar_sqlalchemy/pagination/view.py new file mode 100644 index 0000000..4327cfb --- /dev/null +++ b/ellar_sqlalchemy/pagination/view.py @@ -0,0 +1,242 @@ +import typing as t +from abc import ABC, abstractmethod +from collections import OrderedDict + +import ellar.common as ecm +import ellar.core as ec +import sqlalchemy as sa +from pydantic import BaseModel, Field + +from ellar_sqlalchemy.model.base import ModelBase +from ellar_sqlalchemy.schemas import BasicPaginationSchema, PageNumberPaginationSchema + +from .base import Paginator +from .utils import remove_query_param, replace_query_param + + +class PaginationBase(ABC): + InputSource = ecm.Query + + class Input(BaseModel): + pass + + def get_annotation(self) -> t.Any: + return self.InputSource[self.Input, self.InputSource.P(...)] + + @abstractmethod + def get_output_schema( + self, item_schema: t.Type[BaseModel] + ) -> t.Type[BaseModel]: # pragma: no cover + """Return a Response Schema Type for item schema""" + + def validate_model( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]], + fallback: t.Optional[t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]]], + ) -> t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]]: + if isinstance(model, sa.sql.Select): + working_model = model + else: + working_model = model or fallback # type:ignore[assignment] + assert working_model, "Model Can not be None" + return working_model + + @abstractmethod + def api_paginate( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]], + input_schema: t.Any, + request: ec.Request, + **params: t.Any, + ) -> t.Any: + pass # pragma: no cover + + @abstractmethod + def pagination_context( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select[t.Any]], + input_schema: t.Any, + request: ec.Request, + **params: t.Any, + ) -> t.Dict[str, t.Any]: + pass # pragma: no cover + + if t.TYPE_CHECKING: + + def __init__(self, **kwargs: t.Any) -> None: + ... + + +class PageNumberPagination(PaginationBase): + class Input(BaseModel): + page: int = Field(1, gt=0) + + paginator_class: t.Type[Paginator] = Paginator + page_query_param: str = "page" + + def __init__( + self, + *, + model: t.Optional[t.Type[ModelBase]] = None, + per_page: int = 20, + max_per_page: int = 100, + error_out: bool = True, + ) -> None: + super().__init__() + self._model = model + self._paginator_init_kwargs = { + "per_page": per_page, + "max_per_page": max_per_page, + "error_out": error_out, + } + + def get_output_schema(self, item_schema: t.Type[BaseModel]) -> t.Type[BaseModel]: + return PageNumberPaginationSchema[item_schema] # type:ignore[valid-type] + + def api_paginate( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select, t.Any], + input_schema: Input, + request: ec.Request, + **params: t.Any, + ) -> t.Any: + working_model = self.validate_model(model, self._model) + + paginator = self.paginator_class( + model=working_model, page=input_schema.page, **self._paginator_init_kwargs + ) + return self._get_paginated_response( + base_url=str(request.url), paginator=paginator + ) + + def pagination_context( + self, + model: t.Union[t.Type, sa.sql.Select], + input_schema: Input, + request: ec.Request, + **params: t.Any, + ) -> t.Dict[str, t.Any]: + working_model = self.validate_model(model, self._model) + + paginator = self.paginator_class( + model=working_model, page=input_schema.page, **self._paginator_init_kwargs + ) + return {"paginator": paginator} + + def _get_paginated_response( + self, *, base_url: str, paginator: Paginator + ) -> t.Dict[str, t.Any]: + is_query = self.InputSource.name == "Query" + next_url = ( + self._get_next_link(base_url, paginator=paginator) if is_query else None + ) + + prev_url = ( + self._get_previous_link(base_url, paginator=paginator) if is_query else None + ) + return OrderedDict( + [ + ("count", paginator.total), + ("next", next_url), + ("previous", prev_url), + ("items", list(paginator)), + ] + ) + + def _get_next_link(self, url: str, paginator: Paginator) -> t.Optional[str]: + if not paginator.has_next: + return None + page_number = paginator.page + 1 + return replace_query_param(url, self.page_query_param, page_number) + + def _get_previous_link(self, url: str, paginator: Paginator) -> t.Optional[str]: + if not paginator.has_prev: + return None + page_number = paginator.page - 1 + if page_number == 1: + return remove_query_param(url, self.page_query_param) + return replace_query_param(url, self.page_query_param, page_number) + + +class LimitOffsetPagination(PaginationBase): + class Input(BaseModel): + limit: int = Field(50, ge=1) + offset: int = Field(0, ge=0) + + paginator_class: t.Type[Paginator] = Paginator + + def __init__( + self, + *, + model: t.Optional[t.Type[ModelBase]] = None, + limit: int = 50, + max_limit: int = 100, + error_out: bool = True, + ) -> None: + super().__init__() + self._model = model + self._max_limit = max_limit + self._error_out = error_out + + self._paginator_init_kwargs = { + "error_out": error_out, + "max_per_page": max_limit, + } + self.Input = self.create_input(limit) # type:ignore[misc] + + def create_input(self, limit: int) -> t.Type[Input]: + _limit = int(limit) + + class DynamicInput(self.Input): + limit: int = Field(_limit, ge=1) + offset: int = Field(0, ge=0) + + return DynamicInput + + def get_output_schema(self, item_schema: t.Type[BaseModel]) -> t.Type[BaseModel]: + return BasicPaginationSchema[item_schema] + + def api_paginate( + self, + model: t.Union[t.Type[ModelBase], sa.sql.Select[t.Any], t.Any], + input_schema: Input, + request: ec.Request, + **params: t.Any, + ) -> t.Any: + working_model = self.validate_model(model, self._model) + + page = input_schema.offset or 1 + per_page: int = min(input_schema.limit, self._max_limit) + + paginator = self.paginator_class( + model=working_model, + page=page, + per_page=per_page, + **self._paginator_init_kwargs, + ) + return OrderedDict( + [ + ("count", paginator.total), + ("items", list(paginator)), + ] + ) + + def pagination_context( + self, + model: t.Union[t.Type, sa.sql.Select[t.Any]], + input_schema: Input, + request: ec.Request, + **params: t.Any, + ) -> t.Dict[str, t.Any]: + working_model = self.validate_model(model, self._model) + + page = input_schema.offset or 1 + per_page: int = min(input_schema.limit, self._max_limit) + + paginator = self.paginator_class( + model=working_model, + page=page, + per_page=per_page, + **self._paginator_init_kwargs, + ) + return {"paginator": paginator} diff --git a/ellar_sqlalchemy/query/__init__.py b/ellar_sqlalchemy/query/__init__.py new file mode 100644 index 0000000..8c8238a --- /dev/null +++ b/ellar_sqlalchemy/query/__init__.py @@ -0,0 +1,17 @@ +from .utils import ( + first_or_404, + first_or_404_async, + get_or_404, + get_or_404_async, + one_or_404, + one_or_404_async, +) + +__all__ = [ + "get_or_404_async", + "get_or_404", + "one_or_404_async", + "one_or_404", + "first_or_404_async", + "first_or_404", +] diff --git a/ellar_sqlalchemy/query/utils.py b/ellar_sqlalchemy/query/utils.py new file mode 100644 index 0000000..13aa40f --- /dev/null +++ b/ellar_sqlalchemy/query/utils.py @@ -0,0 +1,106 @@ +import typing as t + +import ellar.common as ecm +import sqlalchemy as sa +import sqlalchemy.exc as sa_exc +from ellar.app import current_injector + +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +_O = t.TypeVar("_O", bound=object) + + +def get_or_404( + entity: t.Type[_O], + ident: t.Any, + *, + error_message: t.Optional[str] = None, + **kwargs: t.Any, +) -> _O: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + value = session.get(entity, ident, **kwargs) + + if value is None: + raise ecm.NotFound(detail=error_message) + + return t.cast(_O, value) + + +async def get_or_404_async( + entity: t.Type[_O], + ident: t.Any, + *, + error_message: t.Optional[str] = None, + **kwargs: t.Any, +) -> _O: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + value = await session.get(entity, ident, **kwargs) + + if value is None: + raise ecm.NotFound(detail=error_message) + + return t.cast(_O, value) + + +def first_or_404( + statement: sa.sql.Select[t.Any], *, error_message: t.Optional[str] = None +) -> t.Any: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + value = session.execute(statement).scalar() + + if value is None: + raise ecm.NotFound(detail=error_message) + + return value + + +async def first_or_404_async( + statement: sa.sql.Select[t.Any], *, error_message: t.Optional[str] = None +) -> t.Any: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + res = await session.execute(statement) + value = res.scalar() + + if value is None: + raise ecm.NotFound(detail=error_message) + + return value + + +def one_or_404( + statement: sa.sql.Select[t.Any], *, error_message: t.Optional[str] = None +) -> t.Any: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + try: + return session.execute(statement).scalar_one() + except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound) as ex: + raise ecm.NotFound(detail=error_message) from ex + + +async def one_or_404_async( + statement: sa.sql.Select[t.Any], *, error_message: t.Optional[str] = None +) -> t.Any: + """ """ + db_service = current_injector.get(EllarSQLAlchemyService) + session = db_service.session_factory() + + try: + res = await session.execute(statement) + return res.scalar_one() + except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound) as ex: + raise ecm.NotFound(detail=error_message) from ex diff --git a/ellar_sqlalchemy/schemas.py b/ellar_sqlalchemy/schemas.py index eef00f6..6770be1 100644 --- a/ellar_sqlalchemy/schemas.py +++ b/ellar_sqlalchemy/schemas.py @@ -4,6 +4,33 @@ import ellar.common as ecm import sqlalchemy.orm as sa_orm +from pydantic import BaseModel, BeforeValidator, HttpUrl, TypeAdapter +from typing_extensions import Annotated + +T = t.TypeVar("T") + +Url = Annotated[ + str, BeforeValidator(lambda value: str(TypeAdapter(HttpUrl).validate_python(value))) +] + + +class BasePaginatedResponseSchema(BaseModel): + count: int + next: t.Optional[Url] + previous: t.Optional[Url] + results: t.List[t.Any] + + +class BasicPaginationSchema(BaseModel, t.Generic[T]): + count: int + items: t.List[T] + + +class PageNumberPaginationSchema(BaseModel, t.Generic[T]): + count: int + next: t.Optional[Url] + previous: t.Optional[Url] + items: t.List[T] @dataclass diff --git a/pyproject.toml b/pyproject.toml index f983a46..0f102ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ - "ellar >= 0.6.6", + "ellar >= 0.6.7", "sqlalchemy >= 2.0.23", "alembic >= 1.10.0", "python-magic >= 0.4.27", @@ -88,7 +88,7 @@ known-third-party = ["ellar"] python_version = "3.8" show_error_codes = true pretty = true -strict = true +strict_optional = true disable_error_code = ["name-defined", 'union-attr'] disallow_subclassing_any = false [[tool.mypy.overrides]] @@ -98,5 +98,8 @@ ignore_errors = true module = "ellar_sqlalchemy.migrations.*" disable_error_code = ["arg-type", 'union-attr'] [[tool.mypy.overrides]] +module = "ellar_sqlalchemy.pagination.*" +disable_error_code = ["arg-type", 'union-attr', 'valid-type'] +[[tool.mypy.overrides]] module = "ellar_sqlalchemy.model.base" disable_error_code = ["attr-defined", 'union-attr'] diff --git a/tests/conftest.py b/tests/conftest.py index d24fe97..105eb15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,7 +86,8 @@ def db_service_async(tmp_path) -> EllarSQLAlchemyService: @pytest.fixture() def app_setup(tmp_path): def _setup(**kwargs): - kwargs.setdefault( + sql_module = kwargs.pop("sql_module", {}) + sql_module.setdefault( "databases", { "default": { @@ -96,9 +97,12 @@ def _setup(**kwargs): } }, ) - kwargs.setdefault("migration_options", {"directory": "migrations"}) + sql_module.setdefault("migration_options", {"directory": "migrations"}) tm = Test.create_test_module( - modules=[EllarSQLAlchemyModule.setup(root_path=str(tmp_path), **kwargs)] + modules=[ + EllarSQLAlchemyModule.setup(root_path=str(tmp_path), **sql_module) + ], + **kwargs, ) return tm.create_application() @@ -108,7 +112,8 @@ def _setup(**kwargs): @pytest.fixture() def app_setup_async(tmp_path): def _setup(**kwargs): - kwargs.setdefault( + sql_module = kwargs.pop("sql_module", {}) + sql_module.setdefault( "databases", { "default": { @@ -118,10 +123,29 @@ def _setup(**kwargs): } }, ) - kwargs.setdefault("migration_options", {"directory": "migrations"}) + sql_module.setdefault("migration_options", {"directory": "migrations"}) tm = Test.create_test_module( - modules=[EllarSQLAlchemyModule.setup(root_path=str(tmp_path), **kwargs)] + modules=[ + EllarSQLAlchemyModule.setup(root_path=str(tmp_path), **sql_module) + ], + **kwargs, ) return tm.create_application() return _setup + + +@pytest.fixture() +async def app_ctx(app_setup): + app = app_setup() + + async with app.application_context(): + yield app + + +@pytest.fixture() +async def app_ctx_async(app_setup_async): + app = app_setup_async() + + async with app.application_context(): + yield app diff --git a/tests/test_migrations/samples/default_async.py b/tests/test_migrations/samples/default_async.py new file mode 100644 index 0000000..d614b99 --- /dev/null +++ b/tests/test_migrations/samples/default_async.py @@ -0,0 +1,46 @@ +#!/bin/env python +import ellar_cli.click as click +from ellar.app import AppFactory, current_injector +from ellar.common.utils.importer import get_main_directory_by_stack +from ellar_cli.main import create_ellar_cli +from models import User +from sqlalchemy.ext.asyncio import AsyncSession + +from ellar_sqlalchemy import EllarSQLAlchemyModule + + +def bootstrap(): + path = get_main_directory_by_stack( + "__main__/__parent__/__parent__/dumbs/default_async", stack_level=1 + ) + application = AppFactory.create_app( + modules=[ + EllarSQLAlchemyModule.setup( + databases="sqlite+aiosqlite:///app.db", + migration_options={"context_configure": {"compare_types": False}}, + root_path=str(path), + ) + ] + ) + return application + + +cli = create_ellar_cli("default_async:bootstrap") + + +@cli.command() +@click.run_as_async +async def add_user(): + session = current_injector.get(AsyncSession) + user = User(name="default App Ellar") + session.add(user) + + await session.commit() + await session.refresh(user) + await session.close() + + click.echo(f"") + + +if __name__ == "__main__": + cli() diff --git a/tests/test_migrations/samples/multiple_database_async.py b/tests/test_migrations/samples/multiple_database_async.py new file mode 100644 index 0000000..6b9da0f --- /dev/null +++ b/tests/test_migrations/samples/multiple_database_async.py @@ -0,0 +1,57 @@ +#!/bin/env python +import ellar_cli.click as click +from ellar.app import AppFactory, current_injector +from ellar.common.utils.importer import get_main_directory_by_stack +from ellar_cli.main import create_ellar_cli +from models import Group, User +from sqlalchemy.ext.asyncio import AsyncSession + +from ellar_sqlalchemy import EllarSQLAlchemyModule + + +def bootstrap(): + path = get_main_directory_by_stack( + "__main__/__parent__/__parent__/dumbs/multiple_async", stack_level=1 + ) + application = AppFactory.create_app( + modules=[ + EllarSQLAlchemyModule.setup( + databases={ + "default": "sqlite+aiosqlite:///app.db", + "db1": "sqlite+aiosqlite:///app2.db", + }, + migration_options={"context_configure": {"compare_types": False}}, + root_path=str(path), + ) + ] + ) + return application + + +cli = create_ellar_cli("multiple_database_async:bootstrap") + + +@cli.command() +@click.run_as_async +async def add_user(): + session = current_injector.get(AsyncSession) + user = User(name="Multiple Database App Ellar") + group = Group(name="group") + + session.add(user) + session.add(group) + + await session.commit() + + await session.refresh(user) + await session.refresh(group) + + await session.close() + + click.echo( + f"" + ) + + +if __name__ == "__main__": + cli() diff --git a/tests/test_migrations/test_migrations_commands.py b/tests/test_migrations/test_migrations_commands.py index cc2a9c8..fa5555a 100644 --- a/tests/test_migrations/test_migrations_commands.py +++ b/tests/test_migrations/test_migrations_commands.py @@ -132,3 +132,33 @@ def test_other_alembic_commands(): result = run_command("default.py db downgrade") assert result.returncode == 1 assert b"Relative revision -1 didn't produce 1 migrations" in result.stderr + + +@clean_directory("default_async") +def test_migrate_upgrade_async(): + result = run_command("default_async.py db init") + assert result.returncode == 0 + assert ( + b"tests/dumbs/default_async/migrations/alembic.ini' before proceeding." + in result.stdout + ) + + result = run_command("default_async.py db check") + assert result.returncode == 1 + + result = run_command("default_async.py db migrate") + assert result.returncode == 0 + + result = run_command("default_async.py db check") + assert result.returncode == 1 + + result = run_command("default_async.py db upgrade") + assert result.returncode == 0 + + result = run_command("default_async.py db check") + assert result.returncode == 0 + assert result.stdout == b"No new upgrade operations detected.\n" + + result = run_command("default_async.py add-user") + assert result.returncode == 0 + assert result.stdout == b"\n" diff --git a/tests/test_migrations/test_multiple_database.py b/tests/test_migrations/test_multiple_database.py index 33d7240..1cbea4b 100644 --- a/tests/test_migrations/test_multiple_database.py +++ b/tests/test_migrations/test_multiple_database.py @@ -54,3 +54,37 @@ def test_migrate_upgrade_multiple_database_with_model_changes(): b"Detected type change from VARCHAR(length=256) to String(length=128)" in result.stderr ) + + +@clean_directory("multiple_async") +def test_migrate_upgrade_for_multiple_database_async(): + with set_env_variable("multiple_db", "true"): + result = run_command("multiple_database_async.py db init") + assert result.returncode == 0 + assert ( + b"tests/dumbs/multiple_async/migrations/alembic.ini' before proceeding." + in result.stdout + ) + + result = run_command("multiple_database_async.py db check") + assert result.returncode == 1 + + result = run_command("multiple_database_async.py db migrate") + assert result.returncode == 0 + + result = run_command("multiple_database_async.py db check") + assert result.returncode == 1 + + result = run_command("multiple_database_async.py db upgrade") + assert result.returncode == 0 + + result = run_command("multiple_database_async.py db check") + assert result.returncode == 0 + assert result.stdout == b"No new upgrade operations detected.\n" + + result = run_command("multiple_database_async.py add-user") + assert result.returncode == 0 + assert ( + result.stdout + == b"\n" + ) diff --git a/tests/test_pagination/__init__.py b/tests/test_pagination/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_pagination/seed.py b/tests/test_pagination/seed.py new file mode 100644 index 0000000..227d585 --- /dev/null +++ b/tests/test_pagination/seed.py @@ -0,0 +1,35 @@ +import typing as t + +from ellar.app import App +from ellar.threading import execute_coroutine_with_sync_worker + +from ellar_sqlalchemy import EllarSQLAlchemyService, model + + +def create_model(): + class User(model.Model): + id: model.Mapped[int] = model.Column(model.Integer, primary_key=True) + name: model.Mapped[str] = model.Column(model.String) + + return User + + +def seed_100_users(app: App): + user_model = create_model() + db_service = app.injector.get(EllarSQLAlchemyService) + + session = db_service.session_factory() + + if session.get_bind().dialect.is_async: + execute_coroutine_with_sync_worker(db_service.create_all_async()) + else: + db_service.create_all() + + for i in range(100): + session.add(user_model(name=f"User Number {i+1}")) + + res = session.commit() + if isinstance(res, t.Coroutine): + execute_coroutine_with_sync_worker(res) + + return user_model diff --git a/tests/test_pagination/templates/list.html b/tests/test_pagination/templates/list.html new file mode 100644 index 0000000..b2cb570 --- /dev/null +++ b/tests/test_pagination/templates/list.html @@ -0,0 +1,29 @@ + + +

{{ name }}

+{% macro render_pagination(pagination, endpoint) %} +
+ {{ pagination.first }} - {{ pagination.last }} of {{ pagination.total }} +
+
+ {% for page in pagination.iter_pages() %} + {% if page %} + {% if page != pagination.page %} + {{ page }} + {% else %} + {{ page }} + {% endif %} + {% else %} + + {% endif %} + {% endfor %} +
+{% endmacro %} + +
    + {% for user in paginator %} +
  • {{ user.id }} @ {{ user.name }} + {% endfor %} +
+{{render_pagination(pagination=paginator, endpoint="html_pagination") }} + diff --git a/tests/test_pagination/test_base.py b/tests/test_pagination/test_base.py new file mode 100644 index 0000000..4aa8854 --- /dev/null +++ b/tests/test_pagination/test_base.py @@ -0,0 +1,116 @@ +import typing as t + +import pytest + +from ellar_sqlalchemy.pagination import PaginatorBase + + +class RangePagination(PaginatorBase): + def __init__( + self, total: t.Optional[int] = 150, page: int = 1, per_page: int = 10 + ) -> None: + if total is None: + self._data = range(150) + else: + self._data = range(total) + + super().__init__(page=page, per_page=per_page) + + if total is None: + self.total = None + + def _get_init_kwargs(self) -> t.Dict[str, t.Any]: + return {"total": self.total} + + def _query_items(self) -> t.List[t.Any]: + first = self._query_offset + last = first + self.per_page + 1 + return list(self._data[first:last]) + + def _query_count(self) -> int: + return len(self._data) + + +def test_first_page(): + p = RangePagination() + assert p.page == 1 + assert p.per_page == 10 + assert p.total == 150 + assert p.pages == 15 + assert not p.has_prev + assert p.prev_num is None + assert p.has_next + assert p.next_num == 2 + + +def test_last_page(): + p = RangePagination(page=15) + assert p.page == 15 + assert p.has_prev + assert p.prev_num == 14 + assert not p.has_next + assert p.next_num is None + + +def test_item_numbers_first_page(): + p = RangePagination() + p.items = list(range(10)) + assert p.first == 1 + assert p.last == 10 + + +def test_item_numbers_last_page(): + p = RangePagination(page=15) + p.items = list(range(5)) + assert p.first == 141 + assert p.last == 145 + + +def test_item_numbers_0(): + p = RangePagination(total=0) + assert p.first == 0 + assert p.last == 0 + + +@pytest.mark.parametrize("total", [0, None]) +def test_0_pages(total): + p = RangePagination(total=total) + assert p.pages == 0 + assert not p.has_prev + assert not p.has_next + + +@pytest.mark.parametrize( + ("page", "expect"), + [ + (1, [1, 2, 3, 4, 5, None, 14, 15]), + (2, [1, 2, 3, 4, 5, 6, None, 14, 15]), + (3, [1, 2, 3, 4, 5, 6, 7, None, 14, 15]), + (4, [1, 2, 3, 4, 5, 6, 7, 8, None, 14, 15]), + (5, [1, 2, 3, 4, 5, 6, 7, 8, 9, None, 14, 15]), + (6, [1, 2, None, 4, 5, 6, 7, 8, 9, 10, None, 14, 15]), + (7, [1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 14, 15]), + (8, [1, 2, None, 6, 7, 8, 9, 10, 11, 12, None, 14, 15]), + (9, [1, 2, None, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + (10, [1, 2, None, 8, 9, 10, 11, 12, 13, 14, 15]), + (11, [1, 2, None, 9, 10, 11, 12, 13, 14, 15]), + (12, [1, 2, None, 10, 11, 12, 13, 14, 15]), + (13, [1, 2, None, 11, 12, 13, 14, 15]), + (14, [1, 2, None, 12, 13, 14, 15]), + (15, [1, 2, None, 13, 14, 15]), + ], +) +def test_iter_pages(page, expect): + p = RangePagination(page=page) + assert list(p.iter_pages()) == expect + + +def test_iter_0_pages(): + p = RangePagination(total=0) + assert list(p.iter_pages()) == [] + + +@pytest.mark.parametrize("page", [1, 2, 3, 4]) +def test_iter_pages_short(page): + p = RangePagination(page=page, total=40) + assert list(p.iter_pages()) == [1, 2, 3, 4] diff --git a/tests/test_pagination/test_pagination_view.py b/tests/test_pagination/test_pagination_view.py new file mode 100644 index 0000000..74ae0a2 --- /dev/null +++ b/tests/test_pagination/test_pagination_view.py @@ -0,0 +1,311 @@ +from pathlib import Path + +import ellar.common as ecm +import pytest +from ellar.testing import TestClient + +from ellar_sqlalchemy import ( + LimitOffsetPagination, + PageNumberPagination, + model, + paginate, +) + +from .seed import seed_100_users + +base = Path(__file__).parent + + +class UserSerializer(ecm.Serializer): + id: int + name: str + + +def _get_route_test_route( + user_model, pagination_class, case_2=False, case_3=False, invalid=False, **kw +): + kwargs = dict(kw, template_context=True, pagination_class=pagination_class) + if case_2: + kwargs.update(model=user_model) + + @ecm.get("/list") + @ecm.render("list") + @paginate(**kwargs) + def html_pagination(): + if case_2: + return {"name": "Ellar Pagination"} + + if case_3: + return model.select(user_model) + + if invalid: + return [] + + return model.select(user_model), {"name": "Ellar Pagination"} + + return html_pagination + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_1(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append(_get_route_test_route(user_model, pagination_class, **kw)) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_2(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, case_2=True, **kw) + ) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_3(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, case_3=True, **kw) + ) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" not in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_invalid(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, invalid=True, **kw) + ) + client = TestClient(app, raise_server_exceptions=False) + + res = client.get("/list") + assert res.status_code == 500 + + +def test_api_paginate_case_1(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate(item_schema=UserSerializer, per_page=5) + def paginated_user(): + return model.select(user_model) + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list") + + assert res.status_code == 200 + assert res.json() == { + "count": 100, + "next": "http://testserver/list?page=2", + "previous": None, + "items": [ + {"id": 1, "name": "User Number 1"}, + {"id": 2, "name": "User Number 2"}, + {"id": 3, "name": "User Number 3"}, + {"id": 4, "name": "User Number 4"}, + {"id": 5, "name": "User Number 5"}, + ], + } + + +def test_api_paginate_case_2(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate(item_schema=UserSerializer, per_page=10) + def paginated_user(): + return user_model + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list?page=10") + + assert res.status_code == 200 + data = res.json() + assert len(data["items"]) == 10 + assert data["next"] is None + + res = client.get("/list?page=2") + data = res.json() + assert data["next"] == "http://testserver/list?page=3" + assert data["previous"] == "http://testserver/list" + + +def test_api_paginate_case_3(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate(model=user_model, item_schema=UserSerializer, per_page=5) + def paginated_user(): + pass + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 5 + + +def test_api_paginate_case_invalid(ignore_base, app_setup): + with pytest.raises(ecm.exceptions.ImproperConfiguration): + + @ecm.get("/list") + @paginate(per_page=5) + def paginated_user(): + pass + + +def test_api_paginate_with_limit_offset_case_1(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate( + item_schema=UserSerializer, + pagination_class=LimitOffsetPagination, + limit=5, + max_limit=10, + ) + def paginated_user(): + return user_model + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list") + + assert res.status_code == 200 + assert res.json() == { + "count": 100, + "items": [ + {"id": 1, "name": "User Number 1"}, + {"id": 2, "name": "User Number 2"}, + {"id": 3, "name": "User Number 3"}, + {"id": 4, "name": "User Number 4"}, + {"id": 5, "name": "User Number 5"}, + ], + } + + res = client.get("/list?limit=10") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 10 + + res = client.get("/list?limit=20") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 10 + + +def test_api_paginate_with_limit_offset_case_2(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate( + item_schema=UserSerializer, + pagination_class=LimitOffsetPagination, + limit=5, + max_limit=10, + ) + def paginated_user(): + return model.select(user_model) + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list?limit=5&offset=2") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 5 + assert res.json()["items"][0] == {"id": 6, "name": "User Number 6"} + + +def test_api_paginate_with_limit_offset_case_3(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate( + model=user_model, + item_schema=UserSerializer, + pagination_class=LimitOffsetPagination, + limit=5, + max_limit=10, + ) + def paginated_user(): + pass + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list?limit=5&offset=2") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 5 + assert res.json()["items"][0] == {"id": 6, "name": "User Number 6"} + + +def test_api_paginate_with_limit_offset_case_invalid(ignore_base, app_setup): + with pytest.raises(ecm.exceptions.ImproperConfiguration): + + @ecm.get("/list") + @paginate(pagination_class=LimitOffsetPagination) + def paginated_user(): + pass diff --git a/tests/test_pagination/test_pagination_view_async.py b/tests/test_pagination/test_pagination_view_async.py new file mode 100644 index 0000000..d0a3a7a --- /dev/null +++ b/tests/test_pagination/test_pagination_view_async.py @@ -0,0 +1,203 @@ +from pathlib import Path + +import ellar.common as ecm +import pytest +from ellar.testing import TestClient + +from ellar_sqlalchemy import ( + LimitOffsetPagination, + PageNumberPagination, + model, + paginate, +) + +from .seed import seed_100_users + +base = Path(__file__).parent + + +class UserSerializer(ecm.Serializer): + id: int + name: str + + +def _get_route_test_route( + user_model, pagination_class, case_2=False, case_3=False, invalid=False, **kw +): + kwargs = dict(kw, template_context=True, pagination_class=pagination_class) + if case_2: + kwargs.update(model=user_model) + + @ecm.get("/list") + @ecm.render("list") + @paginate(**kwargs) + async def html_pagination(): + if case_2: + return {"name": "Ellar Pagination"} + + if case_3: + return model.select(user_model) + + if invalid: + return [] + + return model.select(user_model), {"name": "Ellar Pagination"} + + return html_pagination + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_1_async(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append(_get_route_test_route(user_model, pagination_class, **kw)) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_2_async(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, case_2=True, **kw) + ) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_3_async(ignore_base, app_setup, pagination_class, kw): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, case_3=True, **kw) + ) + client = TestClient(app) + + res = client.get("/list") + assert res.status_code == 200 + + assert '2' in res.text + assert '20' in res.text + assert "
  • 1 @ User Number 1" in res.text + assert "
  • 5 @ User Number 5" in res.text + assert "Ellar Pagination" not in res.text + + +@pytest.mark.parametrize( + "pagination_class, kw", + [(LimitOffsetPagination, {"limit": 5}), (PageNumberPagination, {"per_page": 5})], +) +def test_paginate_template_case_invalid_async( + ignore_base, app_setup, pagination_class, kw +): + app = app_setup(base_directory=base, template_folder="templates") + user_model = seed_100_users(app) + + app.router.append( + _get_route_test_route(user_model, pagination_class, invalid=True, **kw) + ) + client = TestClient(app, raise_server_exceptions=False) + + res = client.get("/list") + assert res.status_code == 500 + + +def test_api_paginate_case_1_async(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate(item_schema=UserSerializer, per_page=5) + async def paginated_user(): + return model.select(user_model) + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list") + + assert res.status_code == 200 + assert res.json() == { + "count": 100, + "next": "http://testserver/list?page=2", + "previous": None, + "items": [ + {"id": 1, "name": "User Number 1"}, + {"id": 2, "name": "User Number 2"}, + {"id": 3, "name": "User Number 3"}, + {"id": 4, "name": "User Number 4"}, + {"id": 5, "name": "User Number 5"}, + ], + } + + +def test_api_paginate_with_limit_offset_case_1_async(ignore_base, app_setup): + app = app_setup() + user_model = seed_100_users(app) + + @ecm.get("/list") + @paginate( + item_schema=UserSerializer, + pagination_class=LimitOffsetPagination, + limit=5, + max_limit=10, + ) + async def paginated_user(): + return user_model + + app.router.append(paginated_user) + client = TestClient(app) + + res = client.get("/list") + + assert res.status_code == 200 + assert res.json() == { + "count": 100, + "items": [ + {"id": 1, "name": "User Number 1"}, + {"id": 2, "name": "User Number 2"}, + {"id": 3, "name": "User Number 3"}, + {"id": 4, "name": "User Number 4"}, + {"id": 5, "name": "User Number 5"}, + ], + } + + res = client.get("/list?limit=10") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 10 + + res = client.get("/list?limit=20") + + assert res.status_code == 200 + assert len(res.json()["items"]) == 10 diff --git a/tests/test_pagination/test_paginator.py b/tests/test_pagination/test_paginator.py new file mode 100644 index 0000000..14ea9dd --- /dev/null +++ b/tests/test_pagination/test_paginator.py @@ -0,0 +1,99 @@ +import pytest +from ellar.common import NotFound + +from ellar_sqlalchemy import EllarSQLAlchemyService +from ellar_sqlalchemy.pagination import Paginator + +from .seed import create_model, seed_100_users + + +async def test_user_model_paginator(ignore_base, app_ctx, anyio_backend): + user_model = seed_100_users(app_ctx) + page1 = Paginator(model=user_model, per_page=25, page=1, count=True) + assert page1.page == 1 + + assert page1.per_page == 25 + assert len(page1.items) == 25 + + assert page1.total == 100 + assert page1.pages == 4 + + +async def test_user_model_paginator_async(ignore_base, app_ctx_async, anyio_backend): + user_model = seed_100_users(app_ctx_async) + page2 = Paginator(model=user_model, per_page=25, page=2, count=True) + assert page2.page == 2 + + assert page2.per_page == 25 + assert len(page2.items) == 25 + + assert page2.total == 100 + assert page2.pages == 4 + page1 = page2.prev(error_out=True) + assert page1.total == 100 + assert page1.has_prev is False + + +async def test_paginate_qs(ignore_base, app_ctx, anyio_backend): + user_model = seed_100_users(app_ctx) + + p = Paginator(model=user_model, page=2, per_page=10) + assert p.page == 2 + assert p.per_page == 10 + + +async def test_paginate_max(ignore_base, app_ctx, anyio_backend): + user_model = seed_100_users(app_ctx) + + p = Paginator(model=user_model, per_page=100, max_per_page=50) + assert p.per_page == 50 + + +async def test_next_page_size(ignore_base, app_ctx, anyio_backend): + user_model = seed_100_users(app_ctx) + + p = Paginator(model=user_model, per_page=25, max_per_page=50) + assert p.page == 1 + assert p.per_page == 25 + + p = p.next() + assert p.page == 2 + assert p.per_page == 25 + + +async def test_no_count(ignore_base, app_ctx, anyio_backend): + user_model = seed_100_users(app_ctx) + + p = Paginator(model=user_model, count=False) + assert p.total is None + + +async def test_no_items_404(ignore_base, app_ctx, anyio_backend): + user_model = create_model() + db_service = app_ctx.injector.get(EllarSQLAlchemyService) + + db_service.create_all() + + p = Paginator(model=user_model) + assert len(p.items) == 0 + + with pytest.raises(NotFound): + p.next(error_out=True) + + with pytest.raises(NotFound): + p.prev(error_out=True) + + +async def test_error_out(ignore_base, app_ctx, anyio_backend): + user_model = create_model() + db_service = app_ctx.injector.get(EllarSQLAlchemyService) + + db_service.create_all() + for page, per_page in [(-2, 5), (1, -5)]: + with pytest.raises(NotFound): + Paginator(model=user_model, page=page, per_page=per_page) + + for page, per_page in [(-2, -5), (-1, 0)]: + p = Paginator(model=user_model, page=page, per_page=per_page, error_out=False) + assert p.per_page == 20 + assert p.page == 1 diff --git a/tests/test_query/__init__.py b/tests/test_query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_query/test_utils.py b/tests/test_query/test_utils.py new file mode 100644 index 0000000..ebff0bf --- /dev/null +++ b/tests/test_query/test_utils.py @@ -0,0 +1,112 @@ +import typing as t + +import pytest +from ellar.app import App +from ellar.common import NotFound +from ellar.threading import execute_coroutine_with_sync_worker + +from ellar_sqlalchemy import ( + EllarSQLAlchemyService, + first_or_404, + first_or_404_async, + get_or_404, + get_or_404_async, + model, + one_or_404, + one_or_404_async, +) + + +def _create_model(): + class User(model.Model): + id: model.Mapped[int] = model.Column(model.Integer, primary_key=True) + name: model.Mapped[str] = model.Column(model.String) + + return User + + +def _seed_model(app: App): + user_model = _create_model() + db_service = app.injector.get(EllarSQLAlchemyService) + + session = db_service.session_factory() + + if session.get_bind().dialect.is_async: + execute_coroutine_with_sync_worker(db_service.create_all_async()) + else: + db_service.create_all() + + session.add(user_model(name="First User")) + res = session.commit() + + if isinstance(res, t.Coroutine): + execute_coroutine_with_sync_worker(res) + + return user_model + + +async def test_get_or_404_works(ignore_base, app_ctx, anyio_backend): + user_model = _seed_model(app_ctx) + + user_instance = get_or_404(user_model, 1) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + get_or_404(user_model, 2) + + +async def test_get_or_404_async_works(ignore_base, app_ctx_async, anyio_backend): + if anyio_backend == "asyncio": + user_model = _seed_model(app_ctx_async) + + user_instance = await get_or_404_async(user_model, 1) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + await get_or_404_async(user_model, 2) + + +async def test_first_or_404_works(ignore_base, app_ctx, anyio_backend): + user_model = _seed_model(app_ctx) + + user_instance = first_or_404(model.select(user_model).where(user_model.id == 1)) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + first_or_404(model.select(user_model).where(user_model.id == 2)) + + +async def test_first_or_404_async_works(ignore_base, app_ctx_async, anyio_backend): + if anyio_backend == "asyncio": + user_model = _seed_model(app_ctx_async) + + user_instance = await first_or_404_async( + model.select(user_model).where(user_model.id == 1) + ) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + await first_or_404_async(model.select(user_model).where(user_model.id == 2)) + + +async def test_one_or_404_works(ignore_base, app_ctx, anyio_backend): + user_model = _seed_model(app_ctx) + + user_instance = one_or_404(model.select(user_model).where(user_model.id == 1)) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + one_or_404(model.select(user_model).where(user_model.id == 2)) + + +async def test_one_or_404_async_works(ignore_base, app_ctx_async, anyio_backend): + if anyio_backend == "asyncio": + user_model = _seed_model(app_ctx_async) + + user_instance = await one_or_404_async( + model.select(user_model).where(user_model.id == 1) + ) + assert user_instance.name == "First User" + + with pytest.raises(NotFound): + await one_or_404_async(model.select(user_model).where(user_model.id == 2)) diff --git a/tests/test_session.py b/tests/test_session.py index a6ebf33..96dfc27 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -23,7 +23,7 @@ def scope() -> int: count += 1 return count - app = app_setup(session_options={"scopefunc": scope}) + app = app_setup(sql_module={"session_options": {"scopefunc": scope}}) async with app.application_context(): first = app.injector.get(model.Session) @@ -37,7 +37,7 @@ def test_session_class(app_setup, ignore_base): class CustomSession(model.Session): pass - app = app_setup(session_options={"class_": CustomSession}) + app = app_setup(sql_module={"session_options": {"class_": CustomSession}}) session = app.injector.get(model.Session) assert isinstance(session, CustomSession) @@ -56,7 +56,9 @@ class Post(model.Model): __database__ = "a" id: model.Mapped[int] = model.mapped_column(model.Integer, primary_key=True) - app = app_setup(databases={"a": "sqlite://", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"a": "sqlite://", "default": "sqlite://"}} + ) session = app.injector.get(model.Session) db_service = app.injector.get(EllarSQLAlchemyService) @@ -80,7 +82,9 @@ class Post(model.Model): __database__ = "a" id = model.Column(model.Integer, primary_key=True) - app = app_setup(databases={"a": "sqlite://", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"a": "sqlite://", "default": "sqlite://"}} + ) session = app.injector.get(model.Session) db_service = app.injector.get(EllarSQLAlchemyService) @@ -111,7 +115,9 @@ class Admin(User): org: model.Mapped[str] = model.mapped_column(model.String, nullable=False) __mapper_args__ = {"polymorphic_identity": "admin"} - app = app_setup(databases={"a": "sqlite://", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"a": "sqlite://", "default": "sqlite://"}} + ) db_service = app.injector.get(EllarSQLAlchemyService) db_service.create_all() @@ -138,7 +144,9 @@ class Admin(User): org = model.Column(model.String, nullable=False) __mapper_args__ = {"polymorphic_identity": "admin"} - app = app_setup(databases={"a": "sqlite://", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"a": "sqlite://", "default": "sqlite://"}} + ) db_service = app.injector.get(EllarSQLAlchemyService) db_service.create_all() @@ -176,7 +184,9 @@ class Product(Base): model.String(50), nullable=False, init=False ) - app = app_setup(databases={"db1": "sqlite:///", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"db1": "sqlite:///", "default": "sqlite://"}} + ) db_service = app.injector.get(EllarSQLAlchemyService) db_service.create_all() @@ -212,7 +222,9 @@ class Product(model.Model): id = model.Column(model.Integer, primary_key=True) name = model.Column(model.String(50), nullable=False) - app = app_setup(databases={"db1": "sqlite:///", "default": "sqlite://"}) + app = app_setup( + sql_module={"databases": {"db1": "sqlite:///", "default": "sqlite://"}} + ) db_service = app.injector.get(EllarSQLAlchemyService) db_service.create_all()