From e2b407b48ab7509a771ef7f9bc21860b8db65bc5 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Wed, 18 Oct 2023 17:35:31 +0330 Subject: [PATCH 01/11] Make path of the item to validate available in plugin --- pydantic/_internal/_dataclasses.py | 2 +- pydantic/_internal/_model_construction.py | 16 +++- pydantic/_internal/_validate_call.py | 6 +- pydantic/main.py | 1 + pydantic/plugin/__init__.py | 2 + pydantic/plugin/_schema_validator.py | 7 +- pydantic/type_adapter.py | 5 +- tests/plugin/example_plugin.py | 2 +- tests/test_plugins.py | 94 +++++++++++++++++++++-- 9 files changed, 118 insertions(+), 17 deletions(-) diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index 17240571e5..fe694ba47c 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - cls.__pydantic_core_schema__ = schema cls.__pydantic_validator__ = validator = create_schema_validator( - schema, core_config, config_wrapper.plugin_settings + schema, f'{cls.__module__}:{cls.__name__}', core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 26cb8ffda7..6d92b748d1 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -1,6 +1,7 @@ """Private logic for creating models.""" from __future__ import annotations as _annotations +import sys import typing import warnings import weakref @@ -81,6 +82,7 @@ def __new__( # Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact # that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__` # call we're in the middle of is for the `BaseModel` class. + is_dynamic_model: bool = kwargs.pop('is_dynamic_model', False) if bases: base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases) @@ -182,6 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: config_wrapper, raise_errors=False, types_namespace=types_namespace, + is_dynamic_model=is_dynamic_model, ) # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is @@ -438,6 +441,7 @@ def complete_model_class( *, raise_errors: bool = True, types_namespace: dict[str, Any] | None, + is_dynamic_model: bool = False, ) -> bool: """Finish building a model class. @@ -450,6 +454,7 @@ def complete_model_class( config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors. types_namespace: Optional extra namespace to look for types in. + is_dynamic_model: Whether the model is a dynamic model (function is called from `create_model`). Returns: `True` if the model is successfully completed, else `False`. @@ -493,7 +498,16 @@ def complete_model_class( # debug(schema) cls.__pydantic_core_schema__ = schema - cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings) + + if is_dynamic_model: + f = sys._getframe(3) + cls_module = f.f_globals['__name__'] + else: + cls_module = cls.__module__ + + cls.__pydantic_validator__ = create_schema_validator( + schema, f'{cls_module}:{cls_name}', core_config, config_wrapper.plugin_settings + ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index bd5f80a1b9..5b1bc4afc2 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -64,7 +64,9 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali schema = gen_schema.clean_schema(gen_schema.generate_schema(function)) self.__pydantic_core_schema__ = schema core_config = config_wrapper.core_config(self) - self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings) + path = f'{self.__module__}:{self.__name__}' + + self.__pydantic_validator__ = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings) if self._validate_return: return_type = ( @@ -75,7 +77,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) self.__return_pydantic_core_schema__ = schema - validator = create_schema_validator(schema, core_config, config_wrapper.plugin_settings) + validator = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings) if inspect.iscoroutinefunction(self.raw_function): async def return_val_wrapper(aw: Awaitable[Any]) -> None: diff --git a/pydantic/main.py b/pydantic/main.py index 28ad104177..458ad4153a 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1431,6 +1431,7 @@ def create_model( if resolved_bases is not __base__: ns['__orig_bases__'] = __base__ namespace.update(ns) + kwds['is_dynamic_model'] = True return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds) diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index b1e0c0d6fe..c5374585b2 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -27,6 +27,7 @@ class PydanticPluginProtocol(Protocol): def new_schema_validator( self, schema: CoreSchema, + path: str, config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -40,6 +41,7 @@ def new_schema_validator( Args: schema: The schema to validate against. + path: The path of item to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 4fc41d2db1..83bc857f4b 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -18,7 +18,7 @@ def create_schema_validator( - schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None + schema: CoreSchema, path: str, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None ) -> SchemaValidator: """Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed. @@ -29,7 +29,7 @@ def create_schema_validator( plugins = get_plugins() if plugins: - return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore + return PluggableSchemaValidator(schema, path, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -42,6 +42,7 @@ class PluggableSchemaValidator: def __init__( self, schema: CoreSchema, + path: str, config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], @@ -52,7 +53,7 @@ def __init__( json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - p, j, s = plugin.new_schema_validator(schema, config, plugin_settings) + p, j, s = plugin.new_schema_validator(schema, path, config, plugin_settings) if p is not None: python_event_handlers.append(p) if j is not None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index c839388ee0..fe9eea9ea0 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -264,7 +264,10 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth try: validator = _getattr_no_parents(type, '__pydantic_validator__') except AttributeError: - validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings) + f = sys._getframe(1) + validator = create_schema_validator( + core_schema, f'{f.f_globals["__name__"]}:{type}', core_config, config_wrapper.plugin_settings + ) serializer: SchemaSerializer try: diff --git a/tests/plugin/example_plugin.py b/tests/plugin/example_plugin.py index 3450e335cd..119cae584b 100644 --- a/tests/plugin/example_plugin.py +++ b/tests/plugin/example_plugin.py @@ -24,7 +24,7 @@ def on_error(self, error) -> None: class Plugin: - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): return ValidatePythonHandler(), None, None diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 25bda29f0c..5fb01c7f6c 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,11 +1,12 @@ from __future__ import annotations import contextlib -from typing import Any, Generator +from functools import partial +from typing import Any, Generator, List from pydantic_core import ValidationError -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, TypeAdapter, create_model, dataclasses, field_validator, validate_call from pydantic.plugin import ( PydanticPluginProtocol, ValidateJsonHandlerProtocol, @@ -43,9 +44,10 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class CustomPlugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert path == 'tests.test_plugins:Model' return None, CustomOnValidateJson(), None plugin = CustomPlugin() @@ -85,7 +87,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return None, CustomOnValidateJson(), None @@ -121,7 +123,7 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class Plugin: - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return CustomOnValidatePython(), None, None @@ -164,7 +166,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return CustomOnValidatePython(), None, None @@ -205,7 +207,7 @@ def on_exception(self, exception: Exception) -> None: stack.pop() class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): return CustomOnValidatePython(), None, None plugin = Plugin() @@ -268,7 +270,7 @@ def on_error(self, error: ValidationError) -> None: log.append(f'strings error error={error}') class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, config, plugin_settings): + def new_schema_validator(self, schema, path, config, plugin_settings): return Python(), Json(), Strings() plugin = Plugin() @@ -294,3 +296,79 @@ class Model(BaseModel): "strings enter input={'a': '3'} kwargs={'strict': True, 'context': {'c': 3}}", 'strings success result=a=3', ] + + +def test_plugin_path_dataclass() -> None: + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin: + def new_schema_validator(self, schema, path, config, plugin_settings): + assert path == 'tests.test_plugins:Bar' + return CustomOnValidatePython(), None, None + + plugin = Plugin() + with install_plugin(plugin): + + @dataclasses.dataclass + class Bar: + a: int + + +def test_plugin_path_type_adapter() -> None: + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin: + def new_schema_validator(self, schema, path, config, plugin_settings): + assert path == 'tests.test_plugins:typing.List[str]' + return CustomOnValidatePython(), None, None + + plugin = Plugin() + with install_plugin(plugin): + TypeAdapter(List[str]) + + +def test_plugin_path_validate_call() -> None: + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin1: + def new_schema_validator(self, schema, path, config, plugin_settings): + assert path == 'tests.test_plugins:foo' + return CustomOnValidatePython(), None, None + + plugin = Plugin1() + with install_plugin(plugin): + + @validate_call() + def foo(a: int): + return a + + class Plugin2: + def new_schema_validator(self, schema, path, config, plugin_settings): + assert path == 'tests.test_plugins:partial(my_wrapped_function)' + return CustomOnValidatePython(), None, None + + plugin = Plugin2() + with install_plugin(plugin): + + def my_wrapped_function(a: int, b: int, c: int): + return a + b + c + + my_partial_function = partial(my_wrapped_function, c=3) + validate_call(my_partial_function) + + +def test_plugin_path_create_model() -> None: + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin: + def new_schema_validator(self, schema, path, config, plugin_settings): + assert path == 'tests.test_plugins:FooModel' + return CustomOnValidatePython(), None, None + + plugin = Plugin() + with install_plugin(plugin): + create_model('FooModel', foo=(str, ...), bar=(int, 123)) From be9c166d2f4f606b2302a707e2e43a24918572b3 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Wed, 18 Oct 2023 18:30:48 +0330 Subject: [PATCH 02/11] Use __qualname__ --- pydantic/_internal/_dataclasses.py | 2 +- pydantic/_internal/_model_construction.py | 2 +- pydantic/_internal/_validate_call.py | 2 +- tests/test_plugins.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index fe694ba47c..5b6927d80c 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - cls.__pydantic_core_schema__ = schema cls.__pydantic_validator__ = validator = create_schema_validator( - schema, f'{cls.__module__}:{cls.__name__}', core_config, config_wrapper.plugin_settings + schema, f'{cls.__module__}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 6d92b748d1..bcd824deb2 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -506,7 +506,7 @@ def complete_model_class( cls_module = cls.__module__ cls.__pydantic_validator__ = create_schema_validator( - schema, f'{cls_module}:{cls_name}', core_config, config_wrapper.plugin_settings + schema, f'{cls_module}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 5b1bc4afc2..6938616284 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -64,7 +64,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali schema = gen_schema.clean_schema(gen_schema.generate_schema(function)) self.__pydantic_core_schema__ = schema core_config = config_wrapper.core_config(self) - path = f'{self.__module__}:{self.__name__}' + path = f'{self.__module__}:{self.__qualname__}' self.__pydantic_validator__ = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 5fb01c7f6c..a5a900b688 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -47,7 +47,7 @@ class CustomPlugin(PydanticPluginProtocol): def new_schema_validator(self, schema, path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} - assert path == 'tests.test_plugins:Model' + assert path == 'tests.test_plugins:test_on_validate_json_on_success..Model' return None, CustomOnValidateJson(), None plugin = CustomPlugin() @@ -304,7 +304,7 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): class Plugin: def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:Bar' + assert path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -335,7 +335,7 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): class Plugin1: def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:foo' + assert path == 'tests.test_plugins:test_plugin_path_validate_call..foo' return CustomOnValidatePython(), None, None plugin = Plugin1() @@ -347,7 +347,7 @@ def foo(a: int): class Plugin2: def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:partial(my_wrapped_function)' + assert path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' return CustomOnValidatePython(), None, None plugin = Plugin2() From f98883f44284c63e69abf3578bd6b4444096696d Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Wed, 18 Oct 2023 18:52:23 +0330 Subject: [PATCH 03/11] Add more tests --- tests/test_plugins.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index a5a900b688..eaac2ae601 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -372,3 +372,34 @@ def new_schema_validator(self, schema, path, config, plugin_settings): plugin = Plugin() with install_plugin(plugin): create_model('FooModel', foo=(str, ...), bar=(int, 123)) + + +def test_plugin_path_complex() -> None: + paths: list[str] = [] + + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin: + def new_schema_validator(self, schema, path, config, plugin_settings): + paths.append(path) + return CustomOnValidatePython(), None, None + + plugin = Plugin() + with install_plugin(plugin): + + def foo(): + class Model(BaseModel): + pass + + def bar(): + class Model(BaseModel): + pass + + foo() + bar() + + assert paths == [ + 'tests.test_plugins:test_plugin_path_complex..foo..Model', + 'tests.test_plugins:test_plugin_path_complex..bar..Model', + ] From 514e267ac572a5edb8246a728bbe56db4763aaa8 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 20 Oct 2023 11:34:46 +0330 Subject: [PATCH 04/11] Address comments --- docs/concepts/plugins.md | 1 + pydantic/_internal/_dataclasses.py | 2 +- pydantic/_internal/_model_construction.py | 17 +++---- pydantic/_internal/_validate_call.py | 9 ++-- pydantic/main.py | 15 ++++++- pydantic/plugin/__init__.py | 4 +- pydantic/plugin/_schema_validator.py | 17 +++++-- pydantic/type_adapter.py | 19 +++++--- tests/plugin/example_plugin.py | 2 +- tests/test_plugins.py | 54 +++++++++++++++-------- 10 files changed, 91 insertions(+), 49 deletions(-) diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index 24fbfe5649..5a3e9b6592 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -89,6 +89,7 @@ class Plugin(PydanticPluginProtocol): def new_schema_validator( self, schema: CoreSchema, + type_path: str, config: Union[CoreConfig, None], plugin_settings: Dict[str, object], ) -> NewSchemaReturns: diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index 5b6927d80c..122ec9007f 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - cls.__pydantic_core_schema__ = schema cls.__pydantic_validator__ = validator = create_schema_validator( - schema, f'{cls.__module__}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings + schema, cls.__module__, cls.__qualname__, core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index bcd824deb2..36f48d3afd 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -1,7 +1,6 @@ """Private logic for creating models.""" from __future__ import annotations as _annotations -import sys import typing import warnings import weakref @@ -64,6 +63,7 @@ def __new__( namespace: dict[str, Any], __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, __pydantic_reset_parent_namespace__: bool = True, + cls_module: str | None = None, **kwargs: Any, ) -> type: """Metaclass for creating Pydantic models. @@ -74,6 +74,7 @@ def __new__( namespace: The attribute dictionary of the class to be created. __pydantic_generic_metadata__: Metadata for generic models. __pydantic_reset_parent_namespace__: Reset parent namespace. + cls_module: The module of the class to be created. **kwargs: Catch-all for any other keyword arguments. Returns: @@ -82,7 +83,6 @@ def __new__( # Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact # that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__` # call we're in the middle of is for the `BaseModel` class. - is_dynamic_model: bool = kwargs.pop('is_dynamic_model', False) if bases: base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases) @@ -184,7 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: config_wrapper, raise_errors=False, types_namespace=types_namespace, - is_dynamic_model=is_dynamic_model, + cls_module=cls_module, ) # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is @@ -441,7 +441,7 @@ def complete_model_class( *, raise_errors: bool = True, types_namespace: dict[str, Any] | None, - is_dynamic_model: bool = False, + cls_module: str | None = None, ) -> bool: """Finish building a model class. @@ -454,7 +454,7 @@ def complete_model_class( config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors. types_namespace: Optional extra namespace to look for types in. - is_dynamic_model: Whether the model is a dynamic model (function is called from `create_model`). + cls_module: The module of the class to be created. Returns: `True` if the model is successfully completed, else `False`. @@ -499,14 +499,11 @@ def complete_model_class( # debug(schema) cls.__pydantic_core_schema__ = schema - if is_dynamic_model: - f = sys._getframe(3) - cls_module = f.f_globals['__name__'] - else: + if cls_module is None: cls_module = cls.__module__ cls.__pydantic_validator__ = create_schema_validator( - schema, f'{cls_module}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings + schema, cls_module, cls.__qualname__, core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 6938616284..113af7f5b4 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -64,9 +64,10 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali schema = gen_schema.clean_schema(gen_schema.generate_schema(function)) self.__pydantic_core_schema__ = schema core_config = config_wrapper.core_config(self) - path = f'{self.__module__}:{self.__qualname__}' - self.__pydantic_validator__ = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings) + self.__pydantic_validator__ = create_schema_validator( + schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings + ) if self._validate_return: return_type = ( @@ -77,7 +78,9 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) self.__return_pydantic_core_schema__ = schema - validator = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings) + validator = create_schema_validator( + schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings + ) if inspect.iscoroutinefunction(self.raw_function): async def return_val_wrapper(aw: Awaitable[Any]) -> None: diff --git a/pydantic/main.py b/pydantic/main.py index 458ad4153a..a615f7177a 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1,6 +1,7 @@ """Logic for creating models.""" from __future__ import annotations as _annotations +import sys import types import typing import warnings @@ -1431,8 +1432,18 @@ def create_model( if resolved_bases is not __base__: ns['__orig_bases__'] = __base__ namespace.update(ns) - kwds['is_dynamic_model'] = True - return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds) + + f = sys._getframe(1) + cls_module = f.f_globals['__name__'] + + return meta( + __model_name, + resolved_bases, + namespace, + __pydantic_reset_parent_namespace__=False, + cls_module=cls_module, + **kwds, + ) __getattr__ = getattr_migration(__name__) diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index c5374585b2..65cb514224 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -27,7 +27,7 @@ class PydanticPluginProtocol(Protocol): def new_schema_validator( self, schema: CoreSchema, - path: str, + type_path: str, config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -41,7 +41,7 @@ def new_schema_validator( Args: schema: The schema to validate against. - path: The path of item to validate against. + type_path: The path of item to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 83bc857f4b..6e3e679488 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -17,8 +17,16 @@ events: list[Event] = list(Event.__args__) # type: ignore +def build_type_path(module: str, name: str) -> str: + return f'{module}:{name}' + + def create_schema_validator( - schema: CoreSchema, path: str, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None + schema: CoreSchema, + module: str, + type_name: str, + config: CoreConfig | None = None, + plugin_settings: dict[str, Any] | None = None, ) -> SchemaValidator: """Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed. @@ -29,7 +37,8 @@ def create_schema_validator( plugins = get_plugins() if plugins: - return PluggableSchemaValidator(schema, path, config, plugins, plugin_settings or {}) # type: ignore + type_path = build_type_path(module, type_name) + return PluggableSchemaValidator(schema, type_path, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -42,7 +51,7 @@ class PluggableSchemaValidator: def __init__( self, schema: CoreSchema, - path: str, + type_path: str, config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], @@ -53,7 +62,7 @@ def __init__( json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - p, j, s = plugin.new_schema_validator(schema, path, config, plugin_settings) + p, j, s = plugin.new_schema_validator(schema, type_path, config, plugin_settings) if p is not None: python_event_handlers.append(p) if j is not None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index fe9eea9ea0..5ca1eefa0e 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -205,17 +205,21 @@ def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter raise NotImplementedError @overload - def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + def __init__( + self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None + ) -> None: ... # this overload is for non-type things like Union[int, str] # Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object] # so an explicit type cast is needed @overload - def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + def __init__( + self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None + ) -> None: ... - def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None) -> None: """Initializes the TypeAdapter object. Args: @@ -223,6 +227,7 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. _parent_depth: depth at which to search the parent namespace to construct the local namespace. + !!! note You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A @@ -264,10 +269,10 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth try: validator = _getattr_no_parents(type, '__pydantic_validator__') except AttributeError: - f = sys._getframe(1) - validator = create_schema_validator( - core_schema, f'{f.f_globals["__name__"]}:{type}', core_config, config_wrapper.plugin_settings - ) + if module is None: + f = sys._getframe(1) + module = f.f_globals['__name__'] + validator = create_schema_validator(core_schema, module, type, core_config, config_wrapper.plugin_settings) # type: ignore serializer: SchemaSerializer try: diff --git a/tests/plugin/example_plugin.py b/tests/plugin/example_plugin.py index 119cae584b..3eee3808a8 100644 --- a/tests/plugin/example_plugin.py +++ b/tests/plugin/example_plugin.py @@ -24,7 +24,7 @@ def on_error(self, error) -> None: class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): return ValidatePythonHandler(), None, None diff --git a/tests/test_plugins.py b/tests/test_plugins.py index eaac2ae601..03a3fe4b14 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -44,10 +44,10 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class CustomPlugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} - assert path == 'tests.test_plugins:test_on_validate_json_on_success..Model' + assert type_path == 'tests.test_plugins:test_on_validate_json_on_success..Model' return None, CustomOnValidateJson(), None plugin = CustomPlugin() @@ -87,7 +87,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return None, CustomOnValidateJson(), None @@ -123,7 +123,7 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return CustomOnValidatePython(), None, None @@ -166,7 +166,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return CustomOnValidatePython(), None, None @@ -207,7 +207,7 @@ def on_exception(self, exception: Exception) -> None: stack.pop() class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): return CustomOnValidatePython(), None, None plugin = Plugin() @@ -270,7 +270,7 @@ def on_error(self, error: ValidationError) -> None: log.append(f'strings error error={error}') class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, config, plugin_settings): return Python(), Json(), Strings() plugin = Plugin() @@ -303,8 +303,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert type_path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -320,8 +320,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:typing.List[str]' + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert type_path == 'tests.test_plugins:typing.List[str]' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -329,13 +329,27 @@ def new_schema_validator(self, schema, path, config, plugin_settings): TypeAdapter(List[str]) +def test_plugin_path_type_adapter_with_module() -> None: + class CustomOnValidatePython(ValidatePythonHandlerProtocol): + pass + + class Plugin: + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert type_path == 'provided_module_by_type_adapter:typing.List[str]' + return CustomOnValidatePython(), None, None + + plugin = Plugin() + with install_plugin(plugin): + TypeAdapter(List[str], module='provided_module_by_type_adapter') + + def test_plugin_path_validate_call() -> None: class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin1: - def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:test_plugin_path_validate_call..foo' + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert type_path == 'tests.test_plugins:test_plugin_path_validate_call..foo' return CustomOnValidatePython(), None, None plugin = Plugin1() @@ -346,8 +360,10 @@ def foo(a: int): return a class Plugin2: - def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert ( + type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' + ) return CustomOnValidatePython(), None, None plugin = Plugin2() @@ -365,8 +381,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): - assert path == 'tests.test_plugins:FooModel' + def new_schema_validator(self, schema, type_path, config, plugin_settings): + assert type_path == 'tests.test_plugins:FooModel' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -381,8 +397,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, path, config, plugin_settings): - paths.append(path) + def new_schema_validator(self, schema, type_path, config, plugin_settings): + paths.append(type_path) return CustomOnValidatePython(), None, None plugin = Plugin() From ecbfd5bf68ba8e36b7e017e5a852a985a68127c5 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 20 Oct 2023 12:55:21 +0330 Subject: [PATCH 05/11] Add item_type --- docs/concepts/plugins.md | 1 + pydantic/_internal/_dataclasses.py | 2 +- pydantic/_internal/_model_construction.py | 10 +++--- pydantic/_internal/_validate_call.py | 4 +-- pydantic/plugin/__init__.py | 2 ++ pydantic/plugin/_schema_validator.py | 6 ++-- pydantic/type_adapter.py | 2 +- tests/plugin/example_plugin.py | 2 +- tests/test_plugins.py | 43 ++++++++++++++--------- 9 files changed, 44 insertions(+), 28 deletions(-) diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index 5a3e9b6592..ba94b1f62b 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -90,6 +90,7 @@ class Plugin(PydanticPluginProtocol): self, schema: CoreSchema, type_path: str, + item_type: str, config: Union[CoreConfig, None], plugin_settings: Dict[str, object], ) -> NewSchemaReturns: diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index 122ec9007f..e0da672abf 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - cls.__pydantic_core_schema__ = schema cls.__pydantic_validator__ = validator = create_schema_validator( - schema, cls.__module__, cls.__qualname__, core_config, config_wrapper.plugin_settings + schema, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 36f48d3afd..59cc369d13 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -499,11 +499,13 @@ def complete_model_class( # debug(schema) cls.__pydantic_core_schema__ = schema - if cls_module is None: - cls_module = cls.__module__ - cls.__pydantic_validator__ = create_schema_validator( - schema, cls_module, cls.__qualname__, core_config, config_wrapper.plugin_settings + schema, + cls_module if cls_module else cls.__module__, + cls.__qualname__, + 'create_model' if cls_module else 'BaseModel', + core_config, + config_wrapper.plugin_settings, ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 113af7f5b4..0eb20c900c 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -66,7 +66,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali core_config = config_wrapper.core_config(self) self.__pydantic_validator__ = create_schema_validator( - schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings + schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings ) if self._validate_return: @@ -79,7 +79,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) self.__return_pydantic_core_schema__ = schema validator = create_schema_validator( - schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings + schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings ) if inspect.iscoroutinefunction(self.raw_function): diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index 65cb514224..fdc2aa1083 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -28,6 +28,7 @@ def new_schema_validator( self, schema: CoreSchema, type_path: str, + item_type: str, config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -42,6 +43,7 @@ def new_schema_validator( Args: schema: The schema to validate against. type_path: The path of item to validate against. + item_type: The type of item to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 6e3e679488..638fc8bcc4 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -25,6 +25,7 @@ def create_schema_validator( schema: CoreSchema, module: str, type_name: str, + item_type: str, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None, ) -> SchemaValidator: @@ -38,7 +39,7 @@ def create_schema_validator( plugins = get_plugins() if plugins: type_path = build_type_path(module, type_name) - return PluggableSchemaValidator(schema, type_path, config, plugins, plugin_settings or {}) # type: ignore + return PluggableSchemaValidator(schema, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -52,6 +53,7 @@ def __init__( self, schema: CoreSchema, type_path: str, + item_type: str, config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], @@ -62,7 +64,7 @@ def __init__( json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - p, j, s = plugin.new_schema_validator(schema, type_path, config, plugin_settings) + p, j, s = plugin.new_schema_validator(schema, type_path, item_type, config, plugin_settings) if p is not None: python_event_handlers.append(p) if j is not None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index 5ca1eefa0e..e6c2d47b93 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -272,7 +272,7 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth if module is None: f = sys._getframe(1) module = f.f_globals['__name__'] - validator = create_schema_validator(core_schema, module, type, core_config, config_wrapper.plugin_settings) # type: ignore + validator = create_schema_validator(core_schema, module, type, 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore serializer: SchemaSerializer try: diff --git a/tests/plugin/example_plugin.py b/tests/plugin/example_plugin.py index 3eee3808a8..cb9402a5a5 100644 --- a/tests/plugin/example_plugin.py +++ b/tests/plugin/example_plugin.py @@ -24,7 +24,7 @@ def on_error(self, error) -> None: class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): return ValidatePythonHandler(), None, None diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 03a3fe4b14..863727cf13 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -44,10 +44,11 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class CustomPlugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} assert type_path == 'tests.test_plugins:test_on_validate_json_on_success..Model' + assert item_type == 'BaseModel' return None, CustomOnValidateJson(), None plugin = CustomPlugin() @@ -87,7 +88,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return None, CustomOnValidateJson(), None @@ -123,9 +124,10 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert item_type == 'BaseModel' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -166,9 +168,10 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert item_type == 'BaseModel' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -207,7 +210,7 @@ def on_exception(self, exception: Exception) -> None: stack.pop() class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): return CustomOnValidatePython(), None, None plugin = Plugin() @@ -270,7 +273,7 @@ def on_error(self, error: ValidationError) -> None: log.append(f'strings error error={error}') class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): return Python(), Json(), Strings() plugin = Plugin() @@ -303,8 +306,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert type_path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' + assert item_type == 'dataclass' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -320,8 +324,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert type_path == 'tests.test_plugins:typing.List[str]' + assert item_type == 'type_adapter' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -334,8 +339,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert type_path == 'provided_module_by_type_adapter:typing.List[str]' + assert item_type == 'type_adapter' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -348,8 +354,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin1: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert type_path == 'tests.test_plugins:test_plugin_path_validate_call..foo' + assert item_type == 'validate_call' return CustomOnValidatePython(), None, None plugin = Plugin1() @@ -360,10 +367,11 @@ def foo(a: int): return a class Plugin2: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert ( type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' ) + assert item_type == 'validate_call' return CustomOnValidatePython(), None, None plugin = Plugin2() @@ -381,8 +389,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): assert type_path == 'tests.test_plugins:FooModel' + assert item_type == 'create_model' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -391,14 +400,14 @@ def new_schema_validator(self, schema, type_path, config, plugin_settings): def test_plugin_path_complex() -> None: - paths: list[str] = [] + paths: list[tuple(str, str)] = [] class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, config, plugin_settings): - paths.append(type_path) + def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + paths.append((type_path, item_type)) return CustomOnValidatePython(), None, None plugin = Plugin() @@ -416,6 +425,6 @@ class Model(BaseModel): bar() assert paths == [ - 'tests.test_plugins:test_plugin_path_complex..foo..Model', - 'tests.test_plugins:test_plugin_path_complex..bar..Model', + ('tests.test_plugins:test_plugin_path_complex..foo..Model', 'BaseModel'), + ('tests.test_plugins:test_plugin_path_complex..bar..Model', 'BaseModel'), ] From b3606eb054e773ea44bfe43d0666ae8090cb7939 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 20 Oct 2023 19:11:48 +0330 Subject: [PATCH 06/11] Add source_type --- docs/concepts/plugins.md | 1 + pydantic/_internal/_dataclasses.py | 2 +- pydantic/_internal/_model_construction.py | 1 + pydantic/_internal/_validate_call.py | 18 ++++++++- pydantic/plugin/__init__.py | 2 + pydantic/plugin/_schema_validator.py | 6 ++- pydantic/type_adapter.py | 2 +- tests/plugin/example_plugin.py | 2 +- tests/test_plugins.py | 46 ++++++++++++++--------- 9 files changed, 55 insertions(+), 25 deletions(-) diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index ba94b1f62b..e06c5bea88 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -89,6 +89,7 @@ class Plugin(PydanticPluginProtocol): def new_schema_validator( self, schema: CoreSchema, + source_type: str, type_path: str, item_type: str, config: Union[CoreConfig, None], diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index e0da672abf..2bc43e9665 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -172,7 +172,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - cls.__pydantic_core_schema__ = schema cls.__pydantic_validator__ = validator = create_schema_validator( - schema, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings + schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 59cc369d13..70d241eef2 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -501,6 +501,7 @@ def complete_model_class( cls.__pydantic_validator__ = create_schema_validator( schema, + cls, cls_module if cls_module else cls.__module__, cls.__qualname__, 'create_model' if cls_module else 'BaseModel', diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 0eb20c900c..30235ff435 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -46,12 +46,14 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali self.__signature__ = inspect.signature(function) if isinstance(function, partial): func = function.func + source_type = func self.__name__ = f'partial({func.__name__})' self.__qualname__ = f'partial({func.__qualname__})' self.__annotations__ = func.__annotations__ self.__module__ = func.__module__ self.__doc__ = func.__doc__ else: + source_type = function self.__name__ = function.__name__ self.__qualname__ = function.__qualname__ self.__annotations__ = function.__annotations__ @@ -66,7 +68,13 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali core_config = config_wrapper.core_config(self) self.__pydantic_validator__ = create_schema_validator( - schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings + schema, + source_type, + self.__module__, + self.__qualname__, + 'validate_call', + core_config, + config_wrapper.plugin_settings, ) if self._validate_return: @@ -79,7 +87,13 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) self.__return_pydantic_core_schema__ = schema validator = create_schema_validator( - schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings + schema, + source_type, + self.__module__, + self.__qualname__, + 'validate_call', + core_config, + config_wrapper.plugin_settings, ) if inspect.iscoroutinefunction(self.raw_function): diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index fdc2aa1083..106082c2cb 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -27,6 +27,7 @@ class PydanticPluginProtocol(Protocol): def new_schema_validator( self, schema: CoreSchema, + source_type: str, type_path: str, item_type: str, config: CoreConfig | None, @@ -42,6 +43,7 @@ def new_schema_validator( Args: schema: The schema to validate against. + source_type: The item to validate against. type_path: The path of item to validate against. item_type: The type of item to validate against. config: The config to use for validation. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 638fc8bcc4..ed6bc1daae 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -23,6 +23,7 @@ def build_type_path(module: str, name: str) -> str: def create_schema_validator( schema: CoreSchema, + source_type: Any, module: str, type_name: str, item_type: str, @@ -39,7 +40,7 @@ def create_schema_validator( plugins = get_plugins() if plugins: type_path = build_type_path(module, type_name) - return PluggableSchemaValidator(schema, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore + return PluggableSchemaValidator(schema, source_type, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -52,6 +53,7 @@ class PluggableSchemaValidator: def __init__( self, schema: CoreSchema, + source_type: Any, type_path: str, item_type: str, config: CoreConfig | None, @@ -64,7 +66,7 @@ def __init__( json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - p, j, s = plugin.new_schema_validator(schema, type_path, item_type, config, plugin_settings) + p, j, s = plugin.new_schema_validator(schema, source_type, type_path, item_type, config, plugin_settings) if p is not None: python_event_handlers.append(p) if j is not None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index e6c2d47b93..4b8dda4b48 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -272,7 +272,7 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth if module is None: f = sys._getframe(1) module = f.f_globals['__name__'] - validator = create_schema_validator(core_schema, module, type, 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore + validator = create_schema_validator(core_schema, type, module, str(type), 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore serializer: SchemaSerializer try: diff --git a/tests/plugin/example_plugin.py b/tests/plugin/example_plugin.py index cb9402a5a5..8f0c70a755 100644 --- a/tests/plugin/example_plugin.py +++ b/tests/plugin/example_plugin.py @@ -24,7 +24,7 @@ def on_error(self, error) -> None: class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): return ValidatePythonHandler(), None, None diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 863727cf13..b667b1c294 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -44,9 +44,10 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class CustomPlugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert source_type.__name__ == 'Model' assert type_path == 'tests.test_plugins:test_on_validate_json_on_success..Model' assert item_type == 'BaseModel' return None, CustomOnValidateJson(), None @@ -88,7 +89,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return None, CustomOnValidateJson(), None @@ -124,9 +125,10 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert source_type.__name__ == 'Model' assert item_type == 'BaseModel' return CustomOnValidatePython(), None, None @@ -168,9 +170,10 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} + assert source_type.__name__ == 'Model' assert item_type == 'BaseModel' return CustomOnValidatePython(), None, None @@ -210,7 +213,7 @@ def on_exception(self, exception: Exception) -> None: stack.pop() class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): return CustomOnValidatePython(), None, None plugin = Plugin() @@ -273,7 +276,7 @@ def on_error(self, error: ValidationError) -> None: log.append(f'strings error error={error}') class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): return Python(), Json(), Strings() plugin = Plugin() @@ -306,7 +309,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert source_type.__name__ == 'Bar' assert type_path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' assert item_type == 'dataclass' return CustomOnValidatePython(), None, None @@ -324,7 +328,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert str(source_type) == 'typing.List[str]' assert type_path == 'tests.test_plugins:typing.List[str]' assert item_type == 'type_adapter' return CustomOnValidatePython(), None, None @@ -339,7 +344,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert str(source_type) == 'typing.List[str]' assert type_path == 'provided_module_by_type_adapter:typing.List[str]' assert item_type == 'type_adapter' return CustomOnValidatePython(), None, None @@ -354,7 +360,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin1: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert source_type.__name__ == 'foo' assert type_path == 'tests.test_plugins:test_plugin_path_validate_call..foo' assert item_type == 'validate_call' return CustomOnValidatePython(), None, None @@ -367,7 +374,8 @@ def foo(a: int): return a class Plugin2: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert source_type.__name__ == 'my_wrapped_function' assert ( type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' ) @@ -389,7 +397,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + assert source_type.__name__ == 'FooModel' + assert list(source_type.model_fields.keys()) == ['foo', 'bar'] assert type_path == 'tests.test_plugins:FooModel' assert item_type == 'create_model' return CustomOnValidatePython(), None, None @@ -406,25 +416,25 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings): - paths.append((type_path, item_type)) + def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + paths.append((source_type.__name__, type_path, item_type)) return CustomOnValidatePython(), None, None plugin = Plugin() with install_plugin(plugin): def foo(): - class Model(BaseModel): + class Model1(BaseModel): pass def bar(): - class Model(BaseModel): + class Model2(BaseModel): pass foo() bar() assert paths == [ - ('tests.test_plugins:test_plugin_path_complex..foo..Model', 'BaseModel'), - ('tests.test_plugins:test_plugin_path_complex..bar..Model', 'BaseModel'), + ('Model1', 'tests.test_plugins:test_plugin_path_complex..foo..Model1', 'BaseModel'), + ('Model2', 'tests.test_plugins:test_plugin_path_complex..bar..Model2', 'BaseModel'), ] From 586d460ec30158b6aaaab32763a18d9008f02378 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Thu, 26 Oct 2023 11:08:40 +0330 Subject: [PATCH 07/11] Address comments --- docs/concepts/plugins.md | 14 +++- pydantic/_internal/_validate_call.py | 8 +-- pydantic/plugin/__init__.py | 21 +++--- pydantic/plugin/_schema_validator.py | 27 ++++---- pydantic/type_adapter.py | 8 ++- tests/plugin/example_plugin.py | 2 +- tests/test_plugins.py | 99 ++++++++++++++++------------ 7 files changed, 103 insertions(+), 76 deletions(-) diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index e06c5bea88..55716889cd 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -58,10 +58,12 @@ Let's see an example of a plugin that _wraps_ the `validate_python` method of th from typing import Any, Dict, Optional, Union from pydantic_core import CoreConfig, CoreSchema, ValidationError +from typing_extensions import Literal from pydantic.plugin import ( NewSchemaReturns, PydanticPluginProtocol, + SchemaTypePath, ValidatePythonHandlerProtocol, ) @@ -89,9 +91,15 @@ class Plugin(PydanticPluginProtocol): def new_schema_validator( self, schema: CoreSchema, - source_type: str, - type_path: str, - item_type: str, + schema_type: str, + schema_type_path: SchemaTypePath, + schema_kind: Literal[ + 'BaseModel', + 'TypeAdapter', + 'dataclass', + 'create_model', + 'validate_call', + ], config: Union[CoreConfig, None], plugin_settings: Dict[str, object], ) -> NewSchemaReturns: diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 30235ff435..543b064b95 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -46,14 +46,14 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali self.__signature__ = inspect.signature(function) if isinstance(function, partial): func = function.func - source_type = func + schema_type = func self.__name__ = f'partial({func.__name__})' self.__qualname__ = f'partial({func.__qualname__})' self.__annotations__ = func.__annotations__ self.__module__ = func.__module__ self.__doc__ = func.__doc__ else: - source_type = function + schema_type = function self.__name__ = function.__name__ self.__qualname__ = function.__qualname__ self.__annotations__ = function.__annotations__ @@ -69,7 +69,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali self.__pydantic_validator__ = create_schema_validator( schema, - source_type, + schema_type, self.__module__, self.__qualname__, 'validate_call', @@ -88,7 +88,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali self.__return_pydantic_core_schema__ = schema validator = create_schema_validator( schema, - source_type, + schema_type, self.__module__, self.__qualname__, 'validate_call', diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index 106082c2cb..8de1cc76ec 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -4,10 +4,10 @@ """ from __future__ import annotations -from typing import Any, Callable +from typing import Any, Callable, NamedTuple from pydantic_core import CoreConfig, CoreSchema, ValidationError -from typing_extensions import Protocol, TypeAlias +from typing_extensions import Literal, Protocol, TypeAlias __all__ = ( 'PydanticPluginProtocol', @@ -21,15 +21,20 @@ NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]' +class SchemaTypePath(NamedTuple): + module: str + name: str + + class PydanticPluginProtocol(Protocol): """Protocol defining the interface for Pydantic plugins.""" def new_schema_validator( self, schema: CoreSchema, - source_type: str, - type_path: str, - item_type: str, + schema_type: Any, + schema_type_path: SchemaTypePath, + schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -43,9 +48,9 @@ def new_schema_validator( Args: schema: The schema to validate against. - source_type: The item to validate against. - type_path: The path of item to validate against. - item_type: The type of item to validate against. + schema_type: The schema to validate against. + schema_type_path: The path of schema to validate against. + schema_kind: The kind of schema to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index ed6bc1daae..5506840ccd 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -7,6 +7,8 @@ from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError from typing_extensions import Literal, ParamSpec +from pydantic.plugin import SchemaTypePath + if TYPE_CHECKING: from . import BaseValidateHandlerProtocol, PydanticPluginProtocol @@ -17,16 +19,12 @@ events: list[Event] = list(Event.__args__) # type: ignore -def build_type_path(module: str, name: str) -> str: - return f'{module}:{name}' - - def create_schema_validator( schema: CoreSchema, - source_type: Any, - module: str, - type_name: str, - item_type: str, + schema_type: Any, + schema_type_module: str, + schema_type_name: str, + schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None, ) -> SchemaValidator: @@ -39,8 +37,7 @@ def create_schema_validator( plugins = get_plugins() if plugins: - type_path = build_type_path(module, type_name) - return PluggableSchemaValidator(schema, source_type, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore + return PluggableSchemaValidator(schema, schema_type, SchemaTypePath(schema_type_module, schema_type_name), schema_kind, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -53,9 +50,9 @@ class PluggableSchemaValidator: def __init__( self, schema: CoreSchema, - source_type: Any, - type_path: str, - item_type: str, + schema_type: Any, + schema_type_path: SchemaTypePath, + schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], @@ -66,7 +63,9 @@ def __init__( json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - p, j, s = plugin.new_schema_validator(schema, source_type, type_path, item_type, config, plugin_settings) + p, j, s = plugin.new_schema_validator( + schema, schema_type, schema_type_path, schema_kind, config, plugin_settings + ) if p is not None: python_event_handlers.append(p) if j is not None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index 4b8dda4b48..b9a725b69e 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -219,14 +219,16 @@ def __init__( ) -> None: ... - def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None) -> None: + def __init__( + self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None + ) -> None: """Initializes the TypeAdapter object. Args: type: The type associated with the `TypeAdapter`. config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. _parent_depth: depth at which to search the parent namespace to construct the local namespace. - + module: The module that passes to plugin if provided. !!! note You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own @@ -272,7 +274,7 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth if module is None: f = sys._getframe(1) module = f.f_globals['__name__'] - validator = create_schema_validator(core_schema, type, module, str(type), 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore + validator = create_schema_validator(core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings) # type: ignore serializer: SchemaSerializer try: diff --git a/tests/plugin/example_plugin.py b/tests/plugin/example_plugin.py index 8f0c70a755..8f53787c1c 100644 --- a/tests/plugin/example_plugin.py +++ b/tests/plugin/example_plugin.py @@ -24,7 +24,7 @@ def on_error(self, error) -> None: class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): return ValidatePythonHandler(), None, None diff --git a/tests/test_plugins.py b/tests/test_plugins.py index b667b1c294..0eb7d4c9b3 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, TypeAdapter, create_model, dataclasses, field_validator, validate_call from pydantic.plugin import ( PydanticPluginProtocol, + SchemaTypePath, ValidateJsonHandlerProtocol, ValidatePythonHandlerProtocol, ValidateStringsHandlerProtocol, @@ -44,12 +45,14 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class CustomPlugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} - assert source_type.__name__ == 'Model' - assert type_path == 'tests.test_plugins:test_on_validate_json_on_success..Model' - assert item_type == 'BaseModel' + assert schema_type.__name__ == 'Model' + assert schema_type_path == SchemaTypePath( + 'tests.test_plugins', 'test_on_validate_json_on_success..Model' + ) + assert schema_kind == 'BaseModel' return None, CustomOnValidateJson(), None plugin = CustomPlugin() @@ -89,7 +92,7 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} return None, CustomOnValidateJson(), None @@ -125,11 +128,11 @@ def on_success(self, result: Any) -> None: assert isinstance(result, Model) class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} - assert source_type.__name__ == 'Model' - assert item_type == 'BaseModel' + assert schema_type.__name__ == 'Model' + assert schema_kind == 'BaseModel' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -170,11 +173,11 @@ def on_error(self, error: ValidationError) -> None: ] class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): assert config == {'title': 'Model'} assert plugin_settings == {'observe': 'all'} - assert source_type.__name__ == 'Model' - assert item_type == 'BaseModel' + assert schema_type.__name__ == 'Model' + assert schema_kind == 'BaseModel' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -213,7 +216,7 @@ def on_exception(self, exception: Exception) -> None: stack.pop() class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): return CustomOnValidatePython(), None, None plugin = Plugin() @@ -276,7 +279,7 @@ def on_error(self, error: ValidationError) -> None: log.append(f'strings error error={error}') class Plugin(PydanticPluginProtocol): - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): return Python(), Json(), Strings() plugin = Plugin() @@ -309,10 +312,10 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert source_type.__name__ == 'Bar' - assert type_path == 'tests.test_plugins:test_plugin_path_dataclass..Bar' - assert item_type == 'dataclass' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert schema_type.__name__ == 'Bar' + assert schema_type_path == SchemaTypePath('tests.test_plugins', 'test_plugin_path_dataclass..Bar') + assert schema_kind == 'dataclass' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -328,10 +331,10 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert str(source_type) == 'typing.List[str]' - assert type_path == 'tests.test_plugins:typing.List[str]' - assert item_type == 'type_adapter' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert str(schema_type) == 'typing.List[str]' + assert schema_type_path == SchemaTypePath('tests.test_plugins', 'typing.List[str]') + assert schema_kind == 'TypeAdapter' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -344,10 +347,10 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert str(source_type) == 'typing.List[str]' - assert type_path == 'provided_module_by_type_adapter:typing.List[str]' - assert item_type == 'type_adapter' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert str(schema_type) == 'typing.List[str]' + assert schema_type_path == SchemaTypePath('provided_module_by_type_adapter', 'typing.List[str]') + assert schema_kind == 'TypeAdapter' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -360,10 +363,12 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin1: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert source_type.__name__ == 'foo' - assert type_path == 'tests.test_plugins:test_plugin_path_validate_call..foo' - assert item_type == 'validate_call' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert schema_type.__name__ == 'foo' + assert schema_type_path == SchemaTypePath( + 'tests.test_plugins', 'test_plugin_path_validate_call..foo' + ) + assert schema_kind == 'validate_call' return CustomOnValidatePython(), None, None plugin = Plugin1() @@ -374,12 +379,12 @@ def foo(a: int): return a class Plugin2: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert source_type.__name__ == 'my_wrapped_function' - assert ( - type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call..my_wrapped_function)' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert schema_type.__name__ == 'my_wrapped_function' + assert schema_type_path == SchemaTypePath( + 'tests.test_plugins', 'partial(test_plugin_path_validate_call..my_wrapped_function)' ) - assert item_type == 'validate_call' + assert schema_kind == 'validate_call' return CustomOnValidatePython(), None, None plugin = Plugin2() @@ -397,11 +402,11 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - assert source_type.__name__ == 'FooModel' - assert list(source_type.model_fields.keys()) == ['foo', 'bar'] - assert type_path == 'tests.test_plugins:FooModel' - assert item_type == 'create_model' + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + assert schema_type.__name__ == 'FooModel' + assert list(schema_type.model_fields.keys()) == ['foo', 'bar'] + assert schema_type_path == SchemaTypePath('tests.test_plugins', 'FooModel') + assert schema_kind == 'create_model' return CustomOnValidatePython(), None, None plugin = Plugin() @@ -416,8 +421,8 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol): pass class Plugin: - def new_schema_validator(self, schema, source_type, type_path, item_type, config, plugin_settings): - paths.append((source_type.__name__, type_path, item_type)) + def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings): + paths.append((schema_type.__name__, schema_type_path, schema_kind)) return CustomOnValidatePython(), None, None plugin = Plugin() @@ -435,6 +440,14 @@ class Model2(BaseModel): bar() assert paths == [ - ('Model1', 'tests.test_plugins:test_plugin_path_complex..foo..Model1', 'BaseModel'), - ('Model2', 'tests.test_plugins:test_plugin_path_complex..bar..Model2', 'BaseModel'), + ( + 'Model1', + SchemaTypePath('tests.test_plugins', 'test_plugin_path_complex..foo..Model1'), + 'BaseModel', + ), + ( + 'Model2', + SchemaTypePath('tests.test_plugins', 'test_plugin_path_complex..bar..Model2'), + 'BaseModel', + ), ] From 218b61b97545e11dabc9e5557cbbdfaa70a044e9 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Tue, 31 Oct 2023 10:46:14 +0330 Subject: [PATCH 08/11] Address comments --- docs/concepts/plugins.md | 2 +- pydantic/_internal/_model_construction.py | 14 +++++++------- pydantic/main.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index 55716889cd..d76b42e7dc 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -91,7 +91,7 @@ class Plugin(PydanticPluginProtocol): def new_schema_validator( self, schema: CoreSchema, - schema_type: str, + schema_type: Any, schema_type_path: SchemaTypePath, schema_kind: Literal[ 'BaseModel', diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 70d241eef2..0728c22e47 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -63,7 +63,7 @@ def __new__( namespace: dict[str, Any], __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, __pydantic_reset_parent_namespace__: bool = True, - cls_module: str | None = None, + _cls_module: str | None = None, **kwargs: Any, ) -> type: """Metaclass for creating Pydantic models. @@ -74,7 +74,7 @@ def __new__( namespace: The attribute dictionary of the class to be created. __pydantic_generic_metadata__: Metadata for generic models. __pydantic_reset_parent_namespace__: Reset parent namespace. - cls_module: The module of the class to be created. + _cls_module: The module of the class to be created. **kwargs: Catch-all for any other keyword arguments. Returns: @@ -184,7 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: config_wrapper, raise_errors=False, types_namespace=types_namespace, - cls_module=cls_module, + _cls_module=_cls_module, ) # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is @@ -441,7 +441,7 @@ def complete_model_class( *, raise_errors: bool = True, types_namespace: dict[str, Any] | None, - cls_module: str | None = None, + _cls_module: str | None = None, ) -> bool: """Finish building a model class. @@ -454,7 +454,7 @@ def complete_model_class( config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors. types_namespace: Optional extra namespace to look for types in. - cls_module: The module of the class to be created. + _cls_module: The module of the class to be created. Returns: `True` if the model is successfully completed, else `False`. @@ -502,9 +502,9 @@ def complete_model_class( cls.__pydantic_validator__ = create_schema_validator( schema, cls, - cls_module if cls_module else cls.__module__, + _cls_module if _cls_module else cls.__module__, cls.__qualname__, - 'create_model' if cls_module else 'BaseModel', + 'create_model' if _cls_module else 'BaseModel', core_config, config_wrapper.plugin_settings, ) diff --git a/pydantic/main.py b/pydantic/main.py index a615f7177a..c2a6b71096 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1434,14 +1434,14 @@ def create_model( namespace.update(ns) f = sys._getframe(1) - cls_module = f.f_globals['__name__'] + _cls_module = f.f_globals['__name__'] return meta( __model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, - cls_module=cls_module, + _cls_module=_cls_module, **kwds, ) From 4624c9fa9db59a3d5cd31f5d7e032714d585992d Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Tue, 31 Oct 2023 10:56:33 +0330 Subject: [PATCH 09/11] Fix formatting --- pydantic/plugin/_schema_validator.py | 10 +++++++++- pydantic/type_adapter.py | 8 +++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 5506840ccd..7ee21f820e 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -37,7 +37,15 @@ def create_schema_validator( plugins = get_plugins() if plugins: - return PluggableSchemaValidator(schema, schema_type, SchemaTypePath(schema_type_module, schema_type_name), schema_kind, config, plugins, plugin_settings or {}) # type: ignore + return PluggableSchemaValidator( + schema, + schema_type, + SchemaTypePath(schema_type_module, schema_type_name), + schema_kind, + config, + plugins, + plugin_settings or {}, + ) # type: ignore else: return SchemaValidator(schema, config) diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index b9a725b69e..72d8b918f4 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -78,7 +78,7 @@ class Item(BaseModel): import sys from dataclasses import is_dataclass -from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, overload from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some from typing_extensions import Literal, is_typeddict @@ -273,8 +273,10 @@ def __init__( except AttributeError: if module is None: f = sys._getframe(1) - module = f.f_globals['__name__'] - validator = create_schema_validator(core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings) # type: ignore + module = cast(str, f.f_globals['__name__']) + validator = create_schema_validator( + core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings + ) # type: ignore serializer: SchemaSerializer try: From 9162495ad011138491bc6fdba52de7c4d81f235e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 31 Oct 2023 15:14:24 +0000 Subject: [PATCH 10/11] type tweaks --- pydantic/_internal/_model_construction.py | 14 +++++++------- pydantic/main.py | 19 ++++++++++--------- pydantic/plugin/__init__.py | 9 ++++++--- pydantic/plugin/_schema_validator.py | 9 ++++----- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 0728c22e47..8a43b1b4a8 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -63,7 +63,7 @@ def __new__( namespace: dict[str, Any], __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, __pydantic_reset_parent_namespace__: bool = True, - _cls_module: str | None = None, + _create_model_module: str | None = None, **kwargs: Any, ) -> type: """Metaclass for creating Pydantic models. @@ -74,7 +74,7 @@ def __new__( namespace: The attribute dictionary of the class to be created. __pydantic_generic_metadata__: Metadata for generic models. __pydantic_reset_parent_namespace__: Reset parent namespace. - _cls_module: The module of the class to be created. + _create_model_module: The module of the class to be created, if created by `create_model`. **kwargs: Catch-all for any other keyword arguments. Returns: @@ -184,7 +184,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: config_wrapper, raise_errors=False, types_namespace=types_namespace, - _cls_module=_cls_module, + create_model_module=_create_model_module, ) # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is @@ -441,7 +441,7 @@ def complete_model_class( *, raise_errors: bool = True, types_namespace: dict[str, Any] | None, - _cls_module: str | None = None, + create_model_module: str | None = None, ) -> bool: """Finish building a model class. @@ -454,7 +454,7 @@ def complete_model_class( config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors. types_namespace: Optional extra namespace to look for types in. - _cls_module: The module of the class to be created. + create_model_module: The module of the class to be created, if created by `create_model`. Returns: `True` if the model is successfully completed, else `False`. @@ -502,9 +502,9 @@ def complete_model_class( cls.__pydantic_validator__ = create_schema_validator( schema, cls, - _cls_module if _cls_module else cls.__module__, + create_model_module or cls.__module__, cls.__qualname__, - 'create_model' if _cls_module else 'BaseModel', + 'create_model' if create_model_module else 'BaseModel', core_config, config_wrapper.plugin_settings, ) diff --git a/pydantic/main.py b/pydantic/main.py index c2a6b71096..10205abc46 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -1346,13 +1346,13 @@ def create_model( ... -def create_model( +def create_model( # noqa: C901 __model_name: str, *, __config__: ConfigDict | None = None, __doc__: str | None = None, __base__: type[Model] | tuple[type[Model], ...] | None = None, - __module__: str = __name__, + __module__: str | None = None, __validators__: dict[str, AnyClassMethod] | None = None, __cls_kwargs__: dict[str, Any] | None = None, __slots__: tuple[str, ...] | None = None, @@ -1366,9 +1366,9 @@ def create_model( __config__: The configuration of the new model. __doc__: The docstring of the new model. __base__: The base class for the new model. - __module__: The name of the module that the model belongs to. - __validators__: A dictionary of methods that validate - fields. + __module__: The name of the module that the model belongs to, + if `None` the value is taken from `sys._getframe(1)` + __validators__: A dictionary of methods that validate fields. __cls_kwargs__: A dictionary of keyword arguments for class creation. __slots__: Deprecated. Should not be passed to `create_model`. **field_definitions: Attributes of the new model. They should be passed in the format: @@ -1419,6 +1419,10 @@ def create_model( annotations[f_name] = f_annotation fields[f_name] = f_value + if __module__ is None: + f = sys._getframe(1) + __module__ = f.f_globals['__name__'] + namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__} if __doc__: namespace.update({'__doc__': __doc__}) @@ -1433,15 +1437,12 @@ def create_model( ns['__orig_bases__'] = __base__ namespace.update(ns) - f = sys._getframe(1) - _cls_module = f.f_globals['__name__'] - return meta( __model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, - _cls_module=_cls_module, + _create_model_module=__module__, **kwds, ) diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index 8de1cc76ec..e8b15439b2 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -26,6 +26,9 @@ class SchemaTypePath(NamedTuple): name: str +SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'] + + class PydanticPluginProtocol(Protocol): """Protocol defining the interface for Pydantic plugins.""" @@ -34,7 +37,7 @@ def new_schema_validator( schema: CoreSchema, schema_type: Any, schema_type_path: SchemaTypePath, - schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], + schema_kind: SchemaKind, config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -48,8 +51,8 @@ def new_schema_validator( Args: schema: The schema to validate against. - schema_type: The schema to validate against. - schema_type_path: The path of schema to validate against. + schema_type: The original type which the schema was created from, e.g. the model class. + schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called. schema_kind: The kind of schema to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. diff --git a/pydantic/plugin/_schema_validator.py b/pydantic/plugin/_schema_validator.py index 7ee21f820e..3610a32d07 100644 --- a/pydantic/plugin/_schema_validator.py +++ b/pydantic/plugin/_schema_validator.py @@ -7,10 +7,8 @@ from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError from typing_extensions import Literal, ParamSpec -from pydantic.plugin import SchemaTypePath - if TYPE_CHECKING: - from . import BaseValidateHandlerProtocol, PydanticPluginProtocol + from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath P = ParamSpec('P') @@ -24,7 +22,7 @@ def create_schema_validator( schema_type: Any, schema_type_module: str, schema_type_name: str, - schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], + schema_kind: SchemaKind, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None, ) -> SchemaValidator: @@ -33,6 +31,7 @@ def create_schema_validator( Returns: If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`. """ + from . import SchemaTypePath from ._loader import get_plugins plugins = get_plugins() @@ -60,7 +59,7 @@ def __init__( schema: CoreSchema, schema_type: Any, schema_type_path: SchemaTypePath, - schema_kind: Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'], + schema_kind: SchemaKind, config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], From 17b712565ece86cc97855362a141776cbb949870 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 31 Oct 2023 15:21:28 +0000 Subject: [PATCH 11/11] fix tests --- docs/concepts/models.md | 2 +- docs/concepts/plugins.md | 10 ++-------- pydantic/plugin/__init__.py | 4 ++++ pydantic/types.py | 9 +++++---- tests/test_create_model.py | 2 +- tests/test_json_schema.py | 13 +++++++------ 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/concepts/models.md b/docs/concepts/models.md index d52edf8e56..3d1786d4e2 100644 --- a/docs/concepts/models.md +++ b/docs/concepts/models.md @@ -1057,7 +1057,7 @@ BarModel = create_model( __base__=FooModel, ) print(BarModel) -#> +#> print(BarModel.model_fields.keys()) #> dict_keys(['foo', 'bar', 'apple', 'banana']) ``` diff --git a/docs/concepts/plugins.md b/docs/concepts/plugins.md index d76b42e7dc..1e5e054abe 100644 --- a/docs/concepts/plugins.md +++ b/docs/concepts/plugins.md @@ -58,11 +58,11 @@ Let's see an example of a plugin that _wraps_ the `validate_python` method of th from typing import Any, Dict, Optional, Union from pydantic_core import CoreConfig, CoreSchema, ValidationError -from typing_extensions import Literal from pydantic.plugin import ( NewSchemaReturns, PydanticPluginProtocol, + SchemaKind, SchemaTypePath, ValidatePythonHandlerProtocol, ) @@ -93,13 +93,7 @@ class Plugin(PydanticPluginProtocol): schema: CoreSchema, schema_type: Any, schema_type_path: SchemaTypePath, - schema_kind: Literal[ - 'BaseModel', - 'TypeAdapter', - 'dataclass', - 'create_model', - 'validate_call', - ], + schema_kind: SchemaKind, config: Union[CoreConfig, None], plugin_settings: Dict[str, object], ) -> NewSchemaReturns: diff --git a/pydantic/plugin/__init__.py b/pydantic/plugin/__init__.py index e8b15439b2..ba26a41ea4 100644 --- a/pydantic/plugin/__init__.py +++ b/pydantic/plugin/__init__.py @@ -16,12 +16,16 @@ 'ValidateJsonHandlerProtocol', 'ValidateStringsHandlerProtocol', 'NewSchemaReturns', + 'SchemaTypePath', + 'SchemaKind', ) NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]' class SchemaTypePath(NamedTuple): + """Path defining where `schema_type` was defined, or where `TypeAdapter` was called.""" + module: str name: str diff --git a/pydantic/types.py b/pydantic/types.py index 02bdf8f2ca..92ef65f0b4 100644 --- a/pydantic/types.py +++ b/pydantic/types.py @@ -992,6 +992,7 @@ class Foo(BaseModel): === ":white_check_mark: Do this" ```py from decimal import Decimal + from typing_extensions import Annotated from pydantic import BaseModel, Field @@ -1091,7 +1092,7 @@ def __hash__(self) -> int: ```py import uuid -from pydantic import BaseModel, UUID1 +from pydantic import UUID1, BaseModel class Model(BaseModel): uuid1: UUID1 @@ -1105,7 +1106,7 @@ class Model(BaseModel): ```py import uuid -from pydantic import BaseModel, UUID3 +from pydantic import UUID3, BaseModel class Model(BaseModel): uuid3: UUID3 @@ -1119,7 +1120,7 @@ class Model(BaseModel): ```py import uuid -from pydantic import BaseModel, UUID4 +from pydantic import UUID4, BaseModel class Model(BaseModel): uuid4: UUID4 @@ -1133,7 +1134,7 @@ class Model(BaseModel): ```py import uuid -from pydantic import BaseModel, UUID5 +from pydantic import UUID5, BaseModel class Model(BaseModel): uuid5: UUID5 diff --git a/tests/test_create_model.py b/tests/test_create_model.py index 2721ea8e02..f43e0bf44a 100644 --- a/tests/test_create_model.py +++ b/tests/test_create_model.py @@ -31,7 +31,7 @@ def test_create_model(): assert not model.__pydantic_decorators__.field_validators assert not model.__pydantic_decorators__.field_serializers - assert model.__module__ == 'pydantic.main' + assert model.__module__ == 'tests.test_create_model' def test_create_model_usage(): diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 4e0b4b4826..675b52a4a4 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -5419,24 +5419,25 @@ def test_multiple_models_with_same_qualname(): outer_a2=(model_a2, ...), ) + # insert_assert(model_c.model_json_schema()) assert model_c.model_json_schema() == { '$defs': { - 'pydantic__main__A__1': { - 'properties': {'inner_a1': {'title': 'Inner ' 'A1', 'type': 'string'}}, + 'tests__test_json_schema__A__1': { + 'properties': {'inner_a1': {'title': 'Inner A1', 'type': 'string'}}, 'required': ['inner_a1'], 'title': 'A', 'type': 'object', }, - 'pydantic__main__A__2': { - 'properties': {'inner_a2': {'title': 'Inner ' 'A2', 'type': 'string'}}, + 'tests__test_json_schema__A__2': { + 'properties': {'inner_a2': {'title': 'Inner A2', 'type': 'string'}}, 'required': ['inner_a2'], 'title': 'A', 'type': 'object', }, }, 'properties': { - 'outer_a1': {'$ref': '#/$defs/pydantic__main__A__1'}, - 'outer_a2': {'$ref': '#/$defs/pydantic__main__A__2'}, + 'outer_a1': {'$ref': '#/$defs/tests__test_json_schema__A__1'}, + 'outer_a2': {'$ref': '#/$defs/tests__test_json_schema__A__2'}, }, 'required': ['outer_a1', 'outer_a2'], 'title': 'B',