# generators.py (187 lines, 151 loc, 7.89 KB)
import inspect
from urllib.parse import urljoin
from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, error,
is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, reset_generator_stats, sanitize_result_object, warn,
)
from drf_spectacular.settings import spectacular_settings
class EndpointEnumerator(BaseEndpointEnumerator):
    """
    DRF endpoint enumerator whose endpoints also carry the raw URL regex.

    Every endpoint is a 4-tuple ``(path, path_regex, method, callback)``.
    """

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Collect endpoints, run the preprocessing hooks, de-duplicate and sort.

        Each hook in ``spectacular_settings.PREPROCESSING_HOOKS`` is invoked as
        ``hook(endpoints=...)`` and its return value replaces the endpoint list.
        When several endpoints share a ``(path, method)`` pair, only the first one
        is kept. The survivors are ordered with ``alpha_operation_sorter``.
        """
        endpoints = self._get_api_endpoints(patterns, prefix)
        for preprocess in spectacular_settings.PREPROCESSING_HOOKS:
            endpoints = preprocess(endpoints=endpoints)

        first_seen = {}
        for path, path_regex, method, callback in endpoints:
            # setdefault never overwrites, so the earliest entry per (path, method) wins
            first_seen.setdefault((path, method), (path, path_regex, method, callback))
        return sorted(first_seen.values(), key=alpha_operation_sorter)

    def _get_api_endpoints(self, patterns, prefix):
        """
        Recursively walk the URL conf and return every available API endpoint.

        This mirrors DRF's own endpoint enumeration. The only difference is that
        the accumulated ``path_regex`` is kept inside each endpoint tuple.

        :param patterns: URL patterns to inspect; ``None`` means ``self.patterns``.
        :param prefix: regex prefix accumulated from enclosing resolvers.
        :return: list of ``(path, path_regex, method, callback)`` tuples.
        """
        source = self.patterns if patterns is None else patterns
        found = []
        for entry in source:
            regex = prefix + str(entry.pattern)
            if isinstance(entry, URLPattern):
                path = self.get_path_from_regex(regex)
                callback = entry.callback
                if self.should_include_endpoint(path, callback):
                    found.extend(
                        (path, regex, method, callback)
                        for method in self.get_allowed_methods(callback)
                    )
            elif isinstance(entry, URLResolver):
                # descend into include()-ed url confs, carrying the regex built so far
                found.extend(
                    self._get_api_endpoints(patterns=entry.url_patterns, prefix=regex)
                )
        return found
class SchemaGenerator(BaseSchemaGenerator):
    """
    OpenAPI schema generator that drives drf-spectacular's ``AutoSchema``.

    It uses the path_regex-aware ``EndpointEnumerator``. Components are collected
    in a shared ``ComponentRegistry``, and versioning, camelization and the
    configured post-processing hooks are applied to the result.
    """
    endpoint_inspector_cls = EndpointEnumerator

    def __init__(self, *args, **kwargs):
        """
        Accept an optional ``api_version`` keyword and forward everything else to DRF.

        When ``api_version`` is given, it takes precedence over the request's version
        when versioned views are filtered in ``parse``.
        """
        self.registry = ComponentRegistry()
        self.api_version = kwargs.pop('api_version', None)
        self.inspector = None
        super().__init__(*args, **kwargs)

    def create_view(self, callback, method, request=None):
        """
        customized create_view which is called when all routes are traversed. part of this
        is instantiating views with default params. in case of custom routes (@action) the
        custom AutoSchema is injected properly through 'initkwargs' on view. However, when
        decorating plain views like retrieve, this initialization logic is not running.
        Therefore forcefully set the schema if @extend_schema decorator was used.

        If an ``OpenApiViewExtension`` matches ``callback.cls``, the view is built from
        the extension's ``view_replacement()``. The callback's original class is
        restored afterwards.

        :return: the instantiated view. Unsupported view classes are reported via
            ``error`` and returned without schema injection.
        """
        override_view = OpenApiViewExtension.get_match(callback.cls)
        if override_view:
            original_cls = callback.cls
            callback.cls = override_view.view_replacement()
            try:
                view = super().create_view(callback, method, request)
            finally:
                # callback is hosted in the url patterns, so this swap must not outlive
                # view creation; later visits of the same route start from a clean state.
                callback.cls = original_cls
        else:
            view = super().create_view(callback, method, request)

        if isinstance(view, viewsets.ViewSetMixin):
            # ViewSetMixin covers GenericViewSet, ViewSet and custom mixin-based viewsets,
            # whose handlers are looked up by action name instead of HTTP method.
            action = getattr(view, view.action)
        elif isinstance(view, views.APIView):
            action = getattr(view, method.lower())
        else:
            error(
                'Using not supported View class. Class must be derived from APIView '
                'or any of its subclasses like GenericApiView, GenericViewSet.'
            )
            return view

        # in case of @extend_schema, manually init custom schema class here due to
        # weakref reverse schema.view bug for multi annotations.
        schema = getattr(action, 'kwargs', {}).get('schema', None)
        if schema and inspect.isclass(schema):
            view.schema = schema()
        return view

    def _initialise_endpoints(self):
        """Enumerate endpoints once and keep the enumerator for later pattern lookups."""
        if self.endpoints is None:
            self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
            self.endpoints = self.inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self, request):
        """
        Turn (path, path_regex, method, callback) into (path, path_regex, method, view).
        """
        view_endpoints = []
        for path, path_regex, method, callback in self.endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, method, view)
            view_endpoints.append((path, path_regex, method, view))
        return view_endpoints

    def parse(self, request, public):
        """
        Iterate endpoints, generating path operations for each method.

        :param request: incoming request, or a falsy value. When falsy, a mock request
            is built per endpoint via ``spectacular_settings.GET_MOCK_REQUEST``.
        :param public: when true, views are instantiated without the incoming request.
        :return: dict mapping each (mount-url joined) path to
            ``{lowercase http method: operation}``.
        :raises AssertionError: if a view's schema is not a drf-spectacular ``AutoSchema``.
        """
        result = {}
        self._initialise_endpoints()

        for path, path_regex, method, view in self._get_paths_and_endpoints(None if public else request):
            if not self.has_view_permissions(path, method, view):
                continue

            if request:
                endpoint_request = request
            else:
                # mocked request to allow certain operations in get_queryset and
                # get_serializer[_class] without exceptions being raised due to no request.
                # Built per endpoint (not reused) so each view sees its own method and path.
                endpoint_request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, request)
            view.request = endpoint_request

            if view.versioning_class and not is_versioning_supported(view.versioning_class):
                warn(
                    f'using unsupported versioning class "{view.versioning_class}". view will be '
                    f'processed as unversioned view.'
                )
            elif view.versioning_class:
                version = (
                    self.api_version  # generator was explicitly versioned
                    or getattr(endpoint_request, 'version', None)  # incoming request was versioned
                    or view.versioning_class.default_version  # fallback
                )
                path = modify_for_versioning(self.inspector.patterns, method, path, view, version)
                if not version or not operation_matches_version(view, version):
                    continue

            # explicit raise instead of ``assert`` so the check survives ``python -O``
            if not isinstance(view.schema, AutoSchema):
                raise AssertionError(
                    'Incompatible AutoSchema used on View. Is DRF\'s DEFAULT_SCHEMA_CLASS '
                    'pointing to "drf_spectacular.openapi.AutoSchema" or any other drf-spectacular '
                    'compatible AutoSchema?'
                )
            operation = view.schema.get_operation(path, path_regex, method, self.registry)

            # operation was manually removed via @extend_schema
            if not operation:
                continue

            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            if spectacular_settings.CAMELIZE_NAMES:
                path, operation = camelize_operation(path, operation)

            result.setdefault(path, {})[method.lower()] = operation

        return result

    def get_schema(self, request=None, public=False):
        """ Generate a OpenAPI schema. """
        reset_generator_stats()
        # keyword arguments are evaluated left to right: parse() must run first because
        # it fills self.registry (via get_operation) before the components are built.
        result = build_root_object(
            paths=self.parse(request, public),
            components=self.registry.build(spectacular_settings.APPEND_COMPONENTS),
        )
        for hook in spectacular_settings.POSTPROCESSING_HOOKS:
            result = hook(result=result, generator=self, request=request, public=public)

        return sanitize_result_object(normalize_result_object(result))