Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 49 additions & 29 deletions tests/test_qblike_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,21 @@
Bool,
Length,
GetArg,
GetInit,
GetMemberType,
GetName,
GetQuals,
GetSpecialAttr,
GetType,
GetInit,
InitField,
IsAssignable,
Iter,
IsEquivalent,
Member,
Members,
NewProtocol,
Slice,
UpdateClass,
)

from . import format_helper
Expand Down Expand Up @@ -134,16 +137,39 @@ class DbString:


class Table[name: str]:
def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[
*[
Member[
GetName[m],
_Field[
GetArg[GetType[m], Field, Literal[0]],
T,
GetName[m],
],
GetQuals[m],
GetInit[m],
]
for m in Iter[Members[T]]
if IsAssignable[GetType[m], Field]
],
*[m for m in Iter[Members[T]] if not IsAssignable[GetType[m], Field]],
]:
super().__init_subclass__()


class Field[PyType]:
pass


class Field[Table, Name, PyType]:
class _Field[PyType, Table, Name]:
def __lt__(self, other: Any) -> Filter[Table]: ...


type FieldTable[T] = GetArg[T, Field, Literal[0]]
type FieldName[T] = GetArg[T, Field, Literal[1]]
type FieldPyType[T] = GetArg[T, Field, Literal[2]]
type FieldPyType[T] = GetArg[T, _Field, Literal[0]]
type FieldTable[T] = GetArg[T, _Field, Literal[1]]
type FieldName[T] = GetArg[T, _Field, Literal[2]]


