From 61b1b0a2aa045ac50730dc7e02fe00361425673d Mon Sep 17 00:00:00 2001 From: John Carter Date: Tue, 24 Dec 2019 10:02:57 +1300 Subject: [PATCH] Pass model_class to schema_extra staticmethod Resolves #1122 --- docs/examples/schema_extra_callable.py | 10 ++++++---- pydantic/schema.py | 7 ++++++- tests/test_schema.py | 27 ++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/docs/examples/schema_extra_callable.py b/docs/examples/schema_extra_callable.py index c4d518cebd0..e304b804241 100644 --- a/docs/examples/schema_extra_callable.py +++ b/docs/examples/schema_extra_callable.py @@ -1,5 +1,5 @@ # output-json -from typing import Dict, Any +from typing import Dict, Any, Type from pydantic import BaseModel class Person(BaseModel): @@ -8,8 +8,10 @@ class Person(BaseModel): class Config: @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - for prop in schema.get('properties', {}).values(): - prop.pop('title', None) + def schema_extra( + schema: Dict[str, Any], model_class: Type["Person"] + ) -> None: + for prop in schema.get("properties", {}).values(): + prop.pop("title", None) print(Person.schema_json(indent=2)) diff --git a/pydantic/schema.py b/pydantic/schema.py index 528525a5883..18c3f0b5680 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -6,6 +6,7 @@ from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path +from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -464,7 +465,11 @@ def model_process_schema( s.update(m_schema) schema_extra = model.__config__.schema_extra if callable(schema_extra): - schema_extra(s) + assert isinstance(schema_extra, FunctionType), 'Config.schema_extra callable is expected to be a staticmethod' + if len(inspect.signature(schema_extra).parameters) == 1: + schema_extra(s) + else: + schema_extra(s, model) else: s.update(schema_extra) return s, m_definitions, nested_models diff --git a/tests/test_schema.py b/tests/test_schema.py index 3a96360c0fc..ebf9cc5cecb 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1490,6 +1490,19 @@ class Config: def test_model_with_schema_extra_callable(): + class Model(BaseModel): + name: str = None + + class Config: + @staticmethod + def schema_extra(schema, model_class): + schema.pop('properties') + schema['type'] = 'override' + + assert Model.schema() == {'title': 'Model', 'type': 'override'} + + +def test_model_with_schema_extra_callable_no_model_class(): class Model(BaseModel): name: str = None @@ -1502,6 +1515,20 @@ def schema_extra(schema): assert Model.schema() == {'title': 'Model', 'type': 'override'} +def test_model_with_schema_extra_callable_classmethod_asserts(): + class Model(BaseModel): + name: str = None + + class Config: + @classmethod + def schema_extra(cls, schema, model_class): + schema.pop('properties') + schema['type'] = 'override' + + with pytest.raises(AssertionError, match='Config.schema_extra callable is expected to be a staticmethod'): + Model.schema() + + def test_model_with_extra_forbidden(): class Model(BaseModel): a: str