From 311dfaa2f45793f33778d68fbbf699e4d5946763 Mon Sep 17 00:00:00 2001 From: Meir Komet Date: Sat, 27 Jan 2024 02:48:13 +0200 Subject: [PATCH] Support for optional as UnionType (Python >= 3.10) Signed-off-by: Meir Komet --- tests/test_type_conversion.py | 21 +++++++++++++++++++++ typer/main.py | 20 +++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index a4102daadc..27491547fc 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,3 +1,4 @@ +import sys from enum import Enum from pathlib import Path from typing import Any, List, Optional, Tuple @@ -29,6 +30,26 @@ def opt(user: Optional[str] = None): assert "User: Camila" in result.output +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_optional_uniontype(): + app = typer.Typer() + + @app.command() + def opt(user: str | None = None): + if user: + print(f"User: {user}") + else: + print("No user") + + result = runner.invoke(app) + assert result.exit_code == 0 + assert "No user" in result.output + + result = runner.invoke(app, ["--user", "Camila"]) + assert result.exit_code == 0 + assert "User: Camila" in result.output + + def test_no_type(): app = typer.Typer() diff --git a/typer/main.py b/typer/main.py index 9de5f5960d..8878a10cea 100644 --- a/typer/main.py +++ b/typer/main.py @@ -46,6 +46,12 @@ except ImportError: # pragma: nocover rich = None # type: ignore +if sys.version_info >= (3, 10): + from types import UnionType +else: + # Python < 3.10 doesn't have UnionType, so we define it manually as non inheritable type + UnionType = type("", (), {}) + _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -816,8 +822,9 @@ def get_click_param( is_tuple = False parameter_type: Any = None is_flag = None + is_union_type = lenient_issubclass(type(main_type), UnionType) origin = getattr(main_type, "__origin__", None) - if origin is not None: + if origin is not None or is_union_type: # Handle Optional[SomeType] if origin is Union: types = [] @@ -828,6 +835,17 @@ def get_click_param( assert len(types) == 1, "Typer Currently doesn't support Union types" main_type = types[0] origin = getattr(main_type, "__origin__", None) + # Handle (SomeType | None) + elif is_union_type: + types = [] + for type_ in main_type.__args__: + if type_ is NoneType: + continue + types.append(type_) + assert ( + len(types) == 1 + ), "Typer Currently doesn't support UnionType other than (T | None)" + main_type = types[0] # Handle Tuples and Lists if lenient_issubclass(origin, List): main_type = main_type.__args__[0]