Skip to content

Commit

Permalink
Types: don't leave generic types without a parameter (#2401)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism committed Jan 20, 2023
2 parents a6c7ee0 + 9afe27e commit c0092d2
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 112 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Version 8.1.4

Unreleased

- Improve type hinting for decorators and give all generic types parameters.
:issue:`2398`


Version 8.1.3
-------------
Expand Down
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ ignore =
W503
# zip with strict=, requires python >= 3.10
B905
# string formatting opinion, B028 renamed to B907
B028
B907
# up to 88 allowed by bugbear B950
max-line-length = 80
per-file-ignores =
Expand All @@ -85,6 +88,7 @@ disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
disallow_any_generics = True
check_untyped_defs = True
no_implicit_optional = True
local_partial_types = True
Expand Down
38 changes: 19 additions & 19 deletions src/click/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def is_ascii_encoding(encoding: str) -> bool:
return False


def get_best_encoding(stream: t.IO) -> str:
def get_best_encoding(stream: t.IO[t.Any]) -> str:
"""Returns the default stream encoding if not found."""
rv = getattr(stream, "encoding", None) or sys.getdefaultencoding()
if is_ascii_encoding(rv):
Expand Down Expand Up @@ -153,7 +153,7 @@ def seekable(self) -> bool:
return True


def _is_binary_reader(stream: t.IO, default: bool = False) -> bool:
def _is_binary_reader(stream: t.IO[t.Any], default: bool = False) -> bool:
try:
return isinstance(stream.read(0), bytes)
except Exception:
Expand All @@ -162,7 +162,7 @@ def _is_binary_reader(stream: t.IO, default: bool = False) -> bool:
# closed. In this case, we assume the default.


def _is_binary_writer(stream: t.IO, default: bool = False) -> bool:
def _is_binary_writer(stream: t.IO[t.Any], default: bool = False) -> bool:
try:
stream.write(b"")
except Exception:
Expand All @@ -175,7 +175,7 @@ def _is_binary_writer(stream: t.IO, default: bool = False) -> bool:
return True


def _find_binary_reader(stream: t.IO) -> t.Optional[t.BinaryIO]:
def _find_binary_reader(stream: t.IO[t.Any]) -> t.Optional[t.BinaryIO]:
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detaching
# the streams to get binary streams. Some code might do this, so
Expand All @@ -193,7 +193,7 @@ def _find_binary_reader(stream: t.IO) -> t.Optional[t.BinaryIO]:
return None


