From cd8b504568a7da3c0649347e2bcec4e3a4e79143 Mon Sep 17 00:00:00 2001 From: John Carter Date: Tue, 7 Jan 2020 01:01:03 +1300 Subject: [PATCH] Pass model_class to schema_extra staticmethod (#1125) * Pass model_class to schema_extra staticmethod Resolves #1122 * Add changelog * Apply suggestions from code review Co-Authored-By: Samuel Colvin * Fix import after rebase * Fix test bug * Use TypeError instead of assert as per review * Rename var so declaration fits one one line * tiny tweaks Co-authored-by: Samuel Colvin --- changes/1125-therefromhere.md | 1 + docs/examples/schema_extra_callable.py | 4 ++-- pydantic/schema.py | 10 +++++++-- tests/test_schema.py | 28 ++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 4 deletions(-) create mode 100644 changes/1125-therefromhere.md diff --git a/changes/1125-therefromhere.md b/changes/1125-therefromhere.md new file mode 100644 index 0000000000..5fa216d499 --- /dev/null +++ b/changes/1125-therefromhere.md @@ -0,0 +1 @@ +Pass model class to the `Config.schema_extra` callable diff --git a/docs/examples/schema_extra_callable.py b/docs/examples/schema_extra_callable.py index c4d518cebd..d48b5a7fa9 100644 --- a/docs/examples/schema_extra_callable.py +++ b/docs/examples/schema_extra_callable.py @@ -1,5 +1,5 @@ # output-json -from typing import Dict, Any +from typing import Dict, Any, Type from pydantic import BaseModel class Person(BaseModel): @@ -8,7 +8,7 @@ class Person(BaseModel): class Config: @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: + def schema_extra(schema: Dict[str, Any], model: Type['Person']) -> None: for prop in schema.get('properties', {}).values(): prop.pop('title', None) diff --git a/pydantic/schema.py b/pydantic/schema.py index 60b2f23dec..e9482201ba 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -5,6 +5,7 @@ from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path +from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -450,7 +451,7 @@ def model_process_schema( sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All the definitions are returned as the second value. """ - from inspect import getdoc + from inspect import getdoc, signature ref_prefix = ref_prefix or default_prefix known_models = known_models or set() @@ -465,7 +466,12 @@ def model_process_schema( s.update(m_schema) schema_extra = model.__config__.schema_extra if callable(schema_extra): - schema_extra(s) + if not isinstance(schema_extra, FunctionType): + raise TypeError(f'{model.__name__}.Config.schema_extra callable is expected to be a staticmethod') + if len(signature(schema_extra).parameters) == 1: + schema_extra(s) + else: + schema_extra(s, model) else: s.update(schema_extra) return s, m_definitions, nested_models diff --git a/tests/test_schema.py b/tests/test_schema.py index dd54c91f51..f01e534a20 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1490,6 +1490,20 @@ class Config: def test_model_with_schema_extra_callable(): + class Model(BaseModel): + name: str = None + + class Config: + @staticmethod + def schema_extra(schema, model_class): + schema.pop('properties') + schema['type'] = 'override' + assert model_class is Model + + assert Model.schema() == {'title': 'Model', 'type': 'override'} + + +def test_model_with_schema_extra_callable_no_model_class(): class Model(BaseModel): name: str = None @@ -1502,6 +1516,20 @@ def schema_extra(schema): assert Model.schema() == {'title': 'Model', 'type': 'override'} +def test_model_with_schema_extra_callable_classmethod_asserts(): + class Model(BaseModel): + name: str = None + + class Config: + @classmethod + def schema_extra(cls, schema, model_class): + schema.pop('properties') + schema['type'] = 'override' + + with pytest.raises(TypeError, match='Model.Config.schema_extra callable is expected to be a staticmethod'): + Model.schema() + + def test_model_with_extra_forbidden(): class Model(BaseModel): a: str