Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion django_restql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__title__ = 'Django RESTQL'
__description__ = 'Turn your API made with Django REST Framework(DRF) into a GraphQL like API.'
__url__ = 'https://yezyilomo.github.io/django-restql'
__version__ = '1.3.0'
__version__ = '1.4.0'
__author__ = 'Yezy Ilomo'
__author_email__ = 'yezileliilomo@hotmail.com'
__license__ = 'MIT'
Expand Down
17 changes: 15 additions & 2 deletions django_restql/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.db.models import Prefetch
from django.db.models import Prefetch, QuerySet
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.fields.related import ManyToManyRel, ManyToOneRel
from django.http import QueryDict
Expand Down Expand Up @@ -1049,6 +1049,7 @@ class OptimizedEagerLoadingMixin(EagerLoadingMixin):
always_apply_only = False
force_query_usage = None
to_select = []
annotated_fields = []

@property
def should_always_apply_only(self):
Expand Down Expand Up @@ -1187,7 +1188,8 @@ def get_queryset(self):
raise ValidationError(
_(f"'{query_param_name}' must be defined in query params.")
)
return super().get_queryset()
queryset = super().get_queryset()
return self.annotate_fields(queryset=queryset)

def parse_model_fields(self, model, fields, skip_non_model_fields=True):
results = []
Expand All @@ -1213,3 +1215,14 @@ def parse_model_fields(self, model, fields, skip_non_model_fields=True):
results.append(field)

return results

def annotate_fields(self, queryset: QuerySet) -> QuerySet:
for field_name in self.annotated_fields:
if self.should_annotate_field(field_name):
queryset = getattr(self, f"annotate_{field_name}")(queryset)

return queryset

def should_annotate_field(self, field_name: str) -> bool:
query = self.get_dict_parsed_restql_query(self.parsed_restql_query)
return "*" in query or field_name in query
91 changes: 73 additions & 18 deletions tests/testapp/tests/test_optimized_views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import Value
from django.db.models.functions import Concat
from django.urls import path, reverse
from django_restql.mixins import DynamicFieldsMixin, OptimizedEagerLoadingMixin
from model_bakery import baker
Expand Down Expand Up @@ -44,6 +47,15 @@ def get_first_letter(self, obj):
return obj.title[0]


class SamplePostWithAnnotationSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
author_full_name = serializers.CharField()

class Meta:
model = SamplePost
fields = ("id", "author_full_name")
read_only_fields = fields


class SampleTagSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
class Meta:
model = SampleTag
Expand Down Expand Up @@ -82,6 +94,27 @@ class SampleViewSet(
always_apply_only = True


class SamplePostWithAnnotationView(
OptimizedEagerLoadingMixin, UpdateAPIView, ListAPIView
):
queryset = SamplePost.objects.all()
serializer_class = SamplePostWithAnnotationSerializer
permission_classes = []
only = {"author_full_name": []}
pagination_class = None
always_apply_only = True
annotated_fields = ("author_full_name",)

@staticmethod
def annotate_author_full_name(queryset):
return queryset.annotate(
author_full_name=Concat(
"author__first_name", Value(" "), "author__last_name",
output_field=models.CharField(),
),
)


class SampleAuthorViewSet(
OptimizedEagerLoadingMixin,
ListAPIView,
Expand All @@ -99,10 +132,16 @@ class SampleAuthorViewSet(
path("", SampleViewSet.as_view(), name="view"),
path("<int:pk>", SampleViewSet.as_view(), name="view-update"),
path("authors/", SampleAuthorViewSet.as_view(), name="authors-view"),
path(
"posts-with-annotation/",
SamplePostWithAnnotationView.as_view(),
name="posts-annotation-view",
),
]


@pytest.mark.django_db
@pytest.mark.urls(__name__)
class TestOnlyInEagerLoading:
@pytest.fixture
def instance(self):
Expand All @@ -112,7 +151,6 @@ def instance(self):
def url(self):
return reverse("view")

@pytest.mark.urls(__name__)
def test_fields_correctly_selected(
self, client, django_assert_max_num_queries, instance, url
):
Expand All @@ -126,13 +164,11 @@ def test_fields_correctly_selected(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_multiple_records(self, client, django_assert_max_num_queries, url):
baker.make(SamplePost, _quantity=10)
with django_assert_max_num_queries(1):
client.get(url, {"query": "{title,author{id}}"})

@pytest.mark.urls(__name__)
def test_no_query(self, client, django_assert_max_num_queries, instance, url):
with django_assert_max_num_queries(1) as x:
response = client.get(url)
Expand All @@ -148,7 +184,6 @@ def test_no_query(self, client, django_assert_max_num_queries, instance, url):
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_using_all_fields(
self, client, django_assert_max_num_queries, instance, url
):
Expand All @@ -165,7 +200,6 @@ def test_using_all_fields(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_custom_only(self, client, django_assert_max_num_queries, instance, url):
with django_assert_max_num_queries(1) as x:
client.get(url, {"query": "{first_letter}"})
Expand All @@ -175,7 +209,6 @@ def test_custom_only(self, client, django_assert_max_num_queries, instance, url)
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_custom_only_in_foreign_key(
self, client, django_assert_max_num_queries, instance, url
):
Expand All @@ -189,7 +222,6 @@ def test_custom_only_in_foreign_key(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_using_nested_all_fields(
self, client, django_assert_max_num_queries, instance, url
):
Expand All @@ -204,7 +236,6 @@ def test_using_nested_all_fields(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_using_exclude_operator(
self, client, django_assert_max_num_queries, instance, url
):
Expand All @@ -222,15 +253,13 @@ def test_using_exclude_operator(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_incorrect_view(
self, client, django_assert_max_num_queries, instance, url, monkeypatch
):
monkeypatch.setattr(SampleViewSet, "only", {"author_str": "author__first_name"})
with pytest.raises(FieldDoesNotExist):
client.get(url)

@pytest.mark.urls(__name__)
def test_m2m_field(
self, client, django_assert_max_num_queries, instance, url, monkeypatch
):
Expand All @@ -252,7 +281,6 @@ def test_m2m_field(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_m2m_field_no_query(
self, client, django_assert_max_num_queries, instance, url, monkeypatch
):
Expand All @@ -279,7 +307,6 @@ def test_m2m_field_no_query(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_no_query_only_serializer_fields(
self, client, django_assert_max_num_queries, instance, url, monkeypatch
):
Expand All @@ -297,7 +324,6 @@ def test_no_query_only_serializer_fields(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_force_query_usage(
self, client, instance, url, settings
):
Expand All @@ -318,7 +344,6 @@ def test_force_query_usage(
(True, False, status.HTTP_200_OK),
)
)
@pytest.mark.urls(__name__)
def test_force_query_usage_defined_in_view(
self,
client,
Expand All @@ -336,7 +361,6 @@ def test_force_query_usage_defined_in_view(
response = client.get(url)
assert response.status_code == status_code

@pytest.mark.urls(__name__)
def test_dont_force_query_usage_on_put_method(
self, client, instance, url, settings
):
Expand All @@ -346,7 +370,6 @@ def test_dont_force_query_usage_on_put_method(
response = client.put(url)
assert response.status_code == status.HTTP_200_OK

@pytest.mark.urls(__name__)
def test_using_aliases(
self, client, django_assert_max_num_queries, instance, url, settings
):
Expand All @@ -364,7 +387,6 @@ def test_using_aliases(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_incorrect_parameters(
self, client, django_assert_max_num_queries, instance, url, settings
):
Expand All @@ -382,7 +404,6 @@ def test_incorrect_parameters(
}
assert expected_fields == get_fields_queried(x)

@pytest.mark.urls(__name__)
def test_many_to_one_rel_ignored_when_no_query(
self, client, django_assert_max_num_queries, instance
):
Expand All @@ -396,3 +417,37 @@ def test_many_to_one_rel_ignored_when_no_query(
"sampleauthor.id",
}
assert expected_fields == get_fields_queried(x)


@pytest.mark.django_db
@pytest.mark.urls(__name__)
class TestSamplePostWithAnnotationView:
@pytest.fixture
def url(self):
return reverse("posts-annotation-view")

def test_annotated_author_full_name(self, client, url):
baker.make(SamplePost, author__first_name="John", author__last_name="Doe")

response = client.get(url)

assert response.status_code == status.HTTP_200_OK
assert response.data[0]["author_full_name"] == "John Doe"

def test_query_author_full_name(self, client, url):
baker.make(SamplePost, author__first_name="John", author__last_name="Doe")

response = client.get(url, {"query": "{author_full_name}"})

assert response.status_code == status.HTTP_200_OK
assert response.data[0]["author_full_name"] == "John Doe"

def test_author_full_name_not_queried(self, client, url, django_assert_max_num_queries):
baker.make(SamplePost, author__first_name="John", author__last_name="Doe")

with django_assert_max_num_queries(1) as x:
client.get(url, {"query": "{id}"})
expected_fields = {
"samplepost.id",
}
assert expected_fields == get_fields_queried(x)
Loading