From 99f8ce3894564444f33770577a934749e479034e Mon Sep 17 00:00:00 2001 From: Daniil Fajnberg <60156134+daniil-berg@users.noreply.github.com> Date: Thu, 26 Oct 2023 12:18:05 +0200 Subject: [PATCH 1/7] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20all=20`Fi?= =?UTF-8?q?eld`=20parameters=20from=20Pydantic=20`1.9.0`=20and=20above,=20?= =?UTF-8?q?make=20Pydantic=20`1.9.0`=20the=20minimum=20required=20version?= =?UTF-8?q?=20(#440)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- pyproject.toml | 2 +- sqlmodel/main.py | 16 ++++++++- tests/test_pydantic/__init__.py | 0 tests/test_pydantic/test_field.py | 57 +++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 tests/test_pydantic/__init__.py create mode 100644 tests/test_pydantic/test_field.py diff --git a/pyproject.toml b/pyproject.toml index c7956daaa..181064e4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=1.4.36,<2.0.0" -pydantic = "^1.8.2" +pydantic = "^1.9.0" sqlalchemy2-stubs = {version = "*", allow-prereleases = true} [tool.poetry.group.dev.dependencies] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 07e600e4d..3015aa9fb 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -145,12 +145,17 @@ def Field( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, min_items: Optional[int] = None, max_items: Optional[int] = None, + unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, primary_key: bool = False, foreign_key: Optional[Any] = None, unique: bool = False, @@ -176,12 +181,17 @@ def Field( lt=lt, le=le, multiple_of=multiple_of, + max_digits=max_digits, + decimal_places=decimal_places, min_items=min_items, max_items=max_items, + unique_items=unique_items, min_length=min_length, max_length=max_length, allow_mutation=allow_mutation, regex=regex, + discriminator=discriminator, + repr=repr, primary_key=primary_key, foreign_key=foreign_key, unique=unique, @@ -587,7 +597,11 @@ def parse_obj( def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes - return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] + return [ + (k, v) + for k, v in super().__repr_args__() + if not (isinstance(k, str) and k.startswith("_sa_")) + ] # From Pydantic, override to enforce validation with dict @classmethod diff --git a/tests/test_pydantic/__init__.py b/tests/test_pydantic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py new file mode 100644 index 000000000..9d7bc7762 --- /dev/null +++ b/tests/test_pydantic/test_field.py @@ -0,0 +1,57 @@ +from decimal import Decimal +from typing import Optional, Union + +import pytest +from pydantic import ValidationError +from sqlmodel import Field, SQLModel +from typing_extensions import Literal + + +def test_decimal(): + class Model(SQLModel): + dec: Decimal = Field(max_digits=4, decimal_places=2) + + Model(dec=Decimal("3.14")) + Model(dec=Decimal("69.42")) + + with pytest.raises(ValidationError): + Model(dec=Decimal("3.142")) + with pytest.raises(ValidationError): + Model(dec=Decimal("0.069")) + with pytest.raises(ValidationError): + Model(dec=Decimal("420")) + + +def test_discriminator(): + # Example adapted from + # [Pydantic docs](https://pydantic-docs.helpmanual.io/usage/types/#discriminated-unions-aka-tagged-unions): + + class Cat(SQLModel): + pet_type: Literal["cat"] + meows: int + + class Dog(SQLModel): + pet_type: Literal["dog"] + barks: float + + class Lizard(SQLModel): + pet_type: Literal["reptile", "lizard"] + scales: bool + + class Model(SQLModel): + pet: Union[Cat, Dog, Lizard] = Field(..., discriminator="pet_type") + n: int + + Model(pet={"pet_type": "dog", "barks": 3.14}, n=1) # type: ignore[arg-type] + + with pytest.raises(ValidationError): + Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type] + + +def test_repr(): + class Model(SQLModel): + id: Optional[int] = Field(primary_key=True) + foo: str = Field(repr=False) + + instance = Model(id=123, foo="bar") + assert "foo=" not in repr(instance) From 8d1423253812af999aed55ce56d2918211039b08 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 26 Oct 2023 10:18:41 +0000 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=93=9D=20Update=20release=20notes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/release-notes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/release-notes.md b/docs/release-notes.md index 900804d8d..4956a1247 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -2,6 +2,7 @@ ## Latest Changes +* ✨ Add support for all `Field` parameters from Pydantic `1.9.0` and above, make Pydantic `1.9.0` the minimum required version. PR [#440](https://github.com/tiangolo/sqlmodel/pull/440) by [@daniil-berg](https://github.com/daniil-berg). ## 0.0.9 ### Breaking Changes From 7fdfee10a5275b6f076d18d70e584da6b632a313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 26 Oct 2023 18:32:26 +0400 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=94=A7=20Adopt=20Ruff=20for=20formatt?= =?UTF-8?q?ing=20(#679)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 16 +++------------- pyproject.toml | 9 ++++++++- scripts/format.sh | 2 +- scripts/lint.sh | 2 +- tests/conftest.py | 3 +-- 5 files changed, 14 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 022ef24a0..61aaf7141 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,23 +13,13 @@ repos: - --unsafe - id: end-of-file-fixer - id: trailing-whitespace -- repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 - hooks: - - id: pyupgrade - args: - - --py3-plus - - --keep-runtime-typing -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.1.2 hooks: - id: ruff args: - --fix -- repo: https://github.com/psf/black - rev: 23.10.0 - hooks: - - id: black + - id: ruff-format ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate diff --git a/pyproject.toml b/pyproject.toml index 181064e4e..20188513c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true} [tool.poetry.group.dev.dependencies] pytest = "^7.0.1" mypy = "0.971" +# Needed by the code generator using templates black = "^22.10.0" mkdocs-material = "9.1.21" pillow = "^9.3.0" @@ -46,7 +47,7 @@ mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = "^6.2"} fastapi = "^0.68.1" requests = "^2.26.0" -ruff = "^0.1.1" +ruff = "^0.1.2" [build-system] requires = ["poetry-core"] @@ -87,11 +88,13 @@ select = [ "I", # isort "C", # flake8-comprehensions "B", # flake8-bugbear + "UP", # pyupgrade ] ignore = [ "E501", # line too long, handled by black "B008", # do not perform function calls in argument defaults "C901", # too complex + "W191", # indentation contains tabs ] [tool.ruff.per-file-ignores] @@ -99,3 +102,7 @@ ignore = [ [tool.ruff.isort] known-third-party = ["sqlmodel", "sqlalchemy", "pydantic", "fastapi"] + +[tool.ruff.pyupgrade] +# Preserve types, even if a file imports `from __future__ import annotations`. +keep-runtime-typing = true diff --git a/scripts/format.sh b/scripts/format.sh index b6aebd10d..70c12e579 100755 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -2,4 +2,4 @@ set -x ruff sqlmodel tests docs_src scripts --fix -black sqlmodel tests docs_src scripts +ruff format sqlmodel tests docs_src scripts diff --git a/scripts/lint.sh b/scripts/lint.sh index b328e3d9a..f66882239 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -5,4 +5,4 @@ set -x mypy sqlmodel ruff sqlmodel tests docs_src scripts -black sqlmodel tests docs_src --check +ruff format sqlmodel tests docs_src --check diff --git a/tests/conftest.py b/tests/conftest.py index cd66420c8..2b8e5fc29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,8 +42,7 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP module, ], cwd=str(cwd), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, encoding="utf-8", ) return result From 13cc722110b9d3b37cd00a9a0480fdd4bccb289a Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 26 Oct 2023 14:32:59 +0000 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=93=9D=20Update=20release=20notes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/release-notes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/release-notes.md b/docs/release-notes.md index 4956a1247..11ad6fc23 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -2,6 +2,7 @@ ## Latest Changes +* 🔧 Adopt Ruff for formatting. PR [#679](https://github.com/tiangolo/sqlmodel/pull/679) by [@tiangolo](https://github.com/tiangolo). * ✨ Add support for all `Field` parameters from Pydantic `1.9.0` and above, make Pydantic `1.9.0` the minimum required version. PR [#440](https://github.com/tiangolo/sqlmodel/pull/440) by [@daniil-berg](https://github.com/daniil-berg). ## 0.0.9 From e4e1385eedc700ad8c4e079841a85c32a29f1cff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 26 Oct 2023 18:34:49 +0400 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=94=96=20Release=20version=200.0.10?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/release-notes.md | 10 +++++++++- sqlmodel/__init__.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/release-notes.md b/docs/release-notes.md index 11ad6fc23..61cd9f66a 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -2,8 +2,16 @@ ## Latest Changes -* 🔧 Adopt Ruff for formatting. PR [#679](https://github.com/tiangolo/sqlmodel/pull/679) by [@tiangolo](https://github.com/tiangolo). +## 0.0.10 + +### Features + * ✨ Add support for all `Field` parameters from Pydantic `1.9.0` and above, make Pydantic `1.9.0` the minimum required version. PR [#440](https://github.com/tiangolo/sqlmodel/pull/440) by [@daniil-berg](https://github.com/daniil-berg). + +### Internal + +* 🔧 Adopt Ruff for formatting. PR [#679](https://github.com/tiangolo/sqlmodel/pull/679) by [@tiangolo](https://github.com/tiangolo). + ## 0.0.9 ### Breaking Changes diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index b51084b3f..6b65860d3 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.9" +__version__ = "0.0.10" # Re-export from SQLAlchemy from sqlalchemy.engine import create_mock_engine as create_mock_engine From 717594ef13b64d0d4fd1ce7c6305c945f68460d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 28 Oct 2023 17:55:23 +0400 Subject: [PATCH 6/7] =?UTF-8?q?=E2=9C=A8=20Do=20not=20allow=20invalid=20co?= =?UTF-8?q?mbinations=20of=20field=20parameters=20for=20columns=20and=20re?= =?UTF-8?q?lationships,=20`sa=5Fcolumn`=20excludes=20`sa=5Fcolumn=5Fargs`,?= =?UTF-8?q?=20`primary=5Fkey`,=20`nullable`,=20etc.=20(#681)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ Make sa_column exclusive, do not allow incompatible arguments, sa_column_args, primary_key, etc * ✅ Add tests for new errors when incorrectly using sa_column * ✅ Add tests for sa_column_args and sa_column_kwargs * ♻️ Do not allow sa_relationship with sa_relationship_args or sa_relationship_kwargs * ✅ Add tests for relationship errors * ✅ Fix test for sa_column_args --- sqlmodel/main.py | 151 ++++++++++++++++++++++++++-- tests/test_field_sa_args_kwargs.py | 39 +++++++ tests/test_field_sa_column.py | 99 ++++++++++++++++++ tests/test_field_sa_relationship.py | 53 ++++++++++ 4 files changed, 332 insertions(+), 10 deletions(-) create mode 100644 tests/test_field_sa_args_kwargs.py create mode 100644 tests/test_field_sa_column.py create mode 100644 tests/test_field_sa_relationship.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3015aa9fb..f48e388e1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -22,6 +22,7 @@ TypeVar, Union, cast, + overload, ) from pydantic import BaseConfig, BaseModel @@ -87,6 +88,28 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) + if primary_key is not Undefined: + raise RuntimeError( + "Passing primary_key is not supported when " + "also passing a sa_column" + ) + if nullable is not Undefined: + raise RuntimeError( + "Passing nullable is not supported when " "also passing a sa_column" + ) + if foreign_key is not Undefined: + raise RuntimeError( + "Passing foreign_key is not supported when " + "also passing a sa_column" + ) + if unique is not Undefined: + raise RuntimeError( + "Passing unique is not supported when " "also passing a sa_column" + ) + if index is not Undefined: + raise RuntimeError( + "Passing index is not supported when " "also passing a sa_column" + ) super().__init__(default=default, **kwargs) self.primary_key = primary_key self.nullable = nullable @@ -126,6 +149,86 @@ def __init__( self.sa_relationship_kwargs = sa_relationship_kwargs +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + +@overload +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + ... + + def Field( default: Any = Undefined, *, @@ -156,9 +259,9 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: bool = False, - foreign_key: Optional[Any] = None, - unique: bool = False, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore @@ -206,6 +309,27 @@ def Field( return field_info +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + ... + + +@overload +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore +) -> Any: + ... + + def Relationship( *, back_populates: Optional[str] = None, @@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", False) + primary_key = getattr(field.field_info, "primary_key", Undefined) + if primary_key is Undefined: + primary_key = False index = getattr(field.field_info, "index", Undefined) if index is Undefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - if hasattr(field.field_info, "nullable"): - field_nullable = getattr(field.field_info, "nullable") # noqa: B009 - if field_nullable != Undefined: - nullable = field_nullable + field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 + if field_nullable != Undefined: + assert not isinstance(field_nullable, UndefinedType) + nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", None) - unique = getattr(field.field_info, "unique", False) + foreign_key = getattr(field.field_info, "foreign_key", Undefined) + if foreign_key is Undefined: + foreign_key = None + unique = getattr(field.field_info, "unique", Undefined) + if unique is Undefined: + unique = False if foreign_key: + assert isinstance(foreign_key, str) args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, diff --git a/tests/test_field_sa_args_kwargs.py b/tests/test_field_sa_args_kwargs.py new file mode 100644 index 000000000..94a1a1348 --- /dev/null +++ b/tests/test_field_sa_args_kwargs.py @@ -0,0 +1,39 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlmodel import Field, SQLModel, create_engine + + +def test_sa_column_args(clear_sqlmodel, caplog) -> None: + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + sa_column_args=[ForeignKey("team.id")], + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE hero" in message + ][0] + assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log + + +def test_sa_column_kargs(clear_sqlmodel, caplog) -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + ) + + engine = create_engine("sqlite://", echo=True) + SQLModel.metadata.create_all(engine) + create_table_log = [ + message for message in caplog.messages if "CREATE TABLE item" in message + ][0] + assert "PRIMARY KEY (id)" in create_table_log diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py new file mode 100644 index 000000000..51cfdfa79 --- /dev/null +++ b/tests/test_field_sa_column.py @@ -0,0 +1,99 @@ +from typing import Optional + +import pytest +from sqlalchemy import Column, Integer, String +from sqlmodel import Field, SQLModel + + +def test_sa_column_takes_precedence() -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column=Column(String, primary_key=True, nullable=False), + ) + + # It would have been nullable with no sa_column + assert Item.id.nullable is False # type: ignore + assert isinstance(Item.id.type, String) # type: ignore + + +def test_sa_column_no_sa_args() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_args=[Integer], + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_sa_kargs() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_column_kwargs={"primary_key": True}, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_primary_key() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + primary_key=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_nullable() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + nullable=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_foreign_key() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + team_id: Optional[int] = Field( + default=None, + foreign_key="team.id", + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_unique() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + unique=True, + sa_column=Column(Integer, primary_key=True), + ) + + +def test_sa_column_no_index() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + index=True, + sa_column=Column(Integer, primary_key=True), + ) diff --git a/tests/test_field_sa_relationship.py b/tests/test_field_sa_relationship.py new file mode 100644 index 000000000..7606fd86d --- /dev/null +++ b/tests/test_field_sa_relationship.py @@ -0,0 +1,53 @@ +from typing import List, Optional + +import pytest +from sqlalchemy.orm import relationship +from sqlmodel import Field, Relationship, SQLModel + + +def test_sa_relationship_no_args() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_args=["Hero"], + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + + +def test_sa_relationship_no_kwargs() -> None: + with pytest.raises(RuntimeError): + + class Team(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str + + heroes: List["Hero"] = Relationship( + back_populates="team", + sa_relationship_kwargs={"lazy": "selectin"}, + sa_relationship=relationship("Hero", back_populates="team"), + ) + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") From 6457775a0f6994cdaf48ecaf286bb16d4bd44840 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sat, 28 Oct 2023 13:55:56 +0000 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=93=9D=20Update=20release=20notes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/release-notes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/release-notes.md b/docs/release-notes.md index 61cd9f66a..9e5a87696 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -2,6 +2,7 @@ ## Latest Changes +* ✨ Do not allow invalid combinations of field parameters for columns and relationships, `sa_column` excludes `sa_column_args`, `primary_key`, `nullable`, etc.. PR [#681](https://github.com/tiangolo/sqlmodel/pull/681) by [@tiangolo](https://github.com/tiangolo). ## 0.0.10 ### Features