Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations for peewee #4776

Closed
wants to merge 32 commits into from
Closed

Add type annotations for peewee #4776

wants to merge 32 commits into from

Conversation

dargueta
Copy link
Contributor

@dargueta dargueta commented Nov 18, 2020

  • AsIs
  • AutoField
  • BareField
  • BigAutoField
  • BigBitField
  • BigIntegerField
  • BinaryUUIDField
  • BitField
  • BlobField
  • BooleanField
  • Case
  • Cast
  • CharField
  • Check
  • chunked
  • Column
  • CompositeKey
  • Context
  • Database
  • DatabaseError
  • DatabaseProxy
  • DataError
  • DateField
  • DateTimeField
  • DecimalField
  • DeferredForeignKey
  • DeferredThroughModel
  • DJANGO_MAP
  • DoesNotExist
  • DoubleField
  • DQ
  • EXCLUDED
  • Field
  • FixedCharField
  • FloatField
  • fn
  • ForeignKeyField
  • IdentityField
  • ImproperlyConfigured
  • Index
  • IntegerField
  • IntegrityError
  • InterfaceError
  • InternalError
  • IPField
  • JOIN
  • ManyToManyField
  • Model
  • ModelIndex
  • MySQLDatabase
  • NotSupportedError
  • OP
  • OperationalError
  • PostgresqlDatabase
  • PrimaryKeyField (Deprecated)
  • ProgrammingError
  • Proxy
  • QualifiedNames
  • SchemaManager
  • SmallIntegerField
  • Select
  • SQL
  • SqliteDatabase
  • Table
  • TextField
  • TimeField
  • TimestampField
  • Tuple
  • UUIDField
  • Value
  • ValuesList
  • Window

@srittau
Copy link
Collaborator

srittau commented Nov 18, 2020

A few short hints:

  • It's fine to leave types unannotated at the start. They are treated as Any. Of course, annotations are preferable.
  • It's also possible to have incomplete type stubs, marked with __getattr__(). See the CONTRIBUTING guide.
  • Forward references don't need to be quoted in stubs.

@jtschoonhoven
Copy link
Contributor

You've probably seen #4262 already, but flagging it here just in case. Thanks for working on this!

@dargueta
Copy link
Contributor Author

I haven't, actually. Thanks for pointing it out!

@srittau
Copy link
Collaborator

srittau commented Jan 23, 2021

Same heads up as for #4262: Next week we will reshuffle the layout of typeshed. This means that the files in this PR will also need to be moved. This shouldn't be much of a problem, though, as there should be no conflicts.

@dargueta
Copy link
Contributor Author

Thanks for the warning. I'll be sure to do that!

@dargueta
Copy link
Contributor Author

dargueta commented Jan 26, 2021

By the way @srittau I have a few problems with this code and was wondering if there are conventions for handling them. Most of these surround incompatible subclasses.

"Overloaded function signatures 1 and 2 overlap with incompatible return types"

Given code like this:

def __eq__(self, other):
    if isinstance(other, _HashableSource):
        return other is self
    return Expression(...)

I've tried typing it like this:

 @overload
 def __eq__(self, other: _HashableSource) -> bool: ...
 @overload
 def __eq__(self, other: object) -> Expression: ...

MyPy's guidelines for overloading make it seem like this is the proper way to annotate this (specifically this bit) but I'm still getting an error from the linter. This seems like a common enough case that there should be a way to type the function accurately, right?

"Return type Expression of '__eq__' incompatible with return type 'bool' in supertype 'object'"

The above code also throws a second error that I'm not sure how to get around. This is an ORM so a non-boolean return value from __eq__ is expected. I could use # type: ignore[override] but that somehow seems wrong to me.

Additional required arguments

There are numerous cases where we have something like:

class A:
    def method(self, x: int, y: int) -> bool: ...

class B(A):
    # Return type is the same but we have an additional required argument
    def method(self, x: int, y: int, z: int) -> bool: ...  

class C(A):
    # Arguments are the same but the return type is incompatible
    def method(self, x: int, y: int) -> OtherClass: ...

The way I see it, there are two options:

  1. Use # type: ignore[override] on the subclass implementations.
  2. Break the inheritance and insert Protocol classes where necessary. This seems really wrong.

Any thoughts? Sorry to dump all this on you at once.

@srittau
Copy link
Collaborator

srittau commented Jan 26, 2021

