From 626531d006ea297c1c2f6156c825e04e554e1e36 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 14 Jun 2023 11:12:33 -0500 Subject: [PATCH] Make JSON Schema work --- pydantic/json_schema.py | 27 ++++++++++++++++++++++----- tests/test_main.py | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/pydantic/json_schema.py b/pydantic/json_schema.py index 4edd9cef900..15da6c05f56 100644 --- a/pydantic/json_schema.py +++ b/pydantic/json_schema.py @@ -995,10 +995,11 @@ def _update_class_schema( # referenced_schema['title'] = title schema_to_update.setdefault('title', title) - if extra == 'allow': - schema_to_update['additionalProperties'] = True - elif extra == 'forbid': - schema_to_update['additionalProperties'] = False + if 'additionalProperties' not in schema_to_update: + if extra == 'allow': + schema_to_update['additionalProperties'] = True + elif extra == 'forbid': + schema_to_update['additionalProperties'] = False if isinstance(json_schema_extra, (staticmethod, classmethod)): # In older versions of python, this is necessary to ensure staticmethod/classmethods are callable @@ -1018,6 +1019,17 @@ def _update_class_schema( return json_schema + def resolve_schema_to_update(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: + """Resolve a JsonSchemaValue to the non-ref schema if it is a $ref schema""" + if '$ref' in json_schema: + schema_to_update = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) + if schema_to_update is None: + raise RuntimeError(f'Cannot update undefined schema for $ref={json_schema["$ref"]}') + return self.resolve_schema_to_update(schema_to_update) + else: + schema_to_update = json_schema + return schema_to_update + def model_fields_schema(self, schema: core_schema.ModelFieldsSchema) -> JsonSchemaValue: named_required_fields: list[tuple[str, bool, CoreSchemaField]] = [ (name, self.field_is_required(field), field) @@ -1026,7 +1038,12 @@ def model_fields_schema(self, schema: core_schema.ModelFieldsSchema) -> JsonSche ] if self.mode == 'serialization': named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) - return self._named_required_fields_schema(named_required_fields) + json_schema = self._named_required_fields_schema(named_required_fields) + extra_validator = schema.get('extra_validator', None) + if extra_validator is not None: + schema_to_update = self.resolve_schema_to_update(json_schema) + schema_to_update['additionalProperties'] = self.generate_inner(extra_validator) + return json_schema def field_is_present(self, field: CoreSchemaField) -> bool: """Whether the field should be included in the generated JSON schema.""" diff --git a/tests/test_main.py b/tests/test_main.py index 5a5fe77cb05..f5ca776beee 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2446,7 +2446,7 @@ def set_x(m: 'OuterModel') -> 'OuterModel': assert OuterModel(inner=InnerModel()).x == 2 -def test_extra_validator(): +def test_extra_validator_scalar() -> None: class Model(BaseModel): model_config = ConfigDict(extra='allow') @@ -2455,3 +2455,41 @@ class Child(Model): m = Child(a='1') assert m.__pydantic_extra__ == {'a': 1} + + # insert_assert(Child.model_json_schema()) + assert Child.model_json_schema() == { + 'additionalProperties': {'type': 'integer'}, + 'properties': {}, + 'title': 'Child', + 'type': 'object', + } + + +def test_extra_validator_named() -> None: + class Foo(BaseModel): + x: int + + class Model(BaseModel): + model_config = ConfigDict(extra='allow') + + class Child(Model): + __pydantic_extra__: Dict[str, Foo] + + m = Child(a={'x': '1'}) + assert m.__pydantic_extra__ == {'a': Foo(x=1)} + + # insert_assert(Child.model_json_schema()) + assert Child.model_json_schema() == { + '$defs': { + 'Foo': { + 'properties': {'x': {'title': 'X', 'type': 'integer'}}, + 'required': ['x'], + 'title': 'Foo', + 'type': 'object', + } + }, + 'additionalProperties': {'$ref': '#/$defs/Foo'}, + 'properties': {}, + 'title': 'Child', + 'type': 'object', + }