Skip to content

Commit

Permalink
Enum support and strip annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Feb 7, 2024
1 parent c38ed6a commit cb9c1c3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
25 changes: 24 additions & 1 deletion pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import is_dataclass
from enum import Enum
from pathlib import Path
from types import FunctionType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Tuple, TypeVar, Union, cast
Expand All @@ -27,7 +28,7 @@
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings

from argparse import SUPPRESS, ArgumentParser, _ArgumentGroup, _SubParsersAction
from argparse import SUPPRESS, Action, ArgumentParser, Namespace, _ArgumentGroup, _SubParsersAction

DotenvType = Union[Path, str, List[Union[Path, str]], Tuple[Union[Path, str], ...]]

Expand All @@ -50,6 +51,22 @@ class _CliPositionalArg:
CliPositionalArg = Annotated[T, _CliPositionalArg]


class _CliEnumAction(Action):
"""
CLI argparse action handler for enum types
"""

def __init__(self, **kwargs: Any):
self._enum = kwargs.pop('type')
kwargs['choices'] = tuple(val.name for val in self._enum)
super().__init__(**kwargs)

def __call__(
self, parser: ArgumentParser, namespace: Namespace, value: Any, option_string: str | None = None
) -> None:
setattr(namespace, self.dest, self._enum[value])


class EnvNoneType(str):
pass

Expand Down Expand Up @@ -917,6 +934,10 @@ def _add_fields_to_parser(
kwargs['action'] = 'append'
if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_include_origin=True):
self._cli_dict_arg_names.append(kwargs['dest'])
elif lenient_issubclass(field_info.annotation, Enum):
kwargs['type'] = field_info.annotation
kwargs['action'] = _CliEnumAction
del kwargs['metavar']

arg_name = f'{arg_prefix.replace(subcommand_prefix, "", 1)}{field_name}'
if _CliPositionalArg in field_info.metadata:
Expand Down Expand Up @@ -972,6 +993,8 @@ def _metavar_format_list(self, args: list[str]) -> str:

def _metavar_format_recurse(self, obj: Any) -> str:
"""Pretty metavar representation of a type. Adapts logic from `pydantic._repr.display_as_type`."""
while get_origin(obj) == Annotated:
obj = get_args(obj)[0]
if isinstance(obj, FunctionType):
return obj.__name__
elif obj is ...:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
import uuid
from datetime import datetime, timezone
from enum import IntEnum
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union

Expand All @@ -14,6 +15,7 @@
AliasChoices,
AliasPath,
BaseModel,
DirectoryPath,
Field,
HttpUrl,
Json,
Expand Down Expand Up @@ -2302,6 +2304,22 @@ class Cfg(BaseSettings):
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}


def test_cli_enum():
class Fruit(IntEnum):
apple = 0
banna = 1
orange = 2

class Cfg(BaseSettings):
fruit: Fruit

cfg = Cfg(_cli_parse_args=['--fruit', 'orange'])
assert cfg.model_dump() == {'fruit': Fruit.orange}

with pytest.raises(SystemExit):
Cfg(_cli_parse_args=['--fruit', 'lettuce'])


def test_cli_annotation_exceptions(monkeypatch):
class SubCmdAlt(BaseModel):
pass
Expand Down Expand Up @@ -2586,6 +2604,8 @@ class Settings(BaseSettings):
(Union[SimpleSettings, SettingWithIgnoreEmpty], 'JSON'),
(Union[SimpleSettings, str, SettingWithIgnoreEmpty], '{JSON,str}'),
(Union[str, SimpleSettings, SettingWithIgnoreEmpty], '{str,JSON}'),
(Annotated[SimpleSettings, 'annotation'], 'JSON'),
(DirectoryPath, 'Path'),
],
)
def test_cli_metavar_format(value, expected):
Expand Down

0 comments on commit cb9c1c3

Please sign in to comment.