Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Counter generic over the value #11632

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
41 changes: 21 additions & 20 deletions stdlib/collections/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ _KT = TypeVar("_KT")
_VT = TypeVar("_VT")
_KT_co = TypeVar("_KT_co", covariant=True)
_VT_co = TypeVar("_VT_co", covariant=True)
_C = TypeVar("_C", default=int)

# namedtuple is special-cased in the type checker; the initializer is ignored.
def namedtuple(
Expand Down Expand Up @@ -261,55 +262,55 @@ class deque(MutableSequence[_T]):
if sys.version_info >= (3, 9):
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...

class Counter(dict[_T, int], Generic[_T]):
class Counter(dict[_T, _C]):
@overload
def __init__(self, iterable: None = None, /) -> None: ...
@overload
def __init__(self: Counter[str], iterable: None = None, /, **kwargs: int) -> None: ...
def __init__(self: Counter[str, _C], iterable: None = None, /, **kwargs: _C) -> None: ...
@overload
def __init__(self, mapping: SupportsKeysAndGetItem[_T, int], /) -> None: ...
def __init__(self: Counter[_T, _C], mapping: SupportsKeysAndGetItem[_T, _C], /) -> None: ...
@overload
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def copy(self) -> Self: ...
def elements(self) -> Iterator[_T]: ...
def most_common(self, n: int | None = None) -> list[tuple[_T, int]]: ...
@classmethod
def fromkeys(cls, iterable: Any, v: int | None = None) -> NoReturn: ... # type: ignore[override]
def fromkeys(cls, iterable: Any, v: _C | None = None) -> NoReturn: ... # type: ignore[override]
@overload
def subtract(self, iterable: None = None, /) -> None: ...
@overload
def subtract(self, mapping: Mapping[_T, int], /) -> None: ...
def subtract(self, mapping: Mapping[_T, _C], /) -> None: ...
@overload
def subtract(self, iterable: Iterable[_T], /) -> None: ...
# Unlike dict.update(), use Mapping instead of SupportsKeysAndGetItem for the first overload
# (source code does an `isinstance(other, Mapping)` check)
#
# The second overload is also deliberately different to dict.update()
# (if it were `Iterable[_T] | Iterable[tuple[_T, int]]`,
# (if it were `Iterable[_T] | Iterable[tuple[_T, _C]]`,
# the tuples would be added as keys, breaking type safety)
@overload # type: ignore[override]
def update(self, m: Mapping[_T, int], /, **kwargs: int) -> None: ...
def update(self, m: Mapping[_T, _C], /, **kwargs: _C) -> None: ...
@overload
def update(self, iterable: Iterable[_T], /, **kwargs: int) -> None: ...
def update(self, iterable: Iterable[_T], /, **kwargs: _C) -> None: ...
@overload
def update(self, iterable: None = None, /, **kwargs: int) -> None: ...
def __missing__(self, key: _T) -> int: ...
def update(self, iterable: None = None, /, **kwargs: _C) -> None: ...
def __missing__(self, key: _T) -> _C: ...
def __delitem__(self, elem: object) -> None: ...
if sys.version_info >= (3, 10):
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...

def __add__(self, other: Counter[_S]) -> Counter[_T | _S]: ...
def __sub__(self, other: Counter[_T]) -> Counter[_T]: ...
def __and__(self, other: Counter[_T]) -> Counter[_T]: ...
def __or__(self, other: Counter[_S]) -> Counter[_T | _S]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T]: ...
def __neg__(self) -> Counter[_T]: ...
def __add__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ...
def __sub__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ...
def __and__(self, other: Counter[_T, _C]) -> Counter[_T, _C]: ...
def __or__(self, other: Counter[_S, _C]) -> Counter[_T | _S, _C]: ... # type: ignore[override]
def __pos__(self) -> Counter[_T, _C]: ...
def __neg__(self) -> Counter[_T, _C]: ...
# several type: ignores because __iadd__ is supposedly incompatible with __add__, etc.
def __iadd__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, int]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, int]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, int]) -> Self: ... # type: ignore[override,misc]
def __iadd__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[misc]
def __isub__(self, other: SupportsItems[_T, _C]) -> Self: ...
def __iand__(self, other: SupportsItems[_T, _C]) -> Self: ...
def __ior__(self, other: SupportsItems[_T, _C]) -> Self: ... # type: ignore[override,misc]
if sys.version_info >= (3, 10):
def total(self) -> int: ...
def __le__(self, other: Counter[Any]) -> bool: ...
Expand Down
38 changes: 38 additions & 0 deletions test_cases/stdlib/collections/check_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from collections import Counter
from typing import Any, cast
from typing_extensions import assert_type


class Foo: ...


# Test the constructor
# mypy derives Never for the first type argument while, pyright derives Unknown
assert_type(Counter(), "Counter[Any, int]")
assert_type(Counter(foo=42.2), "Counter[str, float]")
assert_type(Counter({42: "bar"}), "Counter[int, str]")
assert_type(Counter([1, 2, 3]), "Counter[int, int]")

int_c: Counter[str] = Counter()
assert_type(int_c, "Counter[str, int]")
assert_type(int_c["a"], int)
int_c["a"] = 1
int_c["a"] += 3
int_c["a"] += 3.5 # type: ignore

float_c = Counter(foo=42.2)
assert_type(float_c, "Counter[str, float]")
assert_type(float_c["a"], float)
float_c["a"] = 1.0
float_c["a"] += 3.0
float_c["a"] += 42
float_c["a"] += "42" # type: ignore

custom_c = cast("Counter[str, Foo]", Counter())
assert_type(custom_c, "Counter[str, Foo]")
assert_type(custom_c["a"], Foo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At runtime this is actually an int though. I wonder if we need to make all these methods return _C | int.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line should probably not be accepted, as Counter() is a Counter[unknown, int], which is incompatible with Counter[..., Foo]. I'm not sure the test makes much sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this sort of problem apply to any Counter with a non-int value type, though? This seems like a fundamental problem with this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At runtime this is actually an int though. I wonder if we need to make all these methods return _C | int.

I wonder whether type checkers support __missing__, in which case, this should happen automatically when we add it to the stubs. But returning _C | int makes some sense to me for getter methods.

custom_c["a"] = Foo()
custom_c["a"] += Foo() # type: ignore
custom_c["a"] += 42 # type: ignore

Check failure on line 38 in test_cases/stdlib/collections/check_counter.py

View workflow job for this annotation

GitHub Actions / Test typeshed with pyright (Linux, 3.11)

Unnecessary "# type: ignore" comment
Loading