Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/unit/search/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_project_docs(db_session):
).description.raw,
},
}
for p, prs in sorted(releases.items(), key=lambda x: x[0].id)
for p, prs in sorted(releases.items(), key=lambda x: x[0].name)
]


Expand Down
15 changes: 6 additions & 9 deletions tests/unit/utils/db/test_windowed_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import pytest

from sqlalchemy import select
from sqlalchemy.orm import aliased

from warehouse.packaging.models import Project
from warehouse.utils.db.windowed_query import windowed_query
Expand All @@ -30,15 +29,13 @@ def test_windowed_query(db_session, query_recorder, window_size):

expected = math.ceil(len(projects) / window_size) + 1

subquery = select(Project.normalized_name).order_by(Project.id).subquery()
pa = aliased(Project, subquery)

query = select(Project.name).select_from(pa).distinct(Project.id)
query = select(Project)

result_set = set()
with query_recorder:
assert (
set(windowed_query(db_session, query, Project.id, window_size))
== project_set
)
for result in windowed_query(db_session, query, Project.id, window_size):
for project in result.scalars():
result_set.add((project.name, project.id))

assert result_set == project_set
assert len(query_recorder.queries) == expected
15 changes: 8 additions & 7 deletions warehouse/search/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@ def _project_docs(db, project_name=None):
.outerjoin(Release.project)
)

for release in windowed_query(db, release_data, Project.id, 25000):
p = ProjectDocument.from_db(release)
p._index = None
p.full_clean()
doc = p.to_dict(include_meta=True)
doc.pop("_index", None)
yield doc
for chunk in windowed_query(db, release_data, Project.name, 25000):
for release in chunk:
p = ProjectDocument.from_db(release)
p._index = None
p.full_clean()
doc = p.to_dict(include_meta=True)
doc.pop("_index", None)
yield doc


class SearchLock(Lock):
Expand Down
74 changes: 59 additions & 15 deletions warehouse/utils/db/windowed_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,66 @@

# Taken from "Theatrum Chemicum" at
# https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
# updated from, with minor tweaks:
# https://github.com/sqlalchemy/sqlalchemy/discussions/7948#discussioncomment-2597083

from __future__ import annotations

def windowed_query(s, q, column, windowsize):
"""Break a Query into chunks on a given column."""
import typing

q = q.add_columns(column).order_by(column)
last_id = None
from collections.abc import Iterator
from typing import Any

while True:
subq = q
if last_id is not None:
subq = subq.filter(column > last_id)
chunk = s.execute(subq.limit(windowsize)).all()
if not chunk:
break
last_id = chunk[-1][-1]
yield from chunk
from sqlalchemy import and_, func, select
from sqlalchemy.orm import Session

if typing.TYPE_CHECKING:
from sqlalchemy import Result, Select, SQLColumnExpression


def column_windows(
session: Session,
stmt: Select[Any],
column: SQLColumnExpression[Any],
windowsize: int,
) -> Iterator[SQLColumnExpression[bool]]:
"""Return a series of WHERE clauses against
a given column that break it into windows.

Result is an iterable of WHERE clauses that are packaged with
the individual ranges to select from.

Requires a database that supports window functions.
"""
rownum = func.row_number().over(order_by=column).label("rownum")

subq = stmt.add_columns(rownum).subquery()
subq_column = list(subq.columns)[-1]

target_column = subq.corresponding_column(column) # type: ignore
new_stmt = select(target_column) # type: ignore

if windowsize > 1:
new_stmt = new_stmt.filter(subq_column % windowsize == 1)

intervals = list(session.scalars(new_stmt))

# yield out WHERE clauses for each range
while intervals:
start = intervals.pop(0)
if intervals:
yield and_(column >= start, column < intervals[0])
else:
yield column >= start


def windowed_query(
session: Session,
stmt: Select[Any],
column: SQLColumnExpression[Any],
windowsize: int,
) -> Iterator[Result[Any]]:
"""Given a Session and Select() object, organize and execute the statement
such that it is invoked for ordered chunks of the total result. yield
out individual Result objects for each chunk.
"""
for whereclause in column_windows(session, stmt, column, windowsize):
yield session.execute(stmt.filter(whereclause).order_by(column))