Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix support for UnionType with Python 3.11 #548

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
22 changes: 22 additions & 0 deletions tests/test_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import typer
from typer.testing import CliRunner

from .utils import needs_py310

runner = CliRunner()


Expand All @@ -29,6 +31,26 @@ def opt(user: Optional[str] = None):
assert "User: Camila" in result.output


@needs_py310
def test_union_type_optional():
app = typer.Typer()

@app.command()
def opt(user: str | None = None):
if user:
print(f"User: {user}")
else:
print("No user")

result = runner.invoke(app)
assert result.exit_code == 0
assert "No user" in result.output

result = runner.invoke(app, ["--user", "Camila"])
assert result.exit_code == 0
assert "User: Camila" in result.output


def test_optional_tuple():
app = typer.Typer()

Expand Down
28 changes: 15 additions & 13 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import click

from ._typing import get_args, get_origin, is_union
from .completion import get_completion_inspect_parameters
from .core import MarkupMode, TyperArgument, TyperCommand, TyperGroup, TyperOption
from .models import (
Expand Down Expand Up @@ -825,30 +826,31 @@ def get_click_param(
is_tuple = False
parameter_type: Any = None
is_flag = None
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)

if origin is not None:
# Handle Optional[SomeType]
if origin is Union:
# Handle SomeType | None and Optional[SomeType]
if is_union(origin):
types = []
for type_ in main_type.__args__:
for type_ in get_args(main_type):
if type_ is NoneType:
continue
types.append(type_)
assert len(types) == 1, "Typer Currently doesn't support Union types"
main_type = types[0]
origin = getattr(main_type, "__origin__", None)
origin = get_origin(main_type)
# Handle Tuples and Lists
if lenient_issubclass(origin, List):
main_type = main_type.__args__[0]
assert not getattr(
main_type, "__origin__", None
main_type = get_args(main_type)[0]
assert not get_origin(
main_type
), "List types with complex sub-types are not currently supported"
is_list = True
elif lenient_issubclass(origin, Tuple): # type: ignore
types = []
for type_ in main_type.__args__:
assert not getattr(
type_, "__origin__", None
for type_ in get_args(main_type):
assert not get_origin(
type_
), "Tuple types with complex sub-types are not currently supported"
types.append(
get_click_type(annotation=type_, parameter_info=parameter_info)
Expand All @@ -865,7 +867,7 @@ def get_click_param(
convertor=convertor, default_value=default_value
)
if is_tuple:
convertor = generate_tuple_convertor(main_type.__args__)
convertor = generate_tuple_convertor(get_args(main_type))
if isinstance(parameter_info, OptionInfo):
if main_type is bool and parameter_info.is_flag is not False:
is_flag = True
Expand Down Expand Up @@ -1019,7 +1021,7 @@ def get_param_completion(
incomplete_name = None
unassigned_params = list(parameters.values())
for param_sig in unassigned_params[:]:
origin = getattr(param_sig.annotation, "__origin__", None)
origin = get_origin(param_sig.annotation)
if lenient_issubclass(param_sig.annotation, click.Context):
ctx_name = param_sig.name
unassigned_params.remove(param_sig)
Expand Down