From e4a6dba2af504777a20fd66ad0da05e6d3283345 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 26 Jul 2023 12:13:08 -0600 Subject: [PATCH] Make it harder to hit collisions with json schema defrefs (#6566) --- pydantic/json_schema.py | 6 +- tests/test_json_schema.py | 194 ++++++++++++++++++++++++++++++++++---- 2 files changed, 181 insertions(+), 19 deletions(-) diff --git a/pydantic/json_schema.py b/pydantic/json_schema.py index 7bf152298f..2a03b60142 100644 --- a/pydantic/json_schema.py +++ b/pydantic/json_schema.py @@ -1808,9 +1808,9 @@ def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef: # be generated for any other core_ref. Currently, this should be the case because we include # the id of the source type in the core_ref name = DefsRef(self.normalize_name(short_ref)) - name_mode = DefsRef(self.normalize_name(short_ref + mode_title)) + name_mode = DefsRef(self.normalize_name(short_ref) + f'-{mode_title}') module_qualname = DefsRef(self.normalize_name(core_ref_no_id)) - module_qualname_mode = DefsRef(module_qualname + mode_title) + module_qualname_mode = DefsRef(f'{module_qualname}-{mode_title}') module_qualname_id = DefsRef(self.normalize_name(core_ref)) occurrence_index = self._collision_index.get(module_qualname_id) if occurrence_index is None: @@ -1818,7 +1818,7 @@ def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef: occurrence_index = self._collision_index[module_qualname_id] = self._collision_counter[module_qualname] module_qualname_occurrence = DefsRef(f'{module_qualname}__{occurrence_index}') - module_qualname_occurrence_mode = DefsRef(f'{module_qualname}{mode_title}__{occurrence_index}') + module_qualname_occurrence_mode = DefsRef(f'{module_qualname_mode}__{occurrence_index}') self._prioritized_defsref_choices[module_qualname_occurrence_mode] = [ name, diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index ce630f3c65..480694d5cd 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -49,6 +49,7 @@ ValidationError, WithJsonSchema, computed_field, + field_serializer, field_validator, ) from pydantic._internal._core_metadata import CoreMetadataHandler, build_metadata_dict @@ -2673,16 +2674,16 @@ class NestedModel(BaseModel): ) model_names = set(schema['$defs'].keys()) expected_model_names = { - 'ModelOneInput', - 'ModelOneOutput', - 'ModelTwoInput', - 'ModelTwoOutput', - f'{module.__name__}__ModelOne__NestedModelInput', - f'{module.__name__}__ModelOne__NestedModelOutput', - f'{module.__name__}__ModelTwo__NestedModelInput', - f'{module.__name__}__ModelTwo__NestedModelOutput', - f'{module.__name__}__NestedModelInput', - f'{module.__name__}__NestedModelOutput', + 'ModelOne-Input', + 'ModelOne-Output', + 'ModelTwo-Input', + 'ModelTwo-Output', + f'{module.__name__}__ModelOne__NestedModel-Input', + f'{module.__name__}__ModelOne__NestedModel-Output', + f'{module.__name__}__ModelTwo__NestedModel-Input', + f'{module.__name__}__ModelTwo__NestedModel-Output', + f'{module.__name__}__NestedModel-Input', + f'{module.__name__}__NestedModel-Output', } assert model_names == expected_model_names @@ -2739,6 +2740,167 @@ class Model(BaseModel): } +def test_mode_name_causes_no_conflict(): + class Organization(BaseModel): + pass + + class OrganizationInput(BaseModel): + pass + + class OrganizationOutput(BaseModel): + pass + + class Model(BaseModel): + # Ensure the validation and serialization schemas are different: + x: Organization = Field(validation_alias='x_validation', serialization_alias='x_serialization') + y: OrganizationInput + z: OrganizationOutput + + assert Model.model_json_schema(mode='validation') == { + '$defs': { + 'Organization': {'properties': {}, 'title': 'Organization', 'type': 'object'}, + 'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'}, + 'OrganizationOutput': {'properties': {}, 'title': 'OrganizationOutput', 'type': 'object'}, + }, + 'properties': { + 'x_validation': {'$ref': '#/$defs/Organization'}, + 'y': {'$ref': '#/$defs/OrganizationInput'}, + 'z': {'$ref': '#/$defs/OrganizationOutput'}, + }, + 'required': ['x_validation', 'y', 'z'], + 'title': 'Model', + 'type': 'object', + } + assert Model.model_json_schema(mode='serialization') == { + '$defs': { + 'Organization': {'properties': {}, 'title': 'Organization', 'type': 'object'}, + 'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'}, + 'OrganizationOutput': {'properties': {}, 'title': 'OrganizationOutput', 'type': 'object'}, + }, + 'properties': { + 'x_serialization': {'$ref': '#/$defs/Organization'}, + 'y': {'$ref': '#/$defs/OrganizationInput'}, + 'z': {'$ref': '#/$defs/OrganizationOutput'}, + }, + 'required': ['x_serialization', 'y', 'z'], + 'title': 'Model', + 'type': 'object', + } + + +def test_ref_conflict_resolution_without_mode_difference(): + class OrganizationInput(BaseModel): + pass + + class Organization(BaseModel): + x: int + + schema_with_defs, defs = GenerateJsonSchema().generate_definitions( + [ + (Organization, 'validation', Organization.__pydantic_core_schema__), + (Organization, 'serialization', Organization.__pydantic_core_schema__), + (OrganizationInput, 'validation', OrganizationInput.__pydantic_core_schema__), + ] + ) + assert schema_with_defs == { + (Organization, 'serialization'): {'$ref': '#/$defs/Organization'}, + (Organization, 'validation'): {'$ref': '#/$defs/Organization'}, + (OrganizationInput, 'validation'): {'$ref': '#/$defs/OrganizationInput'}, + } + + assert defs == { + 'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'}, + 'Organization': { + 'properties': {'x': {'title': 'X', 'type': 'integer'}}, + 'required': ['x'], + 'title': 'Organization', + 'type': 'object', + }, + } + + +def test_ref_conflict_resolution_with_mode_difference(): + class OrganizationInput(BaseModel): + pass + + class Organization(BaseModel): + x: int + + @field_serializer('x') + def serialize_x(self, v: int) -> str: + return str(v) + + schema_with_defs, defs = GenerateJsonSchema().generate_definitions( + [ + (Organization, 'validation', Organization.__pydantic_core_schema__), + (Organization, 'serialization', Organization.__pydantic_core_schema__), + (OrganizationInput, 'validation', OrganizationInput.__pydantic_core_schema__), + ] + ) + assert schema_with_defs == { + (Organization, 'serialization'): {'$ref': '#/$defs/Organization-Output'}, + (Organization, 'validation'): {'$ref': '#/$defs/Organization-Input'}, + (OrganizationInput, 'validation'): {'$ref': '#/$defs/OrganizationInput'}, + } + + assert defs == { + 'OrganizationInput': {'properties': {}, 'title': 'OrganizationInput', 'type': 'object'}, + 'Organization-Input': { + 'properties': {'x': {'title': 'X', 'type': 'integer'}}, + 'required': ['x'], + 'title': 'Organization', + 'type': 'object', + }, + 'Organization-Output': { + 'properties': {'x': {'title': 'X', 'type': 'string'}}, + 'required': ['x'], + 'title': 'Organization', + 'type': 'object', + }, + } + + +def test_conflicting_names(): + class Organization__Input(BaseModel): + pass + + class Organization(BaseModel): + x: int + + @field_serializer('x') + def serialize_x(self, v: int) -> str: + return str(v) + + schema_with_defs, defs = GenerateJsonSchema().generate_definitions( + [ + (Organization, 'validation', Organization.__pydantic_core_schema__), + (Organization, 'serialization', Organization.__pydantic_core_schema__), + (Organization__Input, 'validation', Organization__Input.__pydantic_core_schema__), + ] + ) + assert schema_with_defs == { + (Organization, 'serialization'): {'$ref': '#/$defs/Organization-Output'}, + (Organization, 'validation'): {'$ref': '#/$defs/Organization-Input'}, + (Organization__Input, 'validation'): {'$ref': '#/$defs/Organization__Input'}, + } + + assert defs == { + 'Organization__Input': {'properties': {}, 'title': 'Organization__Input', 'type': 'object'}, + 'Organization-Input': { + 'properties': {'x': {'title': 'X', 'type': 'integer'}}, + 'required': ['x'], + 'title': 'Organization', + 'type': 'object', + }, + 'Organization-Output': { + 'properties': {'x': {'title': 'X', 'type': 'string'}}, + 'required': ['x'], + 'title': 'Organization', + 'type': 'object', + }, + } + + def test_schema_for_generic_field(): T = TypeVar('T') @@ -4356,7 +4518,7 @@ class Outer(BaseModel): _, vs_schema = models_json_schema([(Outer, 'validation'), (Outer, 'serialization')]) assert vs_schema == { '$defs': { - 'InnerInput': { + 'Inner-Input': { 'properties': { 'x': { 'contentMediaType': 'application/json', @@ -4369,20 +4531,20 @@ class Outer(BaseModel): 'title': 'Inner', 'type': 'object', }, - 'InnerOutput': { + 'Inner-Output': { 'properties': {'x': {'title': 'X', 'type': 'integer'}}, 'required': ['x'], 'title': 'Inner', 'type': 'object', }, - 'OuterInput': { - 'properties': {'inner': {'$ref': '#/$defs/InnerInput'}}, + 'Outer-Input': { + 'properties': {'inner': {'$ref': '#/$defs/Inner-Input'}}, 'required': ['inner'], 'title': 'Outer', 'type': 'object', }, - 'OuterOutput': { - 'properties': {'inner': {'$ref': '#/$defs/InnerOutput'}}, + 'Outer-Output': { + 'properties': {'inner': {'$ref': '#/$defs/Inner-Output'}}, 'required': ['inner'], 'title': 'Outer', 'type': 'object',