Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨Properly support inheritance of Relationship attributes #886

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
run: pip install --upgrade "pydantic>=1.10.0,<2.0.0"
- name: Install Pydantic v2
if: matrix.pydantic-version == 'pydantic-v2'
run: pip install --upgrade "pydantic>=2.0.2,<3.0.0"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be rolled back before merging

run: pip install --upgrade "pydantic>=2.0.2,<2.7.0"
- name: Lint
# Do not run on Python 3.7 as mypy behaves differently
if: matrix.python-version != '3.7' && matrix.pydantic-version == 'pydantic-v2'
Expand Down
30 changes: 28 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,28 @@ def __new__(
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
backup_base_annotations: Dict[Type[Any], Dict[str, Any]] = {}
for base in bases:
base_relationships = getattr(base, "__sqlmodel_relationships__", None)
if base_relationships:
relationships.update(base_relationships)
#
# Temporarily pluck out `__annotations__` corresponding to relationships from base classes, otherwise these annotations
# make their way into `cls.model_fields` as `FieldInfo(..., required=True)`, even when the relationships are declared
# optional. When a model instance is then constructed using `model_validate` and an optional relationship field is not
# passed, this leads to an incorrect `pydantic.ValidationError`.
#
# We can't just clean up `new_cls.model_fields` after `new_cls` is constructed because by this time
# Pydantic has created model schema and validation rules, so this won't fix the problem.
#
base_annotations = getattr(base, "__annotations__", None)
if base_annotations:
backup_base_annotations[base] = base_annotations
base.__annotations__ = {
name: typ
for name, typ in base_annotations.items()
if name not in base_relationships
}
dict_for_pydantic = {}
original_annotations = get_annotations(class_dict)
pydantic_annotations = {}
Expand Down Expand Up @@ -449,6 +471,9 @@ def __new__(
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
# Restore base annotations
for base, annotations in backup_base_annotations.items():
base.__annotations__ = annotations
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
Expand All @@ -471,8 +496,9 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
if k not in relationships:
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.model_config['table'] in FastAPI, but
Expand Down
148 changes: 148 additions & 0 deletions tests/test_inherit_relationship.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import datetime
from typing import Optional

import pydantic
from sqlalchemy import DateTime, func
from sqlalchemy.orm import declared_attr, relationship
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
from sqlmodel._compat import IS_PYDANTIC_V2


def test_inherit_relationship(clear_sqlmodel) -> None:
def now():
return datetime.datetime.now(tz=datetime.timezone.utc)

class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str

class CreatedUpdatedMixin(SQLModel):
# Fields in reusable base models must be defined using `sa_type` and `sa_column_kwargs` instead of `sa_column`
# https://github.com/tiangolo/sqlmodel/discussions/743
#
# created_at: datetime.datetime = Field(default_factory=now, sa_column=DateTime(default=now))
created_at: datetime.datetime = Field(
default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now}
)

# With Pydantic V2, it is also possible to define `created_by` like this:
#
# ```python
# @declared_attr
# def _created_by(cls):
# return relationship(User, foreign_keys=cls.created_by_id)
#
# created_by: Optional[User] = Relationship(sa_relationship=_created_by))
# ```
#
# The difference from Pydantic V1 is that Pydantic V2 plucks attributes with names starting with '_' (but not '__')
# from class attributes and stores them separately as instances of `pydantic.ModelPrivateAttr` somewhere in depths of
# Pydantic internals. Under Pydantic V1 this doesn't happen, so SQLAlchemy ends up having two class attributes
# (`_created_by` and `created_by`) corresponding to one database attribute, causing a conflict and unreliable behavior.
# The approach with a lambda always works because it doesn't produce the second class attribute and thus eliminates
# the possibility of a conflict entirely.
#
created_by_id: Optional[int] = Field(default=None, foreign_key="user.id")
created_by: Optional[User] = Relationship(
sa_relationship=declared_attr(
lambda cls: relationship(User, foreign_keys=cls.created_by_id)
)
)

updated_at: datetime.datetime = Field(
default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now}
)
updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id")
updated_by: Optional[User] = Relationship(
sa_relationship=declared_attr(
lambda cls: relationship(User, foreign_keys=cls.updated_by_id)
)
)

class Asset(CreatedUpdatedMixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str

# Demonstrate that the mixin can be applied to more than 1 model
class Document(CreatedUpdatedMixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

john = User(name="John")
jane = User(name="Jane")
asset = Asset(name="Test", created_by=john, updated_by=jane)
doc = Document(name="Resume", created_by=jane, updated_by=john)

with Session(engine) as session:
session.add(asset)
session.add(doc)
session.commit()

with Session(engine) as session:
assert session.scalar(select(func.count()).select_from(User)) == 2

asset = session.exec(select(Asset)).one()
assert asset.created_by.name == "John"
assert asset.updated_by.name == "Jane"

doc = session.exec(select(Document)).one()
assert doc.created_by.name == "Jane"
assert doc.updated_by.name == "John"


def test_inherit_relationship_model_validate(clear_sqlmodel) -> None:
class User(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

class Mixin(SQLModel):
owner_id: Optional[int] = Field(default=None, foreign_key="user.id")
owner: Optional[User] = Relationship(
sa_relationship=declared_attr(
lambda cls: relationship(User, foreign_keys=cls.owner_id)
)
)

class Asset(Mixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

class AssetCreate(pydantic.BaseModel):
pass

asset_create = AssetCreate()

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

user = User()

# Owner must be optional
asset = Asset.model_validate(asset_create)
with Session(engine) as session:
session.add(asset)
session.commit()
session.refresh(asset)
assert asset.id is not None
assert asset.owner_id is None
assert asset.owner is None

# When set, owner must be saved
#
# Under Pydantic V2, relationship fields set it `model_validate` are not saved,
# with or without inheritance. Consider it a known issue.
#
if IS_PYDANTIC_V2:
asset = Asset.model_validate(asset_create, update={"owner": user})
with Session(engine) as session:
session.add(asset)
session.commit()
session.refresh(asset)
session.refresh(user)
assert asset.id is not None
assert user.id is not None
assert asset.owner_id == user.id
assert asset.owner.id == user.id
Loading