By the way @srittau I have a few problems with this code and was wondering if there are conventions for handling them. Most of these surround incompatible subclasses.

"Overloaded function signatures 1 and 2 overlap with incompatible return types"

Given code like this:

def __eq__(self, other):
    if isinstance(other, _HashableSource):
        return other is self
    return Expression(...)

I've tried typing it like this:

 @overload
 def __eq__(self, other: _HashableSource) -> bool: ...
 @overload
 def __eq__(self, other: object) -> Expression: ...

MyPy's guidelines for overloading make it seem like this is the proper way to annotate this (specifically this bit) but I'm still getting an error from the linter. This seems like a common enough case that there should be a way to type the function accurately, right?

"Return type Expression of 'eq' incompatible with return type 'bool' in supertype 'object'"

The above code also throws a second error that I'm not sure how to get around. This is an ORM so a non-boolean return value from __eq__ is expected. I could use # type: ignore[override] but that somehow seems wrong to me.

Both of there require a # type: ignore. (Please don't use # type: ignore[override] as that is a non-standard mypy extension.) In the first example, mypy could be a bit less zealous in its warning, but that's the way it is. In the second example, think of the # type: ignore as "I know we do weird things here, please ignore."

Additional required arguments

There are numerous cases where we have something like:

class A:
    def method(self, x: int, y: int) -> bool: ...

class B(A):
    # Return type is the same but we have an additional required argument
    def method(self, x: int, y: int, z: int) -> bool: ...  

class C(A):
    # Arguments are the same but return type is incompatible
    def method(self, x: int, y: int) -> OtherClass: ...

The way I see it, there are two options:

  1. Use # type: ignore[override] on the subclass implementations.
  2. Break the inheritance and insert Protocol classes where necessary. This seems really wrong.

# type: ignore is the correct way again. In this case it's a marker that the you know that we are breaking the LSP here.

@srittau
Copy link
Collaborator

srittau commented Sep 11, 2021

@dargueta Any updates?

@dargueta
Copy link
Contributor Author

I've been poking at it locally but haven't pushed anything up yet.

@JelleZijlstra
Copy link
Member

Going to close this as stale for now to keep the list of open PRs manageable. If you're still interested in seeing this through, feel free to reopen or open a new PR.

@JelleZijlstra
Copy link
Member

Actually, reopening since there's fairly recent activity. I'd suggest landing a minimal version first instead of trying to get everything fully annotated.

@JelleZijlstra JelleZijlstra reopened this Dec 15, 2021
@mvanlonden
Copy link

mvanlonden commented Jan 8, 2022

Appreciate your work here @dargueta. One final failing check

@dargueta
Copy link
Contributor Author

I'll try to get to it this weekend.

@pylipp
Copy link
Contributor

pylipp commented Jan 31, 2022

Hej, does this patch solve the linter complaints?

diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi
index 0d2adef2..423c44d2 100644
--- a/third_party/2and3/peewee.pyi
+++ b/third_party/2and3/peewee.pyi
@@ -34,7 +34,7 @@ from typing import (
 )
 from typing_extensions import Literal, Protocol
 
-T = TypeVar("T")
+_T = TypeVar("_T")
 _TModel = TypeVar("_TModel", bound="Model")
 _TConvFunc = Callable[[Any], Any]
 _TFunc = TypeVar("_TFunc", bound=Callable)
@@ -213,16 +213,16 @@ class _DynamicColumn:
     @overload
     def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ...
     @overload
-    def __get__(self, instance: T, instance_type: Type[T]) -> ColumnFactory: ...
+    def __get__(self, instance: _T, instance_type: Type[_T]) -> ColumnFactory: ...
 
 class _ExplicitColumn:
     @overload
     def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ...
     @overload
-    def __get__(self, instance: T, instance_type: Type[T]) -> NoReturn: ...
+    def __get__(self, instance: _T, instance_type: Type[_T]) -> NoReturn: ...
 
 class _SupportsAlias(Protocol):
-    def alias(self: T, name: str) -> T: ...
+    def alias(self: _T, name: str) -> _T: ...
 
 class Source(_SupportsAlias, Node):
     c: ClassVar[_DynamicColumn]
@@ -420,7 +420,7 @@ class _DynamicEntity:
     @overload
     def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ...
     @overload
