diff --git a/django_restql/__init__.py b/django_restql/__init__.py index 4bf2bb2..7ffa8f0 100644 --- a/django_restql/__init__.py +++ b/django_restql/__init__.py @@ -1,7 +1,7 @@ __title__ = 'Django RESTQL' __description__ = 'Turn your API made with Django REST Framework(DRF) into a GraphQL like API.' __url__ = 'https://yezyilomo.github.io/django-restql' -__version__ = '1.3.0' +__version__ = '1.4.0' __author__ = 'Yezy Ilomo' __author_email__ = 'yezileliilomo@hotmail.com' __license__ = 'MIT' diff --git a/django_restql/mixins.py b/django_restql/mixins.py index a423447..0b6eaa2 100644 --- a/django_restql/mixins.py +++ b/django_restql/mixins.py @@ -1,4 +1,4 @@ -from django.db.models import Prefetch +from django.db.models import Prefetch, QuerySet from django.core.exceptions import ObjectDoesNotExist from django.db.models.fields.related import ManyToManyRel, ManyToOneRel from django.http import QueryDict @@ -1049,6 +1049,7 @@ class OptimizedEagerLoadingMixin(EagerLoadingMixin): always_apply_only = False force_query_usage = None to_select = [] + annotated_fields = [] @property def should_always_apply_only(self): @@ -1187,7 +1188,8 @@ def get_queryset(self): raise ValidationError( _(f"'{query_param_name}' must be defined in query params.") ) - return super().get_queryset() + queryset = super().get_queryset() + return self.annotate_fields(queryset=queryset) def parse_model_fields(self, model, fields, skip_non_model_fields=True): results = [] @@ -1213,3 +1215,14 @@ def parse_model_fields(self, model, fields, skip_non_model_fields=True): results.append(field) return results + + def annotate_fields(self, queryset: QuerySet) -> QuerySet: + for field_name in self.annotated_fields: + if self.should_annotate_field(field_name): + queryset = getattr(self, f"annotate_{field_name}")(queryset) + + return queryset + + def should_annotate_field(self, field_name: str) -> bool: + query = self.get_dict_parsed_restql_query(self.parsed_restql_query) + return "*" in query or field_name in query diff --git a/tests/testapp/tests/test_optimized_views.py b/tests/testapp/tests/test_optimized_views.py index 5ee86f0..13963c2 100644 --- a/tests/testapp/tests/test_optimized_views.py +++ b/tests/testapp/tests/test_optimized_views.py @@ -1,5 +1,8 @@ import pytest from django.core.exceptions import FieldDoesNotExist +from django.db import models +from django.db.models import Value +from django.db.models.functions import Concat from django.urls import path, reverse from django_restql.mixins import DynamicFieldsMixin, OptimizedEagerLoadingMixin from model_bakery import baker @@ -44,6 +47,15 @@ def get_first_letter(self, obj): return obj.title[0] +class SamplePostWithAnnotationSerializer(DynamicFieldsMixin, serializers.ModelSerializer): + author_full_name = serializers.CharField() + + class Meta: + model = SamplePost + fields = ("id", "author_full_name") + read_only_fields = fields + + class SampleTagSerializer(DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = SampleTag @@ -82,6 +94,27 @@ class SampleViewSet( always_apply_only = True +class SamplePostWithAnnotationView( + OptimizedEagerLoadingMixin, UpdateAPIView, ListAPIView +): + queryset = SamplePost.objects.all() + serializer_class = SamplePostWithAnnotationSerializer + permission_classes = [] + only = {"author_full_name": []} + pagination_class = None + always_apply_only = True + annotated_fields = ("author_full_name",) + + @staticmethod + def annotate_author_full_name(queryset): + return queryset.annotate( + author_full_name=Concat( + "author__first_name", Value(" "), "author__last_name", + output_field=models.CharField(), + ), + ) + + class SampleAuthorViewSet( OptimizedEagerLoadingMixin, ListAPIView, @@ -99,10 +132,16 @@ class SampleAuthorViewSet( path("", SampleViewSet.as_view(), name="view"), path("", SampleViewSet.as_view(), name="view-update"), path("authors/", SampleAuthorViewSet.as_view(), name="authors-view"), + path( + "posts-with-annotation/", + SamplePostWithAnnotationView.as_view(), + name="posts-annotation-view", + ), ] @pytest.mark.django_db +@pytest.mark.urls(__name__) class TestOnlyInEagerLoading: @pytest.fixture def instance(self): @@ -112,7 +151,6 @@ def instance(self): def url(self): return reverse("view") - @pytest.mark.urls(__name__) def test_fields_correctly_selected( self, client, django_assert_max_num_queries, instance, url ): @@ -126,13 +164,11 @@ def test_fields_correctly_selected( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_multiple_records(self, client, django_assert_max_num_queries, url): baker.make(SamplePost, _quantity=10) with django_assert_max_num_queries(1): client.get(url, {"query": "{title,author{id}}"}) - @pytest.mark.urls(__name__) def test_no_query(self, client, django_assert_max_num_queries, instance, url): with django_assert_max_num_queries(1) as x: response = client.get(url) @@ -148,7 +184,6 @@ def test_no_query(self, client, django_assert_max_num_queries, instance, url): } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_using_all_fields( self, client, django_assert_max_num_queries, instance, url ): @@ -165,7 +200,6 @@ def test_using_all_fields( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_custom_only(self, client, django_assert_max_num_queries, instance, url): with django_assert_max_num_queries(1) as x: client.get(url, {"query": "{first_letter}"}) @@ -175,7 +209,6 @@ def test_custom_only(self, client, django_assert_max_num_queries, instance, url) } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_custom_only_in_foreign_key( self, client, django_assert_max_num_queries, instance, url ): @@ -189,7 +222,6 @@ def test_custom_only_in_foreign_key( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_using_nested_all_fields( self, client, django_assert_max_num_queries, instance, url ): @@ -204,7 +236,6 @@ def test_using_nested_all_fields( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_using_exclude_operator( self, client, django_assert_max_num_queries, instance, url ): @@ -222,7 +253,6 @@ def test_using_exclude_operator( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_incorrect_view( self, client, django_assert_max_num_queries, instance, url, monkeypatch ): @@ -230,7 +260,6 @@ def test_incorrect_view( with pytest.raises(FieldDoesNotExist): client.get(url) - @pytest.mark.urls(__name__) def test_m2m_field( self, client, django_assert_max_num_queries, instance, url, monkeypatch ): @@ -252,7 +281,6 @@ def test_m2m_field( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_m2m_field_no_query( self, client, django_assert_max_num_queries, instance, url, monkeypatch ): @@ -279,7 +307,6 @@ def test_m2m_field_no_query( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_no_query_only_serializer_fields( self, client, django_assert_max_num_queries, instance, url, monkeypatch ): @@ -297,7 +324,6 @@ def test_no_query_only_serializer_fields( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_force_query_usage( self, client, instance, url, settings ): @@ -318,7 +344,6 @@ def test_force_query_usage( (True, False, status.HTTP_200_OK), ) ) - @pytest.mark.urls(__name__) def test_force_query_usage_defined_in_view( self, client, @@ -336,7 +361,6 @@ def test_force_query_usage_defined_in_view( response = client.get(url) assert response.status_code == status_code - @pytest.mark.urls(__name__) def test_dont_force_query_usage_on_put_method( self, client, instance, url, settings ): @@ -346,7 +370,6 @@ def test_dont_force_query_usage_on_put_method( response = client.put(url) assert response.status_code == status.HTTP_200_OK - @pytest.mark.urls(__name__) def test_using_aliases( self, client, django_assert_max_num_queries, instance, url, settings ): @@ -364,7 +387,6 @@ def test_using_aliases( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_incorrect_parameters( self, client, django_assert_max_num_queries, instance, url, settings ): @@ -382,7 +404,6 @@ def test_incorrect_parameters( } assert expected_fields == get_fields_queried(x) - @pytest.mark.urls(__name__) def test_many_to_one_rel_ignored_when_no_query( self, client, django_assert_max_num_queries, instance ): @@ -396,3 +417,37 @@ def test_many_to_one_rel_ignored_when_no_query( "sampleauthor.id", } assert expected_fields == get_fields_queried(x) + + +@pytest.mark.django_db +@pytest.mark.urls(__name__) +class TestSamplePostWithAnnotationView: + @pytest.fixture + def url(self): + return reverse("posts-annotation-view") + + def test_annotated_author_full_name(self, client, url): + baker.make(SamplePost, author__first_name="John", author__last_name="Doe") + + response = client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["author_full_name"] == "John Doe" + + def test_query_author_full_name(self, client, url): + baker.make(SamplePost, author__first_name="John", author__last_name="Doe") + + response = client.get(url, {"query": "{author_full_name}"}) + + assert response.status_code == status.HTTP_200_OK + assert response.data[0]["author_full_name"] == "John Doe" + + def test_author_full_name_not_queried(self, client, url, django_assert_max_num_queries): + baker.make(SamplePost, author__first_name="John", author__last_name="Doe") + + with django_assert_max_num_queries(1) as x: + client.get(url, {"query": "{id}"}) + expected_fields = { + "samplepost.id", + } + assert expected_fields == get_fields_queried(x)