Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass model_class to schema_extra staticmethod #1125

Merged
merged 8 commits into from Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1125-therefromhere.md
@@ -0,0 +1 @@
Pass model class to the `Config.schema_extra` callable
4 changes: 2 additions & 2 deletions docs/examples/schema_extra_callable.py
@@ -1,5 +1,5 @@
# output-json
from typing import Dict, Any
from typing import Dict, Any, Type
from pydantic import BaseModel

class Person(BaseModel):
Expand All @@ -8,7 +8,7 @@ class Person(BaseModel):

class Config:
@staticmethod
def schema_extra(schema: Dict[str, Any]) -> None:
def schema_extra(schema: Dict[str, Any], model: Type['Person']) -> None:
for prop in schema.get('properties', {}).values():
prop.pop('title', None)

Expand Down
10 changes: 8 additions & 2 deletions pydantic/schema.py
Expand Up @@ -5,6 +5,7 @@
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -450,7 +451,7 @@ def model_process_schema(
sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All
the definitions are returned as the second value.
"""
from inspect import getdoc
from inspect import getdoc, signature

ref_prefix = ref_prefix or default_prefix
known_models = known_models or set()
Expand All @@ -465,7 +466,12 @@ def model_process_schema(
s.update(m_schema)
schema_extra = model.__config__.schema_extra
if callable(schema_extra):
schema_extra(s)
if not isinstance(schema_extra, FunctionType):
raise TypeError(f'{model.__name__}.Config.schema_extra callable is expected to be a staticmethod')
if len(signature(schema_extra).parameters) == 1:
schema_extra(s)
else:
schema_extra(s, model)
else:
s.update(schema_extra)
return s, m_definitions, nested_models
Expand Down
28 changes: 28 additions & 0 deletions tests/test_schema.py
Expand Up @@ -1490,6 +1490,20 @@ class Config:


def test_model_with_schema_extra_callable():
class Model(BaseModel):
name: str = None

class Config:
@staticmethod
def schema_extra(schema, model_class):
schema.pop('properties')
schema['type'] = 'override'
therefromhere marked this conversation as resolved.
Show resolved Hide resolved
assert model_class is Model

assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_no_model_class():
class Model(BaseModel):
name: str = None

Expand All @@ -1502,6 +1516,20 @@ def schema_extra(schema):
assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_classmethod_asserts():
class Model(BaseModel):
name: str = None

class Config:
@classmethod
def schema_extra(cls, schema, model_class):
schema.pop('properties')
schema['type'] = 'override'

with pytest.raises(TypeError, match='Model.Config.schema_extra callable is expected to be a staticmethod'):
Model.schema()


def test_model_with_extra_forbidden():
class Model(BaseModel):
a: str
Expand Down