Skip to content

Commit

Permalink
✨ Add support for hybrid_property
Browse files Browse the repository at this point in the history
  • Loading branch information
van51 committed Oct 31, 2022
1 parent 75ce455 commit 55c9526
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
11 changes: 10 additions & 1 deletion sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
Expand Down Expand Up @@ -290,7 +291,11 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
col = v
# Treat `hybrid_property` properties as already specified columns
# and let sqlalchemy take care of them
if not issubclass(v.type_, hybrid_property):
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
Expand Down Expand Up @@ -326,6 +331,10 @@ def __init__(
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
# Ignore `hybrid_property` properties as already specified columns
# and let sqlalchemy take care of them
if issubclass(field_value.type_, hybrid_property):
continue
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
Expand Down
1 change: 1 addition & 0 deletions sqlmodel/sql/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Select(_Select, Generic[_TSelect]):
class SelectOfScalar(_Select, Generic[_TSelect]):
inherit_cache = True


else:
from typing import GenericMeta # type: ignore

Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from pydantic import BaseModel
from sqlmodel import SQLModel
from sqlmodel import SQLModel, create_engine
from sqlmodel.main import default_registry

top_level_path = Path(__file__).resolve().parent.parent
Expand All @@ -23,6 +23,13 @@ def clear_sqlmodel():
default_registry.dispose()


@pytest.fixture()
def in_memory_engine(clear_sqlmodel):
engine = create_engine("sqlite:///memory")
yield engine
SQLModel.metadata.drop_all(engine, checkfirst=True)


@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
yield tmp_path
Expand Down
41 changes: 41 additions & 0 deletions tests/test_sqlalchemy_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Optional

from sqlalchemy import func
from sqlalchemy.ext.hybrid import hybrid_property
from sqlmodel import Field, Session, SQLModel, select


def test_hybrid_property(in_memory_engine):
class Interval(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
length: float

@hybrid_property
def radius(self) -> float:
return abs(self.length) / 2

@radius.expression
def radius(cls) -> float:
return func.abs(cls.length) / 2

class Config:
arbitrary_types_allowed = True

SQLModel.metadata.create_all(in_memory_engine)
session = Session(in_memory_engine)

interval = Interval(length=-2)
assert interval.radius == 1

session.add(interval)
session.commit()
interval_2 = session.exec(select(Interval)).all()[0]
assert interval_2.radius == 1

interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0]
assert interval_3.radius == 1

intervals = session.exec(select(Interval).where(Interval.radius > 1)).all()
assert len(intervals) == 0

assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0

0 comments on commit 55c9526

Please sign in to comment.