From 7a14cffe508c9e3011d0d37dc146ccd02cfe5e13 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Sat, 8 Nov 2025 11:54:13 -0500 Subject: [PATCH] Handle list of models, dict with model values, and existing dict --- hatch_build/cli.py | 254 +++++++++++++++++++++++----- hatch_build/tests/test_cli_model.py | 17 +- 2 files changed, 232 insertions(+), 39 deletions(-) diff --git a/hatch_build/cli.py b/hatch_build/cli.py index 68bc970..13e5894 100644 --- a/hatch_build/cli.py +++ b/hatch_build/cli.py @@ -1,5 +1,5 @@ from argparse import ArgumentParser -from logging import getLogger +from logging import Formatter, StreamHandler, getLogger from pathlib import Path from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, get_args, get_origin @@ -16,6 +16,10 @@ _extras = None _log = getLogger(__name__) +_handler = StreamHandler() +_formatter = Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z") +_handler.setFormatter(_formatter) +_log.addHandler(_handler) def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]: @@ -27,18 +31,34 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]: def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["BaseModel"]], prefix: str = ""): from pydantic import BaseModel + from pydantic_core import PydanticUndefined + # Model is required if model is None: raise ValueError("Model instance cannot be None") + + # Extract the fields from a model instance or class if isinstance(model, type): model_fields = model.model_fields else: model_fields = model.__class__.model_fields + + # For each available field, add an argument to the parser for field_name, field in model_fields.items(): + # Grab the annotation to map to type field_type = field.annotation - arg_name = f"--{prefix}{field_name.replace('_', '-')}" + # Build the argument name converting underscores to dashes + arg_name = f"--{prefix.replace('_', '-')}{field_name.replace('_', '-')}" + + # If theres an instance, use that so we have concrete values + model_instance = model if not isinstance(model, type) else None - # Wrappers + # If we have an instance, grab the field value + field_instance = getattr(model_instance, field_name, None) if model_instance else None + + # MARK: Wrappers: + # - Optional[T] + # - Union[T, None] if get_origin(field_type) is Optional: field_type = get_args(field_type)[0] elif get_origin(field_type) is Union: @@ -49,44 +69,126 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type[" _log.warning(f"Unsupported Union type for argument '{field_name}': {field_type}") continue + # Default value, promote PydanticUndefined to None + if field.default is PydanticUndefined: + default_value = None + else: + default_value = field.default + # Handled types + # - bool, str, int, float + # - Path + # - Nested BaseModel + # - Literal + # - List[T] + # - where T is bool, str, int, float + # - List[BaseModel] where we have an instance to recurse into + # - Dict[str, T] + # - where T is bool, str, int, float + # - Dict[str, BaseModel] where we have an instance to recurse into if field_type is bool: - parser.add_argument(arg_name, action="store_true", default=field.default) + ############# + # MARK: bool + parser.add_argument(arg_name, action="store_true", default=default_value) elif field_type in (str, int, float): + ######################## + # MARK: str, int, float try: - parser.add_argument(arg_name, type=field_type, default=field.default) + parser.add_argument(arg_name, type=field_type, default=default_value) except TypeError: # TODO: handle more complex types if needed - parser.add_argument(arg_name, type=str, default=field.default) + parser.add_argument(arg_name, type=str, default=default_value) elif isinstance(field_type, type) and issubclass(field_type, Path): + ############# + # MARK: Path # Promote to/from string - parser.add_argument(arg_name, type=str, default=str(field.default) if isinstance(field.default, Path) else None) - elif isinstance(field_type, Type) and issubclass(field_type, BaseModel): + parser.add_argument(arg_name, type=str, default=str(default_value) if isinstance(default_value, Path) else None) + elif isinstance(field_instance, BaseModel): + ############################ + # MARK: instance(BaseModel) # Nested model, add its fields with a prefix + _recurse_add_fields(parser, field_instance, prefix=f"{field_name}.") + elif isinstance(field_type, Type) and issubclass(field_type, BaseModel): + ######################## + # MARK: type(BaseModel) + # Nested model class, add its fields with a prefix _recurse_add_fields(parser, field_type, prefix=f"{field_name}.") elif get_origin(field_type) is Literal: + ################ + # MARK: Literal literal_args = get_args(field_type) if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args): - _log.warning(f"Only Literal types of str, int, float, or bool are supported - got {literal_args}") - else: - parser.add_argument(arg_name, type=type(literal_args[0]), choices=literal_args, default=field.default) + # Only support simple literal types for now + _log.warning(f"Only Literal types of str, int, float, or bool are supported - field `{field_name}` got {literal_args}") + continue + #################################### + # MARK: Literal[str|int|float|bool] + parser.add_argument(arg_name, type=type(literal_args[0]), choices=literal_args, default=default_value) elif get_origin(field_type) in (list, List): - # TODO: if list arg is complex type, warn as not implemented for now + ################ + # MARK: List[T] if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool): - _log.warning(f"Only lists of str, int, float, or bool are supported - got {get_args(field_type)[0]}") - else: - parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)) if isinstance(field, str) else None) + # If theres already something here, we can procede by adding the command with a positional indicator + if field_instance: + ######################## + # MARK: List[BaseModel] + for i, value in enumerate(field_instance): + _recurse_add_fields(parser, value, prefix=f"{field_name}.{i}.") + continue + # If there's nothing here, we don't know how to address them + # TODO: we could just prefill e.g. --field.0, --field.1 up to some limit + _log.warning(f"Only lists of str, int, float, or bool are supported - field `{field_name}` got {get_args(field_type)[0]}") + continue + ################################# + # MARK: List[str|int|float|bool] + parser.add_argument(arg_name, type=str, default=",".join(map(str, default_value)) if isinstance(field, str) else None) elif get_origin(field_type) in (dict, Dict): - # TODO: if key args are complex type, warn as not implemented for now + ###################### + # MARK: Dict[str, T] key_type, value_type = get_args(field_type) - if key_type not in (str, int, float, bool): - _log.warning(f"Only dicts with str keys are supported - got key type {key_type}") - if value_type not in (str, int, float, bool): - _log.warning(f"Only dicts with str values are supported - got value type {value_type}") - else: - parser.add_argument( - arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None - ) + + if key_type not in (str, int, float, bool) and not ( + get_origin(key_type) is Literal and all(isinstance(arg, (str, int, float, bool)) for arg in get_args(key_type)) + ): + # Check Key type, must be str, int, float, bool + _log.warning(f"Only dicts with str keys are supported - field `{field_name}` got key type {key_type}") + continue + + if value_type not in (str, int, float, bool) and not field_instance: + # Check Value type, must be str, int, float, bool if an instance isnt provided + _log.warning(f"Only dicts with str values are supported - field `{field_name}` got value type {value_type}") + continue + + # If theres already something here, we can procede by adding the command by keyword + if field_instance: + if all(isinstance(v, BaseModel) for v in field_instance.values()): + ############################# + # MARK: Dict[str, BaseModel] + for key, value in field_instance.items(): + _recurse_add_fields(parser, value, prefix=f"{field_name}.{key}.") + continue + # If we have mixed, we don't support + elif any(isinstance(v, BaseModel) for v in field_instance.values()): + _log.warning(f"Mixed dict value types are not supported - field `{field_name}` has mixed BaseModel and non-BaseModel values") + continue + # If we have non BaseModel values, we can still add a parser by route + if all(isinstance(v, (str, int, float, bool)) for v in field_instance.values()): + # We can set "known" values here + for key, value in field_instance.items(): + ########################################## + # MARK: Dict[str, str|int|float|bool] + parser.add_argument( + f"{arg_name}.{key}", + type=type(value), + default=value, + ) + # NOTE: don't continue to allow adding the full setter below + # Finally add the full setter for unknown values + ########################################## + # MARK: Dict[str, str|int|float|bool|str] + parser.add_argument( + arg_name, type=str, default=",".join(f"{k}={v}" for k, v in default_value.items()) if isinstance(default_value, dict) else None + ) else: _log.warning(f"Unsupported field type for argument '{field_name}': {field_type}") return parser @@ -107,20 +209,46 @@ def parse_extra_args_model(model: "BaseModel"): for key, value in args.items(): # Handle nested fields if "." in key: + # We're going to walk down the model tree to get to the right sub-model parts = key.split(".") + + # Accounting sub_model = model - for part in parts[:-1]: - model_to_set = getattr(sub_model, part) + parent_model = None + + for i, part in enumerate(parts[:-1]): + if part.isdigit() and isinstance(sub_model, list): + # List index + index = int(part) + + # Should never be out of bounds, but check to be sure + if index >= len(sub_model): + raise IndexError(f"Index {index} out of range for field '{parts[i - 1]}' on model '{parent_model.__class__.__name__}'") + + # Grab the model instance from the list + model_to_set = sub_model[index] + elif isinstance(sub_model, dict): + # Dict key + if part not in sub_model: + raise KeyError(f"Key '{part}' not found for field '{parts[i - 1]}' on model '{parent_model.__class__.__name__}'") + + # Grab the model instance from the dict + model_to_set = sub_model[part] + else: + model_to_set = getattr(sub_model, part) + if model_to_set is None: # Create a new instance of model field = sub_model.__class__.model_fields[part] - # if field annotation is an optional or union with none, extrat type + + # if field annotation is an optional or union with none, extract type if get_origin(field.annotation) is Optional: model_to_instance = get_args(field.annotation)[0] elif get_origin(field.annotation) is Union: non_none_types = [t for t in get_args(field.annotation) if t is not type(None)] if len(non_none_types) == 1: model_to_instance = non_none_types[0] + else: model_to_instance = field.annotation if not isinstance(model_to_instance, type) or not issubclass(model_to_instance, BaseModel): @@ -129,10 +257,39 @@ def parse_extra_args_model(model: "BaseModel"): ) model_to_set = model_to_instance() setattr(sub_model, part, model_to_set) + + parent_model = sub_model + sub_model = model_to_set + key = parts[-1] else: + # Accounting + sub_model = model + parent_model = model model_to_set = model + if not isinstance(model_to_set, BaseModel): + if isinstance(model_to_set, dict): + # We allow setting dict values directly + # Grab the dict from the parent model, set the value, and continue + if key in model_to_set: + model_to_set[key] = value + elif key.replace("_", "-") in model_to_set: + # Argparse converts dashes back to underscores, so undo + model_to_set[key.replace("_", "-")] = value + else: + # Raise + raise KeyError(f"Key '{key}' not found in dict field on model '{parent_model.__class__.__name__}'") + + # Now adjust our variable accounting to set the whole dict back on the parent model, + # allowing us to trigger any validation + key = part + value = model_to_set + model_to_set = parent_model + else: + _log.warning(f"Cannot set field '{key}' on non-BaseModel instance of type '{type(model_to_set).__name__}'") + continue + # Grab the field from the model class and make a type adapter field = model_to_set.__class__.model_fields[key] adapter = TypeAdapter(field.annotation) @@ -140,24 +297,45 @@ def parse_extra_args_model(model: "BaseModel"): # Convert the value using the type adapter if get_origin(field.annotation) in (list, List): value = value or "" - value = value.split(",") + if isinstance(value, list): + # Already a list, use as is + pass + elif isinstance(value, str): + # Convert from comma-separated values + value = value.split(",") + else: + # Unknown, raise + raise ValueError(f"Cannot convert value '{value}' to list for field '{key}'") elif get_origin(field.annotation) in (dict, Dict): value = value or "" - dict_items = value.split(",") - dict_value = {} - for item in dict_items: - if item: - k, v = item.split("=", 1) - dict_value[k] = v - value = dict_value + if isinstance(value, dict): + # Already a dict, use as is + pass + elif isinstance(value, str): + # Convert from comma-separated key=value pairs + dict_items = value.split(",") + dict_value = {} + for item in dict_items: + if item: + k, v = item.split("=", 1) + dict_value[k] = v + # Grab any previously existing dict to preserve other keys + existing_dict = getattr(model_to_set, key, {}) or {} + dict_value.update(existing_dict) + value = dict_value + else: + # Unknown, raise + raise ValueError(f"Cannot convert value '{value}' to dict for field '{key}'") try: - value = adapter.validate_python(value) + if value is not None: + value = adapter.validate_python(value) + + # Set the value on the model + setattr(model_to_set, key, value) except ValidationError: _log.warning(f"Failed to validate field '{key}' with value '{value}' for model '{model_to_set.__class__.__name__}'") continue - # Set the value on the model - setattr(model_to_set, key, value) return model, kwargs diff --git a/hatch_build/tests/test_cli_model.py b/hatch_build/tests/test_cli_model.py index 4213ad0..c14c045 100644 --- a/hatch_build/tests/test_cli_model.py +++ b/hatch_build/tests/test_cli_model.py @@ -20,13 +20,19 @@ class MyTopLevelModel(BaseModel, validate_assignment=True): extra_arg_literal: Literal["a", "b", "c"] = "a" list_arg: List[int] = [1, 2, 3] - dict_arg: Dict[str, str] = {"key": "value"} + dict_arg: Dict[str, str] = {} + dict_arg_default_values: Dict[str, str] = {"existing-key": "existing-value"} path_arg: Path = Path(".") submodel: SubModel submodel2: SubModel = SubModel() submodel3: Optional[SubModel] = None + submodel_list: List[SubModel] = [] + submodel_list_instanced: List[SubModel] = [SubModel()] + submodel_dict: Dict[str, SubModel] = {} + submodel_dict_instanced: Dict[str, SubModel] = {"a": SubModel()} + class TestCLIMdel: def test_get_arg_from_model(self): @@ -47,6 +53,8 @@ def test_get_arg_from_model(self): "1,2,3", "--dict-arg", "key1=value1,key2=value2", + "--dict-arg-default-values.existing-key", + "new-value", "--path-arg", "/some/path", "--submodel.sub-arg", @@ -59,6 +67,10 @@ def test_get_arg_from_model(self): "sub_value2", "--submodel3.sub-arg", "300", + "--submodel-list-instanced.0.sub-arg", + "400", + "--submodel-dict-instanced.a.sub-arg", + "500", ], ): assert hatchling() == 0 @@ -70,11 +82,14 @@ def test_get_arg_from_model(self): assert model.extra_arg_literal == "b" assert model.list_arg == [1, 2, 3] assert model.dict_arg == {"key1": "value1", "key2": "value2"} + assert model.dict_arg_default_values == {"existing-key": "new-value"} assert model.path_arg == Path("/some/path") assert model.submodel.sub_arg == 100 assert model.submodel.sub_arg_with_value == "sub_value" assert model.submodel2.sub_arg == 200 assert model.submodel2.sub_arg_with_value == "sub_value2" assert model.submodel3.sub_arg == 300 + assert model.submodel_list_instanced[0].sub_arg == 400 + assert model.submodel_dict_instanced["a"].sub_arg == 500 assert "--extra-arg-not-in-parser" in extras