Skip to content

Commit

Permalink
✅ Add tests for missing pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
pypae committed Apr 25, 2024
1 parent 6c76ab9 commit ebb5877
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import sys

import pytest
import typer
from typer.testing import CliRunner

Expand Down Expand Up @@ -33,3 +34,18 @@ def test_script():
encoding="utf-8",
)
assert "Usage" in result.stdout


def test_error_without_pydantic():
pydantic = typer.pydantic_extension.pydantic
typer.pydantic_extension.pydantic = None
with pytest.raises(
RuntimeError,
match="Type not yet supported: <class 'docs_src.parameter_types.pydantic.tutorial001.User'>",
):
runner.invoke(
app,
["1", "--user.id", "2", "--user.name", "John Doe"],
catch_exceptions=False,
)
typer.pydantic_extension.pydantic = pydantic
10 changes: 5 additions & 5 deletions typer/pydantic_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
PYDANTIC_FIELD_SEPARATOR = "."


def flatten_pydantic_model(
def _flatten_pydantic_model(
model: "pydantic.BaseModel", ancestors: List[str]
) -> Dict[str, inspect.Parameter]:
if pydantic is None:
raise ImportError("Pydantic is required to use Pydantic models with Typer.")
# This function should only be called if pydantic is available
assert pydantic is not None
pydantic_parameters = {}
for field_name, field in model.model_fields.items():
qualifier = [*ancestors, field_name]
sub_name = f"_pydantic_{'_'.join(qualifier)}"
if lenient_issubclass(field.annotation, pydantic.BaseModel):
params = flatten_pydantic_model(field.annotation, qualifier) # type: ignore
params = _flatten_pydantic_model(field.annotation, qualifier) # type: ignore
pydantic_parameters.update(params)
else:
default = (
Expand All @@ -51,7 +51,7 @@ def wrap_pydantic_callback(callback: Callable[..., Any]) -> Callable[..., Any]:
other_parameters = {}
for name, parameter in original_signature.parameters.items():
if lenient_issubclass(parameter.annotation, pydantic.BaseModel):
params = flatten_pydantic_model(parameter.annotation, [name])
params = _flatten_pydantic_model(parameter.annotation, [name])
pydantic_parameters.update(params)
pydantic_roots[name] = parameter.annotation
else:
Expand Down

0 comments on commit ebb5877

Please sign in to comment.