diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 51e83314de..904a686d2e 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -7,6 +7,8 @@ import typer from typer.testing import CliRunner +from .utils import needs_py310 + runner = CliRunner() @@ -29,6 +31,26 @@ def opt(user: Optional[str] = None): assert "User: Camila" in result.output +@needs_py310 +def test_union_type_optional(): + app = typer.Typer() + + @app.command() + def opt(user: str | None = None): + if user: + print(f"User: {user}") + else: + print("No user") + + result = runner.invoke(app) + assert result.exit_code == 0 + assert "No user" in result.output + + result = runner.invoke(app, ["--user", "Camila"]) + assert result.exit_code == 0 + assert "User: Camila" in result.output + + def test_optional_tuple(): app = typer.Typer() diff --git a/typer/main.py b/typer/main.py index 9db26975ca..d291dc1944 100644 --- a/typer/main.py +++ b/typer/main.py @@ -13,6 +13,7 @@ import click +from ._typing import get_args, get_origin, is_union from .completion import get_completion_inspect_parameters from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption from .models import ( @@ -825,30 +826,31 @@ def get_click_param( is_tuple = False parameter_type: Any = None is_flag = None - origin = getattr(main_type, "__origin__", None) + origin = get_origin(main_type) + if origin is not None: - # Handle Optional[SomeType] - if origin is Union: + # Handle SomeType | None and Optional[SomeType] + if is_union(origin): types = [] - for type_ in main_type.__args__: + for type_ in get_args(main_type): if type_ is NoneType: continue types.append(type_) assert len(types) == 1, "Typer Currently doesn't support Union types" main_type = types[0] - origin = getattr(main_type, "__origin__", None) + origin = get_origin(main_type) # Handle Tuples and Lists if lenient_issubclass(origin, List): - main_type = main_type.__args__[0] - assert not getattr( - main_type, "__origin__", None + main_type = get_args(main_type)[0] + assert not get_origin( + main_type ), "List types with complex sub-types are not currently supported" is_list = True elif lenient_issubclass(origin, Tuple): # type: ignore types = [] - for type_ in main_type.__args__: - assert not getattr( - type_, "__origin__", None + for type_ in get_args(main_type): + assert not get_origin( + type_ ), "Tuple types with complex sub-types are not currently supported" types.append( get_click_type(annotation=type_, parameter_info=parameter_info) @@ -865,7 +867,7 @@ def get_click_param( convertor=convertor, default_value=default_value ) if is_tuple: - convertor = generate_tuple_convertor(main_type.__args__) + convertor = generate_tuple_convertor(get_args(main_type)) if isinstance(parameter_info, OptionInfo): if main_type is bool and parameter_info.is_flag is not False: is_flag = True @@ -1019,7 +1021,7 @@ def get_param_completion( incomplete_name = None unassigned_params = list(parameters.values()) for param_sig in unassigned_params[:]: - origin = getattr(param_sig.annotation, "__origin__", None) + origin = get_origin(param_sig.annotation) if lenient_issubclass(param_sig.annotation, click.Context): ctx_name = param_sig.name unassigned_params.remove(param_sig)