Skip to content

Commit

Permalink
Improve type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
詹家辉 committed Dec 20, 2023
1 parent 538f043 commit 31229f7
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
*,
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: type[Query] = Query,
query_class: type[Query[Model]] = Query,
model_class: _FSA_MCT = Model, # type: ignore[assignment]
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
Expand Down Expand Up @@ -808,7 +808,7 @@ def paginate(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[t.Any]:
"""Apply an offset and limit to a select statment based on the current page and
number of items per page, returning a :class:`.Pagination` object.
Expand Down
9 changes: 5 additions & 4 deletions src/flask_sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import typing as t
import typing_extensions as te

import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
Expand All @@ -18,9 +19,9 @@ class _QueryProperty:
:meta private:
"""

def __get__(self, obj: Model | None, cls: type[Model]) -> Query:
def __get__(self, obj: Model | None, cls: type[Model]) -> Query[Model]:
return cls.query_class(
cls, session=cls.__fsa__.session() # type: ignore[arg-type]
cls, session=cls.__fsa__.session()
)


Expand All @@ -39,12 +40,12 @@ class Model:
:meta private:
"""

query_class: t.ClassVar[type[Query]] = Query
query_class: t.ClassVar[type[Query["Model"]]] = Query
"""Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which
defaults to :class:`.Query`.
"""

query: t.ClassVar[Query] = _QueryProperty() # type: ignore[assignment]
query: t.ClassVar[Query[te.Self]] = _QueryProperty() # type: ignore[assignment]
"""A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be
customized per-model by overriding :attr:`query_class`.
Expand Down
22 changes: 13 additions & 9 deletions src/flask_sqlalchemy/pagination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
import typing_extensions as te
from math import ceil

import sqlalchemy as sa
Expand All @@ -9,7 +10,10 @@
from flask import request


class Pagination:
_T = t.TypeVar("_T")


class Pagination(t.Generic[_T]):
"""Apply an offset and limit to the query based on the current page and number of
items per page.
Expand Down Expand Up @@ -147,7 +151,7 @@ def _query_offset(self) -> int:
"""
return (self.page - 1) * self.per_page

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
"""Execute the query to get the items on the current page.
Uses init arguments stored in :attr:`_query_args`.
Expand Down Expand Up @@ -212,7 +216,7 @@ def prev_num(self) -> int | None:

return self.page - 1

def prev(self, *, error_out: bool = False) -> Pagination:
def prev(self, *, error_out: bool = False) -> te.Self:
"""Query the :class:`Pagination` object for the previous page.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
Expand Down Expand Up @@ -242,7 +246,7 @@ def next_num(self) -> int | None:

return self.page + 1

def next(self, *, error_out: bool = False) -> Pagination:
def next(self, *, error_out: bool = False) -> te.Self:
"""Query the :class:`Pagination` object for the next page.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
Expand Down Expand Up @@ -321,18 +325,18 @@ def iter_pages(

yield from range(right_start, pages_end)

def __iter__(self) -> t.Iterator[t.Any]:
def __iter__(self) -> t.Iterator[_T]:
yield from self.items


class SelectPagination(Pagination):
class SelectPagination(Pagination[_T]):
"""Returned by :meth:`.SQLAlchemy.paginate`. Takes ``select`` and ``session``
arguments in addition to the :class:`Pagination` arguments.
.. versionadded:: 3.0
"""

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
select = self._query_args["select"]
select = select.limit(self.per_page).offset(self._query_offset)
session = self._query_args["session"]
Expand All @@ -346,14 +350,14 @@ def _query_count(self) -> int:
return out # type: ignore[no-any-return]


class QueryPagination(Pagination):
class QueryPagination(Pagination[_T]):
"""Returned by :meth:`.Query.paginate`. Takes a ``query`` argument in addition to
the :class:`Pagination` arguments.
.. versionadded:: 3.0
"""

def _query_items(self) -> list[t.Any]:
def _query_items(self) -> list[_T]:
query = self._query_args["query"]
out = query.limit(self.per_page).offset(self._query_offset).all()
return out # type: ignore[no-any-return]
Expand Down
15 changes: 9 additions & 6 deletions src/flask_sqlalchemy/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from .pagination import QueryPagination


class Query(sa_orm.Query): # type: ignore[type-arg]
_T = t.TypeVar("_T")


class Query(sa_orm.Query[_T]):
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
useful for querying in a web application.
Expand All @@ -20,7 +23,7 @@ class Query(sa_orm.Query): # type: ignore[type-arg]
Renamed to ``Query`` from ``BaseQuery``.
"""

def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any:
def get_or_404(self, ident: t.Any, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.get` but aborts with a ``404 Not Found``
error instead of returning ``None``.
Expand All @@ -32,9 +35,9 @@ def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any:
if rv is None:
abort(404, description=description)

return rv
return rv # type: ignore

def first_or_404(self, description: str | None = None) -> t.Any:
def first_or_404(self, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.first` but aborts with a ``404 Not Found``
error instead of returning ``None``.
Expand All @@ -47,7 +50,7 @@ def first_or_404(self, description: str | None = None) -> t.Any:

return rv

def one_or_404(self, description: str | None = None) -> t.Any:
def one_or_404(self, description: str | None = None) -> _T:
"""Like :meth:`~sqlalchemy.orm.Query.one` but aborts with a ``404 Not Found``
error instead of raising ``NoResultFound`` or ``MultipleResultsFound``.
Expand All @@ -68,7 +71,7 @@ def paginate(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[_T]:
"""Apply an offset and limit to the query based on the current page and number
of items per page, returning a :class:`.Pagination` object.
Expand Down
5 changes: 4 additions & 1 deletion tests/test_legacy_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from flask_sqlalchemy.query import Query


_T = t.TypeVar("_T")


@pytest.fixture(autouse=True)
def ignore_query_warning() -> t.Generator[None, None, None]:
if hasattr(sa_exc, "LegacyAPIWarning"):
Expand Down Expand Up @@ -98,7 +101,7 @@ class Child(db.Model):

@pytest.mark.usefixtures("app_ctx")
def test_custom_query_class(app: Flask) -> None:
class CustomQuery(Query):
class CustomQuery(Query[_T]):
pass

db = SQLAlchemy(app, query_class=CustomQuery)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from flask_sqlalchemy.pagination import Pagination


class RangePagination(Pagination):

class RangePagination(Pagination[t.Any]):
def __init__(
self, total: int | None = 150, page: int = 1, per_page: int = 10
) -> None:
Expand Down Expand Up @@ -131,7 +132,7 @@ def __call__(
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
) -> Pagination[t.Any]:
qs = {"page": page, "per_page": per_page}
with self.app.test_request_context(query_string=qs):
return self.db.paginate(
Expand Down

0 comments on commit 31229f7

Please sign in to comment.