Skip to content

Commit

Permalink
Add get_resource_kwargs method to viewsets
Browse files Browse the repository at this point in the history
It makes possible to pass some values in resource
class from viewset.

Also simplify ImportJobViewSet not to use filterset class
  • Loading branch information
NikAzanov committed Sep 11, 2023
1 parent 5244105 commit e508d66
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 62 deletions.
3 changes: 3 additions & 0 deletions docs/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ to implement import/export via API. Just create custom class with ``resource_cla
urlpatterns = band_import_export_router.urls
By default, all import/export jobs for the set ``resource_class`` will be available,
but you can override ``get_queryset`` method to change it. You can also override
``get_resource_kwargs`` method to provide some values in resource class (for ``start`` action).

These view sets provide all methods required for entire import/export workflow: start, details,
confirm, cancel and list actions. There is also `drf-spectacular <https://github.com/tfranzel/drf-spectacular>`_
Expand Down
14 changes: 6 additions & 8 deletions import_export_extensions/api/serializers/export_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ class CreateExportJob(serializers.Serializer):

def __init__(
self,
filter_kwargs: typing.Optional[dict[str, typing.Any]] = None,
*args,
filter_kwargs: typing.Optional[dict[str, typing.Any]] = None,
resource_kwargs: typing.Optional[dict[str, typing.Any]] = None,
**kwargs,
):
"""Set filter kwargs and current user."""
super().__init__(*args, **kwargs)
self._filter_kwargs: typing.Optional[dict[str, typing.Any]] = (
filter_kwargs
)
self._filter_kwargs = filter_kwargs
self._resource_kwargs = resource_kwargs or {}
self._request: request.Request = self.context.get("request")
self._user = getattr(self._request, "user", None)

Expand All @@ -86,11 +86,10 @@ def create(
]
return models.ExportJob.objects.create(
resource_path=self.resource_class.class_path,
file_format_path=(
f"{file_format_class.__module__}.{file_format_class.__name__}"
),
file_format_path=f"{file_format_class.__module__}.{file_format_class.__name__}",
resource_kwargs=dict(
filter_kwargs=self._filter_kwargs,
**self._resource_kwargs,
),
created_by=self._user,
)
Expand All @@ -108,7 +107,6 @@ class _CreateExportJob(CreateExportJob):
"""Serializer to start export job."""

resource_class: typing.Type[resources.CeleryModelResource] = resource

