Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Oct 20, 2023
1 parent 0d58e87 commit 8e0c732
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 49 deletions.
1 change: 1 addition & 0 deletions docs/concepts/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class Plugin(PydanticPluginProtocol):
def new_schema_validator(
self,
schema: CoreSchema,
type_path: str,
config: Union[CoreConfig, None],
plugin_settings: Dict[str, object],
) -> NewSchemaReturns:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, f'{cls.__module__}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings
schema, cls.__module__, cls.__qualname__, core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
17 changes: 7 additions & 10 deletions pydantic/_internal/_model_construction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations

import sys
import typing
import warnings
import weakref
Expand Down Expand Up @@ -77,6 +76,7 @@ def __new__(
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
cls_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Expand All @@ -87,6 +87,7 @@ def __new__(
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
cls_module: The module of the class to be created.
**kwargs: Catch-all for any other keyword arguments.
Returns:
Expand All @@ -95,7 +96,6 @@ def __new__(
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
is_dynamic_model: bool = kwargs.pop('is_dynamic_model', False)
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)

Expand Down Expand Up @@ -187,7 +187,7 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
is_dynamic_model=is_dynamic_model,
cls_module=cls_module,
)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
Expand Down Expand Up @@ -442,7 +442,7 @@ def complete_model_class(
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
is_dynamic_model: bool = False,
cls_module: str | None = None,
) -> bool:
"""Finish building a model class.
Expand All @@ -455,7 +455,7 @@ def complete_model_class(
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
is_dynamic_model: Whether the model is a dynamic model (function is called from `create_model`).
cls_module: The module of the class to be created.
Returns:
`True` if the model is successfully completed, else `False`.
Expand Down Expand Up @@ -500,14 +500,11 @@ def complete_model_class(
# debug(schema)
cls.__pydantic_core_schema__ = schema

if is_dynamic_model:
f = sys._getframe(3)
cls_module = f.f_globals['__name__']
else:
if cls_module is None:
cls_module = cls.__module__

cls.__pydantic_validator__ = create_schema_validator(
schema, f'{cls_module}:{cls.__qualname__}', core_config, config_wrapper.plugin_settings
schema, cls_module, cls.__qualname__, core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
Expand Down
9 changes: 6 additions & 3 deletions pydantic/_internal/_validate_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
path = f'{self.__module__}:{self.__qualname__}'

self.__pydantic_validator__ = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings)
self.__pydantic_validator__ = create_schema_validator(
schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings
)

if self._validate_return:
return_type = (
Expand All @@ -77,7 +78,9 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(schema, path, core_config, config_wrapper.plugin_settings)
validator = create_schema_validator(
schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings
)
if inspect.iscoroutinefunction(self.raw_function):

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
Expand Down
15 changes: 13 additions & 2 deletions pydantic/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Logic for creating models."""
from __future__ import annotations as _annotations

import sys
import types
import typing
import warnings
Expand Down Expand Up @@ -1425,8 +1426,18 @@ def create_model(
if resolved_bases is not __base__:
ns['__orig_bases__'] = __base__
namespace.update(ns)
kwds['is_dynamic_model'] = True
return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds)

f = sys._getframe(1)
cls_module = f.f_globals['__name__']

return meta(
__model_name,
resolved_bases,
namespace,
__pydantic_reset_parent_namespace__=False,
cls_module=cls_module,
**kwds,
)


__getattr__ = getattr_migration(__name__)
4 changes: 2 additions & 2 deletions pydantic/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PydanticPluginProtocol(Protocol):
def new_schema_validator(
self,
schema: CoreSchema,
path: str,
type_path: str,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
Expand All @@ -41,7 +41,7 @@ def new_schema_validator(
Args:
schema: The schema to validate against.
path: The path of item to validate against.
type_path: The path of item to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
Expand Down
17 changes: 13 additions & 4 deletions pydantic/plugin/_schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
events: list[Event] = list(Event.__args__) # type: ignore


def build_type_path(module: str, name: str) -> str:
return f'{module}:{name}'


def create_schema_validator(
schema: CoreSchema, path: str, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
schema: CoreSchema,
module: str,
type_name: str,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
Expand All @@ -29,7 +37,8 @@ def create_schema_validator(

plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(schema, path, config, plugins, plugin_settings or {}) # type: ignore
type_path = build_type_path(module, type_name)
return PluggableSchemaValidator(schema, type_path, config, plugins, plugin_settings or {}) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -42,7 +51,7 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
path: str,
type_path: str,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
Expand All @@ -53,7 +62,7 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, path, config, plugin_settings)
p, j, s = plugin.new_schema_validator(schema, type_path, config, plugin_settings)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
20 changes: 13 additions & 7 deletions pydantic/type_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,23 @@ def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter
raise NotImplementedError

@overload
def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

# this overload is for non-type things like Union[int, str]
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
# so an explicit type cast is needed
@overload
def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
...

def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
def __init__(
self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2, module: str | None = None
) -> None:
"""Initializes the TypeAdapter object."""
config_wrapper = _config.ConfigWrapper(config)

Expand Down Expand Up @@ -245,10 +251,10 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
f = sys._getframe(1)
validator = create_schema_validator(
core_schema, f'{f.f_globals["__name__"]}:{type}', core_config, config_wrapper.plugin_settings
)
if module is None:
f = sys._getframe(1)
module = f.f_globals['__name__']
validator = create_schema_validator(core_schema, module, type, core_config, config_wrapper.plugin_settings) # type: ignore

serializer: SchemaSerializer
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/plugin/example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def on_error(self, error) -> None:


class Plugin:
def new_schema_validator(self, schema, path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
return ValidatePythonHandler(), None, None


Expand Down

0 comments on commit 8e0c732

Please sign in to comment.