-    def __get__(self, instance: T, instance_type: Type[T]) -> EntityFactory: ...
+    def __get__(self, instance: _T, instance_type: Type[_T]) -> EntityFactory: ...
 
 class Alias(WrappedNode):
     c: ClassVar[_DynamicEntity]
@@ -587,7 +587,7 @@ class ForUpdate(Node):
     def __init__(
         self,
         expr: Union[Literal[True], str],
-        of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...],]] = ...,
+        of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...], ]] = ...,
         nowait: Optional[bool] = ...,
     ): ...
     def __sql__(self, ctx: Context) -> Context: ...
@@ -632,7 +632,7 @@ def qualify_names(node: Expression) -> Expression: ...
 @overload
 def qualify_names(node: ColumnBase) -> QualifiedNames: ...
 @overload
-def qualify_names(node: T) -> T: ...
+def qualify_names(node: _T) -> _T: ...
 
 class OnConflict(Node):
     @overload
@@ -1001,7 +1001,7 @@ class Database(_callable_context_manager):
     def begin(self) -> None: ...
     def commit(self) -> None: ...
     def rollback(self) -> None: ...
-    def batch_commit(self, it: Iterable[T], n: int) -> Iterator[T]: ...
+    def batch_commit(self, it: Iterable[_T], n: int) -> Iterator[_T]: ...
     def table_exists(self, table_name: str, schema: Optional[str] = ...) -> str: ...
     def get_tables(self, schema: Optional[str] = ...) -> List[str]: ...
     def get_indexes(self, table: str, schema: Optional[str] = ...) -> List[IndexMetadata]: ...
@@ -1193,24 +1193,24 @@ class _savepoint(_callable_context_manager):
     def __enter__(self) -> _savepoint: ...
     def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ...
 
-class CursorWrapper(Generic[T]):
+class CursorWrapper(Generic[_T]):
     cursor: __ICursor
     count: int
     index: int
     initialized: bool
     populated: bool
-    row_cache: List[T]
+    row_cache: List[_T]
     def __init__(self, cursor: __ICursor): ...
-    def __iter__(self) -> Union[ResultIterator[T], Iterator[T]]: ...
+    def __iter__(self) -> Union[ResultIterator[_T], Iterator[_T]]: ...
     @overload
-    def __getitem__(self, item: int) -> T: ...
+    def __getitem__(self, item: int) -> _T: ...
     @overload
-    def __getitem__(self, item: slice) -> List[T]: ...
+    def __getitem__(self, item: slice) -> List[_T]: ...
     def __len__(self) -> int: ...
     def initialize(self) -> None: ...
-    def iterate(self, cache: bool = ...) -> T: ...
-    def process_row(self, row: tuple) -> T: ...
-    def iterator(self) -> Iterator[T]: ...
+    def iterate(self, cache: bool = ...) -> _T: ...
+    def process_row(self, row: tuple) -> _T: ...
+    def iterator(self) -> Iterator[_T]: ...
     def fill_cache(self, n: int = ...) -> None: ...
 
 class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ...
@@ -1219,16 +1219,16 @@ class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ...
 class NamedTupleCursorWrapper(CursorWrapper[tuple]):
     tuple_class: Type[tuple]
 
-class ObjectCursorWrapper(DictCursorWrapper[T]):
-    constructor: Callable[..., T]
-    def __init__(self, cursor: __ICursor, constructor: Callable[..., T]): ...
-    def process_row(self, row: tuple) -> T: ...  # type: ignore
+class ObjectCursorWrapper(DictCursorWrapper[_T]):
+    constructor: Callable[..., _T]
+    def __init__(self, cursor: __ICursor, constructor: Callable[..., _T]): ...
+    def process_row(self, row: tuple) -> _T: ...  # type: ignore
 
-class ResultIterator(Generic[T]):
-    cursor_wrapper: CursorWrapper[T]
+class ResultIterator(Generic[_T]):
+    cursor_wrapper: CursorWrapper[_T]
     index: int
-    def __init__(self, cursor_wrapper: CursorWrapper[T]): ...
-    def __iter__(self) -> Iterator[T]: ...
+    def __init__(self, cursor_wrapper: CursorWrapper[_T]): ...
+    def __iter__(self) -> Iterator[_T]: ...
 
 # FIELDS
 
@@ -1240,7 +1240,7 @@ class FieldAccessor:
     @overload
     def __get__(self, instance: None, instance_type: type) -> Field: ...
     @overload
