From b3504c1b8404fd1cf8f1814fa097f4fac835ee6c Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Thu, 6 Nov 2025 16:27:08 -0500 Subject: [PATCH] Instantiate missing models, support paths --- hatch_build/cli.py | 53 +++++++++++++++++++++++++---- hatch_build/tests/test_cli_model.py | 13 +++++-- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/hatch_build/cli.py b/hatch_build/cli.py index ad3590e..bbaecfb 100644 --- a/hatch_build/cli.py +++ b/hatch_build/cli.py @@ -1,6 +1,7 @@ from argparse import ArgumentParser from logging import getLogger -from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, get_args, get_origin +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, get_args, get_origin from hatchling.cli.build import build_command @@ -24,12 +25,31 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]: return vars(kwargs), extras -def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""): +def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["BaseModel"]], prefix: str = ""): from pydantic import BaseModel - for field_name, field in model.__class__.model_fields.items(): + if model is None: + raise ValueError("Model instance cannot be None") + if isinstance(model, type): + model_fields = model.model_fields + else: + model_fields = model.__class__.model_fields + for field_name, field in model_fields.items(): field_type = field.annotation arg_name = f"--{prefix}{field_name.replace('_', '-')}" + + # Wrappers + if get_origin(field_type) is Optional: + field_type = get_args(field_type)[0] + elif get_origin(field_type) is Union: + non_none_types = [t for t in get_args(field_type) if t is not type(None)] + if len(non_none_types) == 1: + field_type = non_none_types[0] + else: + _log.warning(f"Unsupported Union type for argument '{field_name}': {field_type}") + continue + + # Handled types if field_type is bool: parser.add_argument(arg_name, action="store_true", default=field.default) elif field_type in (str, int, float): @@ -38,9 +58,12 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str except TypeError: # TODO: handle more complex types if needed parser.add_argument(arg_name, type=str, default=field.default) + elif isinstance(field_type, type) and issubclass(field_type, Path): + # Promote to/from string + parser.add_argument(arg_name, type=str, default=str(field.default) if isinstance(field.default, Path) else None) elif isinstance(field_type, Type) and issubclass(field_type, BaseModel): # Nested model, add its fields with a prefix - _recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.") + _recurse_add_fields(parser, field_type, prefix=f"{field_name}.") elif get_origin(field_type) is Literal: literal_args = get_args(field_type) if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args): @@ -65,13 +88,13 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None ) else: - _log.warning(f"Unsupported field type for argument '{arg_name}': {field_type}") + _log.warning(f"Unsupported field type for argument '{field_name}': {field_type}") return parser def parse_extra_args_model(model: "BaseModel"): try: - from pydantic import TypeAdapter + from pydantic import BaseModel, TypeAdapter except ImportError: raise ImportError("pydantic is required to use parse_extra_args_model") # Recursively parse fields from a pydantic model and its sub-models @@ -88,6 +111,24 @@ def parse_extra_args_model(model: "BaseModel"): sub_model = model for part in parts[:-1]: model_to_set = getattr(sub_model, part) + if model_to_set is None: + # Create a new instance of model + field = sub_model.__class__.model_fields[part] + # if field annotation is an optional or union with none, extrat type + if get_origin(field.annotation) is Optional: + model_to_instance = get_args(field.annotation)[0] + elif get_origin(field.annotation) is Union: + non_none_types = [t for t in get_args(field.annotation) if t is not type(None)] + if len(non_none_types) == 1: + model_to_instance = non_none_types[0] + else: + model_to_instance = field.annotation + if not isinstance(model_to_instance, type) or not issubclass(model_to_instance, BaseModel): + raise ValueError( + f"Cannot create sub-model for field '{part}' on model '{sub_model.__class__.__name__}': - type is {model_to_instance}" + ) + model_to_set = model_to_instance() + setattr(sub_model, part, model_to_set) key = parts[-1] else: model_to_set = model diff --git a/hatch_build/tests/test_cli_model.py b/hatch_build/tests/test_cli_model.py index 11c3194..4213ad0 100644 --- a/hatch_build/tests/test_cli_model.py +++ b/hatch_build/tests/test_cli_model.py @@ -1,5 +1,6 @@ import sys -from typing import Dict, List, Literal +from pathlib import Path +from typing import Dict, List, Literal, Optional from unittest.mock import patch from pydantic import BaseModel @@ -15,14 +16,16 @@ class SubModel(BaseModel, validate_assignment=True): class MyTopLevelModel(BaseModel, validate_assignment=True): extra_arg: bool = False extra_arg_with_value: str = "default" - extra_arg_with_value_equals: str = "default_equals" + extra_arg_with_value_equals: Optional[str] = "default_equals" extra_arg_literal: Literal["a", "b", "c"] = "a" list_arg: List[int] = [1, 2, 3] dict_arg: Dict[str, str] = {"key": "value"} + path_arg: Path = Path(".") submodel: SubModel submodel2: SubModel = SubModel() + submodel3: Optional[SubModel] = None class TestCLIMdel: @@ -44,6 +47,8 @@ def test_get_arg_from_model(self): "1,2,3", "--dict-arg", "key1=value1,key2=value2", + "--path-arg", + "/some/path", "--submodel.sub-arg", "100", "--submodel.sub-arg-with-value", @@ -52,6 +57,8 @@ def test_get_arg_from_model(self): "200", "--submodel2.sub-arg-with-value", "sub_value2", + "--submodel3.sub-arg", + "300", ], ): assert hatchling() == 0 @@ -63,9 +70,11 @@ def test_get_arg_from_model(self): assert model.extra_arg_literal == "b" assert model.list_arg == [1, 2, 3] assert model.dict_arg == {"key1": "value1", "key2": "value2"} + assert model.path_arg == Path("/some/path") assert model.submodel.sub_arg == 100 assert model.submodel.sub_arg_with_value == "sub_value" assert model.submodel2.sub_arg == 200 assert model.submodel2.sub_arg_with_value == "sub_value2" + assert model.submodel3.sub_arg == 300 assert "--extra-arg-not-in-parser" in extras