def _find_binary_writer(stream: t.IO) -> t.Optional[t.BinaryIO]:
def _find_binary_writer(stream: t.IO[t.Any]) -> t.Optional[t.BinaryIO]:
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detaching
# the streams to get binary streams. Some code might do this, so
Expand Down Expand Up @@ -241,11 +241,11 @@ def _is_compatible_text_stream(


def _force_correct_text_stream(
text_stream: t.IO,
text_stream: t.IO[t.Any],
encoding: t.Optional[str],
errors: t.Optional[str],
is_binary: t.Callable[[t.IO, bool], bool],
find_binary: t.Callable[[t.IO], t.Optional[t.BinaryIO]],
is_binary: t.Callable[[t.IO[t.Any], bool], bool],
find_binary: t.Callable[[t.IO[t.Any]], t.Optional[t.BinaryIO]],
force_readable: bool = False,
force_writable: bool = False,
) -> t.TextIO:
Expand Down Expand Up @@ -287,7 +287,7 @@ def _force_correct_text_stream(


def _force_correct_text_reader(
text_reader: t.IO,
text_reader: t.IO[t.Any],
encoding: t.Optional[str],
errors: t.Optional[str],
force_readable: bool = False,
Expand All @@ -303,7 +303,7 @@ def _force_correct_text_reader(


def _force_correct_text_writer(
text_writer: t.IO,
text_writer: t.IO[t.Any],
encoding: t.Optional[str],
errors: t.Optional[str],
force_writable: bool = False,
Expand Down Expand Up @@ -367,11 +367,11 @@ def get_text_stderr(


def _wrap_io_open(
file: t.Union[str, os.PathLike, int],
file: t.Union[str, "os.PathLike[t.AnyStr]", int],
mode: str,
encoding: t.Optional[str],
errors: t.Optional[str],
) -> t.IO:
) -> t.IO[t.Any]:
"""Handles not passing ``encoding`` and ``errors`` in binary mode."""
if "b" in mode:
return open(file, mode)
Expand All @@ -385,7 +385,7 @@ def open_stream(
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
atomic: bool = False,
) -> t.Tuple[t.IO, bool]:
) -> t.Tuple[t.IO[t.Any], bool]:
binary = "b" in mode

# Standard streams first. These are simple because they ignore the
Expand Down Expand Up @@ -456,11 +456,11 @@ def open_stream(

f = _wrap_io_open(fd, mode, encoding, errors)
af = _AtomicFile(f, tmp_filename, os.path.realpath(filename))
return t.cast(t.IO, af), True
return t.cast(t.IO[t.Any], af), True


class _AtomicFile:
def __init__(self, f: t.IO, tmp_filename: str, real_filename: str) -> None:
def __init__(self, f: t.IO[t.Any], tmp_filename: str, real_filename: str) -> None:
self._f = f
self._tmp_filename = tmp_filename
self._real_filename = real_filename
Expand All @@ -483,7 +483,7 @@ def __getattr__(self, name: str) -> t.Any:
def __enter__(self) -> "_AtomicFile":
return self

def __exit__(self, exc_type, exc_value, tb): # type: ignore
def __exit__(self, exc_type: t.Optional[t.Type[BaseException]], *_: t.Any) -> None:
self.close(delete=exc_type is not None)

def __repr__(self) -> str:
Expand All @@ -494,15 +494,15 @@ def strip_ansi(value: str) -> str:
return _ansi_re.sub("", value)


def _is_jupyter_kernel_output(stream: t.IO) -> bool:
def _is_jupyter_kernel_output(stream: t.IO[t.Any]) -> bool:
while isinstance(stream, (_FixupStream, _NonClosingTextIOWrapper)):
stream = stream._stream

return stream.__class__.__module__.startswith("ipykernel.")


def should_strip_ansi(
stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
stream: t.Optional[t.IO[t.Any]] = None, color: t.Optional[bool] = None
) -> bool:
if color is None:
if stream is None:
Expand Down Expand Up @@ -576,7 +576,7 @@ def term_len(x: str) -> int:
return len(strip_ansi(x))


def isatty(stream: t.IO) -> bool:
def isatty(stream: t.IO[t.Any]) -> bool:
try:
return stream.isatty()
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions src/click/_termui_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def __init__(
self.is_hidden = not isatty(self.file)
self._last_line: t.Optional[str] = None

def __enter__(self) -> "ProgressBar":
def __enter__(self) -> "ProgressBar[V]":
self.entered = True
self.render_progress()
return self

def __exit__(self, exc_type, exc_value, tb): # type: ignore
def __exit__(self, *_: t.Any) -> None:
self.render_finish()

def __iter__(self) -> t.Iterator[V]:
Expand Down
36 changes: 27 additions & 9 deletions src/click/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def __enter__(self) -> "Context":
push_context(self)
return self

def __exit__(self, exc_type, exc_value, tb): # type: ignore
def __exit__(self, *_: t.Any) -> None:
self._depth -= 1
if self._depth == 0:
self.close()
Expand Down Expand Up @@ -706,12 +706,30 @@ def _make_sub_context(self, command: "Command") -> "Context":
"""
return type(self)(command, info_name=command.name, parent=self)

@t.overload
def invoke(
__self, # noqa: B902
__callback: "t.Callable[..., V]",
*args: t.Any,
**kwargs: t.Any,
) -> V:
...

@t.overload
def invoke(
__self, # noqa: B902
__callback: t.Union["Command", t.Callable[..., t.Any]],
__callback: "Command",
*args: t.Any,
**kwargs: t.Any,
) -> t.Any:
...

def invoke(
__self, # noqa: B902
__callback: t.Union["Command", "t.Callable[..., V]"],
*args: t.Any,
**kwargs: t.Any,
) -> t.Union[t.Any, V]:
"""Invokes a command callback in exactly the way it expects. There
are two ways to invoke this method:
Expand Down Expand Up @@ -739,7 +757,7 @@ def invoke(
"The given command does not have a callback that can be invoked."
)
else:
__callback = other_cmd.callback
__callback = t.cast("t.Callable[..., V]", other_cmd.callback)

ctx = __self._make_sub_context(other_cmd)

Expand Down Expand Up @@ -1841,7 +1859,7 @@ def command(
if self.command_class and kwargs.get("cls") is None:
kwargs["cls"] = self.command_class

func: t.Optional[t.Callable] = None
func: t.Optional[t.Callable[..., t.Any]] = None

if args and callable(args[0]):
assert (
Expand Down Expand Up @@ -1889,7 +1907,7 @@ def group(
"""
from .decorators import group

func: t.Optional[t.Callable] = None
func: t.Optional[t.Callable[..., t.Any]] = None

if args and callable(args[0]):
assert (
Expand Down Expand Up @@ -2260,7 +2278,7 @@ def type_cast_value(self, ctx: Context, value: t.Any) -> t.Any:
if value is None:
return () if self.multiple or self.nargs == -1 else None

def check_iter(value: t.Any) -> t.Iterator:
def check_iter(value: t.Any) -> t.Iterator[t.Any]:
try:
return _check_iter(value)
except TypeError:
Expand All @@ -2277,12 +2295,12 @@ def check_iter(value: t.Any) -> t.Iterator:
)
elif self.nargs == -1:

def convert(value: t.Any) -> t.Tuple:
def convert(value: t.Any) -> t.Tuple[t.Any, ...]:
return tuple(self.type(x, self, ctx) for x in check_iter(value))

else: # nargs > 1

def convert(value: t.Any) -> t.Tuple:
def convert(value: t.Any) -> t.Tuple[t.Any, ...]:
value = tuple(check_iter(value))

if len(value) != self.nargs:
Expand Down Expand Up @@ -2817,7 +2835,7 @@ def get_default(
if self.is_flag and not self.is_bool_flag:
for param in ctx.command.params:
if param.name == self.name and param.default:
return param.flag_value # type: ignore
return t.cast(Option, param).flag_value

return None

Expand Down
Loading

0 comments on commit c0092d2

Please sign in to comment.