Skip to content

Commit

Permalink
✨Add foreign_key_args and foreign_key_kwargs arguments to `Field(…
Browse files Browse the repository at this point in the history
…...)` to let the user define additional `sqlalchemy.orm.ForeignKey` attributes, such as `ondelete` and `onupdate`, for foreign keys defined in a base model.
  • Loading branch information
earshinov committed Apr 12, 2024
1 parent c75743d commit 252c047
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
24 changes: 23 additions & 1 deletion sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined)
sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined)
if sa_column is not Undefined:
if sa_column_args is not Undefined:
raise RuntimeError(
Expand Down Expand Up @@ -153,6 +155,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.sa_column = sa_column
self.sa_column_args = sa_column_args
self.sa_column_kwargs = sa_column_kwargs
self.sa_foreign_key_args = sa_foreign_key_args
self.sa_foreign_key_kwargs = sa_foreign_key_kwargs


class RelationshipInfo(Representation):
Expand Down Expand Up @@ -222,6 +226,8 @@ def Field(
sa_type: Union[Type[Any], UndefinedType] = Undefined,
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
...
Expand Down Expand Up @@ -303,6 +309,8 @@ def Field(
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
Expand Down Expand Up @@ -340,6 +348,8 @@ def Field(
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
sa_foreign_key_args=sa_foreign_key_args,
sa_foreign_key_kwargs=sa_foreign_key_kwargs,
**current_schema_extra,
)
post_init_field_info(field_info)
Expand Down Expand Up @@ -638,7 +648,19 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
unique = False
if foreign_key:
assert isinstance(foreign_key, str)
args.append(ForeignKey(foreign_key))
sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined)
fk_args = (
[]
if sa_foreign_key_args is Undefined
else list(cast(Sequence[Any], sa_foreign_key_args))
)
sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined)
fk_kwargs = (
{}
if sa_foreign_key_kwargs is Undefined
else cast(Dict[Any, Any], sa_foreign_key_kwargs)
)
args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
Expand Down
76 changes: 76 additions & 0 deletions tests/test_foreign_key_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional

import pytest
import sqlalchemy.event
import sqlalchemy.exc
from sqlalchemy import ForeignKey, create_engine, func
from sqlmodel import Field, SQLModel, select
from sqlmodel.orm.session import Session


def test_fk_constructed_in_base_model_fails(clear_sqlmodel) -> None:
class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

class Base(SQLModel):
owner_id: Optional[int] = Field(
default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),)
)

class Asset(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

with pytest.raises(sqlalchemy.exc.InvalidRequestError) as e:

class Document(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

assert "This ForeignKey already has a parent" in str(e.errisinstance)


def test_fk_args_in_base_model_work(clear_sqlmodel) -> None:
class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

class Base(SQLModel):
owner_id: Optional[int] = Field(
default=None,
foreign_key="user.id",
sa_foreign_key_kwargs={"ondelete": "SET NULL"},
)

class Asset(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

class Document(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

engine = create_engine("sqlite://")
sqlalchemy.event.listen(
engine, "connect", lambda conn, *args: conn.execute("pragma foreign_keys=ON")
)

SQLModel.metadata.create_all(engine)

# Test that the ON DELETE SET NULL we assigned actually works
with Session(engine) as session:
user = User()
session.add(user)
session.commit()
session.refresh(user)

asset = Asset(owner_id=user.id)
session.add(asset)
session.commit()
session.refresh(asset)
assert asset.owner_id == user.id

session.delete(user)
session.commit()
assert session.scalar(select(func.count()).select_from(User)) == 0

# Normally, one would also define a relationship (in the Asset class, `owner: Optional[User] = Relationship("User")`)
# so that SQLAlchemy knows that Asset and User are related, marks the Asset as dirty and refreshes it when requested.
# But Relationships are a separate complicated topic, which we don't want to touch here.
asset = session.exec(select(Asset)).one()
assert asset.owner_id is None

0 comments on commit 252c047

Please sign in to comment.