diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 1b673f911e..f52b6ac0cc 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -53,6 +53,7 @@ define_expected_missing_refs, get_ref, get_type_ref, + is_function_with_inner_schema, is_list_like_schema_with_items_schema, simplify_schema_references, validate_core_schema, @@ -620,7 +621,13 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C schema = self._unpack_refs_defs(schema) - ref = get_ref(schema) + if is_function_with_inner_schema(schema): + ref = schema['schema'].pop('ref', None) + if ref: + schema['ref'] = ref + else: + ref = get_ref(schema) + if ref: self.defs.definitions[ref] = self._post_process_generated_schema(schema) return core_schema.definition_reference_schema(ref) diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 675b52a4a4..4fe2e30398 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -33,10 +33,11 @@ import pytest from dirty_equals import HasRepr from pydantic_core import CoreSchema, SchemaValidator, core_schema, to_json -from typing_extensions import Annotated, Literal, TypedDict +from typing_extensions import Annotated, Literal, Self, TypedDict import pydantic from pydantic import ( + AfterValidator, BaseModel, Field, GetCoreSchemaHandler, @@ -5779,3 +5780,35 @@ class Foo(BaseModel): } } } + + +def test_repeated_custom_type(): + class Numeric(pydantic.BaseModel): + value: float + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler) -> CoreSchema: + return core_schema.no_info_before_validator_function(cls.validate, handler(source_type)) + + @classmethod + def validate(cls, v: Any) -> Union[Dict[str, Any], Self]: + if isinstance(v, (str, float, int)): + return cls(value=v) + if isinstance(v, Numeric): + return v + if isinstance(v, dict): + return v + raise ValueError(f'Invalid value for {cls}: {v}') + + def is_positive(value: Numeric): + assert value.value > 0.0, 'Must be positive' + + class OuterModel(pydantic.BaseModel): + x: Numeric + y: Numeric + z: Annotated[Numeric, AfterValidator(is_positive)] + + assert OuterModel(x=2, y=-1, z=1) + + with pytest.raises(ValidationError): + OuterModel(x=2, y=-1, z=-1)