From 252c04778fff6ed973a276af8c8c0107509878dc Mon Sep 17 00:00:00 2001 From: Evgeny Arshinov Date: Fri, 12 Apr 2024 18:01:33 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8Add=20`foreign=5Fkey=5Fargs`=20and=20`?= =?UTF-8?q?foreign=5Fkey=5Fkwargs`=20arguments=20to=20`Field(...)`=20to=20?= =?UTF-8?q?let=20the=20user=20define=20additional=20`sqlalchemy.orm.Foreig?= =?UTF-8?q?nKey`=20attributes,=20such=20as=20`ondelete`=20and=20`onupdate`?= =?UTF-8?q?,=20for=20foreign=20keys=20defined=20in=20a=20base=20model.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 24 ++++++++++- tests/test_foreign_key_args.py | 76 ++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 tests/test_foreign_key_args.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e8330d69..8b7a20777 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -106,6 +106,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined) + sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined) if sa_column is not Undefined: if sa_column_args is not Undefined: raise RuntimeError( @@ -153,6 +155,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs + self.sa_foreign_key_args = sa_foreign_key_args + self.sa_foreign_key_kwargs = sa_foreign_key_kwargs class RelationshipInfo(Representation): @@ -222,6 +226,8 @@ def Field( sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -303,6 +309,8 @@ def Field( sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -340,6 +348,8 @@ def Field( sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, + sa_foreign_key_args=sa_foreign_key_args, + sa_foreign_key_kwargs=sa_foreign_key_kwargs, **current_schema_extra, ) post_init_field_info(field_info) @@ -638,7 +648,19 @@ def get_column_from_field(field: Any) -> Column: # type: ignore unique = False if foreign_key: assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) + sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined) + fk_args = ( + [] + if sa_foreign_key_args is Undefined + else list(cast(Sequence[Any], sa_foreign_key_args)) + ) + sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined) + fk_kwargs = ( + {} + if sa_foreign_key_kwargs is Undefined + else cast(Dict[Any, Any], sa_foreign_key_kwargs) + ) + args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs)) kwargs = { "primary_key": primary_key, "nullable": nullable, diff --git a/tests/test_foreign_key_args.py b/tests/test_foreign_key_args.py new file mode 100644 index 000000000..8d4f95871 --- /dev/null +++ b/tests/test_foreign_key_args.py @@ -0,0 +1,76 @@ +from typing import Optional + +import pytest +import sqlalchemy.event +import sqlalchemy.exc +from sqlalchemy import ForeignKey, create_engine, func +from sqlmodel import Field, SQLModel, select +from sqlmodel.orm.session import Session + + +def test_fk_constructed_in_base_model_fails(clear_sqlmodel) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),) + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + with pytest.raises(sqlalchemy.exc.InvalidRequestError) as e: + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + assert "This ForeignKey already has a parent" in str(e.errisinstance) + + +def test_fk_args_in_base_model_work(clear_sqlmodel) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Base(SQLModel): + owner_id: Optional[int] = Field( + default=None, + foreign_key="user.id", + sa_foreign_key_kwargs={"ondelete": "SET NULL"}, + ) + + class Asset(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Document(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://") + sqlalchemy.event.listen( + engine, "connect", lambda conn, *args: conn.execute("pragma foreign_keys=ON") + ) + + SQLModel.metadata.create_all(engine) + + # Test that the ON DELETE SET NULL we assigned actually works + with Session(engine) as session: + user = User() + session.add(user) + session.commit() + session.refresh(user) + + asset = Asset(owner_id=user.id) + session.add(asset) + session.commit() + session.refresh(asset) + assert asset.owner_id == user.id + + session.delete(user) + session.commit() + assert session.scalar(select(func.count()).select_from(User)) == 0 + + # Normally, one would also define a relationship (in the Asset class, `owner: Optional[User] = Relationship("User")`) + # so that SQLAlchemy knows that Asset and User are related, marks the Asset as dirty and refreshes it when requested. + # But Relationships are a separate complicated topic, which we don't want to touch here. + asset = session.exec(select(Asset)).one() + assert asset.owner_id is None