From 5989375dc3953a5709c411e11f658c242319858e Mon Sep 17 00:00:00 2001 From: Andreas Backx Date: Wed, 22 Apr 2026 18:46:54 +0200 Subject: [PATCH] ParamType typing improvements Co-authored-by: Kevin Deldycke --- CHANGES.rst | 18 ++++ src/click/core.py | 10 +- src/click/termui.py | 4 +- src/click/types.py | 220 +++++++++++++++++++++++++----------------- tests/test_imports.py | 2 + 5 files changed, 159 insertions(+), 95 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index d5b05c526e..00a9a5ef66 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,23 @@ .. currentmodule:: click +Version 8.4.0 +------------- + +Unreleased + +- :class:`ParamType` typing improvements. :pr:`3371` + + - :class:`ParamType` is now a generic abstract base class, + parameterized by its converted value type. + - :meth:`~ParamType.convert` return types are narrowed on all + concrete types (``str`` for :class:`STRING`, ``int`` for + :class:`INT`, etc.). + - :meth:`~ParamType.to_info_dict` returns specific + :class:`~typing.TypedDict` subclasses instead of + ``dict[str, Any]``. + - :class:`CompositeParamType` and the number-range base are now + generic with abstract methods. + Version 8.3.3 ------------- diff --git a/src/click/core.py b/src/click/core.py index d940dd80e1..13d8841da1 100644 --- a/src/click/core.py +++ b/src/click/core.py @@ -2149,7 +2149,7 @@ class Parameter: def __init__( self, param_decls: cabc.Sequence[str] | None = None, - type: types.ParamType | t.Any | None = None, + type: types.ParamType[t.Any] | t.Any | None = None, required: bool = False, # XXX The default historically embed two concepts: # - the declaration of a Parameter object carrying the default (handy to @@ -2181,7 +2181,7 @@ def __init__( self.name, self.opts, self.secondary_opts = self._parse_decls( param_decls or (), expose_value ) - self.type: types.ParamType = types.convert_type(type, default) + self.type: types.ParamType[t.Any] = types.convert_type(type, default) # Default nargs to what the type tells us if we have that # information available. @@ -2648,7 +2648,7 @@ def shell_complete(self, ctx: Context, incomplete: str) -> list[CompletionItem]: """Return a list of completions for the incomplete value. If a ``shell_complete`` function was given during init, it is used. Otherwise, the :attr:`type` - :meth:`~click.types.ParamType.shell_complete` function is used. + :meth:`~click.types.ParamType[t.Any].shell_complete` function is used. :param ctx: Invocation context for this command. :param incomplete: Value being completed. May be empty. @@ -2749,7 +2749,7 @@ def __init__( multiple: bool = False, count: bool = False, allow_from_autoenv: bool = True, - type: types.ParamType | t.Any | None = None, + type: types.ParamType[t.Any] | t.Any | None = None, help: str | None = None, hidden: bool = False, show_choices: bool = True, @@ -2825,7 +2825,7 @@ def __init__( if type is None: # A flag without a flag_value is a boolean flag. if flag_value is UNSET: - self.type: types.ParamType = types.BoolParamType() + self.type: types.ParamType[t.Any] = types.BoolParamType() # If the flag value is a boolean, use BoolParamType. elif isinstance(flag_value, bool): self.type = types.BoolParamType() diff --git a/src/click/termui.py b/src/click/termui.py index 48f671b217..6801e30fa4 100644 --- a/src/click/termui.py +++ b/src/click/termui.py @@ -63,7 +63,7 @@ def _build_prompt( show_default: bool | str = False, default: t.Any | None = None, show_choices: bool = True, - type: ParamType | None = None, + type: ParamType[t.Any] | None = None, ) -> str: prompt = text if type is not None and show_choices and isinstance(type, Choice): @@ -87,7 +87,7 @@ def prompt( default: t.Any | None = None, hide_input: bool = False, confirmation_prompt: bool | str = False, - type: ParamType | t.Any | None = None, + type: ParamType[t.Any] | t.Any | None = None, value_proc: t.Callable[[str], t.Any] | None = None, prompt_suffix: str = ": ", show_default: bool | str = True, diff --git a/src/click/types.py b/src/click/types.py index e71c1c21e4..5d51d20d57 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -1,11 +1,13 @@ from __future__ import annotations +import abc import collections.abc as cabc import enum import os import stat import sys import typing as t +import uuid from datetime import datetime from gettext import gettext as _ from gettext import ngettext @@ -27,7 +29,12 @@ ParamTypeValue = t.TypeVar("ParamTypeValue") -class ParamType: +class ParamTypeInfoDict(t.TypedDict): + param_type: str + name: str + + +class ParamType(t.Generic[ParamTypeValue], abc.ABC): """Represents the type of a parameter. Validates and converts values from the command line or Python into the correct type. @@ -59,7 +66,7 @@ class ParamType: #: Windows). envvar_list_splitter: t.ClassVar[str | None] = None - def to_info_dict(self) -> dict[str, t.Any]: + def to_info_dict(self) -> ParamTypeInfoDict: """Gather information that could be useful for a tool generating user-facing documentation. @@ -85,9 +92,10 @@ def __call__( value: t.Any, param: Parameter | None = None, ctx: Context | None = None, - ) -> t.Any: + ) -> ParamTypeValue | None: if value is not None: return self.convert(value, param, ctx) + return None def get_metavar(self, param: Parameter, ctx: Context) -> str | None: """Returns the metavar default for this param if it provides one.""" @@ -101,7 +109,7 @@ def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | No def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: """Convert the value to the correct type. This is not called if the value is ``None`` (the missing value). @@ -121,7 +129,7 @@ def convert( :param ctx: The current context that arrived at this value. May be ``None``. """ - return value + return value # type: ignore[no-any-return] def split_envvar_value(self, rv: str) -> cabc.Sequence[str]: """Given a value from an environment variable this splits it up @@ -160,23 +168,25 @@ def shell_complete( return [] -class CompositeParamType(ParamType): +class CompositeParamType(ParamType[ParamTypeValue]): is_composite = True @property - def arity(self) -> int: # type: ignore - raise NotImplementedError() + @abc.abstractmethod + def arity(self) -> int: ... # type: ignore[override] + + +class FuncParamTypeInfoDict(ParamTypeInfoDict): + func: t.Callable[[t.Any], t.Any] -class FuncParamType(ParamType): +class FuncParamType(ParamType[t.Any]): def __init__(self, func: t.Callable[[t.Any], t.Any]) -> None: self.name: str = func.__name__ self.func = func - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["func"] = self.func - return info_dict + def to_info_dict(self) -> FuncParamTypeInfoDict: + return {"func": self.func, **super().to_info_dict()} def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None @@ -192,7 +202,7 @@ def convert( self.fail(value, param, ctx) -class UnprocessedParamType(ParamType): +class UnprocessedParamType(ParamType[t.Any]): name = "text" def convert( @@ -204,12 +214,12 @@ def __repr__(self) -> str: return "UNPROCESSED" -class StringParamType(ParamType): +class StringParamType(ParamType[str]): name = "text" def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> str: if isinstance(value, bytes): enc = _get_argv_encoding() try: @@ -223,14 +233,19 @@ def convert( value = value.decode("utf-8", "replace") else: value = value.decode("utf-8", "replace") - return value + return value # type: ignore[no-any-return] return str(value) def __repr__(self) -> str: return "STRING" -class Choice(ParamType, t.Generic[ParamTypeValue]): +class ChoiceInfoDict(ParamTypeInfoDict): + choices: cabc.Sequence[t.Any] + case_sensitive: bool + + +class Choice(ParamType[ParamTypeValue], t.Generic[ParamTypeValue]): """The choice type allows a value to be checked against a fixed set of supported values. @@ -261,11 +276,12 @@ def __init__( self.choices: cabc.Sequence[ParamTypeValue] = tuple(choices) self.case_sensitive = case_sensitive - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["choices"] = self.choices - info_dict["case_sensitive"] = self.case_sensitive - return info_dict + def to_info_dict(self) -> ChoiceInfoDict: + return { + "choices": self.choices, + "case_sensitive": self.case_sensitive, + **super().to_info_dict(), + } def _normalized_mapping( self, ctx: Context | None = None @@ -398,7 +414,11 @@ def shell_complete( return [CompletionItem(c) for c in matched] -class DateTime(ParamType): +class DateTimeInfoDict(ParamTypeInfoDict): + formats: cabc.Sequence[str] + + +class DateTime(ParamType[datetime]): """The DateTime type converts date strings into `datetime` objects. The format strings which are checked are configurable, but default to some @@ -428,10 +448,8 @@ def __init__(self, formats: cabc.Sequence[str] | None = None): "%Y-%m-%d %H:%M:%S", ] - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["formats"] = self.formats - return info_dict + def to_info_dict(self) -> DateTimeInfoDict: + return {"formats": self.formats, **super().to_info_dict()} def get_metavar(self, param: Parameter, ctx: Context) -> str | None: return f"[{'|'.join(self.formats)}]" @@ -444,7 +462,7 @@ def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None: def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> datetime: if isinstance(value, datetime): return value @@ -469,12 +487,12 @@ def __repr__(self) -> str: return "DateTime" -class _NumberParamTypeBase(ParamType): - _number_class: t.ClassVar[type[t.Any]] +class _NumberParamTypeBase(ParamType[ParamTypeValue]): + _number_class: t.Callable[[t.Any], ParamTypeValue] def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: try: return self._number_class(value) except ValueError: @@ -487,7 +505,15 @@ def convert( ) -class _NumberRangeBase(_NumberParamTypeBase): +class NumberRangeInfoDict(ParamTypeInfoDict): + min: float | None + max: float | None + min_open: bool + max_open: bool + clamp: bool + + +class _NumberRangeBase(_NumberParamTypeBase[ParamTypeValue]): def __init__( self, min: float | None = None, @@ -502,29 +528,28 @@ def __init__( self.max_open = max_open self.clamp = clamp - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update( - min=self.min, - max=self.max, - min_open=self.min_open, - max_open=self.max_open, - clamp=self.clamp, - ) - return info_dict + def to_info_dict(self) -> NumberRangeInfoDict: + return { + "min": self.min, + "max": self.max, + "min_open": self.min_open, + "max_open": self.max_open, + "clamp": self.clamp, + **super().to_info_dict(), + } def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: import operator rv = super().convert(value, param, ctx) lt_min: bool = self.min is not None and ( operator.le if self.min_open else operator.lt - )(rv, self.min) + )(rv, self.min) # type: ignore[arg-type] gt_max: bool = self.max is not None and ( operator.ge if self.max_open else operator.gt - )(rv, self.max) + )(rv, self.max) # type: ignore[arg-type] if self.clamp: if lt_min: @@ -544,7 +569,10 @@ def convert( return rv - def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: + @abc.abstractmethod + def _clamp( + self, bound: ParamTypeValue, dir: t.Literal[1, -1], open: bool + ) -> ParamTypeValue: """Find the valid value to clamp to bound in the given direction. @@ -552,7 +580,7 @@ def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: :param dir: 1 or -1 indicating the direction to move. :param open: If true, the range does not include the bound. """ - raise NotImplementedError + ... def _describe_range(self) -> str: """Describe the range for use in help text.""" @@ -573,7 +601,7 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self._describe_range()}{clamp}>" -class IntParamType(_NumberParamTypeBase): +class IntParamType(_NumberParamTypeBase[int]): name = "integer" _number_class = int @@ -581,7 +609,7 @@ def __repr__(self) -> str: return "INT" -class IntRange(_NumberRangeBase, IntParamType): +class IntRange(_NumberRangeBase[int], IntParamType): """Restrict an :data:`click.INT` value to a range of accepted values. See :ref:`ranges`. @@ -598,16 +626,14 @@ class IntRange(_NumberRangeBase, IntParamType): name = "integer range" - def _clamp( # type: ignore - self, bound: int, dir: t.Literal[1, -1], open: bool - ) -> int: + def _clamp(self, bound: int, dir: t.Literal[1, -1], open: bool) -> int: if not open: return bound return bound + dir -class FloatParamType(_NumberParamTypeBase): +class FloatParamType(_NumberParamTypeBase[float]): name = "float" _number_class = float @@ -615,7 +641,7 @@ def __repr__(self) -> str: return "FLOAT" -class FloatRange(_NumberRangeBase, FloatParamType): +class FloatRange(_NumberRangeBase[float], FloatParamType): """Restrict a :data:`click.FLOAT` value to a range of accepted values. See :ref:`ranges`. @@ -658,7 +684,7 @@ def _clamp(self, bound: float, dir: t.Literal[1, -1], open: bool) -> float: raise RuntimeError("Clamping is not supported for open bounds.") -class BoolParamType(ParamType): +class BoolParamType(ParamType[bool]): name = "boolean" bool_states: dict[str, bool] = { @@ -727,14 +753,12 @@ def __repr__(self) -> str: return "BOOL" -class UUIDParameterType(ParamType): +class UUIDParameterType(ParamType[uuid.UUID]): name = "uuid" def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: - import uuid - + ) -> uuid.UUID: if isinstance(value, uuid.UUID): return value @@ -751,7 +775,12 @@ def __repr__(self) -> str: return "UUID" -class File(ParamType): +class FileInfoDict(ParamTypeInfoDict): + mode: str + encoding: str | None + + +class File(ParamType[t.IO[t.Any]]): """Declares a parameter to be a file for reading or writing. The file is automatically closed once the context tears down (after the command finished working). @@ -798,10 +827,12 @@ def __init__( self.lazy = lazy self.atomic = atomic - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update(mode=self.mode, encoding=self.encoding) - return info_dict + def to_info_dict(self) -> FileInfoDict: + return { + "mode": self.mode, + "encoding": self.encoding, + **super().to_info_dict(), + } def resolve_lazy_flag(self, value: str | os.PathLike[str]) -> bool: if self.lazy is not None: @@ -876,7 +907,16 @@ def _is_file_like(value: t.Any) -> te.TypeGuard[t.IO[t.Any]]: return hasattr(value, "read") or hasattr(value, "write") -class Path(ParamType): +class PathInfoDict(ParamTypeInfoDict): + exists: bool + file_okay: bool + dir_okay: bool + writable: bool + readable: bool + allow_dash: bool + + +class Path(ParamType[str | bytes | os.PathLike[str]]): """The ``Path`` type is similar to the :class:`File` type, but returns the filename instead of an open file. Various checks can be enabled to validate the type of file and permissions. @@ -940,17 +980,16 @@ def __init__( else: self.name = _("path") - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict.update( - exists=self.exists, - file_okay=self.file_okay, - dir_okay=self.dir_okay, - writable=self.writable, - readable=self.readable, - allow_dash=self.allow_dash, - ) - return info_dict + def to_info_dict(self) -> PathInfoDict: + return { + "exists": self.exists, + "file_okay": self.file_okay, + "dir_okay": self.dir_okay, + "writable": self.writable, + "readable": self.readable, + "allow_dash": self.allow_dash, + **super().to_info_dict(), + } def coerce_path_result( self, value: str | os.PathLike[str] @@ -1057,7 +1096,11 @@ def shell_complete( return [CompletionItem(incomplete, type=type)] -class Tuple(CompositeParamType): +class TupleInfoDict(ParamTypeInfoDict): + types: cabc.Sequence[ParamTypeInfoDict] + + +class Tuple(CompositeParamType[tuple[t.Any, ...]]): """The default behavior of Click is to apply a type on a value directly. This works well in most cases, except for when `nargs` is set to a fixed count and different types should be used for different items. In this @@ -1071,25 +1114,26 @@ class Tuple(CompositeParamType): :param types: a list of types that should be used for the tuple items. """ - def __init__(self, types: cabc.Sequence[type[t.Any] | ParamType]) -> None: - self.types: cabc.Sequence[ParamType] = [convert_type(ty) for ty in types] + def __init__(self, types: cabc.Sequence[type[t.Any] | ParamType[t.Any]]) -> None: + self.types: cabc.Sequence[ParamType[t.Any]] = [convert_type(ty) for ty in types] - def to_info_dict(self) -> dict[str, t.Any]: - info_dict = super().to_info_dict() - info_dict["types"] = [t.to_info_dict() for t in self.types] - return info_dict + def to_info_dict(self) -> TupleInfoDict: + return { + "types": [ty.to_info_dict() for ty in self.types], + **super().to_info_dict(), + } @property - def name(self) -> str: # type: ignore + def name(self) -> str: # type: ignore[override] return f"<{' '.join(ty.name for ty in self.types)}>" @property - def arity(self) -> int: # type: ignore + def arity(self) -> int: # type: ignore[override] return len(self.types) def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> tuple[t.Any, ...]: len_type = len(self.types) len_value = len(value) @@ -1109,7 +1153,7 @@ def convert( ) -def convert_type(ty: t.Any | None, default: t.Any | None = None) -> ParamType: +def convert_type(ty: t.Any | None, default: t.Any | None = None) -> ParamType[t.Any]: """Find the most appropriate :class:`ParamType` for the given Python type. If the type isn't provided, it can be inferred from a default value. diff --git a/tests/test_imports.py b/tests/test_imports.py index 917b245f29..74b78642bc 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -27,6 +27,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, ALLOWED_IMPORTS = { "__future__", + "abc", "codecs", "collections", "collections.abc", @@ -49,6 +50,7 @@ def tracking_import(module, locals=None, globals=None, fromlist=None, "threading", "types", "typing", + "uuid", "weakref", }