Skip to content

Commit

Permalink
rework pagination for common use cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sheppard committed Mar 30, 2017
1 parent 388336d commit 663bbda
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 3 deletions.
35 changes: 35 additions & 0 deletions rest/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,41 @@ class Pagination(PageNumberPagination):
page_size = 50
page_size_query_param = 'limit'

def paginate_queryset(self, queryset, request, view=None):
data = super(Pagination, self).paginate_queryset(
queryset, request, view
)
if not view or not view.router:
return data

if request.accepted_renderer.format != 'json':
return data

non_format_kwargs = [
kwarg for kwarg in
list(view.kwargs.keys()) + list(request.GET.keys())
if kwarg != 'format'
]
if view.action != 'list' or any(non_format_kwargs):
return data

conf = view.router.get_model_config(queryset.model)
cache = conf.get('cache', 'first_page')

if cache == 'first_page':
return data
elif cache == 'all':
return list(queryset)
elif cache == 'none':
return []
elif cache == 'filter':
cache_filter = view.router.get_cache_filter_for_model(
queryset.model
)
return list(cache_filter(queryset, self.request))

return data

def get_paginated_response(self, data):
return Response(OrderedDict([
# DRF default metadata
Expand Down
27 changes: 24 additions & 3 deletions rest/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ModelRouter(DefaultRouter):
_fields = {}
_querysets = {}
_filters = {}
_cache_filters = {}
_viewsets = {}
_extra_pages = {}
_config = {}
Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(self, trailing_slash=False):
super(ModelRouter, self).__init__(trailing_slash=trailing_slash)

def register_model(self, model, viewset=None, serializer=None, fields=None,
queryset=None, filter=None, **kwargs):
queryset=None, filter=None, cache_filter=None,
**kwargs):
if isinstance(model, string_types) and '.' in model:
from django.db.models import get_model
model = get_model(*model.split('.'))
Expand All @@ -62,6 +64,9 @@ def register_model(self, model, viewset=None, serializer=None, fields=None,
self.register_queryset(model, queryset)
if filter:
self.register_filter(model, filter)
if cache_filter:
self.register_cache_filter(model, cache_filter)
kwargs['cache'] = 'filter'

if 'name' not in kwargs:
kwargs['name'] = model._meta.model_name
Expand Down Expand Up @@ -110,16 +115,29 @@ def register_fields(self, model, fields):
def register_queryset(self, model, queryset):
self._querysets[model] = queryset

def register_filter(self, model, queryset):
self._filters[model] = queryset
def register_filter(self, model, filter):
self._filters[model] = filter

def register_cache_filter(self, model, cache_filter):
self._cache_filters[model] = cache_filter

def register_config(self, model, config):
for key in ('partial', 'reversed', 'max_local_pages'):
if key in config:
raise ImproperlyConfigured(
'"%s" is deprecated in favor of "cache"' % key
)
self._config[model] = config
self._base_config = None

def update_config(self, model, **kwargs):
if model not in self._config:
raise RuntimeError("%s must be registered first" % model)
for key in ('partial', 'reversed', 'max_local_pages'):
if key in kwargs:
raise ImproperlyConfigured(
'"%s" is deprecated in favor of "cache"' % key
)
self._config[model].update(kwargs)
self._base_config = None

Expand Down Expand Up @@ -223,6 +241,9 @@ def get_queryset_for_model(self, model, request=None):
qs = self._filters[model](qs, request)
return qs

def get_cache_filter_for_model(self, model):
return self._cache_filters.get(model, lambda qs, req: qs)

def get_lookup_for_model(self, model_class):
config = self.get_model_config(model_class) or {}
return config.get('lookup', 'pk')
Expand Down
7 changes: 7 additions & 0 deletions tests/conflict_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ class Item(models.Model):

def __str__(self):
return self.name


class TestModel(models.Model):
name = models.CharField(max_length=10)

def __str__(self):
return self.name
11 changes: 11 additions & 0 deletions tests/rest_app/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
SlugRefChildSerializer,
)


def cache_users_own_data(qs, req):
if req.user.is_authenticated():
return qs.filter(user=req.user)
else:
return qs.none()


rest.router.register_model(
RootModel,
url="",
Expand All @@ -20,6 +28,7 @@
rest.router.register_model(
UserManagedModel,
fields="__all__",
cache_filter=cache_users_own_data,
)
rest.router.register_model(
Parent,
Expand All @@ -35,11 +44,13 @@
rest.router.register_model(
ItemType,
fields="__all__",
cache="all",
)
rest.router.register_model(
Item,
serializer=ItemSerializer,
fields="__all__",
cache="none",
)
rest.router.register_model(
GeometryModel,
Expand Down
5 changes: 5 additions & 0 deletions tests/templates/itemtype_list.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<ul>
{{#list}}
<li><a href="/itemtypes/{{id}}">{{label}}</a></li>
{{/list}}
</ul>
8 changes: 8 additions & 0 deletions tests/templates/usermanagedmodel_list.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{{#parent_id}}
<h3><a href="/{{parent_url}}">{{parent_label}}</a>'s Items</h3>
{{/parent_id}}
<ul>
{{#list}}
<li><a href="/usermanagedmodels/{{id}}">{{label}}</a></li>
{{/list}}
</ul>
101 changes: 101 additions & 0 deletions tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,82 @@ def test_rest_limit(self):
self.assertTrue(status.is_success(response.status_code), response.data)
self.assertEqual(response.data['per_page'], 10)

def test_rest_cache_all(self):
for num in range(2, 101):
ItemType.objects.create(
pk=num,
name='Type #%s' % num
)

tests = [
(100, '/itemtypes.json'),
(50, '/itemtypes/'),
(50, '/itemtypes.json?page=1'),
]
for expect_items, url in tests:
response = self.client.get(url)
self.assertTrue(
status.is_success(response.status_code), response.data
)
self.assertEqual(
len(response.data['list']), expect_items,
"%s should return %s items" % (url, expect_items)
)
self.assertEqual(response.data['pages'], 2)
self.assertEqual(response.data['per_page'], 50)
self.assertEqual(response.data['count'], 100)

def test_rest_cache_filter(self):
other_user = User.objects.create(username='otheruser')
UserManagedModel.objects.create(id=2, user=other_user)
UserManagedModel.objects.create(id=3, user=other_user)

for auth in False, True:
if auth:
self.client.force_authenticate(self.user)

tests = [
(3, 1 if auth else 0, '/usermanagedmodels.json'),
(3, 3, '/usermanagedmodels/'),
(3, 3, '/usermanagedmodels.json?page=1'),
(1, 1, '/usermanagedmodels.json?user_id=%s' % self.user.pk),
(2, 2, '/usermanagedmodels.json?user_id=%s' % other_user.pk),
]
for expect_count, expect_items, url in tests:
response = self.client.get(url)
self.assertTrue(
status.is_success(response.status_code), response.data
)
self.assertEqual(
len(response.data['list']), expect_items,
"%s should return %s items for %s" % (
url, expect_items,
"authed user" if auth else "anonymous user"
)
)
self.assertEqual(response.data['pages'], 1)
self.assertEqual(response.data['per_page'], 50)
self.assertEqual(response.data['count'], expect_count)

def test_rest_cache_none(self):
tests = [
(0, '/items.json'),
(2, '/items/'),
(2, '/items.json?page=1'),
]
for expect_items, url in tests:
response = self.client.get(url)
self.assertTrue(
status.is_success(response.status_code), response.data
)
self.assertEqual(
len(response.data['list']), expect_items,
"%s should return %s items" % (url, expect_items)
)
self.assertEqual(response.data['pages'], 1)
self.assertEqual(response.data['per_page'], 50)
self.assertEqual(response.data['count'], 2)

def test_rest_date_label(self):
response = self.client.get("/datemodels/1.json")
self.assertTrue(status.is_success(response.status_code), response.data)
Expand Down Expand Up @@ -330,6 +406,31 @@ def test_rest_model_conflict(self):
)
self.assertIn("conflictitem", rest.router.get_config()['pages'])

def test_rest_old_config(self):
from wq.db import rest
from tests.conflict_app.models import TestModel

with self.assertRaises(ImproperlyConfigured):
rest.router.register_model(
TestModel,
partial=True,
fields="__all__"
)

with self.assertRaises(ImproperlyConfigured):
rest.router.register_model(
TestModel,
reversed=True,
fields="__all__"
)

with self.assertRaises(ImproperlyConfigured):
rest.router.register_model(
TestModel,
max_local_pages=0,
fields="__all__"
)


class RestPostTestCase(APITestCase):
def setUp(self):
Expand Down

0 comments on commit 663bbda

Please sign in to comment.