Skip to content

Commit

Permalink
feat(set): type hints for .count() and .select()
Browse files Browse the repository at this point in the history
  • Loading branch information
robinvandernoord committed Jun 27, 2023
1 parent f6ebd38 commit 1f23581
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 7 deletions.
11 changes: 9 additions & 2 deletions example_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class AllFieldsBasic(TypedTable):
upload = TypedField(str, type="upload", uploadfield="upload_data")
upload_data: bytes
reference: OtherTable
reference_two: typing.Optional[db.other_table]
reference_two: typing.Optional[db.other_table] # type: ignore
list_string: list[str]
list_integer: list[int]
list_reference: list[OtherTable]
Expand All @@ -112,7 +112,7 @@ class AllFieldsAdvanced(TypedTable):
upload = TypedField(str, type="upload", uploadfield="upload_data")
upload_data = TypedField(bytes)
reference = TypedField(OtherTable)
reference_two = TypedField(db.other_table, notnull=False)
reference_two: int = TypedField(db.other_table, notnull=False)
list_string = TypedField(list[str])
list_integer = TypedField(list[int])
list_reference = TypedField(list[OtherTable])
Expand Down Expand Up @@ -249,3 +249,10 @@ class AllFieldsExplicit(TypedTable):
# for field in rowa:
# print(field, type(rowa[field]))
print(rowb)

counted = db(AllFieldsExplicit).count()

rows: TypedRows[AllFieldsExplicit] = db(AllFieldsExplicit).select()

for row in rows:
print(row.id)
42 changes: 37 additions & 5 deletions src/typedal/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from decimal import Decimal

import pydal
from pydal.objects import Field, Query, Row, Rows, Table
from pydal.objects import Field, Row, Rows, Table

# use typing.cast(type, ...) to make mypy happy with unions
T_annotation = typing.Type[typing.Any] | types.UnionType
Expand Down Expand Up @@ -83,7 +83,7 @@ class TypeDAL(pydal.DAL): # type: ignore
"notnull": True,
}

def define(self, cls: T) -> Table:
def define(self, cls: T) -> T:
"""
Can be used as a decorator on a class that inherits `TypedTable`, \
or as a regular method if you need to define your classes before you have access to a 'db' instance.
Expand Down Expand Up @@ -140,9 +140,9 @@ class Article(TypedTable):

# the ACTUAL output is not TypedTable but rather pydal.Table
# but telling the editor it is T helps with hinting.
return table
return cls

def __call__(self, *_args: Query | bool, **kwargs: typing.Any) -> pydal.objects.Set:
def __call__(self, *_args: typing.Any, **kwargs: typing.Any) -> "TypedSet":
"""
A db instance can be called directly to perform a query.
Expand All @@ -162,7 +162,8 @@ def __call__(self, *_args: Query | bool, **kwargs: typing.Any) -> pydal.objects.
# table defined without @db.define decorator!
args[0] = cls.id != None

return super().__call__(*args, **kwargs)
_set = super().__call__(*args, **kwargs)
return typing.cast(TypedSet, _set)

# todo: insert etc shadowen?

Expand Down Expand Up @@ -419,3 +420,34 @@ class TypedRows(typing.Collection[S], Rows): # type: ignore
Example:
people: TypedRows[Person] = db(Person).select()
"""


T_Table = typing.TypeVar("T_Table", bound=Table)


class TypedSet(pydal.objects.Set): # type: ignore # pragma: no cover
"""
Used to make pydal Set more typed.
This class is not actually used, only 'cast' by TypeDAL.__call__
"""

def count(self, distinct: bool = None, cache: dict[str, typing.Any] = None) -> int:
"""
Count returns an int.
"""
result = super().count(distinct, cache)
return typing.cast(int, result)

def select(self, *fields: typing.Any, **attributes: typing.Any) -> TypedRows[T_Table]:
"""
Select returns a TypedRows of a user defined table.
Example:
result: TypedRows[MyTable] = db(MyTable.id > 0).select()
for row in result:
typing.reveal_type(row) # MyTable
"""
rows = super().select(*fields, **attributes)
return typing.cast(TypedRows[T_Table], rows)
12 changes: 12 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,15 @@ class OtherNewTable(TypedTable):
assert str(ListReferenceField("somenewtable")) == "TypedField.list:reference somenewtable"
assert str(JSONField()) == "TypedField.json"
assert str(BigintField()) == "TypedField.bigint"

# test typedset:
counted1 = db(SomeNewTable).count()
counted2 = db(OtherNewTable).count()
counted3 = db(db.some_new_table).count()

assert counted1 == counted2 == counted3 == 0

select2: TypedRows[SomeNewTable] = db(SomeNewTable).select()

for row in select2:
raise ValueError("no rows should exist")
34 changes: 34 additions & 0 deletions tests/test_mypy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import typing

from src.typedal import TypeDAL, TypedTable, TypedRows

db = TypeDAL("sqlite:memory")


@db.define
class MyTable(TypedTable):
...


old_style = db.define_table("old_table")


@pytest.mark.mypy_testing
def mypy_test_typedset() -> None:
counted1 = db(MyTable).count()
counted2 = db(db.old_style).count()
counted3 = db(old_style).count()

typing.reveal_type(counted1) # R: builtins.int
typing.reveal_type(counted2) # R: builtins.int
typing.reveal_type(counted3) # R: builtins.int

select1 = db(MyTable).select() # E: Need type annotation for "select1"
select2: TypedRows[MyTable] = db(MyTable).select()

typing.reveal_type(select1) # R: src.typedal.core.TypedRows[Any]
typing.reveal_type(select2) # R: src.typedal.core.TypedRows[tests.test_mypy.MyTable]

for row in select2:
typing.reveal_type(row) # R: tests.test_mypy.MyTable

0 comments on commit 1f23581

Please sign in to comment.