Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
a flag to determine whether value is complex.
"""

for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
# paths reversed to match the last-wins behaviour of `env_file`
for secrets_path in reversed(self.secrets_paths):
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
Expand All @@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
continue

if path.is_file():
return path.read_text().strip(), field_key, value_is_complex
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
return path.read_text().strip(), preferred_key, value_is_complex
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
stacklevel=4,
)

return None, field_key, value_is_complex
return None, preferred_key, value_is_complex

def __repr__(self) -> str:
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
Expand Down Expand Up @@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
"""

env_val: str | None = None
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
env_val = self.env_vars.get(env_name)
if env_val is not None:
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
break

return env_val, field_key, value_is_complex
return env_val, preferred_key, value_is_complex

def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Expand Down
33 changes: 32 additions & 1 deletion tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing_extensions
from pydantic import (
AliasChoices,
AliasGenerator,
AliasPath,
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -107,7 +108,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
return self.parser.parse_args(*args, **kwargs)


def test_validation_alias_with_cli_prefix():
def test_cli_validation_alias_with_cli_prefix():
class Settings(BaseSettings, cli_exit_on_error=False):
foobar: str = Field(validation_alias='foo')

Expand All @@ -119,6 +120,36 @@ class Settings(BaseSettings, cli_exit_on_error=False):
assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar'


@pytest.mark.parametrize(
'alias_generator',
[
AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))),
AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)),
],
)
def test_cli_alias_resolution_consistency_with_env(env, alias_generator):
class SubModel(BaseModel):
v1: str = 'model default'

class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_nested_delimiter='__',
nested_model_default_partial_update=True,
alias_generator=alias_generator,
)

sub_model: SubModel = SubModel(v1='top default')

assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}}

env.set('SUB_MODEL__V1', 'env default')
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}}

assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == {
'sub_model': {'v1': 'cli default'}
}


def test_cli_nested_arg():
class SubSubValue(BaseModel):
v6: str
Expand Down
Loading