Skip to content

Commit

Permalink
Improve pagination typing
Browse files Browse the repository at this point in the history
  • Loading branch information
max-muoto committed Apr 12, 2024
1 parent 08af4b9 commit 284d74d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 23 deletions.
105 changes: 82 additions & 23 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
import inspect
from abc import ABC, abstractmethod
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Tuple, Type

from django.db.models import QuerySet
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

from django.db import models
from django.http import HttpRequest
from django.utils.module_loading import import_string
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeGuard
from typing_extensions import get_args as get_collection_args

from ninja import Field, Query, Router, Schema
from ninja.conf import settings
from ninja.constants import NOT_SET
from ninja.constants import NOT_SET, NOT_SET_TYPE
from ninja.errors import ConfigError
from ninja.operation import Operation
from ninja.signature.details import is_collection_type
from ninja.utils import contribute_operation_args, contribute_operation_callback

Req = TypeVar("Req", bound=HttpRequest)
M = TypeVar("M", bound=models.Model)
P = ParamSpec("P")

ViewFn: TypeAlias = Callable[Concatenate[Req, P], Sequence[Any]]
PaginatedViewFn: TypeAlias = Callable[Concatenate[Req, P], Dict[str, Any]]


class PaginationBase(ABC):
class Input(Schema):
Expand All @@ -35,20 +55,21 @@ def __init__(self, *, pass_parameter: Optional[str] = None, **kwargs: Any) -> No
@abstractmethod
def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Any,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
pass # pragma: no cover

def _items_count(self, queryset: QuerySet) -> int:
def _items_count(self, queryset: Sequence[M]) -> int:
"""
Since lists are mainly compatible with QuerySets and can be passed to paginator.
We will first to try to use .count - and if not there will use a len
"""
try:
# forcing to find queryset.count instead of list.count:
return queryset.all().count()
# Avoid checking the type with `isinstance` because this might not work with
# monkey-patched QuerySets.
return queryset.all().count() # type: ignore
except AttributeError:
return len(queryset)

Expand All @@ -60,10 +81,10 @@ class Input(Schema):

def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Input,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
offset = pagination.offset
limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
return {
Expand All @@ -84,18 +105,40 @@ def __init__(

def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Input,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
offset = (pagination.page - 1) * self.page_size
return {
"items": queryset[offset : offset + self.page_size],
"count": self._items_count(queryset),
} # noqa: E203


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
@overload
def paginate(
func_or_pgn_class: ViewFn[Req, P], **paginator_params: Any
) -> PaginatedViewFn[Req, P]:
...


@overload
def paginate(
func_or_pgn_class: Union[Type[PaginationBase], NOT_SET_TYPE] = NOT_SET,
**paginator_params: Any,
) -> Callable[[ViewFn[Req, P]], PaginatedViewFn[Req, P]]:
...


def paginate(
func_or_pgn_class: Union[
ViewFn[Req, P], Type[PaginationBase], NOT_SET_TYPE
] = NOT_SET,
**paginator_params: Any,
) -> Union[
PaginatedViewFn[Req, P], Callable[[ViewFn[Req, P]], PaginatedViewFn[Req, P]]
]:
"""
@api.get(...
@paginate
Expand All @@ -109,37 +152,53 @@ def my_view(request):
"""

isfunction = inspect.isfunction(func_or_pgn_class)
def _is_view_func(func: Any) -> TypeGuard[ViewFn[Req, P]]:
return inspect.isfunction(func_or_pgn_class)

isnotset = func_or_pgn_class == NOT_SET

pagination_class: Type[PaginationBase] = import_string(settings.PAGINATION_CLASS)

if isfunction:
if _is_view_func(func_or_pgn_class):
return _inject_pagination(func_or_pgn_class, pagination_class)

if not isnotset:
# Second check is redundant, but `TypeGuard` doesn't narrow the negative case.
# `TypeIs` should resolve this: https://peps.python.org/pep-0742/
if not isnotset and isinstance(func_or_pgn_class, type):
pagination_class = func_or_pgn_class

def wrapper(func: Callable) -> Any:
def wrapper(func: ViewFn[Req, P]) -> PaginatedViewFn[Req, P]:
return _inject_pagination(func, pagination_class, **paginator_params)

return wrapper


def _inject_pagination(
func: Callable,
func: ViewFn[Req, P],
paginator_class: Type[PaginationBase],
**paginator_params: Any,
) -> Callable:
paginator: PaginationBase = paginator_class(**paginator_params)
) -> PaginatedViewFn[Req, P]:
"""Inject pagination into the view function.
Args:
func: The view function.
paginator_class: The paginator class.
**paginator_params: Parameters for the paginator class.
Returns:
The view function with pagination injected into the response.
"""
paginator = paginator_class(**paginator_params)

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
def view_with_pagination(
request: Req, *args: P.args, **kwargs: P.kwargs
) -> Dict[str, Any]:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)
items = func(request, *args, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ branch = true
fail_under = 100
skip_covered = true
show_missing = true
exclude_also = [
"@(typing\\.)?overload",
]

0 comments on commit 284d74d

Please sign in to comment.