-    def __get__(self, instance: T, instance_type: Type[T]) -> Any: ...
+    def __get__(self, instance: _T, instance_type: Type[_T]) -> Any: ...
 
 class ForeignKeyAccessor(FieldAccessor):
     model: Type[Model]
@@ -1324,9 +1324,9 @@ class Field(ColumnBase):
     def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ...
     @property
     def column(self) -> Column: ...
-    def adapt(self, value: T) -> T: ...
-    def db_value(self, value: T) -> T: ...
-    def python_value(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
+    def db_value(self, value: _T) -> _T: ...
+    def python_value(self, value: _T) -> _T: ...
     def to_value(self, value: Any) -> Value: ...
     def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ...
     def __sql__(self, ctx: Context) -> Context: ...
@@ -1338,7 +1338,7 @@ class IntegerField(Field):
     @overload
     def adapt(self, value: Union[str, float, bool]) -> int: ...  # type: ignore
     @overload
-    def adapt(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
 
 class BigIntegerField(IntegerField): ...
 class SmallIntegerField(IntegerField): ...
@@ -1357,7 +1357,7 @@ class FloatField(Field):
     @overload
     def adapt(self, value: Union[str, float, bool]) -> float: ...  # type: ignore
     @overload
-    def adapt(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
 
 class DoubleField(FloatField): ...
 
@@ -1381,7 +1381,7 @@ class DecimalField(Field):
     @overload
     def db_value(self, value: Union[float, decimal.Decimal]) -> decimal.Decimal: ...  # type: ignore
     @overload
-    def db_value(self, value: T) -> T: ...
+    def db_value(self, value: _T) -> _T: ...
     @overload
     def python_value(self, value: None) -> None: ...
     @overload
@@ -1404,7 +1404,7 @@ class BlobField(Field):
     @overload
     def db_value(self, value: Union[str, bytes]) -> bytearray: ...
     @overload
-    def db_value(self, value: T) -> T: ...
+    def db_value(self, value: _T) -> _T: ...
 
 class BitField(BitwiseMixin, BigIntegerField):
     def __init__(self, *args: object, default: Optional[int] = ..., **kwargs: object): ...
@@ -1434,13 +1434,13 @@ class BigBitField(BlobField):
     @overload
     def db_value(self, value: None) -> None: ...
     @overload
-    def db_value(self, value: T) -> bytes: ...
+    def db_value(self, value: _T) -> bytes: ...
 
 class UUIDField(Field):
     @overload
     def db_value(self, value: AnyStr) -> str: ...
     @overload
-    def db_value(self, value: T) -> T: ...
+    def db_value(self, value: _T) -> _T: ...
     @overload
     def python_value(self, value: Union[uuid.UUID, AnyStr]) -> uuid.UUID: ...
     @overload
@@ -1458,7 +1458,7 @@ class BinaryUUIDField(BlobField):
 
 def format_date_time(value: str, formats: Iterable[str], post_process: Optional[_TConvFunc] = ...) -> str: ...
 @overload
-def simple_date_time(value: T) -> T: ...
+def simple_date_time(value: _T) -> _T: ...
 
 class _BaseFormattedField(Field):
     # TODO (dargueta): This is a class variable that can be overridden for instances
@@ -1478,7 +1478,7 @@ class DateTimeField(_BaseFormattedField):
     def minute(self) -> int: ...
     @property
     def second(self) -> int: ...
-    def adapt(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
     def to_timestamp(self) -> Function: ...
     def truncate(self, part: str) -> Function: ...
 
@@ -1492,7 +1492,7 @@ class DateField(_BaseFormattedField):
     @overload
     def adapt(self, value: datetime.datetime) -> datetime.date: ...
     @overload
-    def adapt(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
     def to_timestamp(self) -> Function: ...
     def truncate(self, part: str) -> Function: ...
 
@@ -1500,7 +1500,7 @@ class TimeField(_BaseFormattedField):
     @overload
     def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ...
     @overload
-    def adapt(self, value: T) -> T: ...
+    def adapt(self, value: _T) -> _T: ...
     @property
     def hour(self) -> int: ...
     @property
@@ -1525,7 +1525,7 @@ class TimestampField(BigIntegerField):
     @overload
     def python_value(self, value: Union[int, float]) -> datetime.datetime: ...
     @overload
-    def python_value(self, value: T) -> T: ...
+    def python_value(self, value: _T) -> _T: ...
     def from_timestamp(self) -> float: ...
     @property
     def year(self) -> int: ...
@@ -1632,12 +1632,12 @@ class ManyToManyFieldAccessor(FieldAccessor):
     dest_fk: ForeignKeyField
     def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ...
     @overload
-    def __get__(self, instance: None, instance_type: Type[T] = ..., force_query: bool = ...) -> Field: ...
+    def __get__(self, instance: None, instance_type: Type[_T] = ..., force_query: bool = ...) -> Field: ...
     @overload
     def __get__(
-        self, instance: T, instance_type: Type[T] = ..., force_query: bool = ...
+        self, instance: _T, instance_type: Type[_T] = ..., force_query: bool = ...
     ) -> Union[List[str], ManyToManyQuery]: ...
-    def __set__(self, instance: T, value) -> None: ...
+    def __set__(self, instance: _T, value) -> None: ...
 
 class ManyToManyField(MetaField):
     accessor_class: ClassVar[Type[ManyToManyFieldAccessor]]
@@ -1679,7 +1679,7 @@ class CompositeKey(MetaField):
     @overload
     def __get__(self, instance: None, instance_type: type) -> CompositeKey: ...
     @overload
-    def __get__(self, instance: T, instance_type: Type[T]) -> tuple: ...
+    def __get__(self, instance: _T, instance_type: Type[_T]) -> tuple: ...
     def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ...
     def __eq__(self, other: Expression) -> Expression: ...
     def __ne__(self, other: Expression) -> Expression: ...
@@ -1843,7 +1843,7 @@ class Model(Node, metaclass=ModelBase):
     @classmethod
     def delete(cls) -> ModelDelete: ...
     @classmethod
-    def create(cls: Type[T], **query) -> T: ...
+    def create(cls: Type[_T], **query) -> _T: ...
     @classmethod
     def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ...
     @classmethod
@@ -1874,7 +1874,7 @@ class Model(Node, metaclass=ModelBase):
     @property
     def dirty_fields(self) -> List[Field]: ...
     def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ...
-    def delete_instance(self: T, recursive: bool = ..., delete_nullable: bool = ...) -> T: ...
+    def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ...
     def __hash__(self) -> int: ...
     def __eq__(self, other: object) -> bool: ...
     def __ne__(self, other: object) -> bool: ...
@@ -1898,8 +1898,7 @@ class Model(Node, metaclass=ModelBase):
     @classmethod
     def truncate_table(cls, **options: object) -> None: ...
     @classmethod
-    def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex:
-        return ModelIndex(cls, fields, **kwargs)
+    def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: ...
     @classmethod
     def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ...

@dargueta
Copy link
Contributor Author

I'll check it out, thanks

@dargueta
Copy link
Contributor Author

dargueta commented Feb 20, 2022

@pylipp git says that the patch is corrupt on the very last line but it looks fine to me. Any idea what that's about?

Never mind, I just did it manually. Hopefully I didn't miss anything.

@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@AlexWaygood
Copy link
Member

@dargueta, would you be able to look at the remaining mypy errors? In some cases the "overlapping overloads" errors can probably be type: ignored, but some of them might point to real problems. I'm not really familiar enough with the runtime implementation to be able to distinguish the false positives from the true positives without doing a lot of work.

@Akuli
Copy link
Collaborator

Akuli commented Jul 28, 2022

We could split this up into multiple smaller PRs by first auto-generating peewee stubs (there are instructions in CONTRIBUTING.md), and then improving them by applying individual things from this PR. That might be easier than trying to get the whole CI to pass at once, but then dargueta likely won't show up as an author of most of the improvement commits.

@AlexWaygood AlexWaygood added the help wanted An actionable problem of low to medium complexity where a PR would be very welcome label Aug 20, 2022
@github-actions
Copy link
Contributor

github-actions bot commented Oct 3, 2022

According to mypy_primer, this change has no effect on the checked open source code. 🤖🎉

@srittau
Copy link
Collaborator

srittau commented Oct 4, 2022

Closing this now as we now have a first auto-generated version in typeshed. As I understand it, @Akuli is planning on stealing from this PR in smaller PRs. Thank you, @dargueta, and I'm sorry for any delays on our part.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
help wanted An actionable problem of low to medium complexity where a PR would be very welcome
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants