Skip to content

Commit

Permalink
feat: add unique constraint param to Field function
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Gibson authored and raphaelgibson committed Jun 10, 2022
1 parent 4d20051 commit 2b39e32
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
6 changes: 6 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", Undefined)
foreign_key = kwargs.pop("foreign_key", Undefined)
unique = kwargs.pop("unique", False)
index = kwargs.pop("index", Undefined)
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
Expand All @@ -88,6 +89,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
self.primary_key = primary_key
self.nullable = nullable
self.foreign_key = foreign_key
self.unique = unique
self.index = index
self.sa_column = sa_column
self.sa_column_args = sa_column_args
Expand Down Expand Up @@ -149,6 +151,7 @@ def Field(
regex: Optional[str] = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
Expand Down Expand Up @@ -179,6 +182,7 @@ def Field(
regex=regex,
primary_key=primary_key,
foreign_key=foreign_key,
unique=unique,
nullable=nullable,
index=index,
sa_column=sa_column,
Expand Down Expand Up @@ -432,12 +436,14 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
unique = getattr(field.field_info, "unique", False)
if foreign_key:
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
"unique": unique
}
sa_default = Undefined
if field.field_info.default_factory:
Expand Down
91 changes: 91 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine
from sqlalchemy.exc import IntegrityError


def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None

hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(hero_1)
session.commit()
session.refresh(hero_1)

with Session(engine) as session:
session.add(hero_2)
session.commit()
session.refresh(hero_2)

with Session(engine) as session:
heroes = session.query(Hero).all()
assert len(heroes) == 2
assert heroes[0].name == heroes[1].name


def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str = Field(unique=False)
age: Optional[int] = None

hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(hero_1)
session.commit()
session.refresh(hero_1)

with Session(engine) as session:
session.add(hero_2)
session.commit()
session.refresh(hero_2)

with Session(engine) as session:
heroes = session.query(Hero).all()
assert len(heroes) == 2
assert heroes[0].name == heroes[1].name


def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(clear_sqlmodel):
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str = Field(unique=True)
age: Optional[int] = None

hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)

with Session(engine) as session:
session.add(hero_1)
session.commit()
session.refresh(hero_1)

with pytest.raises(IntegrityError):
with Session(engine) as session:
session.add(hero_2)
session.commit()
session.refresh(hero_2)

0 comments on commit 2b39e32

Please sign in to comment.