Skip to content

Commit

Permalink
bugfix isolation for @extend_schema/@extend_schema_view reorg #554
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Oct 12, 2021
1 parent 4aaaa81 commit bee6c94
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 42 deletions.
31 changes: 28 additions & 3 deletions drf_spectacular/drainage.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,42 @@ def set_override(obj, prop, value):
return obj


def get_view_methods(view, schema=None):
def get_view_method_names(view, schema=None):
schema = schema or view.schema
return [
getattr(view, item) for item in dir(view) if callable(getattr(view, item)) and (
item for item in dir(view) if callable(getattr(view, item)) and (
item in view.http_method_names
or item in (schema or view.schema).method_mapping.values()
or item in schema.method_mapping.values()
or item == 'list'
or hasattr(getattr(view, item), 'mapping')
)
]


def isolate_view_method(view, method_name):
"""
Prevent modifying a view method which is derived from other views. Changes to
a derived method would leak into the view where the method originated from.
Break derivation by wrapping the method and explicitly setting it on the view.
"""
method = getattr(view, method_name)
# no isolation required as the view method is not derived
if method_name in view.__dict__:
return method

@functools.wraps(method)
def wrapped_method(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)

# wraps() will only create a shallow copy of method.__dict__. Updates to "kwargs"
# via @extend_schema would leak to the original method. Isolate by creating a copy.
if hasattr(method, 'kwargs'):
wrapped_method.kwargs = method.kwargs.copy()

setattr(view, method_name, wrapped_method)
return wrapped_method


def cache(user_function):
""" simple polyfill for python < 3.9 """
return functools.lru_cache(maxsize=None)(user_function)
43 changes: 16 additions & 27 deletions drf_spectacular/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import inspect
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
Expand All @@ -7,7 +6,9 @@
from rest_framework.serializers import Serializer
from rest_framework.settings import api_settings

from drf_spectacular.drainage import error, get_view_methods, set_override, warn
from drf_spectacular.drainage import (
error, get_view_method_names, isolate_view_method, set_override, warn,
)
from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -362,11 +363,13 @@ def get_extensions(self):
)
# reorder schema class MRO so that view method annotation takes precedence
# over view class annotation. only relevant if there is a method annotation
for view_method in get_view_methods(view=f, schema=BaseSchema):
if 'schema' in getattr(view_method, 'kwargs', {}):
view_method.kwargs['schema'] = type(
'ExtendedMetaSchema', (view_method.kwargs['schema'], ExtendedSchema), {}
)
for view_method_name in get_view_method_names(view=f, schema=BaseSchema):
if 'schema' not in getattr(getattr(f, view_method_name), 'kwargs', {}):
continue
view_method = isolate_view_method(f, view_method_name)
view_method.kwargs['schema'] = type(
'ExtendedMetaSchema', (view_method.kwargs['schema'], ExtendedSchema), {}
)
# persist schema on class to provide annotation to derived view methods.
# the second purpose is to serve as base for view multi-annotation
f.schema = ExtendedSchema()
Expand Down Expand Up @@ -472,35 +475,21 @@ def extend_schema_view(**kwargs) -> Callable[[F], F]:
:param kwargs: method names as argument names and :func:`@extend_schema <.extend_schema>`
calls as values
"""
def wrapping_decorator(method_decorator, method):
@functools.wraps(method)
def wrapped_method(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)

if hasattr(method, 'kwargs'):
wrapped_method.kwargs = method.kwargs.copy()

return method_decorator(wrapped_method)

def decorator(view):
view_methods = {m.__name__: m for m in get_view_methods(view)}
available_view_methods = get_view_method_names(view)

for method_name, method_decorator in kwargs.items():
if method_name not in view_methods:
if method_name not in available_view_methods:
warn(
f'@extend_schema_view argument "{method_name}" was not found on view '
f'{view.__name__}. method override for "{method_name}" will be ignored.'
)
continue

method = view_methods[method_name]
# the context of derived methods must not be altered, as it belongs to the other
# class. create a new context via the wrapping_decorator so the schema can be safely
# stored in the wrapped_method. methods belonging to the view can be safely altered.
if method_name in view.__dict__:
method_decorator(method)
else:
setattr(view, method_name, wrapping_decorator(method_decorator, method))
# the context of derived methods must not be altered, as it belongs to the
# other view. create a new context so the schema can be safely stored in the
# wrapped_method. view methods that are not derived can be safely altered.
method_decorator(isolate_view_method(view, method_name))
return view

return decorator
Expand Down
74 changes: 62 additions & 12 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2568,22 +2568,13 @@ def custom_action(self):
pass # pragma: no cover

schema = generate_schema('x', viewset=XViewSet)
schema['paths']['/x/{id}/custom_action/']['get']['summary'] == 'A custom action!'
assert schema['paths']['/x/{id}/custom_action/']['get']['summary'] == 'A custom action!'


def test_extend_schema_view_isolation(no_warnings):

class Animal(models.Model):
pass

class AnimalSerializer(serializers.ModelSerializer):
class Meta:
model = Animal
fields = '__all__'

class AnimalViewSet(viewsets.GenericViewSet):
serializer_class = AnimalSerializer
queryset = Animal.objects.all()
serializer_class = SimpleSerializer
queryset = SimpleModel.objects.all()

@action(detail=False)
def notes(self, request):
Expand All @@ -2604,3 +2595,62 @@ class InsectViewSet(AnimalViewSet):
schema = generate_schema(None, patterns=router.urls)
assert schema['paths']['/api/mammals/notes/']['get']['summary'] == 'List mammals.'
assert schema['paths']['/api/insects/notes/']['get']['summary'] == 'List insects.'


def test_extend_schema_view_layering(no_warnings):
class YSerializer(serializers.Serializer):
field = serializers.FloatField()

class ZSerializer(serializers.Serializer):
field = serializers.UUIDField()

class XViewSet(viewsets.ReadOnlyModelViewSet):
queryset = SimpleModel.objects.all()
serializer_class = SimpleSerializer

@extend_schema_view(retrieve=extend_schema(responses=YSerializer))
class YViewSet(XViewSet):
pass

@extend_schema_view(retrieve=extend_schema(responses=ZSerializer))
class ZViewSet(YViewSet):
pass

router = routers.SimpleRouter()
router.register('x', XViewSet)
router.register('y', YViewSet)
router.register('z', ZViewSet)
schema = generate_schema(None, patterns=router.urls)
resp = {
c: get_response_schema(schema['paths'][f'/{c.lower()}/{{id}}/']['get'])
for c in ['X', 'Y', 'Z']
}
assert resp['X'] == {'$ref': '#/components/schemas/Simple'}
assert resp['Y'] == {'$ref': '#/components/schemas/Y'}
assert resp['Z'] == {'$ref': '#/components/schemas/Z'}


def test_extend_schema_view_extend_schema_crosstalk(no_warnings):
class XSerializer(serializers.Serializer):
field = serializers.FloatField()

# extend_schema_view provokes decorator reordering in extend_schema
@extend_schema(tags=['X'])
@extend_schema_view(retrieve=extend_schema(responses=XSerializer))
class XViewSet(viewsets.ReadOnlyModelViewSet):
queryset = SimpleModel.objects.all()
serializer_class = SimpleSerializer

@extend_schema(tags=['Y'])
class YViewSet(XViewSet):
pass

router = routers.SimpleRouter()
router.register('x', XViewSet)
router.register('y', YViewSet)
schema = generate_schema(None, patterns=router.urls)
op = {
c: schema['paths'][f'/{c.lower()}/{{id}}/']['get'] for c in ['X', 'Y']
}
assert op['X']['tags'] == ['X']
assert op['Y']['tags'] == ['Y']

0 comments on commit bee6c94

Please sign in to comment.