Skip to content

Commit

Permalink
Fix annotation propagation for non-filter queries (#1590) (#1593)
Browse files Browse the repository at this point in the history
* Fix annotation propagation for non-filter queries (#1590)

* Fix lint

* Fix test
  • Loading branch information
abondar committed Apr 24, 2024
1 parent 3644d67 commit 78ef3dd
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 18 deletions.
17 changes: 10 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ aiomysql = { version = "*", optional = true }
asyncmy = { version = "^0.2.8", optional = true, allow-prereleases = true }
psycopg = { extras = ["pool", "binary"], version = "^3.0.12", optional = true }
asyncodbc = { version = "^0.1.1", optional = true }
pydantic = { version = "^2.0,!=2.7.0", optional = true }

[tool.poetry.dev-dependencies]
# Linter tools
Expand All @@ -72,7 +73,7 @@ sanic = "*"
# Sample integration - Starlette
starlette = "*"
# Pydantic support
pydantic = "^2.0"
pydantic = "^2.0,!=2.7.0"
# FastAPI support
fastapi = "^0.100.0"
asgi_lifespan = "*"
Expand Down
76 changes: 75 additions & 1 deletion tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ async def test_aggregation(self):
await Event.all().annotate(tournament_test_id=Sum("tournament__id")).first()
)
self.assertEqual(
event_with_annotation.tournament_test_id, event_with_annotation.tournament_id
event_with_annotation.tournament_test_id,
event_with_annotation.tournament_id,
)

with self.assertRaisesRegex(ConfigurationError, "name__id not resolvable"):
Expand Down Expand Up @@ -162,3 +163,76 @@ async def test_concat_functions(self):
.values("long_info")
)
self.assertEqual(ret, [{"long_info": "Physics Book(physics)"}])

async def test_count_after_aggregate(self):
author = await Author.create(name="1")
await Book.create(name="First!", author=author, rating=4)
await Book.create(name="Second!", author=author, rating=3)
await Book.create(name="Third!", author=author, rating=3)

author2 = await Author.create(name="2")
await Book.create(name="F-2", author=author2, rating=3)
await Book.create(name="F-3", author=author2, rating=3)

author3 = await Author.create(name="3")
await Book.create(name="F-4", author=author3, rating=3)
await Book.create(name="F-5", author=author3, rating=2)
ret = (
await Author.all()
.annotate(average_rating=Avg("books__rating"))
.filter(average_rating__gte=3)
.count()
)

assert ret == 2

async def test_exist_after_aggregate(self):
author = await Author.create(name="1")
await Book.create(name="First!", author=author, rating=4)
await Book.create(name="Second!", author=author, rating=3)
await Book.create(name="Third!", author=author, rating=3)

ret = (
await Author.all()
.annotate(average_rating=Avg("books__rating"))
.filter(average_rating__gte=3)
.exists()
)

assert ret is True

ret = (
await Author.all()
.annotate(average_rating=Avg("books__rating"))
.filter(average_rating__gte=4)
.exists()
)
assert ret is False

async def test_count_after_aggregate_m2m(self):
tournament = await Tournament.create(name="1")
event1 = await Event.create(name="First!", tournament=tournament)
event2 = await Event.create(name="Second!", tournament=tournament)
event3 = await Event.create(name="Third!", tournament=tournament)
event4 = await Event.create(name="Fourth!", tournament=tournament)

team1 = await Team.create(name="1")
team2 = await Team.create(name="2")
team3 = await Team.create(name="3")

await event1.participants.add(team1, team2, team3)
await event2.participants.add(team1, team2)
await event3.participants.add(team1)
await event4.participants.add(team1, team2, team3)

query = (
Event.filter(participants__id__in=[team1.id, team2.id, team3.id])
.annotate(count=Count("event_id"))
.filter(count=3)
.prefetch_related("participants")
)
result = await query
assert len(result) == 2

res = await query.count()
assert res == 2
4 changes: 2 additions & 2 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ async def test_force_index_available_in_more_query(self):
sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql()
self.assertEqual(
sql_CountQuery,
"SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
"SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
)

sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql()
Expand Down Expand Up @@ -504,7 +504,7 @@ async def test_use_index_available_in_more_query(self):
sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql()
self.assertEqual(
sql_CountQuery,
"SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
"SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
)

sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql()
Expand Down
20 changes: 13 additions & 7 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
)

from pypika import JoinType, Order, Table
from pypika.functions import Cast, Count
from pypika.analytics import Count
from pypika.functions import Cast
from pypika.queries import QueryBuilder
from pypika.terms import Case, Field, Term, ValueWrapper
from typing_extensions import Literal, Protocol
Expand Down Expand Up @@ -131,7 +132,7 @@ def resolve_filters(
:param annotations: Extra annotations to add.
:param custom_filters: Pre-resolved filters to be passed through.
"""
has_aggregate = self._resolve_annotate()
has_aggregate = self._resolve_annotate(annotations)

modifier = QueryModifier()
for node in q_objects:
Expand Down Expand Up @@ -236,13 +237,14 @@ def resolve_ordering(

self.query = self.query.orderby(field, order=ordering[1])

def _resolve_annotate(self) -> bool:
if not self._annotations:
def _resolve_annotate(self, extra_annotations: Dict[str, Any]) -> bool:
if not self._annotations and not extra_annotations:
return False

table = self.model._meta.basetable
all_annotations = {**self._annotations, **extra_annotations}
annotation_info = {}
for key, annotation in self._annotations.items():
for key, annotation in all_annotations.items():
if isinstance(annotation, Term):
annotation_info[key] = {"joins": [], "field": annotation}
else:
Expand All @@ -251,7 +253,8 @@ def _resolve_annotate(self) -> bool:
for key, info in annotation_info.items():
for join in info["joins"]:
self._join_table_by_field(*join)
self.query._select_other(info["field"].as_(key))
if key in self._annotations:
self.query._select_other(info["field"].as_(key))

return any(info["field"].is_aggregate for info in annotation_info.values())

Expand Down Expand Up @@ -1282,7 +1285,10 @@ def _make_query(self) -> None:
annotations=self.annotations,
custom_filters=self.custom_filters,
)
self.query._select_other(Count("*"))
count_term = Count("*")
if self.query._groupbys:
count_term = count_term.over()
self.query._select_other(count_term)

if self.force_indexes:
self.query._force_indexes = []
Expand Down

0 comments on commit 78ef3dd

Please sign in to comment.