Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make path of the item to validate available in plugin #7861

Merged
merged 11 commits into from Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/concepts/models.md
Expand Up @@ -1057,7 +1057,7 @@ BarModel = create_model(
__base__=FooModel,
)
print(BarModel)
#> <class 'pydantic.main.BarModel'>
#> <class '__main__.BarModel'>
print(BarModel.model_fields.keys())
#> dict_keys(['foo', 'bar', 'apple', 'banana'])
```
Expand Down
5 changes: 5 additions & 0 deletions docs/concepts/plugins.md
Expand Up @@ -62,6 +62,8 @@ from pydantic_core import CoreConfig, CoreSchema, ValidationError
from pydantic.plugin import (
NewSchemaReturns,
PydanticPluginProtocol,
SchemaKind,
SchemaTypePath,
ValidatePythonHandlerProtocol,
)

Expand Down Expand Up @@ -89,6 +91,9 @@ class Plugin(PydanticPluginProtocol):
def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: Union[CoreConfig, None],
plugin_settings: Dict[str, object],
) -> NewSchemaReturns:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Expand Up @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, core_config, config_wrapper.plugin_settings
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
16 changes: 15 additions & 1 deletion pydantic/_internal/_model_construction.py
Expand Up @@ -63,6 +63,7 @@ def __new__(
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Expand All @@ -73,6 +74,7 @@ def __new__(
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.

Returns:
Expand Down Expand Up @@ -182,6 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
create_model_module=_create_model_module,
)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
Expand Down Expand Up @@ -438,6 +441,7 @@ def complete_model_class(
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
create_model_module: str | None = None,
) -> bool:
"""Finish building a model class.

Expand All @@ -450,6 +454,7 @@ def complete_model_class(
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
create_model_module: The module of the class to be created, if created by `create_model`.

Returns:
`True` if the model is successfully completed, else `False`.
Expand Down Expand Up @@ -493,7 +498,16 @@ def complete_model_class(

# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)

cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True

Expand Down
23 changes: 21 additions & 2 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -46,12 +46,14 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
Expand All @@ -64,7 +66,16 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)

self.__pydantic_validator__ = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)

if self._validate_return:
return_type = (
Expand All @@ -75,7 +86,15 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
validator = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
Expand Down
25 changes: 19 additions & 6 deletions pydantic/main.py
@@ -1,6 +1,7 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import sys
import types
import typing
import warnings
Expand Down Expand Up @@ -1345,13 +1346,13 @@ def create_model(
...


def create_model(
def create_model( # noqa: C901
__model_name: str,
*,
__config__: ConfigDict | None = None,
__doc__: str | None = None,
__base__: type[Model] | tuple[type[Model], ...] | None = None,
__module__: str = __name__,
__module__: str | None = None,
__validators__: dict[str, AnyClassMethod] | None = None,
__cls_kwargs__: dict[str, Any] | None = None,
__slots__: tuple[str, ...] | None = None,
Expand All @@ -1365,9 +1366,9 @@ def create_model(
__config__: The configuration of the new model.
__doc__: The docstring of the new model.
__base__: The base class for the new model.
__module__: The name of the module that the model belongs to.
__validators__: A dictionary of methods that validate
fields.
__module__: The name of the module that the model belongs to,
if `None` the value is taken from `sys._getframe(1)`
__validators__: A dictionary of methods that validate fields.
__cls_kwargs__: A dictionary of keyword arguments for class creation.
__slots__: Deprecated. Should not be passed to `create_model`.
**field_definitions: Attributes of the new model. They should be passed in the format:
Expand Down Expand Up @@ -1418,6 +1419,10 @@ def create_model(
annotations[f_name] = f_annotation
fields[f_name] = f_value

if __module__ is None:
f = sys._getframe(1)
__module__ = f.f_globals['__name__']

namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__}
if __doc__:
namespace.update({'__doc__': __doc__})
Expand All @@ -1431,7 +1436,15 @@ def create_model(
if resolved_bases is not __base__:
ns['__orig_bases__'] = __base__
namespace.update(ns)
return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds)

return meta(
__model_name,
resolved_bases,
namespace,
__pydantic_reset_parent_namespace__=False,
_create_model_module=__module__,
**kwds,
)


__getattr__ = getattr_migration(__name__)
22 changes: 20 additions & 2 deletions pydantic/plugin/__init__.py
Expand Up @@ -4,10 +4,10 @@
"""
from __future__ import annotations

from typing import Any, Callable
from typing import Any, Callable, NamedTuple

from pydantic_core import CoreConfig, CoreSchema, ValidationError
from typing_extensions import Protocol, TypeAlias
from typing_extensions import Literal, Protocol, TypeAlias

__all__ = (
'PydanticPluginProtocol',
Expand All @@ -16,17 +16,32 @@
'ValidateJsonHandlerProtocol',
'ValidateStringsHandlerProtocol',
'NewSchemaReturns',
'SchemaTypePath',
'SchemaKind',
)

NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'


class SchemaTypePath(NamedTuple):
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""

module: str
name: str


SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']


class PydanticPluginProtocol(Protocol):
"""Protocol defining the interface for Pydantic plugins."""

def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
Expand All @@ -40,6 +55,9 @@ def new_schema_validator(

Args:
schema: The schema to validate against.
schema_type: The original type which the schema was created from, e.g. the model class.
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
schema_kind: The kind of schema to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.

Expand Down
28 changes: 24 additions & 4 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Literal, ParamSpec

if TYPE_CHECKING:
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath


P = ParamSpec('P')
Expand All @@ -18,18 +18,33 @@


def create_schema_validator(
schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
schema: CoreSchema,
schema_type: Any,
schema_type_module: str,
schema_type_name: str,
schema_kind: SchemaKind,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.

Returns:
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
"""
from . import SchemaTypePath
from ._loader import get_plugins

plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore
return PluggableSchemaValidator(
schema,
schema_type,
SchemaTypePath(schema_type_module, schema_type_name),
schema_kind,
config,
plugins,
plugin_settings or {},
) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -42,6 +57,9 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
Expand All @@ -52,7 +70,9 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, config, plugin_settings)
p, j, s = plugin.new_schema_validator(
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
22 changes: 17 additions & 5 deletions pydantic/type_adapter.py
Expand Up @@ -78,7 +78,7 @@ class Item(BaseModel):

import sys
from dataclasses import is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, overload

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import Literal, is_typeddict
Expand Down Expand Up @@ -205,23 +205,30 @@ def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter
raise NotImplementedError

@overload
def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

# this overload is for non-type things like Union[int, str]
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
# so an explicit type cast is needed
@overload
def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
"""Initializes the TypeAdapter object.

Args:
type: The type associated with the `TypeAdapter`.
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
_parent_depth: depth at which to search the parent namespace to construct the local namespace.
module: The module that passes to plugin if provided.

!!! note
You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own
Expand Down Expand Up @@ -264,7 +271,12 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings)
if module is None:
f = sys._getframe(1)
module = cast(str, f.f_globals['__name__'])
validator = create_schema_validator(
core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings
) # type: ignore

serializer: SchemaSerializer
try:
Expand Down