Skip to content

Commit 505447f

Browse files
author
Marcos Dione
committed
Merge branch 'support-viewsets' into local.
2 parents fccdce9 + cf64f80 commit 505447f

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

rest_framework_docs/api_endpoint.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import json
22
import inspect
3+
34
from django.contrib.admindocs.views import simplify_regex
45
from django.utils.encoding import force_str
6+
7+
from rest_framework.viewsets import ModelViewSet
58
from rest_framework.serializers import BaseSerializer
69

10+
VIEWSET_METHODS = {
11+
'List': ['get', 'post'],
12+
'Instance': ['get', 'put', 'patch', 'delete'],
13+
}
14+
715

816
class ApiEndpoint(object):
917

@@ -31,8 +39,14 @@ def __get_path__(self, parent_regex):
3139
return "/{0}{1}".format(self.name_parent, simplify_regex(self.pattern.regex.pattern))
3240
return simplify_regex(self.pattern.regex.pattern)
3341

34-
def __get_allowed_methods__(self):
42+
def is_method_allowed(self, callback_cls, method_name):
43+
has_attr = hasattr(callback_cls, method_name)
44+
viewset_method = (issubclass(callback_cls, ModelViewSet) and
45+
method_name in VIEWSET_METHODS.get(self.callback.suffix, []))
46+
47+
return has_attr or viewset_method
3548

49+
def __get_allowed_methods__(self):
3650
viewset_methods = []
3751
if self.drf_router:
3852
for prefix, viewset, basename in self.drf_router.registry:
@@ -57,14 +71,18 @@ def __get_allowed_methods__(self):
5771
)
5872
if self.pattern.regex.pattern == regex:
5973
funcs, viewset_methods = zip(
60-
*[(mapping[m], m.upper()) for m in self.callback.cls.http_method_names if m in mapping]
74+
*[(mapping[m], m.upper())
75+
for m in self.callback.cls.http_method_names
76+
if m in mapping]
6177
)
6278
viewset_methods = list(viewset_methods)
6379
if len(set(funcs)) == 1:
6480
self.docstring = inspect.getdoc(getattr(self.callback.cls, funcs[0]))
6581

66-
view_methods = [force_str(m).upper() for m in self.callback.cls.http_method_names if hasattr(self.callback.cls, m)]
67-
return viewset_methods + view_methods
82+
view_methods = [force_str(m).upper()
83+
for m in self.callback.cls.http_method_names
84+
if self.is_method_allowed(self.callback.cls, m)]
85+
return sorted(viewset_methods + view_methods)
6886

6987
def __get_docstring__(self):
7088
return inspect.getdoc(self.callback)

tests/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_index_view_with_endpoints(self):
3131

3232
# Test the login view
3333
self.assertEqual(response.context["endpoints"][0].name_parent, "accounts")
34-
self.assertEqual(response.context["endpoints"][0].allowed_methods, ['POST', 'OPTIONS'])
34+
self.assertEqual(sorted(response.context["endpoints"][0].allowed_methods), sorted(['OPTIONS', 'POST']))
3535
self.assertEqual(response.context["endpoints"][0].path, "/accounts/login/")
3636
self.assertEqual(response.context["endpoints"][0].docstring, "A view that allows users to login providing their username and password.")
3737
self.assertEqual(len(response.context["endpoints"][0].fields), 2)

0 commit comments

Comments
 (0)