Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hint #1294

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
*,
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: type[Query] = Query,
query_class: type[Query[Model]] = Query,
model_class: _FSA_MCT = Model, # type: ignore[assignment]
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
Expand Down Expand Up @@ -808,7 +808,7 @@ def paginate(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[t.Any]:
"""Apply an offset and limit to a select statment based on the current page and
number of items per page, returning a :class:`.Pagination` object.

Expand Down
11 changes: 5 additions & 6 deletions src/flask_sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
import typing_extensions as te

from .query import Query

Expand All @@ -18,10 +19,8 @@ class _QueryProperty:
:meta private:
"""

def __get__(self, obj: Model | None, cls: type[Model]) -> Query:
return cls.query_class(
cls, session=cls.__fsa__.session() # type: ignore[arg-type]
)
def __get__(self, obj: Model | None, cls: type[Model]) -> Query[Model]:
return cls.query_class(cls, session=cls.__fsa__.session())


class Model:
Expand All @@ -39,12 +38,12 @@ class Model:
:meta private:
"""

query_class: t.ClassVar[type[Query]] = Query
query_class: t.ClassVar[type[Query[Model]]] = Query
"""Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which
defaults to :class:`.Query`.
"""

query: t.ClassVar[Query] = _QueryProperty() # type: ignore[assignment]
query: t.ClassVar[Query[te.Self]] = _QueryProperty() # type: ignore[assignment]
"""A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be
customized per-model by overriding :attr:`query_class`.

Expand Down
22 changes: 13 additions & 9 deletions src/flask_sqlalchemy/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@

import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
import typing_extensions as te
from flask import abort
from flask import request


class Pagination:
_T = t.TypeVar("_T")


class Pagination(t.Generic[_T]):
"""Apply an offset and limit to the query based on the current page and number of
items per page.

Expand Down Expand Up @@ -147,7 +151,7 @@ def _query_offset(self) -> int:
"""
return (self.page - 1) * self.per_page

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
"""Execute the query to get the items on the current page.

Uses init arguments stored in :attr:`_query_args`.
Expand Down Expand Up @@ -212,7 +216,7 @@ def prev_num(self) -> int | None:

return self.page - 1

def prev(self, *, error_out: bool = False) -> Pagination:
def prev(self, *, error_out: bool = False) -> te.Self:
"""Query the :class:`Pagination` object for the previous page.

:param error_out: Abort with a ``404 Not Found`` error if no items are returned
Expand Down Expand Up @@ -242,7 +246,7 @@ def next_num(self) -> int | None:

return self.page + 1

def next(self, *, error_out: bool = False) -> Pagination:
def next(self, *, error_out: bool = False) -> te.Self:
"""Query the :class:`Pagination` object for the next page.

:param error_out: Abort with a ``404 Not Found`` error if no items are returned
Expand Down Expand Up @@ -321,18 +325,18 @@ def iter_pages(

yield from range(right_start, pages_end)

def __iter__(self) -> t.Iterator[t.Any]:
def __iter__(self) -> t.Iterator[_T]:
yield from self.items


class SelectPagination(Pagination):
class SelectPagination(Pagination[_T]):
"""Returned by :meth:`.SQLAlchemy.paginate`. Takes ``select`` and ``session``
arguments in addition to the :class:`Pagination` arguments.

.. versionadded:: 3.0
"""

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
select = self._query_args["select"]
select = select.limit(self.per_page).offset(self._query_offset)
session = self._query_args["session"]
Expand All @@ -346,14 +350,14 @@ def _query_count(self) -> int:
return out # type: ignore[no-any-return]


class QueryPagination(Pagination):
class QueryPagination(Pagination[_T]):
"""Returned by :meth:`.Query.paginate`. Takes a ``query`` argument in addition to
the :class:`Pagination` arguments.

.. versionadded:: 3.0
"""

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
query = self._query_args["query"]
out = query.limit(self.per_page).offset(self._query_offset).all()
return out # type: ignore[no-any-return]
Expand Down
15 changes: 9 additions & 6 deletions src/flask_sqlalchemy/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from .pagination import QueryPagination


class Query(sa_orm.Query): # type: ignore[type-arg]
_T = t.TypeVar("_T")


class Query(sa_orm.Query[_T]):
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
useful for querying in a web application.

Expand All @@ -20,7 +23,7 @@ class Query(sa_orm.Query): # type: ignore[type-arg]
Renamed to ``Query`` from ``BaseQuery``.
"""

def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any:
def get_or_404(self, ident: t.Any, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.get` but aborts with a ``404 Not Found``
error instead of returning ``None``.

Expand All @@ -32,9 +35,9 @@ def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any:
if rv is None:
abort(404, description=description)

return rv
return rv # type: ignore

def first_or_404(self, description: str | None = None) -> t.Any:
def first_or_404(self, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.first` but aborts with a ``404 Not Found``
error instead of returning ``None``.

Expand All @@ -47,7 +50,7 @@ def first_or_404(self, description: str | None = None) -> t.Any:

return rv

def one_or_404(self, description: str | None = None) -> t.Any:
def one_or_404(self, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.one` but aborts with a ``404 Not Found``
error instead of raising ``NoResultFound`` or ``MultipleResultsFound``.

Expand All @@ -68,7 +71,7 @@ def paginate(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[_T]:
"""Apply an offset and limit to the query based on the current page and number
of items per page, returning a :class:`.Pagination` object.

Expand Down
5 changes: 4 additions & 1 deletion tests/test_legacy_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from flask_sqlalchemy.query import Query


_T = t.TypeVar("_T")


@pytest.fixture(autouse=True)
def ignore_query_warning() -> t.Generator[None, None, None]:
if hasattr(sa_exc, "LegacyAPIWarning"):
Expand Down Expand Up @@ -98,7 +101,7 @@ class Child(db.Model):

@pytest.mark.usefixtures("app_ctx")
def test_custom_query_class(app: Flask) -> None:
class CustomQuery(Query):
class CustomQuery(Query[_T]):
pass

db = SQLAlchemy(app, query_class=CustomQuery)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from flask_sqlalchemy.pagination import Pagination


class RangePagination(Pagination):
class RangePagination(Pagination[t.Any]):
def __init__(
self, total: int | None = 150, page: int = 1, per_page: int = 10
) -> None:
Expand Down Expand Up @@ -131,7 +131,7 @@ def __call__(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[t.Any]:
qs = {"page": page, "per_page": per_page}
with self.app.test_request_context(query_string=qs):
return self.db.paginate(
Expand Down