class ColumnArgs(TypedDict, total=False):
Expand Down Expand Up @@ -236,7 +262,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
*[
GetName[m]
for m in Iter[Attrs[T]]
if IsAssignable[GetType[m], Field]
if IsAssignable[GetType[m], _Field]
],
],
]
Expand All @@ -249,7 +275,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
*[
GetName[m]
for m in Iter[Attrs[T]]
if IsAssignable[GetType[m], Field]
if IsAssignable[GetType[m], _Field]
and any(
IsAssignable[FieldName[GetType[m]], f] for f in Iter[FieldNames]
)
Expand All @@ -268,7 +294,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
else [MakeQueryEntryAllFields[New]]
),
]
type AddField[Entries, New: Field] = tuple[
type AddField[Entries, New: _Field] = tuple[
*[ # Existing entries
(
e # Non-matching entry
Expand All @@ -286,7 +312,7 @@ class DbLinkSource[Args: DbLinkSourceArgs](InitField[Args]):
if not Bool[EntriesHasTable[Entries, FieldTable[New]]]
),
]
type AddEntries[Entries, News: tuple[Table | Field, ...]] = (
type AddEntries[Entries, News: tuple[Table | _Field, ...]] = (
Entries
if IsAssignable[Length[News], Literal[0]]
else AddEntries[
Expand Down Expand Up @@ -351,50 +377,44 @@ def execute[Es: tuple[type[Table], ...]](


class User(Table[Literal["users"]]):
id: Field[User, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
name: Field[User, Literal["name"], str] = column(
db_type=DbString(length=150), nullable=False
)
email: Field[User, Literal["email"], str] = column(
name: Field[str] = column(db_type=DbString(length=150), nullable=False)
email: Field[str] = column(
db_type=DbString(length=100), unique=True, nullable=False
)
age: Field[User, Literal["age"], int | None] = column(db_type=DbInteger())
active: Field[User, Literal["active"], bool] = column(
age: Field[int | None] = column(db_type=DbInteger())
active: Field[bool] = column(
db_type=DbBoolean(), default=True, nullable=False
)
posts: Field[User, Literal["posts"], list[Post]] = column(
posts: Field[list[Post]] = column(
db_type=DbLinkSource(source="Post", cardinality=Cardinality.MANY)
)


class Post(Table[Literal["posts"]]):
id: Field[Post, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
content: Field[Post, Literal["content"], str] = column(
db_type=DbString(length=1000), nullable=False
)
author: Field[Post, Literal["author"], User] = column(
content: Field[str] = column(db_type=DbString(length=1000), nullable=False)
author: Field[User] = column(
db_type=DbLinkTarget(target=User), nullable=False
)
comments: Field[Post, Literal["comments"], list[Comment]] = column(
comments: Field[list[Comment]] = column(
db_type=DbLinkSource(source="Comment", cardinality=Cardinality.MANY)
)


class Comment(Table[Literal["comments"]]):
id: Field[Comment, Literal["id"], int] = column(
id: Field[int] = column(
db_type=DbInteger(), primary_key=True, autoincrement=True
)
content: Field[Comment, Literal["content"], str] = column(
db_type=DbString(length=1000), nullable=False
)
author: Field[Comment, Literal["author"], User] = column(
content: Field[str] = column(db_type=DbString(length=1000), nullable=False)
author: Field[User] = column(
db_type=DbLinkTarget(target=User), nullable=False
)
post: Field[Comment, Literal["post"], Post] = column(
post: Field[Post] = column(
db_type=DbLinkTarget(target=Post), nullable=False
)

Expand Down
52 changes: 50 additions & 2 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,55 @@ class D[T](C[T]):
)


def test_update_class_members_11():
class A:
a: int

def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[*Members[T]]:
super().__init_subclass__()

def f(self) -> int: ...

class B(A):
b: str

def g(self) -> str: ...

attrs = eval_typing(Attrs[B])
assert (
attrs
== tuple[
Member[Literal["a"], int, Never, Never, B],
Member[Literal["b"], str, Never, Never, B],
]
)

members = eval_typing(MembersExceptInitSubclass[B])
assert (
members
== tuple[
Member[Literal["a"], int, Never, Never, B],
Member[Literal["b"], str, Never, Never, B],
Member[
Literal["f"],
Callable[[Param[Literal["self"], Self]], int],
Literal["ClassVar"],
object,
B,
],
Member[
Literal["g"],
Callable[[Param[Literal["self"], Self]], str],
Literal["ClassVar"],
object,
B,
],
]
)


def test_update_class_inheritance_01():
# current class init subclass is not applied
class A:
Expand Down Expand Up @@ -2327,7 +2376,6 @@ class C(B[float]):
assert eval_typing(GetArg[C, A, Literal[1]]) is float


@pytest.mark.xfail(reason="TODO")
def test_update_class_empty_01():
class A:
a: int
Expand All @@ -2341,7 +2389,7 @@ class B(A):
b: int

attrs = eval_typing(Attrs[B])
assert attrs == tuple[()]
assert attrs == tuple[Member[Literal["a"], int, Never, Never, A]]


##############
Expand Down
18 changes: 11 additions & 7 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typemap.type_eval import _apply_generic, _typing_inspect
from typemap.type_eval._eval_typing import (
_child_context,
_eval_args,
_eval_types,
EvalContext,
)
Expand Down Expand Up @@ -192,7 +193,8 @@ def _eval_init_subclass(
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
for abox in box.mro[1:]: # Skip the type itself
with _child_context() as ctx:
if ms := _get_update_class_members(box, abox, ctx=ctx):
ms = _get_update_class_members(box, abox, ctx=ctx)
if ms is not None:
nbox = _apply_generic.box(
_create_updated_class(box, ms, ctx=ctx)
)
Expand All @@ -208,7 +210,7 @@ def _get_update_class_members(
box: _apply_generic.Boxed,
boxed_base: _apply_generic.Boxed,
ctx: EvalContext,
) -> list[Member] | None:
) -> typing.Sequence[Member] | None:
cls = box.cls

# Get __init_subclass__ from the base class's origin if base is generic.
Expand Down Expand Up @@ -267,13 +269,13 @@ def _get_update_class_members(
_typing_inspect.is_generic_alias(evaled_ret)
and typing.get_origin(evaled_ret) is UpdateClass
):
return [m for m in typing.get_args(evaled_ret)]
return _eval_args(typing.get_args(evaled_ret), ctx)

return None


def _create_updated_class(
box: _apply_generic.Boxed, ms: list[Member], ctx: EvalContext
box: _apply_generic.Boxed, ms: typing.Sequence[Member], ctx: EvalContext
) -> type:
t = box.cls
dct: dict[str, object] = {}
Expand All @@ -289,9 +291,11 @@ def _create_updated_class(
typ = _eval_types(typ, ctx)
tquals = _eval_types(quals, ctx)

if type_eval.issubtype(
typing.Literal["ClassVar"], tquals
) and _is_method_like(typ):
if (
type_eval.issubtype(typing.Literal["ClassVar"], tquals)
and _is_method_like(typ)
and _typing_inspect.get_head(typ) is not GenericCallable
):
dct[member_name] = _callable_type_to_method(member_name, typ, ctx)
else:
# Update/add the annotation
Expand Down