Skip to content

Commit

Permalink
Make path of the item to validate available in plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Oct 18, 2023
1 parent 188018c commit a2c10ac
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Expand Up @@ -170,7 +170,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, core_config, config_wrapper.plugin_settings
schema, f'{cls.__module__}:{cls.__name__}', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
16 changes: 15 additions & 1 deletion pydantic/_internal/_model_construction.py
@@ -1,6 +1,7 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations

import sys
import typing
import warnings
import weakref
Expand Down Expand Up @@ -94,6 +95,7 @@ def __new__(
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
is_dynamic_model: bool = kwargs.pop('is_dynamic_model', False)
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)

Expand Down Expand Up @@ -185,6 +187,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
is_dynamic_model=is_dynamic_model,
)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
Expand Down Expand Up @@ -439,6 +442,7 @@ def complete_model_class(
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
is_dynamic_model: bool = False,
) -> bool:
"""Finish building a model class.
Expand All @@ -451,6 +455,7 @@ def complete_model_class(
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
is_dynamic_model: Whether the model is a dynamic model (function is called from `create_model`).
Returns:
`True` if the model is successfully completed, else `False`.
Expand Down Expand Up @@ -494,7 +499,16 @@ def complete_model_class(

# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)

if is_dynamic_model:
f = sys._getframe(3)
cls_module = f.f_globals['__name__']
else:
cls_module = cls.__module__

cls.__pydantic_validator__ = create_schema_validator(
schema, f'{cls_module}:{cls_name}', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True

Expand Down
6 changes: 4 additions & 2 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -64,7 +64,9 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
path = f'{self.__module__}:{self.__name__}'

self.__pydantic_validator__ = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings)

if self._validate_return:
return_type = (
Expand All @@ -75,7 +77,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
validator = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
Expand Down
1 change: 1 addition & 0 deletions pydantic/main.py
Expand Up @@ -1425,6 +1425,7 @@ def create_model(
if resolved_bases is not __base__:
ns['__orig_bases__'] = __base__
namespace.update(ns)
kwds['is_dynamic_model'] = True
return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds)


Expand Down
2 changes: 2 additions & 0 deletions pydantic/plugin/__init__.py
Expand Up @@ -27,6 +27,7 @@ class PydanticPluginProtocol(Protocol):
def new_schema_validator(
self,
schema: CoreSchema,
path: str,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
Expand All @@ -40,6 +41,7 @@ def new_schema_validator(
Args:
schema: The schema to validate against.
path: The path of item to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
Expand Down
7 changes: 4 additions & 3 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -18,7 +18,7 @@


def create_schema_validator(
schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
schema: CoreSchema, path: str, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
Expand All @@ -29,7 +29,7 @@ def create_schema_validator(

plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore
return PluggableSchemaValidator(schema, path, config, plugins, plugin_settings or {}) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -42,6 +42,7 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
path: str,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
Expand All @@ -52,7 +53,7 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, config, plugin_settings)
p, j, s = plugin.new_schema_validator(schema, path, config, plugin_settings)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
5 changes: 4 additions & 1 deletion pydantic/type_adapter.py
Expand Up @@ -245,7 +245,10 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings)
f = sys._getframe(1)
validator = create_schema_validator(
core_schema, f'{f.f_globals["__name__"]}:{type}', core_config, config_wrapper.plugin_settings
)

serializer: SchemaSerializer
try:
Expand Down
94 changes: 86 additions & 8 deletions tests/test_plugins.py
@@ -1,11 +1,12 @@
from __future__ import annotations

import contextlib
from typing import Any, Generator
from functools import partial
from typing import Any, Generator, List

from pydantic_core import ValidationError

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, TypeAdapter, create_model, dataclasses, field_validator, validate_call
from pydantic.plugin import (
PydanticPluginProtocol,
ValidateJsonHandlerProtocol,
Expand Down Expand Up @@ -43,9 +44,10 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class CustomPlugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert path == 'tests.test_plugins:Model'
return None, CustomOnValidateJson(), None

plugin = CustomPlugin()
Expand Down Expand Up @@ -85,7 +87,7 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
return None, CustomOnValidateJson(), None
Expand Down Expand Up @@ -121,7 +123,7 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class Plugin:
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
return CustomOnValidatePython(), None, None
Expand Down Expand Up @@ -164,7 +166,7 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
return CustomOnValidatePython(), None, None
Expand Down Expand Up @@ -205,7 +207,7 @@ def on_exception(self, exception: Exception) -> None:
stack.pop()

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand Down Expand Up @@ -268,7 +270,7 @@ def on_error(self, error: ValidationError) -> None:
log.append(f'strings error error={error}')

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, config, plugin_settings):
def new_schema_validator(self, schema, path, config, plugin_settings):
return Python(), Json(), Strings()

plugin = Plugin()
Expand All @@ -294,3 +296,79 @@ class Model(BaseModel):
"strings enter input={'a': '3'} kwargs={'strict': True, 'context': {'c': 3}}",
'strings success result=a=3',
]


def test_plugin_path_dataclass() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, path, config, plugin_settings):
assert path == 'tests.test_plugins:Bar'
return CustomOnValidatePython(), None, None

plugin = Plugin()
with install_plugin(plugin):

@dataclasses.dataclass
class Bar:
a: int


def test_plugin_path_type_adapter() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, path, config, plugin_settings):
assert path == 'tests.test_plugins:List[str]'
return CustomOnValidatePython(), None, None

plugin = Plugin()
with install_plugin(plugin):
TypeAdapter(List[str])


def test_plugin_path_validate_call() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin1:
def new_schema_validator(self, schema, path, config, plugin_settings):
assert path == 'tests.test_plugins:foo'
return CustomOnValidatePython(), None, None

plugin = Plugin1()
with install_plugin(plugin):

@validate_call()
def foo(a: int):
return a

class Plugin2:
def new_schema_validator(self, schema, path, config, plugin_settings):
assert path == 'tests.test_plugins:partial(my_wrapped_function)'
return CustomOnValidatePython(), None, None

plugin = Plugin2()
with install_plugin(plugin):

def my_wrapped_function(a: int, b: int, c: int):
return a + b + c

my_partial_function = partial(my_wrapped_function, c=3)
validate_call(my_partial_function)


def test_plugin_path_create_model() -> None:
class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, path, config, plugin_settings):
assert path == 'tests.test_plugins:FooModel'
return CustomOnValidatePython(), None, None

plugin = Plugin()
with install_plugin(plugin):
create_model('FooModel', foo=(str, ...), bar=(int, 123))

0 comments on commit a2c10ac

Please sign in to comment.