From a43c29b4a7a888fda643356f38dcc0586adc7404 Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Thu, 28 Mar 2019 10:38:41 +0900 Subject: [PATCH] option to set filename in view (fixes #31) --- README.md | 12 ++++++++++++ rest_pandas/views.py | 37 ++++++++++++++++++++++++++++++++++--- runserver.sh | 2 ++ tests/settings.py | 4 +++- tests/test_excel.py | 8 ++++++++ tests/testapp/views.py | 6 ++++++ 6 files changed, 65 insertions(+), 4 deletions(-) create mode 100755 runserver.sh diff --git a/README.md b/README.md index 817fa9b..30ae392 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,18 @@ class TimeSeriesView(PandasView): renderer_classes = [PandasCSVRenderer, PandasExcelRenderer] # You can also set the default renderers for all of your pandas views by # defining the PANDAS_RENDERERS in your settings.py. + + # Step 5 (Optional). The default filename may not be particularly useful + # for your users. To override, define get_pandas_filename() on your view. + # If a filename is returned, rest_pandas will include the following header: + # 'Content-Disposition: attachment; filename="Data Export.xlsx"' + def get_pandas_filename(self, request, format): + if format in ('xls', 'xlsx'): + # Use custom filename and Content-Disposition header + return "Data Export" # Extension will be appended automatically + else: + # Default filename from URL (no Content-Disposition header) + return None ``` #### Django Pandas Integration diff --git a/rest_pandas/views.py b/rest_pandas/views.py index 75bf63d..b6edee4 100644 --- a/rest_pandas/views.py +++ b/rest_pandas/views.py @@ -75,6 +75,31 @@ def get_serializer_class(self): else: return self.serializer_class + def get_pandas_filename(self, request, format): + return None + + def get_pandas_headers(self, request): + format = request.accepted_renderer.format + filename = self.get_pandas_filename(request, format) + if not filename: + return {} + + extension = '.' + format + if not filename.endswith(extension): + filename += extension + + return { + 'Content-Disposition': 'attachment; filename="{}"'.format( + filename + ) + } + + def update_pandas_headers(self, response): + headers = self.get_pandas_headers(self.request) + for key, val in headers.items(): + response[key] = val + return response + class PandasViewBase(PandasMixin): renderer_classes = PANDAS_RENDERERS @@ -97,18 +122,24 @@ def get(self, request, *args, **kwargs): data = self.get_data(request, *args, **kwargs) serializer_class = self.get_serializer_class() serializer = serializer_class(data, many=True) - return Response(serializer.data) + response = Response(serializer.data) + return self.update_pandas_headers(response) class PandasView(PandasViewBase, ListAPIView): """ Pandas-capable model list view """ - pass + + def list(self, request, *args, **kwargs): + response = super(PandasView, self).list(request, *args, **kwargs) + return self.update_pandas_headers(response) class PandasViewSet(PandasViewBase, ListModelMixin, GenericViewSet): """ Pandas-capable model ViewSet (list only) """ - pass + def list(self, request, *args, **kwargs): + response = super(PandasViewSet, self).list(request, *args, **kwargs) + return self.update_pandas_headers(response) diff --git a/runserver.sh b/runserver.sh new file mode 100755 index 0000000..974ba1b --- /dev/null +++ b/runserver.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 -m django runserver --settings=tests.settings diff --git a/tests/settings.py b/tests/settings.py index 921b8c8..8e65fa9 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -8,7 +8,7 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': ':memory:', + 'NAME': 'rest_pandas_test.sqlite3', } } ROOT_URLCONF = "tests.urls" @@ -19,6 +19,8 @@ }, ] +DEBUG = True + try: import matplotlib # noqa diff --git a/tests/test_excel.py b/tests/test_excel.py index d8e35c1..92a7ae0 100644 --- a/tests/test_excel.py +++ b/tests/test_excel.py @@ -17,6 +17,10 @@ def setUp(self): def test_xls(self): response = self.client.get("/timeseries.xls") + self.assertEqual( + 'attachment; filename="Time Series.xls"', + response['content-disposition'], + ) xlfile = open('tests/output.xls', 'wb') xlfile.write(response.content) xlfile.close() @@ -28,6 +32,10 @@ def test_xls(self): def test_xlsx(self): response = self.client.get("/timeseries.xlsx") + self.assertEqual( + 'attachment; filename="Time Series.xlsx"', + response['content-disposition'], + ) xlfile = open('tests/output.xlsx', 'wb') xlfile.write(response.content) xlfile.close() diff --git a/tests/testapp/views.py b/tests/testapp/views.py index 4cfed7d..5c2c9b4 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -43,6 +43,12 @@ class TimeSeriesView(PandasView): def get_template_context(self, data): return {'name': data['name'] + ' Custom'} + def get_pandas_filename(self, request, format): + if format in ('xls', 'xlsx'): + return self.get_view_name() + else: + return None + class TimeSeriesNoIdView(PandasView): queryset = TimeSeries.objects.all()