Skip to content

Commit

Permalink
Pass model_class to schema_extra staticmethod (#1125)
Browse files Browse the repository at this point in the history
* Pass model_class to schema_extra staticmethod

Resolves #1122

* Add changelog

* Apply suggestions from code review

Co-Authored-By: Samuel Colvin <samcolvin@gmail.com>

* Fix import after rebase

* Fix test bug

* Use TypeError instead of assert as per review

* Rename var so declaration fits one one line

* tiny tweaks

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
  • Loading branch information
therefromhere and samuelcolvin committed Jan 6, 2020
1 parent e169bd6 commit cd8b504
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
1 change: 1 addition & 0 deletions changes/1125-therefromhere.md
@@ -0,0 +1 @@
Pass model class to the `Config.schema_extra` callable
4 changes: 2 additions & 2 deletions docs/examples/schema_extra_callable.py
@@ -1,5 +1,5 @@
# output-json
from typing import Dict, Any
from typing import Dict, Any, Type
from pydantic import BaseModel

class Person(BaseModel):
Expand All @@ -8,7 +8,7 @@ class Person(BaseModel):

class Config:
@staticmethod
def schema_extra(schema: Dict[str, Any]) -> None:
def schema_extra(schema: Dict[str, Any], model: Type['Person']) -> None:
for prop in schema.get('properties', {}).values():
prop.pop('title', None)

Expand Down
10 changes: 8 additions & 2 deletions pydantic/schema.py
Expand Up @@ -5,6 +5,7 @@
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -450,7 +451,7 @@ def model_process_schema(
sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All
the definitions are returned as the second value.
"""
from inspect import getdoc
from inspect import getdoc, signature

ref_prefix = ref_prefix or default_prefix
known_models = known_models or set()
Expand All @@ -465,7 +466,12 @@ def model_process_schema(
s.update(m_schema)
schema_extra = model.__config__.schema_extra
if callable(schema_extra):
schema_extra(s)
if not isinstance(schema_extra, FunctionType):
raise TypeError(f'{model.__name__}.Config.schema_extra callable is expected to be a staticmethod')
if len(signature(schema_extra).parameters) == 1:
schema_extra(s)
else:
schema_extra(s, model)
else:
s.update(schema_extra)
return s, m_definitions, nested_models
Expand Down
28 changes: 28 additions & 0 deletions tests/test_schema.py
Expand Up @@ -1490,6 +1490,20 @@ class Config:


def test_model_with_schema_extra_callable():
class Model(BaseModel):
name: str = None

class Config:
@staticmethod
def schema_extra(schema, model_class):
schema.pop('properties')
schema['type'] = 'override'
assert model_class is Model

assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_no_model_class():
class Model(BaseModel):
name: str = None

Expand All @@ -1502,6 +1516,20 @@ def schema_extra(schema):
assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_classmethod_asserts():
class Model(BaseModel):
name: str = None

class Config:
@classmethod
def schema_extra(cls, schema, model_class):
schema.pop('properties')
schema['type'] = 'override'

with pytest.raises(TypeError, match='Model.Config.schema_extra callable is expected to be a staticmethod'):
Model.schema()


def test_model_with_extra_forbidden():
class Model(BaseModel):
a: str
Expand Down

0 comments on commit cd8b504

Please sign in to comment.