From 0ada61125058ce54bd3dd0f4d4a61fe740a38955 Mon Sep 17 00:00:00 2001 From: Tal Hayon Date: Sat, 17 Apr 2021 09:21:34 +0300 Subject: [PATCH 1/2] Fixes #1110. Use SupportsIndex in str, bytes, bytesarray functions where applicable. --- stdlib/builtins.pyi | 48 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 2a8e3dc12f41..b60233e82d63 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -59,7 +59,7 @@ from typing import ( overload, runtime_checkable, ) -from typing_extensions import Literal +from typing_extensions import Literal, SupportsIndex if sys.version_info >= (3, 9): from types import GenericAlias @@ -327,6 +327,8 @@ class complex: class _FormatMapMapping(Protocol): def __getitem__(self, __key: str) -> Any: ... +_StartEndArg = Optional[Union[int, SupportsIndex]] + class str(Sequence[str]): @overload def __new__(cls: Type[_T], o: object = ...) -> _T: ... @@ -335,16 +337,14 @@ class str(Sequence[str]): def capitalize(self) -> str: ... def casefold(self) -> str: ... def center(self, __width: int, __fillchar: str = ...) -> str: ... - def count(self, x: str, __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def count(self, x: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def encode(self, encoding: str = ..., errors: str = ...) -> bytes: ... - def endswith( - self, __suffix: Union[str, Tuple[str, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... - ) -> bool: ... + def endswith(self, __suffix: Union[str, Tuple[str, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> bool: ... def expandtabs(self, tabsize: int = ...) -> str: ... - def find(self, __sub: str, __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def find(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def format(self, *args: object, **kwargs: object) -> str: ... def format_map(self, map: _FormatMapMapping) -> str: ... - def index(self, __sub: str, __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def index(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... if sys.version_info >= (3, 7): @@ -367,8 +367,8 @@ class str(Sequence[str]): if sys.version_info >= (3, 9): def removeprefix(self, __prefix: str) -> str: ... def removesuffix(self, __suffix: str) -> str: ... - def rfind(self, __sub: str, __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... - def rindex(self, __sub: str, __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def rfind(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rindex(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def rjust(self, __width: int, __fillchar: str = ...) -> str: ... def rpartition(self, __sep: str) -> Tuple[str, str, str]: ... def rsplit(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ... @@ -376,7 +376,7 @@ class str(Sequence[str]): def split(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ... def splitlines(self, keepends: bool = ...) -> List[str]: ... def startswith( - self, __prefix: Union[str, Tuple[str, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... + self, __prefix: Union[str, Tuple[str, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... ) -> bool: ... def strip(self, __chars: Optional[str] = ...) -> str: ... def swapcase(self) -> str: ... @@ -423,18 +423,18 @@ class bytes(ByteString): def __new__(cls: Type[_T], o: SupportsBytes) -> _T: ... def capitalize(self) -> bytes: ... def center(self, __width: int, __fillchar: bytes = ...) -> bytes: ... - def count(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def count(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def decode(self, encoding: str = ..., errors: str = ...) -> str: ... def endswith( - self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... + self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... ) -> bool: ... def expandtabs(self, tabsize: int = ...) -> bytes: ... - def find(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def find(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... if sys.version_info >= (3, 8): def hex(self, sep: Union[str, bytes] = ..., bytes_per_sep: int = ...) -> str: ... else: def hex(self) -> str: ... - def index(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def index(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... if sys.version_info >= (3, 7): @@ -453,8 +453,8 @@ class bytes(ByteString): if sys.version_info >= (3, 9): def removeprefix(self, __prefix: bytes) -> bytes: ... def removesuffix(self, __suffix: bytes) -> bytes: ... - def rfind(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... - def rindex(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def rfind(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rindex(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def rjust(self, __width: int, __fillchar: bytes = ...) -> bytes: ... def rpartition(self, __sep: bytes) -> Tuple[bytes, bytes, bytes]: ... def rsplit(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytes]: ... @@ -462,7 +462,7 @@ class bytes(ByteString): def split(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytes]: ... def splitlines(self, keepends: bool = ...) -> List[bytes]: ... def startswith( - self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... + self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... ) -> bool: ... def strip(self, __bytes: Optional[bytes] = ...) -> bytes: ... def swapcase(self) -> bytes: ... @@ -509,20 +509,20 @@ class bytearray(MutableSequence[int], ByteString): def append(self, __item: int) -> None: ... def capitalize(self) -> bytearray: ... def center(self, __width: int, __fillchar: bytes = ...) -> bytearray: ... - def count(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def count(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def copy(self) -> bytearray: ... def decode(self, encoding: str = ..., errors: str = ...) -> str: ... def endswith( - self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... + self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... ) -> bool: ... def expandtabs(self, tabsize: int = ...) -> bytearray: ... def extend(self, __iterable_of_ints: Iterable[int]) -> None: ... - def find(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def find(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... if sys.version_info >= (3, 8): def hex(self, sep: Union[str, bytes] = ..., bytes_per_sep: int = ...) -> str: ... else: def hex(self) -> str: ... - def index(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def index(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def insert(self, __index: int, __item: int) -> None: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... @@ -542,8 +542,8 @@ class bytearray(MutableSequence[int], ByteString): def removeprefix(self, __prefix: bytes) -> bytearray: ... def removesuffix(self, __suffix: bytes) -> bytearray: ... def replace(self, __old: bytes, __new: bytes, __count: int = ...) -> bytearray: ... - def rfind(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... - def rindex(self, __sub: Union[bytes, int], __start: Optional[int] = ..., __end: Optional[int] = ...) -> int: ... + def rfind(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rindex(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... def rjust(self, __width: int, __fillchar: bytes = ...) -> bytearray: ... def rpartition(self, __sep: bytes) -> Tuple[bytearray, bytearray, bytearray]: ... def rsplit(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytearray]: ... @@ -551,7 +551,7 @@ class bytearray(MutableSequence[int], ByteString): def split(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytearray]: ... def splitlines(self, keepends: bool = ...) -> List[bytearray]: ... def startswith( - self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: Optional[int] = ..., __end: Optional[int] = ... + self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... ) -> bool: ... def strip(self, __bytes: Optional[bytes] = ...) -> bytearray: ... def swapcase(self) -> bytearray: ... From 7e8b3a78b5432e9ac46711af731a7ceb98ee2536 Mon Sep 17 00:00:00 2001 From: Tal Hayon Date: Sat, 17 Apr 2021 21:36:46 +0300 Subject: [PATCH 2/2] Use file defined _SupportIndex and remove int from Union --- stdlib/builtins.pyi | 86 ++++++++++++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index b60233e82d63..e88d9d8d27ae 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -59,7 +59,7 @@ from typing import ( overload, runtime_checkable, ) -from typing_extensions import Literal, SupportsIndex +from typing_extensions import Literal if sys.version_info >= (3, 9): from types import GenericAlias @@ -327,8 +327,6 @@ class complex: class _FormatMapMapping(Protocol): def __getitem__(self, __key: str) -> Any: ... -_StartEndArg = Optional[Union[int, SupportsIndex]] - class str(Sequence[str]): @overload def __new__(cls: Type[_T], o: object = ...) -> _T: ... @@ -337,14 +335,19 @@ class str(Sequence[str]): def capitalize(self) -> str: ... def casefold(self) -> str: ... def center(self, __width: int, __fillchar: str = ...) -> str: ... - def count(self, x: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def count(self, x: str, __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ...) -> int: ... def encode(self, encoding: str = ..., errors: str = ...) -> bytes: ... - def endswith(self, __suffix: Union[str, Tuple[str, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> bool: ... + def endswith( + self, + __suffix: Union[str, Tuple[str, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., + ) -> bool: ... def expandtabs(self, tabsize: int = ...) -> str: ... - def find(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def find(self, __sub: str, __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ...) -> int: ... def format(self, *args: object, **kwargs: object) -> str: ... def format_map(self, map: _FormatMapMapping) -> str: ... - def index(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def index(self, __sub: str, __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ...) -> int: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... if sys.version_info >= (3, 7): @@ -367,8 +370,8 @@ class str(Sequence[str]): if sys.version_info >= (3, 9): def removeprefix(self, __prefix: str) -> str: ... def removesuffix(self, __suffix: str) -> str: ... - def rfind(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... - def rindex(self, __sub: str, __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rfind(self, __sub: str, __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ...) -> int: ... + def rindex(self, __sub: str, __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ...) -> int: ... def rjust(self, __width: int, __fillchar: str = ...) -> str: ... def rpartition(self, __sep: str) -> Tuple[str, str, str]: ... def rsplit(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ... @@ -376,7 +379,10 @@ class str(Sequence[str]): def split(self, sep: Optional[str] = ..., maxsplit: int = ...) -> List[str]: ... def splitlines(self, keepends: bool = ...) -> List[str]: ... def startswith( - self, __prefix: Union[str, Tuple[str, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... + self, + __prefix: Union[str, Tuple[str, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., ) -> bool: ... def strip(self, __chars: Optional[str] = ...) -> str: ... def swapcase(self) -> str: ... @@ -423,18 +429,27 @@ class bytes(ByteString): def __new__(cls: Type[_T], o: SupportsBytes) -> _T: ... def capitalize(self) -> bytes: ... def center(self, __width: int, __fillchar: bytes = ...) -> bytes: ... - def count(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def count( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def decode(self, encoding: str = ..., errors: str = ...) -> str: ... def endswith( - self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... + self, + __suffix: Union[bytes, Tuple[bytes, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., ) -> bool: ... def expandtabs(self, tabsize: int = ...) -> bytes: ... - def find(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def find( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... if sys.version_info >= (3, 8): def hex(self, sep: Union[str, bytes] = ..., bytes_per_sep: int = ...) -> str: ... else: def hex(self) -> str: ... - def index(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def index( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... if sys.version_info >= (3, 7): @@ -453,8 +468,12 @@ class bytes(ByteString): if sys.version_info >= (3, 9): def removeprefix(self, __prefix: bytes) -> bytes: ... def removesuffix(self, __suffix: bytes) -> bytes: ... - def rfind(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... - def rindex(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rfind( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... + def rindex( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def rjust(self, __width: int, __fillchar: bytes = ...) -> bytes: ... def rpartition(self, __sep: bytes) -> Tuple[bytes, bytes, bytes]: ... def rsplit(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytes]: ... @@ -462,7 +481,10 @@ class bytes(ByteString): def split(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytes]: ... def splitlines(self, keepends: bool = ...) -> List[bytes]: ... def startswith( - self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... + self, + __prefix: Union[bytes, Tuple[bytes, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., ) -> bool: ... def strip(self, __bytes: Optional[bytes] = ...) -> bytes: ... def swapcase(self) -> bytes: ... @@ -509,20 +531,29 @@ class bytearray(MutableSequence[int], ByteString): def append(self, __item: int) -> None: ... def capitalize(self) -> bytearray: ... def center(self, __width: int, __fillchar: bytes = ...) -> bytearray: ... - def count(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def count( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def copy(self) -> bytearray: ... def decode(self, encoding: str = ..., errors: str = ...) -> str: ... def endswith( - self, __suffix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... + self, + __suffix: Union[bytes, Tuple[bytes, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., ) -> bool: ... def expandtabs(self, tabsize: int = ...) -> bytearray: ... def extend(self, __iterable_of_ints: Iterable[int]) -> None: ... - def find(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def find( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... if sys.version_info >= (3, 8): def hex(self, sep: Union[str, bytes] = ..., bytes_per_sep: int = ...) -> str: ... else: def hex(self) -> str: ... - def index(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def index( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def insert(self, __index: int, __item: int) -> None: ... def isalnum(self) -> bool: ... def isalpha(self) -> bool: ... @@ -542,8 +573,12 @@ class bytearray(MutableSequence[int], ByteString): def removeprefix(self, __prefix: bytes) -> bytearray: ... def removesuffix(self, __suffix: bytes) -> bytearray: ... def replace(self, __old: bytes, __new: bytes, __count: int = ...) -> bytearray: ... - def rfind(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... - def rindex(self, __sub: Union[bytes, int], __start: _StartEndArg = ..., __end: _StartEndArg = ...) -> int: ... + def rfind( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... + def rindex( + self, __sub: Union[bytes, int], __start: Optional[_SupportsIndex] = ..., __end: Optional[_SupportsIndex] = ... + ) -> int: ... def rjust(self, __width: int, __fillchar: bytes = ...) -> bytearray: ... def rpartition(self, __sep: bytes) -> Tuple[bytearray, bytearray, bytearray]: ... def rsplit(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytearray]: ... @@ -551,7 +586,10 @@ class bytearray(MutableSequence[int], ByteString): def split(self, sep: Optional[bytes] = ..., maxsplit: int = ...) -> List[bytearray]: ... def splitlines(self, keepends: bool = ...) -> List[bytearray]: ... def startswith( - self, __prefix: Union[bytes, Tuple[bytes, ...]], __start: _StartEndArg = ..., __end: _StartEndArg = ... + self, + __prefix: Union[bytes, Tuple[bytes, ...]], + __start: Optional[_SupportsIndex] = ..., + __end: Optional[_SupportsIndex] = ..., ) -> bool: ... def strip(self, __bytes: Optional[bytes] = ...) -> bytearray: ... def swapcase(self) -> bytearray: ...