From b290b31fed18b77f7e168e5f58dbb91b0131606a Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:43:15 -0600 Subject: [PATCH] Fix usage of `AliasGenerator` with `computed_field` decorator (#8806) Co-authored-by: Alex Hall --- pydantic/_internal/_generate_schema.py | 63 ++++++++++++++++++++------ tests/test_aliases.py | 23 ++++++++++ 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index d7df88b39c..76b179f122 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -973,7 +973,7 @@ def _apply_alias_generator_to_field_info( # Apply an alias_generator if # 1. An alias is not specified # 2. An alias is specified, but the priority is <= 1 - if alias_generator and ( + if ( field_info.alias_priority is None or field_info.alias_priority <= 1 or field_info.alias is None @@ -1009,6 +1009,49 @@ def _apply_alias_generator_to_field_info( if field_info.validation_alias is None: field_info.validation_alias = validation_alias or alias + @staticmethod + def _apply_alias_generator_to_computed_field_info( + alias_generator: Callable[[str], str] | AliasGenerator, + computed_field_info: ComputedFieldInfo, + computed_field_name: str, + ): + """Apply an alias_generator to alias on a ComputedFieldInfo instance if appropriate. + + Args: + alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance. + computed_field_info: The ComputedFieldInfo instance to which the alias_generator is (maybe) applied. + computed_field_name: The name of the computed field from which to generate the alias. + """ + # Apply an alias_generator if + # 1. An alias is not specified + # 2. An alias is specified, but the priority is <= 1 + + if ( + computed_field_info.alias_priority is None + or computed_field_info.alias_priority <= 1 + or computed_field_info.alias is None + ): + alias, validation_alias, serialization_alias = None, None, None + + if isinstance(alias_generator, AliasGenerator): + alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name) + elif isinstance(alias_generator, Callable): + alias = alias_generator(computed_field_name) + if not isinstance(alias, str): + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + + # if priority is not set, we set to 1 + # which supports the case where the alias_generator from a child class is used + # to generate an alias for a field in a parent class + if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1: + computed_field_info.alias_priority = 1 + + # if the priority is 1, then we set the aliases to the generated alias + # note that we use the serialization_alias with priority over alias, as computed_field + # aliases are used for serialization only (not validation) + if computed_field_info.alias_priority == 1: + computed_field_info.alias = serialization_alias or alias + def _common_field_schema( # C901 self, name: str, field_info: FieldInfo, decorators: DecoratorInfos ) -> _CommonField: @@ -1659,20 +1702,12 @@ def _computed_field_schema( filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name), computed_field=True, ) - # Handle alias_generator using similar logic to that from - # pydantic._internal._generate_schema.GenerateSchema._common_field_schema, - # with field_info -> d.info and name -> d.cls_var_name + alias_generator = self._config_wrapper.alias_generator - if alias_generator and (d.info.alias_priority is None or d.info.alias_priority <= 1): - alias = None - if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: - alias = alias_generator.alias(d.cls_var_name) - elif isinstance(alias_generator, Callable): - alias = alias_generator(d.cls_var_name) - if not isinstance(alias, str): - raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') - d.info.alias = alias - d.info.alias_priority = 1 + if alias_generator is not None: + self._apply_alias_generator_to_computed_field_info( + alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name + ) def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: json_schema = handler(schema) diff --git a/tests/test_aliases.py b/tests/test_aliases.py index fb21655748..4072f1446d 100644 --- a/tests/test_aliases.py +++ b/tests/test_aliases.py @@ -702,3 +702,26 @@ class Foo(BaseModel): assert f.a == 'a' assert f.model_dump(by_alias=True) == {'a_ser_alias': 'a'} assert f.model_dump(by_alias=False) == {'a': 'a'} + + +def test_alias_generator_with_computed_field_for_serialization() -> None: + """Tests that the alias generator is used for computed fields, with serialization_alias taking precedence over alias.""" + + class Rectangle(BaseModel): + model_config = ConfigDict( + alias_generator=AliasGenerator( + validation_alias=lambda field_name: f'{field_name}_val_alias', + alias=lambda field_name: f'{field_name}_alias', + serialization_alias=lambda field_name: f'{field_name}_ser_alias', + ) + ) + + width: int + height: int + + @computed_field + def area(self) -> int: + return self.width * self.height + + r = Rectangle(width_val_alias=10, height_val_alias=20) + assert r.model_dump(by_alias=True) == {'width_ser_alias': 10, 'height_ser_alias': 20, 'area_ser_alias': 200}