Skip to content

Commit

Permalink
Pass model_class to schema_extra staticmethod
Browse files Browse the repository at this point in the history
Resolves #1122
  • Loading branch information
therefromhere committed Dec 23, 2019
1 parent e65d112 commit 61b1b0a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
10 changes: 6 additions & 4 deletions docs/examples/schema_extra_callable.py
@@ -1,5 +1,5 @@
# output-json
from typing import Dict, Any
from typing import Dict, Any, Type
from pydantic import BaseModel

class Person(BaseModel):
Expand All @@ -8,8 +8,10 @@ class Person(BaseModel):

class Config:
@staticmethod
def schema_extra(schema: Dict[str, Any]) -> None:
for prop in schema.get('properties', {}).values():
prop.pop('title', None)
def schema_extra(
schema: Dict[str, Any], model_class: Type["Person"]
) -> None:
for prop in schema.get("properties", {}).values():
prop.pop("title", None)

print(Person.schema_json(indent=2))
7 changes: 6 additions & 1 deletion pydantic/schema.py
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -464,7 +465,11 @@ def model_process_schema(
s.update(m_schema)
schema_extra = model.__config__.schema_extra
if callable(schema_extra):
schema_extra(s)
assert isinstance(schema_extra, FunctionType), 'Config.schema_extra callable is expected to be a staticmethod'
if len(inspect.signature(schema_extra).parameters) == 1:
schema_extra(s)
else:
schema_extra(s, model)
else:
s.update(schema_extra)
return s, m_definitions, nested_models
Expand Down
27 changes: 27 additions & 0 deletions tests/test_schema.py
Expand Up @@ -1490,6 +1490,19 @@ class Config:


def test_model_with_schema_extra_callable():
class Model(BaseModel):
name: str = None

class Config:
@staticmethod
def schema_extra(schema, model_class):
schema.pop('properties')
schema['type'] = 'override'

assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_no_model_class():
class Model(BaseModel):
name: str = None

Expand All @@ -1502,6 +1515,20 @@ def schema_extra(schema):
assert Model.schema() == {'title': 'Model', 'type': 'override'}


def test_model_with_schema_extra_callable_classmethod_asserts():
class Model(BaseModel):
name: str = None

class Config:
@classmethod
def schema_extra(cls, schema, model_class):
schema.pop('properties')
schema['type'] = 'override'

with pytest.raises(AssertionError, match='Config.schema_extra callable is expected to be a staticmethod'):
Model.schema()


def test_model_with_extra_forbidden():
class Model(BaseModel):
a: str
Expand Down

0 comments on commit 61b1b0a

Please sign in to comment.