Skip to content

Commit

Permalink
Merge pull request #1030 from jamesrkiger/master
Browse files Browse the repository at this point in the history
Async Pagination support
  • Loading branch information
vitalik committed Apr 30, 2024
2 parents 3c0cc93 + 1cf7dfd commit a937981
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 24 deletions.
4 changes: 4 additions & 0 deletions docs/docs/guides/response/pagination.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def paginate_queryset(self, queryset, pagination: Input, **params):
request = params["request"]
```

#### Async Pagination

Standard **Django Ninja** pagination classes support async. If you wish to handle async requests with a custom pagination class, you should subclass `ninja.pagination.AsyncPaginationBase` and override the `apaginate_queryset(self, queryset, request, **params)` method.

### Output attribute

By defult page items are placed to `'items'` attribute. To override this behaviour use `items_attribute`:
Expand Down
123 changes: 100 additions & 23 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from functools import partial, wraps
from math import inf
from typing import Any, Callable, List, Optional, Tuple, Type
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple, Type, Union

from django.db.models import QuerySet
from django.http import HttpRequest
Expand All @@ -15,7 +15,11 @@
from ninja.errors import ConfigError
from ninja.operation import Operation
from ninja.signature.details import is_collection_type
from ninja.utils import contribute_operation_args, contribute_operation_callback
from ninja.utils import (
contribute_operation_args,
contribute_operation_callback,
is_async_callable,
)


class PaginationBase(ABC):
Expand Down Expand Up @@ -54,7 +58,24 @@ def _items_count(self, queryset: QuerySet) -> int:
return len(queryset)


class LimitOffsetPagination(PaginationBase):
class AsyncPaginationBase(PaginationBase):
@abstractmethod
async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Any,
**params: Any,
) -> Any:
pass # pragma: no cover

async def _aitems_count(self, queryset: QuerySet) -> int:
try:
return await queryset.all().acount()
except AttributeError:
return len(queryset)


class LimitOffsetPagination(AsyncPaginationBase):
class Input(Schema):
limit: int = Field(
settings.PAGINATION_PER_PAGE,
Expand All @@ -78,8 +99,21 @@ def paginate_queryset(
"count": self._items_count(queryset),
} # noqa: E203

async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Input,
**params: Any,
) -> Any:
offset = pagination.offset
limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
return {
"items": queryset[offset : offset + limit],
"count": await self._aitems_count(queryset),
} # noqa: E203

class PageNumberPagination(PaginationBase):

class PageNumberPagination(AsyncPaginationBase):
class Input(Schema):
page: int = Field(1, ge=1)

Expand All @@ -101,6 +135,18 @@ def paginate_queryset(
"count": self._items_count(queryset),
} # noqa: E203

async def apaginate_queryset(
self,
queryset: QuerySet,
pagination: Input,
**params: Any,
) -> Any:
offset = (pagination.page - 1) * self.page_size
return {
"items": queryset[offset : offset + self.page_size],
"count": await self._aitems_count(queryset),
} # noqa: E203


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
"""
Expand All @@ -119,7 +165,9 @@ def my_view(request):
isfunction = inspect.isfunction(func_or_pgn_class)
isnotset = func_or_pgn_class == NOT_SET

pagination_class: Type[PaginationBase] = import_string(settings.PAGINATION_CLASS)
pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
settings.PAGINATION_CLASS
)

if isfunction:
return _inject_pagination(func_or_pgn_class, pagination_class)
Expand All @@ -135,26 +183,55 @@ def wrapper(func: Callable) -> Any:

def _inject_pagination(
func: Callable,
paginator_class: Type[PaginationBase],
paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
**paginator_params: Any,
) -> Callable:
paginator: PaginationBase = paginator_class(**paginator_params)

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(result[paginator.items_attribute])
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result
paginator = paginator_class(**paginator_params)
if is_async_callable(func):
if not hasattr(paginator, "apaginate_queryset"):
raise ConfigError("Pagination class not configured for async requests")

@wraps(func)
async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = await func(request, **kwargs)

result = await paginator.apaginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)

async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator:
for result in results:
yield result

if paginator.Output: # type: ignore
result[paginator.items_attribute] = [
result
async for result in evaluate(result[paginator.items_attribute])
]
return result
else:

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(
result[paginator.items_attribute]
)
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result

contribute_operation_args(
view_with_pagination,
Expand Down
123 changes: 123 additions & 0 deletions tests/test_pagination_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import asyncio
from typing import Any, List

import pytest
from django.db.models import QuerySet

from ninja import NinjaAPI, Schema
from ninja.errors import ConfigError
from ninja.pagination import (
AsyncPaginationBase,
PageNumberPagination,
PaginationBase,
paginate,
)
from ninja.testing import TestAsyncClient

api = NinjaAPI()

ITEMS = list(range(100))


class NoAsyncPagination(PaginationBase):
# only offset param, defaults to 5 per page
class Input(Schema):
skip: int

class Output(Schema):
items: List[Any]
count: str
skip: int

def paginate_queryset(self, items, pagination: Input, **params):
skip = pagination.skip
return {
"items": items[skip : skip + 5],
"count": "many",
"skip": skip,
}


class AsyncNoOutputPagination(AsyncPaginationBase):
# Outputs items without count attribute
class Input(Schema):
skip: int

Output = None

def paginate_queryset(self, items, pagination: Input, **params):
skip = pagination.skip
return items[skip : skip + 5]

async def apaginate_queryset(self, items, pagination: Input, **params):
await asyncio.sleep(0)
skip = pagination.skip
return items[skip : skip + 5]

def _items_count(self, queryset: QuerySet) -> int:
try:
# forcing to find queryset.count instead of list.count:
return queryset.all().count()
except AttributeError:
asyncio.sleep(0)
return len(queryset)


@pytest.mark.asyncio
async def test_async_config_error():
api = NinjaAPI()

with pytest.raises(
ConfigError, match="Pagination class not configured for async requests"
):

@api.get("/items_async_undefined", response=List[int])
@paginate(NoAsyncPagination)
async def items_async_undefined(request, **kwargs):
return ITEMS


@pytest.mark.asyncio
async def test_async_custom_pagination():
api = NinjaAPI()

@api.get("/items_async", response=List[int])
@paginate(AsyncNoOutputPagination)
async def items_async(request):
return ITEMS

client = TestAsyncClient(api)

response = await client.get("/items_async?skip=10")
assert response.json() == [10, 11, 12, 13, 14]


@pytest.mark.asyncio
async def test_async_default():
api = NinjaAPI()

@api.get("/items_default", response=List[int])
@paginate # WITHOUT brackets (should use default pagination)
async def items_default(request, someparam: int = 0, **kwargs):
asyncio.sleep(0)
return ITEMS

client = TestAsyncClient(api)

response = await client.get("/items_default?limit=10")
assert response.json() == {"items": ITEMS[:10], "count": 100}


@pytest.mark.asyncio
async def test_async_page_number():
api = NinjaAPI()

@api.get("/items_page_number", response=List[Any])
@paginate(PageNumberPagination, page_size=10, pass_parameter="page_info")
async def items_page_number(request, **kwargs):
return ITEMS + [kwargs["page_info"]]

client = TestAsyncClient(api)

response = await client.get("/items_page_number?page=11")
assert response.json() == {"items": [{"page": 11}], "count": 101}
16 changes: 15 additions & 1 deletion tests/test_pagination_router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List

import pytest

from ninja import NinjaAPI, Schema
from ninja.pagination import RouterPaginated
from ninja.testing import TestClient
from ninja.testing import TestAsyncClient, TestClient

api = NinjaAPI(default_router=RouterPaginated())

Expand Down Expand Up @@ -62,3 +64,15 @@ def test_for_NON_list_reponse():
]
# print(parameters)
assert parameters == []


@pytest.mark.asyncio
async def test_async_pagination():
@api.get("/items_async", response=List[ItemSchema])
async def items_async(request):
return [{"id": i} for i in range(1, 51)]

client = TestAsyncClient(api)

response = await client.get("/items_async?offset=5&limit=1")
assert response.json() == {"items": [{"id": 6}], "count": 50}

0 comments on commit a937981

Please sign in to comment.