file_format = serializers.ChoiceField(
required=True,
choices=[
Expand Down
25 changes: 5 additions & 20 deletions import_export_extensions/api/serializers/import_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from rest_framework import request, serializers

from celery import states
from django_filters.utils import translate_validation

from ... import models, resources
from .progress import ProgressSerializer
Expand Down Expand Up @@ -50,30 +49,20 @@ class CreateImportJob(serializers.Serializer):

resource_class: typing.Type[resources.CeleryModelResource]

file = serializers.FileField(required=True)

def __init__(
self,
filter_kwargs: typing.Optional[dict[str, typing.Any]] = None,
*args,
resource_kwargs: typing.Optional[dict[str, typing.Any]] = None,
**kwargs,
):
"""Set filter kwargs and current user."""
super().__init__(*args, **kwargs)
self._filter_kwargs: typing.Optional[dict[str, typing.Any]] = (
filter_kwargs
)
self._request: request.Request = self.context.get("request")
self._resource_kwargs = resource_kwargs or {}
self._user = getattr(self._request, "user", None)

def validate(self, attrs: dict[str, typing.Any]) -> dict[str, typing.Any]:
"""Check that filter kwargs are valid."""
if self._filter_kwargs:
filter_instance = self.resource_class.filterset_class(
data=self._filter_kwargs,
)
if not filter_instance.is_valid():
raise translate_validation(error_dict=filter_instance.errors)
return attrs

def create(
self,
validated_data: dict[str, typing.Any],
Expand All @@ -82,9 +71,7 @@ def create(
return models.ImportJob.objects.create(
data_file=validated_data["file"],
resource_path=self.resource_class.class_path,
resource_kwargs=dict(
filter_kwargs=self._filter_kwargs,
),
resource_kwargs=self._resource_kwargs,
created_by=self._user,
)

Expand All @@ -102,8 +89,6 @@ class _CreateImportJob(CreateImportJob):

resource_class: typing.Type[resources.CeleryModelResource] = resource

file = serializers.FileField(required=True)

return type(
f"{resource.__name__}CreateImportJob",
(_CreateImportJob,),
Expand Down
10 changes: 10 additions & 0 deletions import_export_extensions/api/views/export_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ def get_queryset(self):
resource_path=self.resource_class.class_path,
)

def get_resource_kwargs(self) -> dict[str, typing.Any]:
"""Provide extra arguments to resource class."""
return {}

def get_serializer(self, *args, **kwargs):
"""Provide resource kwargs to serializer class."""
if self.action == "start":
kwargs.setdefault("resource_kwargs", self.get_resource_kwargs())
return super().get_serializer(*args, **kwargs)

def get_serializer_class(self):
"""Return special serializer on creation."""
if self.action == "start":
Expand Down
60 changes: 26 additions & 34 deletions import_export_extensions/api/views/import_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
viewsets,
)
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request

import django_filters

from ... import models, resources
from .. import serializers
Expand All @@ -23,9 +20,7 @@ class ImportBase(type):
def __new__(cls, name, bases, attrs, **kwargs):
"""Dynamically create an import start api endpoint.
We need this to specify on fly action's filterset_class and queryset
(django-filters requires view's queryset and filterset_class's
queryset model to match). Also, if drf-spectacular is installed
If drf-spectacular is installed
specify request and response, and enable filters.
"""
Expand All @@ -41,40 +36,13 @@ def __new__(cls, name, bases, attrs, **kwargs):
if name == "ImportJobViewSet":
return viewset

def start(self: "ImportJobViewSet", request: Request):
"""Validate request data and start ImportJob."""
serializer = self.get_serializer(
data=request.data,
filter_kwargs=request.query_params,
)
serializer.is_valid(raise_exception=True)
import_job = serializer.save()
return response.Response(
data=self.get_detail_serializer_class()(
instance=import_job,
).data,
status=status.HTTP_201_CREATED,
)

viewset.start = decorators.action(
methods=["POST"],
detail=False,
queryset=viewset.resource_class.get_model_queryset(),
filterset_class=getattr(
viewset.resource_class, "filterset_class", None,
),
filter_backends=[
django_filters.rest_framework.DjangoFilterBackend,
],
)(start)
# Correct specs of drf-spectacular if it is installed
try:
from drf_spectacular.utils import extend_schema, extend_schema_view

detail_serializer_class = viewset().get_detail_serializer_class()
return extend_schema_view(
start=extend_schema(
filters=True,
request=viewset().get_import_create_serializer_class(),
responses={
status.HTTP_201_CREATED: detail_serializer_class,
Expand Down Expand Up @@ -125,14 +93,23 @@ class ImportJobViewSet(
resource_class: typing.Optional[
typing.Type[resources.CeleryModelResource]
] = None
filterset_class: django_filters.rest_framework.FilterSet = None

def get_queryset(self):
"""Filter import jobs by resource used in viewset."""
return super().get_queryset().filter(
resource_path=self.resource_class.class_path,
)

def get_resource_kwargs(self) -> dict[str, typing.Any]:
"""Provide extra arguments to resource class."""
return {}

def get_serializer(self, *args, **kwargs):
"""Provide resource kwargs to serializer class."""
if self.action == "start":
kwargs.setdefault("resource_kwargs", self.get_resource_kwargs())
return super().get_serializer(*args, **kwargs)

def get_serializer_class(self):
"""Return special serializer on creation."""
if self.action == "start":
Expand All @@ -149,6 +126,21 @@ def get_import_create_serializer_class(self):
self.resource_class,
)

@decorators.action(methods=["POST"], detail=False)
def start(self, request, *args, **kwargs):
"""Validate request data and start ImportJob."""
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)

import_job = serializer.save()

return response.Response(
data=self.get_detail_serializer_class()(
instance=import_job,
).data,
status=status.HTTP_201_CREATED,
)

@decorators.action(methods=["POST"], detail=True)
def confirm(self, *args, **kwargs):
"""Confirm import job that has `parsed` status."""
Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
# Configure `drf-spectacular` to check it works for import-export API
REST_FRAMEWORK = {
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
"COMPONENT_SPLIT_REQUEST": True, # Allows to upload import file from Swagger UI
}

# Don't use celery when you're local
Expand Down

0 comments on commit e508d66

Please sign in to comment.