Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions hatch_build/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from argparse import ArgumentParser
from logging import getLogger
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, get_args, get_origin
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, get_args, get_origin

from hatchling.cli.build import build_command

Expand All @@ -24,12 +25,31 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
return vars(kwargs), extras


def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""):
def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["BaseModel"]], prefix: str = ""):
from pydantic import BaseModel

for field_name, field in model.__class__.model_fields.items():
if model is None:
raise ValueError("Model instance cannot be None")
if isinstance(model, type):
model_fields = model.model_fields
else:
model_fields = model.__class__.model_fields
for field_name, field in model_fields.items():
field_type = field.annotation
arg_name = f"--{prefix}{field_name.replace('_', '-')}"

# Wrappers
if get_origin(field_type) is Optional:
field_type = get_args(field_type)[0]
elif get_origin(field_type) is Union:
non_none_types = [t for t in get_args(field_type) if t is not type(None)]
if len(non_none_types) == 1:
field_type = non_none_types[0]
else:
_log.warning(f"Unsupported Union type for argument '{field_name}': {field_type}")
continue

# Handled types
if field_type is bool:
parser.add_argument(arg_name, action="store_true", default=field.default)
elif field_type in (str, int, float):
Expand All @@ -38,9 +58,12 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
except TypeError:
# TODO: handle more complex types if needed
parser.add_argument(arg_name, type=str, default=field.default)
elif isinstance(field_type, type) and issubclass(field_type, Path):
# Promote to/from string
parser.add_argument(arg_name, type=str, default=str(field.default) if isinstance(field.default, Path) else None)
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
# Nested model, add its fields with a prefix
_recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.")
_recurse_add_fields(parser, field_type, prefix=f"{field_name}.")
elif get_origin(field_type) is Literal:
literal_args = get_args(field_type)
if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args):
Expand All @@ -65,13 +88,13 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None
)
else:
_log.warning(f"Unsupported field type for argument '{arg_name}': {field_type}")
_log.warning(f"Unsupported field type for argument '{field_name}': {field_type}")
return parser


def parse_extra_args_model(model: "BaseModel"):
try:
from pydantic import TypeAdapter
from pydantic import BaseModel, TypeAdapter
except ImportError:
raise ImportError("pydantic is required to use parse_extra_args_model")
# Recursively parse fields from a pydantic model and its sub-models
Expand All @@ -88,6 +111,24 @@ def parse_extra_args_model(model: "BaseModel"):
sub_model = model
for part in parts[:-1]:
model_to_set = getattr(sub_model, part)
if model_to_set is None:
# Create a new instance of model
field = sub_model.__class__.model_fields[part]
# if field annotation is an optional or union with none, extrat type
if get_origin(field.annotation) is Optional:
model_to_instance = get_args(field.annotation)[0]
elif get_origin(field.annotation) is Union:
non_none_types = [t for t in get_args(field.annotation) if t is not type(None)]
if len(non_none_types) == 1:
model_to_instance = non_none_types[0]
else:
model_to_instance = field.annotation
if not isinstance(model_to_instance, type) or not issubclass(model_to_instance, BaseModel):
raise ValueError(
f"Cannot create sub-model for field '{part}' on model '{sub_model.__class__.__name__}': - type is {model_to_instance}"
)
model_to_set = model_to_instance()
setattr(sub_model, part, model_to_set)
key = parts[-1]
else:
model_to_set = model
Expand Down
13 changes: 11 additions & 2 deletions hatch_build/tests/test_cli_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from typing import Dict, List, Literal
from pathlib import Path
from typing import Dict, List, Literal, Optional
from unittest.mock import patch

from pydantic import BaseModel
Expand All @@ -15,14 +16,16 @@ class SubModel(BaseModel, validate_assignment=True):
class MyTopLevelModel(BaseModel, validate_assignment=True):
extra_arg: bool = False
extra_arg_with_value: str = "default"
extra_arg_with_value_equals: str = "default_equals"
extra_arg_with_value_equals: Optional[str] = "default_equals"
extra_arg_literal: Literal["a", "b", "c"] = "a"

list_arg: List[int] = [1, 2, 3]
dict_arg: Dict[str, str] = {"key": "value"}
path_arg: Path = Path(".")

submodel: SubModel
submodel2: SubModel = SubModel()
submodel3: Optional[SubModel] = None


class TestCLIMdel:
Expand All @@ -44,6 +47,8 @@ def test_get_arg_from_model(self):
"1,2,3",
"--dict-arg",
"key1=value1,key2=value2",
"--path-arg",
"/some/path",
"--submodel.sub-arg",
"100",
"--submodel.sub-arg-with-value",
Expand All @@ -52,6 +57,8 @@ def test_get_arg_from_model(self):
"200",
"--submodel2.sub-arg-with-value",
"sub_value2",
"--submodel3.sub-arg",
"300",
],
):
assert hatchling() == 0
Expand All @@ -63,9 +70,11 @@ def test_get_arg_from_model(self):
assert model.extra_arg_literal == "b"
assert model.list_arg == [1, 2, 3]
assert model.dict_arg == {"key1": "value1", "key2": "value2"}
assert model.path_arg == Path("/some/path")
assert model.submodel.sub_arg == 100
assert model.submodel.sub_arg_with_value == "sub_value"
assert model.submodel2.sub_arg == 200
assert model.submodel2.sub_arg_with_value == "sub_value2"
assert model.submodel3.sub_arg == 300

assert "--extra-arg-not-in-parser" in extras