Skip to content

Commit

Permalink
Add item_type
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Oct 20, 2023
1 parent 8e0c732 commit 559734b
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/concepts/plugins.md
Expand Up @@ -90,6 +90,7 @@ class Plugin(PydanticPluginProtocol):
self,
schema: CoreSchema,
type_path: str,
item_type: str,
config: Union[CoreConfig, None],
plugin_settings: Dict[str, object],
) -> NewSchemaReturns:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/_internal/_dataclasses.py
Expand Up @@ -170,7 +170,7 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -

cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls.__module__, cls.__qualname__, core_config, config_wrapper.plugin_settings
schema, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

Expand Down
10 changes: 6 additions & 4 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -500,11 +500,13 @@ def complete_model_class(
# debug(schema)
cls.__pydantic_core_schema__ = schema

if cls_module is None:
cls_module = cls.__module__

cls.__pydantic_validator__ = create_schema_validator(
schema, cls_module, cls.__qualname__, core_config, config_wrapper.plugin_settings
schema,
cls_module if cls_module else cls.__module__,
cls.__qualname__,
'create_model' if cls_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
Expand Down
4 changes: 2 additions & 2 deletions pydantic/_internal/_validate_call.py
Expand Up @@ -66,7 +66,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
core_config = config_wrapper.core_config(self)

self.__pydantic_validator__ = create_schema_validator(
schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings
schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings
)

if self._validate_return:
Expand All @@ -79,7 +79,7 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
validator = create_schema_validator(
schema, self.__module__, self.__qualname__, core_config, config_wrapper.plugin_settings
schema, self.__module__, self.__qualname__, 'validate_call', core_config, config_wrapper.plugin_settings
)
if inspect.iscoroutinefunction(self.raw_function):

Expand Down
2 changes: 2 additions & 0 deletions pydantic/plugin/__init__.py
Expand Up @@ -28,6 +28,7 @@ def new_schema_validator(
self,
schema: CoreSchema,
type_path: str,
item_type: str,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
Expand All @@ -42,6 +43,7 @@ def new_schema_validator(
Args:
schema: The schema to validate against.
type_path: The path of item to validate against.
item_type: The type of item to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
Expand Down
6 changes: 4 additions & 2 deletions pydantic/plugin/_schema_validator.py
Expand Up @@ -25,6 +25,7 @@ def create_schema_validator(
schema: CoreSchema,
module: str,
type_name: str,
item_type: str,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator:
Expand All @@ -38,7 +39,7 @@ def create_schema_validator(
plugins = get_plugins()
if plugins:
type_path = build_type_path(module, type_name)
return PluggableSchemaValidator(schema, type_path, config, plugins, plugin_settings or {}) # type: ignore
return PluggableSchemaValidator(schema, type_path, item_type, config, plugins, plugin_settings or {}) # type: ignore
else:
return SchemaValidator(schema, config)

Expand All @@ -52,6 +53,7 @@ def __init__(
self,
schema: CoreSchema,
type_path: str,
item_type: str,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
Expand All @@ -62,7 +64,7 @@ def __init__(
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
p, j, s = plugin.new_schema_validator(schema, type_path, config, plugin_settings)
p, j, s = plugin.new_schema_validator(schema, type_path, item_type, config, plugin_settings)
if p is not None:
python_event_handlers.append(p)
if j is not None:
Expand Down
2 changes: 1 addition & 1 deletion pydantic/type_adapter.py
Expand Up @@ -254,7 +254,7 @@ def __init__(
if module is None:
f = sys._getframe(1)
module = f.f_globals['__name__']
validator = create_schema_validator(core_schema, module, type, core_config, config_wrapper.plugin_settings) # type: ignore
validator = create_schema_validator(core_schema, module, type, 'type_adapter', core_config, config_wrapper.plugin_settings) # type: ignore

serializer: SchemaSerializer
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/plugin/example_plugin.py
Expand Up @@ -24,7 +24,7 @@ def on_error(self, error) -> None:


class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
return ValidatePythonHandler(), None, None


Expand Down
43 changes: 26 additions & 17 deletions tests/test_plugins.py
Expand Up @@ -44,10 +44,11 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class CustomPlugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert type_path == 'tests.test_plugins:test_on_validate_json_on_success.<locals>.Model'
assert item_type == 'BaseModel'
return None, CustomOnValidateJson(), None

plugin = CustomPlugin()
Expand Down Expand Up @@ -87,7 +88,7 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
return None, CustomOnValidateJson(), None
Expand Down Expand Up @@ -123,9 +124,10 @@ def on_success(self, result: Any) -> None:
assert isinstance(result, Model)

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert item_type == 'BaseModel'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand Down Expand Up @@ -166,9 +168,10 @@ def on_error(self, error: ValidationError) -> None:
]

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert config == {'title': 'Model'}
assert plugin_settings == {'observe': 'all'}
assert item_type == 'BaseModel'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand Down Expand Up @@ -207,7 +210,7 @@ def on_exception(self, exception: Exception) -> None:
stack.pop()

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand Down Expand Up @@ -270,7 +273,7 @@ def on_error(self, error: ValidationError) -> None:
log.append(f'strings error error={error}')

class Plugin(PydanticPluginProtocol):
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
return Python(), Json(), Strings()

plugin = Plugin()
Expand Down Expand Up @@ -303,8 +306,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert type_path == 'tests.test_plugins:test_plugin_path_dataclass.<locals>.Bar'
assert item_type == 'dataclass'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand All @@ -320,8 +324,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert type_path == 'tests.test_plugins:typing.List[str]'
assert item_type == 'type_adapter'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand All @@ -334,8 +339,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert type_path == 'provided_module_by_type_adapter:typing.List[str]'
assert item_type == 'type_adapter'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand All @@ -348,8 +354,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin1:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert type_path == 'tests.test_plugins:test_plugin_path_validate_call.<locals>.foo'
assert item_type == 'validate_call'
return CustomOnValidatePython(), None, None

plugin = Plugin1()
Expand All @@ -360,10 +367,11 @@ def foo(a: int):
return a

class Plugin2:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert (
type_path == 'tests.test_plugins:partial(test_plugin_path_validate_call.<locals>.my_wrapped_function)'
)
assert item_type == 'validate_call'
return CustomOnValidatePython(), None, None

plugin = Plugin2()
Expand All @@ -381,8 +389,9 @@ class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
assert type_path == 'tests.test_plugins:FooModel'
assert item_type == 'create_model'
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand All @@ -391,14 +400,14 @@ def new_schema_validator(self, schema, type_path, config, plugin_settings):


def test_plugin_path_complex() -> None:
paths: list[str] = []
paths: list[tuple(str, str)] = []

class CustomOnValidatePython(ValidatePythonHandlerProtocol):
pass

class Plugin:
def new_schema_validator(self, schema, type_path, config, plugin_settings):
paths.append(type_path)
def new_schema_validator(self, schema, type_path, item_type, config, plugin_settings):
paths.append((type_path, item_type))
return CustomOnValidatePython(), None, None

plugin = Plugin()
Expand All @@ -416,6 +425,6 @@ class Model(BaseModel):
bar()

assert paths == [
'tests.test_plugins:test_plugin_path_complex.<locals>.foo.<locals>.Model',
'tests.test_plugins:test_plugin_path_complex.<locals>.bar.<locals>.Model',
('tests.test_plugins:test_plugin_path_complex.<locals>.foo.<locals>.Model', 'BaseModel'),
('tests.test_plugins:test_plugin_path_complex.<locals>.bar.<locals>.Model', 'BaseModel'),
]

0 comments on commit 559734b

Please sign in to comment.