diff --git a/pydantic/_internal/_config.py b/pydantic/_internal/_config.py index 7a43893885..445f77cab8 100644 --- a/pydantic/_internal/_config.py +++ b/pydantic/_internal/_config.py @@ -1,10 +1,21 @@ from __future__ import annotations as _annotations import warnings -from typing import TYPE_CHECKING, Any, Callable, cast +from contextlib import contextmanager, nullcontext +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Iterator, + cast, +) from pydantic_core import core_schema -from typing_extensions import Literal, Self +from typing_extensions import ( + Literal, + Self, +) from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable from ..errors import PydanticUserError @@ -169,6 +180,34 @@ def __repr__(self): return f'ConfigWrapper({c})' +class ConfigWrapperStack: + """A stack of `ConfigWrapper` instances.""" + + def __init__(self, config_wrapper: ConfigWrapper): + self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper] + + @property + def tail(self) -> ConfigWrapper: + return self._config_wrapper_stack[-1] + + def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]: + if config_wrapper is None: + return nullcontext() + + if not isinstance(config_wrapper, ConfigWrapper): + config_wrapper = ConfigWrapper(config_wrapper, check=False) + + @contextmanager + def _context_manager() -> Iterator[None]: + self._config_wrapper_stack.append(config_wrapper) + try: + yield + finally: + self._config_wrapper_stack.pop() + + return _context_manager() + + config_defaults = ConfigDict( title=None, str_to_lower=False, diff --git a/pydantic/_internal/_core_metadata.py b/pydantic/_internal/_core_metadata.py index 1672ff6e31..296d49f598 100644 --- a/pydantic/_internal/_core_metadata.py +++ b/pydantic/_internal/_core_metadata.py @@ -30,6 +30,8 @@ class CoreMetadata(typing_extensions.TypedDict, total=False): # prefer positional over keyword arguments for an 'arguments' schema. pydantic_js_prefer_positional_arguments: bool | None + pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema + class CoreMetadataHandler: """Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents. @@ -67,6 +69,7 @@ def build_metadata_dict( js_functions: list[GetJsonSchemaFunction] | None = None, js_annotation_functions: list[GetJsonSchemaFunction] | None = None, js_prefer_positional_arguments: bool | None = None, + typed_dict_cls: type[Any] | None = None, initial_metadata: Any | None = None, ) -> Any: """Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent @@ -79,6 +82,7 @@ def build_metadata_dict( pydantic_js_functions=js_functions or [], pydantic_js_annotation_functions=js_annotation_functions or [], pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments, + pydantic_typed_dict_cls=typed_dict_cls, ) metadata = {k: v for k, v in metadata.items() if v is not None} diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 5e464d1957..4142f359a5 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -8,7 +8,7 @@ import sys import typing import warnings -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from copy import copy from enum import Enum from functools import partial @@ -20,7 +20,6 @@ TYPE_CHECKING, Any, Callable, - ContextManager, Dict, ForwardRef, Iterable, @@ -45,7 +44,7 @@ from ..warnings import PydanticDeprecatedSince20 from . import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra from ._annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler -from ._config import ConfigWrapper +from ._config import ConfigWrapper, ConfigWrapperStack from ._core_metadata import ( CoreMetadataHandler, build_metadata_dict, @@ -261,34 +260,6 @@ def _add_custom_serialization_from_json_encoders( return schema -class ConfigWrapperStack: - """A stack of `ConfigWrapper` instances.""" - - def __init__(self, config_wrapper: ConfigWrapper): - self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper] - - @property - def tail(self) -> ConfigWrapper: - return self._config_wrapper_stack[-1] - - def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]: - if config_wrapper is None: - return nullcontext() - - if not isinstance(config_wrapper, ConfigWrapper): - config_wrapper = ConfigWrapper(config_wrapper, check=False) - - @contextmanager - def _context_manager() -> Iterator[None]: - self._config_wrapper_stack.append(config_wrapper) - try: - yield - finally: - self._config_wrapper_stack.pop() - - return _context_manager() - - class GenerateSchema: """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... .""" @@ -1098,7 +1069,9 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co field_name, field_info, decorators, required=required ) - metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)]) + metadata = build_metadata_dict( + js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls + ) td_schema = core_schema.typed_dict_schema( fields, diff --git a/pydantic/json_schema.py b/pydantic/json_schema.py index eaf6ed7caa..d7a8dd4c76 100644 --- a/pydantic/json_schema.py +++ b/pydantic/json_schema.py @@ -35,13 +35,21 @@ ) import pydantic_core -from pydantic_core import CoreConfig, CoreSchema, PydanticOmit, core_schema, to_jsonable_python +from pydantic_core import CoreSchema, PydanticOmit, core_schema, to_jsonable_python from pydantic_core.core_schema import ComputedField from typing_extensions import Annotated, Literal, assert_never -from pydantic._internal import _annotated_handlers, _internal_dataclass - -from ._internal import _core_metadata, _core_utils, _mock_val_ser, _schema_generation_shared, _typing_extra +from ._internal import ( + _annotated_handlers, + _config, + _core_metadata, + _core_utils, + _decorators, + _internal_dataclass, + _mock_val_ser, + _schema_generation_shared, + _typing_extra, +) from .config import JsonSchemaExtraCallable from .errors import PydanticInvalidForJsonSchema, PydanticUserError @@ -266,6 +274,7 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA self.json_to_defs_refs: dict[JsonRef, DefsRef] = {} self.definitions: dict[DefsRef, JsonSchemaValue] = {} + self._config_wrapper_stack = _config.ConfigWrapperStack(_config.ConfigWrapper({})) self.mode: JsonSchemaMode = 'validation' @@ -291,6 +300,10 @@ def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLA # of a single instance of a schema generator self._used = False + @property + def _config(self) -> _config.ConfigWrapper: + return self._config_wrapper_stack.tail + def build_schema_type_to_method( self, ) -> dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]]: @@ -649,7 +662,7 @@ def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue: Returns: The generated JSON schema. """ - json_schema = {'type': 'string', 'format': 'binary'} + json_schema = {'type': 'string', 'format': 'base64url' if self._config.ser_json_bytes == 'base64' else 'binary'} self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes) return json_schema @@ -697,6 +710,8 @@ def timedelta_schema(self, schema: core_schema.TimedeltaSchema) -> JsonSchemaVal Returns: The generated JSON schema. """ + if self._config.ser_json_timedelta == 'float': + return {'type': 'number'} return {'type': 'string', 'format': 'duration'} def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue: @@ -1168,10 +1183,12 @@ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaVa ] if self.mode == 'serialization': named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) - json_schema = self._named_required_fields_schema(named_required_fields) - config: CoreConfig | None = schema.get('config', None) - extra = (config or {}).get('extra_fields_behavior', 'ignore') + config = _get_typed_dict_config(schema) + with self._config_wrapper_stack.push(config): + json_schema = self._named_required_fields_schema(named_required_fields) + + extra = config.get('extra', 'ignore') if extra == 'forbid': json_schema['additionalProperties'] = False elif extra == 'allow': @@ -1286,12 +1303,13 @@ def model_schema(self, schema: core_schema.ModelSchema) -> JsonSchemaValue: """ # We do not use schema['model'].model_json_schema() here # because it could lead to inconsistent refs handling, etc. - json_schema = self.generate_inner(schema['schema']) - cls = cast('type[BaseModel]', schema['cls']) config = cls.model_config title = config.get('title') + with self._config_wrapper_stack.push(config): + json_schema = self.generate_inner(schema['schema']) + json_schema_extra = config.get('json_schema_extra') if cls.__pydantic_root_model__: root_json_schema_extra = cls.model_fields['root'].json_schema_extra @@ -1461,13 +1479,13 @@ def dataclass_schema(self, schema: core_schema.DataclassSchema) -> JsonSchemaVal Returns: The generated JSON schema. """ - json_schema = self.generate_inner(schema['schema']).copy() - cls = schema['cls'] config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {})) - title = config.get('title') or cls.__name__ + with self._config_wrapper_stack.push(config): + json_schema = self.generate_inner(schema['schema']).copy() + json_schema_extra = config.get('json_schema_extra') json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra) @@ -1942,7 +1960,12 @@ def encode_default(self, dft: Any) -> Any: Returns: The encoded default value. """ - return pydantic_core.to_jsonable_python(dft) + config = self._config + return pydantic_core.to_jsonable_python( + dft, + timedelta_mode=config.ser_json_timedelta, + bytes_mode=config.ser_json_bytes, + ) def update_with_validations( self, json_schema: JsonSchemaValue, core_schema: CoreSchema, mapping: dict[str, str] @@ -2321,3 +2344,14 @@ def __get_pydantic_json_schema__( def __hash__(self) -> int: return hash(type(self)) + + +def _get_typed_dict_config(schema: core_schema.TypedDictSchema) -> ConfigDict: + metadata = _core_metadata.CoreMetadataHandler(schema).metadata + cls = metadata.get('pydantic_typed_dict_cls') + if cls is not None: + try: + return _decorators.get_attribute_from_bases(cls, '__pydantic_config__') + except AttributeError: + pass + return {} diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 7e4e98c263..cc5a2574f6 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -1679,6 +1679,140 @@ class Outer(BaseModel): } +@pytest.mark.parametrize( + 'ser_json_timedelta,properties', + [ + ('float', {'duration': {'default': 300.0, 'title': 'Duration', 'type': 'number'}}), + ('iso8601', {'duration': {'default': 'PT300S', 'format': 'duration', 'title': 'Duration', 'type': 'string'}}), + ], +) +def test_model_default_timedelta(ser_json_timedelta: Literal['float', 'iso8601'], properties: typing.Dict[str, Any]): + class Model(BaseModel): + model_config = ConfigDict(ser_json_timedelta=ser_json_timedelta) + + duration: timedelta = timedelta(minutes=5) + + # insert_assert(Model.model_json_schema(mode='serialization')) + assert Model.model_json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['duration'], + 'title': 'Model', + 'type': 'object', + } + + +@pytest.mark.parametrize( + 'ser_json_bytes,properties', + [ + ('base64', {'data': {'default': 'Zm9vYmFy', 'format': 'base64url', 'title': 'Data', 'type': 'string'}}), + ('utf8', {'data': {'default': 'foobar', 'format': 'binary', 'title': 'Data', 'type': 'string'}}), + ], +) +def test_model_default_bytes(ser_json_bytes: Literal['base64', 'utf8'], properties: typing.Dict[str, Any]): + class Model(BaseModel): + model_config = ConfigDict(ser_json_bytes=ser_json_bytes) + + data: bytes = b'foobar' + + # insert_assert(Model.model_json_schema(mode='serialization')) + assert Model.model_json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['data'], + 'title': 'Model', + 'type': 'object', + } + + +@pytest.mark.parametrize( + 'ser_json_timedelta,properties', + [ + ('float', {'duration': {'default': 300.0, 'title': 'Duration', 'type': 'number'}}), + ('iso8601', {'duration': {'default': 'PT300S', 'format': 'duration', 'title': 'Duration', 'type': 'string'}}), + ], +) +def test_dataclass_default_timedelta( + ser_json_timedelta: Literal['float', 'iso8601'], properties: typing.Dict[str, Any] +): + @dataclass(config=ConfigDict(ser_json_timedelta=ser_json_timedelta)) + class Dataclass: + duration: timedelta = timedelta(minutes=5) + + # insert_assert(TypeAdapter(Dataclass).json_schema(mode='serialization')) + assert TypeAdapter(Dataclass).json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['duration'], + 'title': 'Dataclass', + 'type': 'object', + } + + +@pytest.mark.parametrize( + 'ser_json_bytes,properties', + [ + ('base64', {'data': {'default': 'Zm9vYmFy', 'format': 'base64url', 'title': 'Data', 'type': 'string'}}), + ('utf8', {'data': {'default': 'foobar', 'format': 'binary', 'title': 'Data', 'type': 'string'}}), + ], +) +def test_dataclass_default_bytes(ser_json_bytes: Literal['base64', 'utf8'], properties: typing.Dict[str, Any]): + @dataclass(config=ConfigDict(ser_json_bytes=ser_json_bytes)) + class Dataclass: + data: bytes = b'foobar' + + # insert_assert(TypeAdapter(Dataclass).json_schema(mode='serialization')) + assert TypeAdapter(Dataclass).json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['data'], + 'title': 'Dataclass', + 'type': 'object', + } + + +@pytest.mark.parametrize( + 'ser_json_timedelta,properties', + [ + ('float', {'duration': {'default': 300.0, 'title': 'Duration', 'type': 'number'}}), + ('iso8601', {'duration': {'default': 'PT300S', 'format': 'duration', 'title': 'Duration', 'type': 'string'}}), + ], +) +def test_typeddict_default_timedelta( + ser_json_timedelta: Literal['float', 'iso8601'], properties: typing.Dict[str, Any] +): + class MyTypedDict(TypedDict): + __pydantic_config__ = ConfigDict(ser_json_timedelta=ser_json_timedelta) + + duration: Annotated[timedelta, Field(timedelta(minutes=5))] + + # insert_assert(TypeAdapter(MyTypedDict).json_schema(mode='serialization')) + assert TypeAdapter(MyTypedDict).json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['duration'], + 'title': 'MyTypedDict', + 'type': 'object', + } + + +@pytest.mark.parametrize( + 'ser_json_bytes,properties', + [ + ('base64', {'data': {'default': 'Zm9vYmFy', 'format': 'base64url', 'title': 'Data', 'type': 'string'}}), + ('utf8', {'data': {'default': 'foobar', 'format': 'binary', 'title': 'Data', 'type': 'string'}}), + ], +) +def test_typeddict_default_bytes(ser_json_bytes: Literal['base64', 'utf8'], properties: typing.Dict[str, Any]): + class MyTypedDict(TypedDict): + __pydantic_config__ = ConfigDict(ser_json_bytes=ser_json_bytes) + + data: Annotated[bytes, Field(b'foobar')] + + # insert_assert(TypeAdapter(MyTypedDict).json_schema(mode='serialization')) + assert TypeAdapter(MyTypedDict).json_schema(mode='serialization') == { + 'properties': properties, + 'required': ['data'], + 'title': 'MyTypedDict', + 'type': 'object', + } + + def test_model_subclass_metadata(): class A(BaseModel): """A Model docstring"""