Skip to content

Commit

Permalink
Merge branch '3.0.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
pamelafox committed Jun 19, 2023
2 parents e81bcc5 + b206d0a commit aa31854
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 76 deletions.
9 changes: 9 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ Unreleased
- Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149`


Version 3.0.4
-------------

Released 2023-06-19

- Fix type hint for ``get_or_404`` return value. :pr:`1208`
- Fix type hints for pyright (used by VS Code Pylance extension). :issue:`1205`


Version 3.0.3
-------------

Expand Down
50 changes: 26 additions & 24 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from weakref import WeakKeyDictionary

import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.exc
import sqlalchemy.orm
import sqlalchemy.event as sa_event
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
from flask import abort
from flask import current_app
from flask import Flask
Expand All @@ -23,6 +23,8 @@
from .session import Session
from .table import _Table

_O = t.TypeVar("_O", bound=object) # Based on sqlalchemy.orm._typing.py


class SQLAlchemy:
"""Integrates SQLAlchemy with Flask. This handles setting up one or more engines,
Expand Down Expand Up @@ -122,7 +124,7 @@ def __init__(
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: type[Query] = Query,
model_class: type[Model] | sa.orm.DeclarativeMeta = Model,
model_class: type[Model] | sa_orm.DeclarativeMeta = Model,
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
):
Expand Down Expand Up @@ -322,7 +324,7 @@ def init_app(self, app: Flask) -> None:

def _make_scoped_session(
self, options: dict[str, t.Any]
) -> sa.orm.scoped_session[Session]:
) -> sa_orm.scoped_session[Session]:
"""Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory
from :meth:`_make_session_factory`. The result is available as :attr:`session`.
Expand All @@ -345,11 +347,11 @@ def _make_scoped_session(
"""
scope = options.pop("scopefunc", _app_ctx_id)
factory = self._make_session_factory(options)
return sa.orm.scoped_session(factory, scope)
return sa_orm.scoped_session(factory, scope)

def _make_session_factory(
self, options: dict[str, t.Any]
) -> sa.orm.sessionmaker[Session]:
) -> sa_orm.sessionmaker[Session]:
"""Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by
:meth:`_make_scoped_session`.
Expand All @@ -372,7 +374,7 @@ def _make_session_factory(
"""
options.setdefault("class_", Session)
options.setdefault("query_cls", self.Query)
return sa.orm.sessionmaker(db=self, **options)
return sa_orm.sessionmaker(db=self, **options)

def _teardown_session(self, exc: BaseException | None) -> None:
"""Remove the current session at the end of the request.
Expand Down Expand Up @@ -437,7 +439,7 @@ def __new__(
return Table

def _make_declarative_base(
self, model: type[Model] | sa.orm.DeclarativeMeta
self, model: type[Model] | sa_orm.DeclarativeMeta
) -> type[t.Any]:
"""Create a SQLAlchemy declarative model class. The result is available as
:attr:`Model`.
Expand All @@ -458,9 +460,9 @@ def _make_declarative_base(
.. versionchanged:: 2.3
``model`` can be an already created declarative model class.
"""
if not isinstance(model, sa.orm.DeclarativeMeta):
if not isinstance(model, sa_orm.DeclarativeMeta):
metadata = self._make_metadata(None)
model = sa.orm.declarative_base(
model = sa_orm.declarative_base(
metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta
)

Expand Down Expand Up @@ -614,12 +616,12 @@ def engine(self) -> sa.engine.Engine:

def get_or_404(
self,
entity: type[t.Any],
entity: type[_O],
ident: t.Any,
*,
description: str | None = None,
**kwargs: t.Any,
) -> t.Any:
) -> t.Optional[_O]:
"""Like :meth:`session.get() <sqlalchemy.orm.Session.get>` but aborts with a
``404 Not Found`` error instead of returning ``None``.
Expand Down Expand Up @@ -672,7 +674,7 @@ def one_or_404(
"""
try:
return self.session.execute(statement).scalar_one()
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound):
abort(404, description=description)

def paginate(
Expand Down Expand Up @@ -751,7 +753,7 @@ def _call_for_binds(
if key is None:
message = f"'SQLALCHEMY_DATABASE_URI' config is not set. {message}"

raise sa.exc.UnboundExecutionError(message) from None
raise sa_exc.UnboundExecutionError(message) from None

metadata = self.metadatas[key]
getattr(metadata, op_name)(bind=engine)
Expand Down Expand Up @@ -828,31 +830,31 @@ def _set_rel_query(self, kwargs: dict[str, t.Any]) -> None:

def relationship(
self, *args: t.Any, **kwargs: t.Any
) -> sa.orm.RelationshipProperty[t.Any]:
) -> sa_orm.RelationshipProperty[t.Any]:
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
:attr:`Query` class for dynamic relationships and backrefs.
.. versionchanged:: 3.0
The :attr:`Query` class is set on ``backref``.
"""
self._set_rel_query(kwargs)
return sa.orm.relationship(*args, **kwargs)
return sa_orm.relationship(*args, **kwargs)

def dynamic_loader(
self, argument: t.Any, **kwargs: t.Any
) -> sa.orm.RelationshipProperty[t.Any]:
) -> sa_orm.RelationshipProperty[t.Any]:
"""A :func:`sqlalchemy.orm.dynamic_loader` that applies this extension's
:attr:`Query` class for relationships and backrefs.
.. versionchanged:: 3.0
The :attr:`Query` class is set on ``backref``.
"""
self._set_rel_query(kwargs)
return sa.orm.dynamic_loader(argument, **kwargs)
return sa_orm.dynamic_loader(argument, **kwargs)

def _relation(
self, *args: t.Any, **kwargs: t.Any
) -> sa.orm.RelationshipProperty[t.Any]:
) -> sa_orm.RelationshipProperty[t.Any]:
"""A :func:`sqlalchemy.orm.relationship` that applies this extension's
:attr:`Query` class for dynamic relationships and backrefs.
Expand All @@ -864,20 +866,20 @@ def _relation(
The :attr:`Query` class is set on ``backref``.
"""
self._set_rel_query(kwargs)
f = sa.orm.relation # type: ignore[attr-defined]
return f(*args, **kwargs) # type: ignore[no-any-return]
f = sa_orm.relationship
return f(*args, **kwargs)

def __getattr__(self, name: str) -> t.Any:
if name == "relation":
return self._relation

if name == "event":
return sa.event
return sa_event

if name.startswith("_"):
raise AttributeError(name)

for mod in (sa, sa.orm):
for mod in (sa, sa_orm):
if hasattr(mod, name):
return getattr(mod, name)

Expand Down
10 changes: 5 additions & 5 deletions src/flask_sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing as t

import sqlalchemy as sa
import sqlalchemy.orm
import sqlalchemy.orm as sa_orm

from .query import Query

Expand Down Expand Up @@ -174,21 +174,21 @@ def should_set_tablename(cls: type) -> bool:
joined-table inheritance. If no primary key is found, the name will be unset.
"""
if cls.__dict__.get("__abstract__", False) or not any(
isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]
isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]
):
return False

for base in cls.__mro__:
if "__tablename__" not in base.__dict__:
continue

if isinstance(base.__dict__["__tablename__"], sa.orm.declared_attr):
if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr):
return False

return not (
base is cls
or base.__dict__.get("__abstract__", False)
or not isinstance(base, sa.orm.DeclarativeMeta)
or not isinstance(base, sa_orm.DeclarativeMeta)
)

return True
Expand All @@ -200,7 +200,7 @@ def camel_to_snake_case(name: str) -> str:
return name.lower().lstrip("_")


class DefaultMeta(BindMetaMixin, NameMetaMixin, sa.orm.DeclarativeMeta):
class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta):
"""SQLAlchemy declarative metaclass that provides ``__bind_key__`` and
``__tablename__`` support.
"""
4 changes: 2 additions & 2 deletions src/flask_sqlalchemy/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from math import ceil

import sqlalchemy as sa
import sqlalchemy.orm
import sqlalchemy.orm as sa_orm
from flask import abort
from flask import request

Expand Down Expand Up @@ -336,7 +336,7 @@ def _query_items(self) -> list[t.Any]:

def _query_count(self) -> int:
select = self._query_args["select"]
sub = select.options(sa.orm.lazyload("*")).order_by(None).subquery()
sub = select.options(sa_orm.lazyload("*")).order_by(None).subquery()
session = self._query_args["session"]
out = session.execute(sa.select(sa.func.count()).select_from(sub)).scalar()
return out # type: ignore[no-any-return]
Expand Down
9 changes: 4 additions & 5 deletions src/flask_sqlalchemy/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import typing as t

import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.orm
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
from flask import abort

from .pagination import Pagination
from .pagination import QueryPagination


class Query(sa.orm.Query): # type: ignore[type-arg]
class Query(sa_orm.Query): # type: ignore[type-arg]
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
useful for querying in a web application.
Expand Down Expand Up @@ -58,7 +57,7 @@ def one_or_404(self, description: str | None = None) -> t.Any:
"""
try:
return self.one()
except (sa.exc.NoResultFound, sa.exc.MultipleResultsFound):
except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound):
abort(404, description=description)

def paginate(
Expand Down
6 changes: 3 additions & 3 deletions src/flask_sqlalchemy/record_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from time import perf_counter

import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.event as sa_event
from flask import current_app
from flask import g
from flask import has_app_context
Expand Down Expand Up @@ -72,8 +72,8 @@ def duration(self) -> float:


def _listen(engine: sa.engine.Engine) -> None:
sa.event.listen(engine, "before_cursor_execute", _record_start, named=True)
sa.event.listen(engine, "after_cursor_execute", _record_end, named=True)
sa_event.listen(engine, "before_cursor_execute", _record_start, named=True)
sa_event.listen(engine, "after_cursor_execute", _record_end, named=True)


def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None:
Expand Down
12 changes: 6 additions & 6 deletions src/flask_sqlalchemy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import typing as t

import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.orm
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
from flask.globals import app_ctx

if t.TYPE_CHECKING:
from .extension import SQLAlchemy


class Session(sa.orm.Session):
class Session(sa_orm.Session):
"""A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to
use based on the bind key associated with the metadata associated with the thing
being queried.
Expand Down Expand Up @@ -55,9 +55,9 @@ def get_bind(
if mapper is not None:
try:
mapper = sa.inspect(mapper)
except sa.exc.NoInspectionAvailable as e:
except sa_exc.NoInspectionAvailable as e:
if isinstance(mapper, type):
raise sa.orm.exc.UnmappedClassError(mapper) from e
raise sa_orm.exc.UnmappedClassError(mapper) from e

raise

Expand Down Expand Up @@ -88,7 +88,7 @@ def _clause_to_engine(
key = clause.metadata.info["bind_key"]

if key not in engines:
raise sa.exc.UnboundExecutionError(
raise sa_exc.UnboundExecutionError(
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
)

Expand Down
16 changes: 8 additions & 8 deletions src/flask_sqlalchemy/track_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import typing as t

import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.orm
import sqlalchemy.event as sa_event
import sqlalchemy.orm as sa_orm
from flask import current_app
from flask import has_app_context
from flask.signals import Namespace # type: ignore[attr-defined]
Expand All @@ -29,12 +29,12 @@
"""


def _listen(session: sa.orm.scoped_session[Session]) -> None:
sa.event.listen(session, "before_flush", _record_ops, named=True)
sa.event.listen(session, "before_commit", _record_ops, named=True)
sa.event.listen(session, "before_commit", _before_commit)
sa.event.listen(session, "after_commit", _after_commit)
sa.event.listen(session, "after_rollback", _after_rollback)
def _listen(session: sa_orm.scoped_session[Session]) -> None:
sa_event.listen(session, "before_flush", _record_ops, named=True)
sa_event.listen(session, "before_commit", _record_ops, named=True)
sa_event.listen(session, "before_commit", _before_commit)
sa_event.listen(session, "after_commit", _after_commit)
sa_event.listen(session, "after_rollback", _after_rollback)


def _record_ops(session: Session, **kwargs: t.Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def db(app: Flask) -> SQLAlchemy:


@pytest.fixture
def Todo(app: Flask, db: SQLAlchemy) -> t.Any:
def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]:
class Todo(db.Model):
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String)
Expand Down

0 comments on commit aa31854

Please sign in to comment.