diff --git a/tests/unit/search/test_tasks.py b/tests/unit/search/test_tasks.py index 3cb189b4e3ad..033412b59090 100644 --- a/tests/unit/search/test_tasks.py +++ b/tests/unit/search/test_tasks.py @@ -69,7 +69,7 @@ def test_project_docs(db_session): ).description.raw, }, } - for p, prs in sorted(releases.items(), key=lambda x: x[0].id) + for p, prs in sorted(releases.items(), key=lambda x: x[0].name) ] diff --git a/tests/unit/utils/db/test_windowed_query.py b/tests/unit/utils/db/test_windowed_query.py index 1f5d0cd66fb9..9c750cc7cafc 100644 --- a/tests/unit/utils/db/test_windowed_query.py +++ b/tests/unit/utils/db/test_windowed_query.py @@ -15,7 +15,6 @@ import pytest from sqlalchemy import select -from sqlalchemy.orm import aliased from warehouse.packaging.models import Project from warehouse.utils.db.windowed_query import windowed_query @@ -30,15 +29,13 @@ def test_windowed_query(db_session, query_recorder, window_size): expected = math.ceil(len(projects) / window_size) + 1 - subquery = select(Project.normalized_name).order_by(Project.id).subquery() - pa = aliased(Project, subquery) - - query = select(Project.name).select_from(pa).distinct(Project.id) + query = select(Project) + result_set = set() with query_recorder: - assert ( - set(windowed_query(db_session, query, Project.id, window_size)) - == project_set - ) + for result in windowed_query(db_session, query, Project.id, window_size): + for project in result.scalars(): + result_set.add((project.name, project.id)) + assert result_set == project_set assert len(query_recorder.queries) == expected diff --git a/warehouse/search/tasks.py b/warehouse/search/tasks.py index a04d3a4a7d21..d6da6bae1ce4 100644 --- a/warehouse/search/tasks.py +++ b/warehouse/search/tasks.py @@ -92,13 +92,14 @@ def _project_docs(db, project_name=None): .outerjoin(Release.project) ) - for release in windowed_query(db, release_data, Project.id, 25000): - p = ProjectDocument.from_db(release) - p._index = None - p.full_clean() - doc = p.to_dict(include_meta=True) - doc.pop("_index", None) - yield doc + for chunk in windowed_query(db, release_data, Project.name, 25000): + for release in chunk: + p = ProjectDocument.from_db(release) + p._index = None + p.full_clean() + doc = p.to_dict(include_meta=True) + doc.pop("_index", None) + yield doc class SearchLock(Lock): diff --git a/warehouse/utils/db/windowed_query.py b/warehouse/utils/db/windowed_query.py index 97d65fddc3c7..bf25b383a3c0 100644 --- a/warehouse/utils/db/windowed_query.py +++ b/warehouse/utils/db/windowed_query.py @@ -12,22 +12,66 @@ # Taken from "Theatrum Chemicum" at # https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery -# updated from, with minor tweaks: -# https://github.com/sqlalchemy/sqlalchemy/discussions/7948#discussioncomment-2597083 +from __future__ import annotations -def windowed_query(s, q, column, windowsize): - """Break a Query into chunks on a given column.""" +import typing - q = q.add_columns(column).order_by(column) - last_id = None +from collections.abc import Iterator +from typing import Any - while True: - subq = q - if last_id is not None: - subq = subq.filter(column > last_id) - chunk = s.execute(subq.limit(windowsize)).all() - if not chunk: - break - last_id = chunk[-1][-1] - yield from chunk +from sqlalchemy import and_, func, select +from sqlalchemy.orm import Session + +if typing.TYPE_CHECKING: + from sqlalchemy import Result, Select, SQLColumnExpression + + +def column_windows( + session: Session, + stmt: Select[Any], + column: SQLColumnExpression[Any], + windowsize: int, +) -> Iterator[SQLColumnExpression[bool]]: + """Return a series of WHERE clauses against + a given column that break it into windows. + + Result is an iterable of WHERE clauses that are packaged with + the individual ranges to select from. + + Requires a database that supports window functions. + """ + rownum = func.row_number().over(order_by=column).label("rownum") + + subq = stmt.add_columns(rownum).subquery() + subq_column = list(subq.columns)[-1] + + target_column = subq.corresponding_column(column) # type: ignore + new_stmt = select(target_column) # type: ignore + + if windowsize > 1: + new_stmt = new_stmt.filter(subq_column % windowsize == 1) + + intervals = list(session.scalars(new_stmt)) + + # yield out WHERE clauses for each range + while intervals: + start = intervals.pop(0) + if intervals: + yield and_(column >= start, column < intervals[0]) + else: + yield column >= start + + +def windowed_query( + session: Session, + stmt: Select[Any], + column: SQLColumnExpression[Any], + windowsize: int, +) -> Iterator[Result[Any]]: + """Given a Session and Select() object, organize and execute the statement + such that it is invoked for ordered chunks of the total result. yield + out individual Result objects for each chunk. + """ + for whereclause in column_windows(session, stmt, column, windowsize): + yield session.execute(stmt.filter(whereclause).order_by(column))