Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Raise a more clear error when a type is not valid #425

Merged
merged 11 commits into from
Oct 23, 2023
79 changes: 40 additions & 39 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,45 +374,46 @@ def __init__(


def get_sqlalchemy_type(field: ModelField) -> Any:
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
return AutoString
if issubclass(field.type_, float):
return Float
if issubclass(field.type_, bool):
return Boolean
if issubclass(field.type_, int):
return Integer
if issubclass(field.type_, datetime):
return DateTime
if issubclass(field.type_, date):
return Date
if issubclass(field.type_, timedelta):
return Interval
if issubclass(field.type_, time):
return Time
if issubclass(field.type_, Enum):
return sa_Enum(field.type_)
if issubclass(field.type_, bytes):
return LargeBinary
if issubclass(field.type_, Decimal):
return Numeric(
precision=getattr(field.type_, "max_digits", None),
scale=getattr(field.type_, "decimal_places", None),
)
if issubclass(field.type_, ipaddress.IPv4Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv4Network):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Address):
return AutoString
if issubclass(field.type_, ipaddress.IPv6Network):
return AutoString
if issubclass(field.type_, Path):
return AutoString
if issubclass(field.type_, uuid.UUID):
return GUID
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")


Expand Down
28 changes: 28 additions & 0 deletions tests/test_sqlalchemy_type_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Dict, List, Optional, Union

import pytest
from sqlmodel import Field, SQLModel


def test_type_list_breaks() -> None:
with pytest.raises(ValueError):

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: List[str]


def test_type_dict_breaks() -> None:
with pytest.raises(ValueError):

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: Dict[str, Any]


def test_type_union_breaks() -> None:
with pytest.raises(ValueError):

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
tags: Union[int, str]