Skip to content

Commit

Permalink
Add tests for cursor pagination
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmullen committed Aug 29, 2023
1 parent b5a11e4 commit 4b99b8c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 6 deletions.
10 changes: 5 additions & 5 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _clamp(val: int, min_: int, max_: int) -> int:
return max(min_, min(val, max_))


def _reverse_order(order: tuple):
def _reverse_order(order: tuple) -> tuple:
"""
Reverse the ordering specification for a Django ORM query.
Expand All @@ -262,7 +262,7 @@ def invert(x):
return tuple(invert(item) for item in order)


def _replace_query_param(url: str, key: str, val: str):
def _replace_query_param(url: str, key: str, val: str) -> str:
scheme, netloc, path, query, fragment = parse.urlsplit(url)
query_dict = parse.parse_qs(query, keep_blank_values=True)
query_dict[key] = [val]
Expand Down Expand Up @@ -296,8 +296,8 @@ def decode_cursor(cls, encoded_cursor: Optional[str]) -> Cursor:
reverse = bool(int(reverse))

position = tokens.get("p", [None])[0]
except (TypeError, ValueError):
raise ValueError(_("Invalid cursor.")) from None
except (TypeError, ValueError) as e:
raise ValueError(_("Invalid cursor.")) from e

return Cursor(offset=offset, reverse=reverse, position=position)

Expand Down Expand Up @@ -535,7 +535,7 @@ def previous_link(
cursor = Cursor(offset=offset, reverse=True, position=position)
return self._encode_cursor(cursor, base_url)

def _get_position_from_instance(self, instance, ordering):
def _get_position_from_instance(self, instance, ordering) -> str:
field_name = ordering[0].lstrip("-")
if isinstance(instance, dict):
attr = instance[field_name]
Expand Down
16 changes: 15 additions & 1 deletion tests/demo_project/someapp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from pydantic import BaseModel

from ninja import Router
from ninja.pagination import CursorPagination, paginate

from .models import Event
from .models import Category, Event

router = Router()

Expand All @@ -20,6 +21,19 @@ class Config:
from_attributes = True


class CategorySchema(BaseModel):
title: str

class Config:
from_attributes = True


@router.get("/categories", response=List[CategorySchema])
@paginate(CursorPagination)
def list_categories(request):
return Category.objects.order_by("title")


@router.post("/create", url_name="event-create-url-name")
def create_event(request, event: EventSchema):
Event.objects.create(**event.model_dump())
Expand Down
58 changes: 58 additions & 0 deletions tests/test_cursor_pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
from django.test import Client
from someapp.models import Category


@pytest.fixture
def categories():
yield [Category.objects.create(title=title) for title in ["C", "E", "B", "D", "A"]]

Category.objects.all().delete()


@pytest.mark.django_db
def test_cursor_pagination_single_page(client: Client, categories):
response = client.get("/api/events/categories")
assert response.status_code == 200, response.json()
assert response.json() == {
"results": [
{"title": "A"},
{"title": "B"},
{"title": "C"},
{"title": "D"},
{"title": "E"},
],
"count": 5,
"next": None,
"previous": None,
}


@pytest.mark.django_db
def test_cursor_pagination_iteration(client: Client, categories):
response = client.get("/api/events/categories", data={"limit": 2})
assert response.status_code == 200, response.json()
assert response.json()["results"] == [{"title": "A"}, {"title": "B"}]
next_url = response.json()["next"]
assert next_url is not None
assert response.json()["previous"] is None
assert response.json()["count"] == 5

# follow next page link
response = client.get(next_url)
assert response.status_code == 200, response.json()
assert response.json()["results"] == [{"title": "C"}, {"title": "D"}]
previous_url = response.json()["previous"]
assert previous_url is not None
assert response.json()["count"] == 5

# follow previous page link
response = client.get(previous_url)
assert response.status_code == 200, response.json()
assert response.json()["results"] == [{"title": "A"}, {"title": "B"}]


def test_invalid_cursor(client: Client):
response = client.get("/api/events/categories", data={"cursor": "invalid"})
assert response.status_code == 422
assert "Invalid cursor." in response.json()["detail"][0]["msg"]

0 comments on commit 4b99b8c

Please sign in to comment.