Skip to content

Commit

Permalink
Replace TransformSchema with GetPydanticSchema (#6484)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Jul 6, 2023
1 parent 7b526de commit ab583cc
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 12 deletions.
1 change: 1 addition & 0 deletions pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
'SkipValidation',
'InstanceOf',
'WithJsonSchema',
'GetPydanticSchema',
# type_adapter
'TypeAdapter',
# version
Expand Down
50 changes: 43 additions & 7 deletions pydantic/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'GetPydanticSchema',
)


Expand Down Expand Up @@ -1333,12 +1334,47 @@ def encode_str(self, value: str) -> str:


@_internal_dataclass.slots_dataclass
class TransformSchema:
"""An annotation that can be used to apply a transform to a core schema."""
class GetPydanticSchema:
"""A convenience class for creating an annotation that provides pydantic custom type hooks.
transform: Callable[[CoreSchema], CoreSchema]
This class is intended to eliminate the need to create a custom "marker" which defines the
`__get_pydantic_core_schema__` and `__get_pydantic_json_schema__` custom hook methods.
def __get_pydantic_core_schema__(
self, source_type: type[Any], handler: _annotated_handlers.GetCoreSchemaHandler
) -> CoreSchema:
return self.transform(handler(source_type))
For example, to have a field treated by type checkers as `int`, but by pydantic as `Any`, you can do:
```python
from typing import Any
from typing_extensions import Annotated
from pydantic import BaseModel, GetPydanticSchema
HandleAsAny = GetPydanticSchema(lambda _s, h: h(Any))
class Model(BaseModel):
x: Annotated[int, HandleAsAny] # pydantic sees `x: Any`
print(repr(Model(x='abc').x))
#> 'abc'
```
"""

get_pydantic_core_schema: Callable[[Any, _annotated_handlers.GetCoreSchemaHandler], CoreSchema] | None = None
get_pydantic_json_schema: Callable[[Any, _annotated_handlers.GetJsonSchemaHandler], JsonSchemaValue] | None = None
# Note: if we find a use, we could uncomment the following as a way to specify `__prepare_pydantic_annotations__`:
# prepare_pydantic_annotations: Callable[
# [Any, tuple[Any, ...], ConfigDict], tuple[Any, Iterable[Any]]
# ] | None = None

# Note: we may want to consider adding a convenience staticmethod `def for_type(type_: Any) -> GetPydanticSchema:`
# which returns `GetPydanticSchema(lambda _s, h: h(type_))`

def __getattr__(self, item: str) -> Any:
"""Use this rather than defining `__get_pydantic_core_schema__` etc. to reduce the number of nested calls."""
if item == '__get_pydantic_core_schema__' and self.get_pydantic_core_schema:
return self.get_pydantic_core_schema
elif item == '__get_pydantic_json_schema__' and self.get_pydantic_json_schema:
return self.get_pydantic_json_schema
else:
return object.__getattribute__(self, item)

__hash__ = object.__hash__
22 changes: 17 additions & 5 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
)
from pydantic.errors import PydanticSchemaGenerationError
from pydantic.functional_validators import AfterValidator
from pydantic.types import AllowInfNan, ImportString, Strict, TransformSchema
from pydantic.types import AllowInfNan, GetPydanticSchema, ImportString, Strict

try:
import email_validator
Expand Down Expand Up @@ -5107,17 +5107,28 @@ class A(BaseModel):
}


def test_transform_schema():
ValidateStrAsInt = Annotated[str, GetPydanticSchema(lambda _s, h: core_schema.int_schema())]

class Model(BaseModel):
x: Optional[ValidateStrAsInt]

assert Model(x=None).x is None
assert Model(x='1').x == 1


def test_transform_schema_for_first_party_class():
# Here, first party means you can define the `__prepare_pydantic_annotations__` method on the class directly.
class LowercaseStr(str):
@classmethod
def __prepare_pydantic_annotations__(
cls, _source: Type[Any], annotations: Tuple[Any, ...], _config: ConfigDict
) -> Tuple[Any, Iterable[Any]]:
def transform_schema(schema: CoreSchema) -> CoreSchema:
def get_pydantic_core_schema(source: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(source)
return core_schema.no_info_after_validator_function(lambda v: v.lower(), schema)

return str, (*annotations, TransformSchema(transform_schema))
return str, (*annotations, GetPydanticSchema(get_pydantic_core_schema))

class Model(BaseModel):
lower: LowercaseStr = Field(min_length=1)
Expand Down Expand Up @@ -5154,10 +5165,11 @@ class _DatetimeWrapperAnnotation:
def __prepare_pydantic_annotations__(
cls, _source: Type[Any], annotations: Tuple[Any, ...], _config: ConfigDict
) -> Tuple[Any, Iterable[Any]]:
def transform(schema: CoreSchema) -> CoreSchema:
def get_pydantic_core_schema(source: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(source)
return core_schema.no_info_after_validator_function(lambda v: DatetimeWrapper(v), schema)

return datetime, list(annotations) + [TransformSchema(transform)]
return datetime, list(annotations) + [GetPydanticSchema(get_pydantic_core_schema)]

# Giving a name to Annotated[DatetimeWrapper, _DatetimeWrapperAnnotation] makes it easier to use in code
# where I want a field of type `DatetimeWrapper` that works as desired with pydantic.
Expand Down

0 comments on commit ab583cc

Please sign in to comment.