From 1f2358169c6690e36153e99983c53a6c370f5e40 Mon Sep 17 00:00:00 2001 From: Robin van der Noord Date: Tue, 27 Jun 2023 17:50:46 +0200 Subject: [PATCH] feat(set): type hints for .count() and .select() --- example_new.py | 11 +++++++++-- src/typedal/core.py | 42 +++++++++++++++++++++++++++++++++++++----- tests/test_main.py | 12 ++++++++++++ tests/test_mypy.py | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 tests/test_mypy.py diff --git a/example_new.py b/example_new.py index 3212431..b3dd965 100644 --- a/example_new.py +++ b/example_new.py @@ -85,7 +85,7 @@ class AllFieldsBasic(TypedTable): upload = TypedField(str, type="upload", uploadfield="upload_data") upload_data: bytes reference: OtherTable - reference_two: typing.Optional[db.other_table] + reference_two: typing.Optional[db.other_table] # type: ignore list_string: list[str] list_integer: list[int] list_reference: list[OtherTable] @@ -112,7 +112,7 @@ class AllFieldsAdvanced(TypedTable): upload = TypedField(str, type="upload", uploadfield="upload_data") upload_data = TypedField(bytes) reference = TypedField(OtherTable) - reference_two = TypedField(db.other_table, notnull=False) + reference_two: int = TypedField(db.other_table, notnull=False) list_string = TypedField(list[str]) list_integer = TypedField(list[int]) list_reference = TypedField(list[OtherTable]) @@ -249,3 +249,10 @@ class AllFieldsExplicit(TypedTable): # for field in rowa: # print(field, type(rowa[field])) print(rowb) + +counted = db(AllFieldsExplicit).count() + +rows: TypedRows[AllFieldsExplicit] = db(AllFieldsExplicit).select() + +for row in rows: + print(row.id) diff --git a/src/typedal/core.py b/src/typedal/core.py index ac98c9a..5dfc75d 100644 --- a/src/typedal/core.py +++ b/src/typedal/core.py @@ -8,7 +8,7 @@ from decimal import Decimal import pydal -from pydal.objects import Field, Query, Row, Rows, Table +from pydal.objects import Field, Row, Rows, Table # use typing.cast(type, ...) to make mypy happy with unions T_annotation = typing.Type[typing.Any] | types.UnionType @@ -83,7 +83,7 @@ class TypeDAL(pydal.DAL): # type: ignore "notnull": True, } - def define(self, cls: T) -> Table: + def define(self, cls: T) -> T: """ Can be used as a decorator on a class that inherits `TypedTable`, \ or as a regular method if you need to define your classes before you have access to a 'db' instance. @@ -140,9 +140,9 @@ class Article(TypedTable): # the ACTUAL output is not TypedTable but rather pydal.Table # but telling the editor it is T helps with hinting. - return table + return cls - def __call__(self, *_args: Query | bool, **kwargs: typing.Any) -> pydal.objects.Set: + def __call__(self, *_args: typing.Any, **kwargs: typing.Any) -> "TypedSet": """ A db instance can be called directly to perform a query. @@ -162,7 +162,8 @@ def __call__(self, *_args: Query | bool, **kwargs: typing.Any) -> pydal.objects. # table defined without @db.define decorator! args[0] = cls.id != None - return super().__call__(*args, **kwargs) + _set = super().__call__(*args, **kwargs) + return typing.cast(TypedSet, _set) # todo: insert etc shadowen? @@ -419,3 +420,34 @@ class TypedRows(typing.Collection[S], Rows): # type: ignore Example: people: TypedRows[Person] = db(Person).select() """ + + +T_Table = typing.TypeVar("T_Table", bound=Table) + + +class TypedSet(pydal.objects.Set): # type: ignore # pragma: no cover + """ + Used to make pydal Set more typed. + + This class is not actually used, only 'cast' by TypeDAL.__call__ + """ + + def count(self, distinct: bool = None, cache: dict[str, typing.Any] = None) -> int: + """ + Count returns an int. + """ + result = super().count(distinct, cache) + return typing.cast(int, result) + + def select(self, *fields: typing.Any, **attributes: typing.Any) -> TypedRows[T_Table]: + """ + Select returns a TypedRows of a user defined table. + + Example: + result: TypedRows[MyTable] = db(MyTable.id > 0).select() + + for row in result: + typing.reveal_type(row) # MyTable + """ + rows = super().select(*fields, **attributes) + return typing.cast(TypedRows[T_Table], rows) diff --git a/tests/test_main.py b/tests/test_main.py index 0c09383..e69814a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -257,3 +257,15 @@ class OtherNewTable(TypedTable): assert str(ListReferenceField("somenewtable")) == "TypedField.list:reference somenewtable" assert str(JSONField()) == "TypedField.json" assert str(BigintField()) == "TypedField.bigint" + + # test typedset: + counted1 = db(SomeNewTable).count() + counted2 = db(OtherNewTable).count() + counted3 = db(db.some_new_table).count() + + assert counted1 == counted2 == counted3 == 0 + + select2: TypedRows[SomeNewTable] = db(SomeNewTable).select() + + for row in select2: + raise ValueError("no rows should exist") \ No newline at end of file diff --git a/tests/test_mypy.py b/tests/test_mypy.py new file mode 100644 index 0000000..516ad8c --- /dev/null +++ b/tests/test_mypy.py @@ -0,0 +1,34 @@ +import pytest +import typing + +from src.typedal import TypeDAL, TypedTable, TypedRows + +db = TypeDAL("sqlite:memory") + + +@db.define +class MyTable(TypedTable): + ... + + +old_style = db.define_table("old_table") + + +@pytest.mark.mypy_testing +def mypy_test_typedset() -> None: + counted1 = db(MyTable).count() + counted2 = db(db.old_style).count() + counted3 = db(old_style).count() + + typing.reveal_type(counted1) # R: builtins.int + typing.reveal_type(counted2) # R: builtins.int + typing.reveal_type(counted3) # R: builtins.int + + select1 = db(MyTable).select() # E: Need type annotation for "select1" + select2: TypedRows[MyTable] = db(MyTable).select() + + typing.reveal_type(select1) # R: src.typedal.core.TypedRows[Any] + typing.reveal_type(select2) # R: src.typedal.core.TypedRows[tests.test_mypy.MyTable] + + for row in select2: + typing.reveal_type(row) # R: tests.test_mypy.MyTable