10 changes: 4 additions & 6 deletions .github/workflows/codspeed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'

Expand Down Expand Up @@ -52,8 +52,7 @@ jobs:
key: v1

- name: Compile pydantic-core for profiling
run: |
pip install -e . --config-settings=build-args='--verbose --profile codspeed' -v
run: make build-profiling
env:
CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed
RUSTFLAGS: "-Cprofile-generate=${{ github.workspace }}/profdata"
Expand All @@ -65,13 +64,12 @@ jobs:
run: rustup run stable bash -c '$RUSTUP_HOME/toolchains/$RUSTUP_TOOLCHAIN/lib/rustlib/x86_64-unknown-linux-gnu/bin/llvm-profdata merge -o ${{ github.workspace }}/merged.profdata ${{ github.workspace }}/profdata'

- name: Compile pydantic-core for benchmarking
run: |
pip install -e . --config-settings=build-args='--verbose --profile codspeed' -v
run: make build-profiling
env:
CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed
RUSTFLAGS: "-Cprofile-use=${{ github.workspace }}/merged.profdata"

- name: Run CodSpeed benchmarks
uses: CodSpeedHQ/action@v1
uses: CodSpeedHQ/action@v2
with:
run: pytest tests/benchmarks/ --codspeed
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ node_modules/
/foobar.py
/python/pydantic_core/*.so
/src/self_schema.py

# samply
/profile.json
105 changes: 58 additions & 47 deletions Cargo.lock
58 changes: 43 additions & 15 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pydantic-core"
version = "2.14.4"
version = "2.16.1"
edition = "2021"
license = "MIT"
homepage = "https://github.com/pydantic/pydantic-core"
Expand All @@ -26,25 +26,24 @@ include = [
]

[dependencies]
pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] }
pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] }
regex = "1.10.2"
strum = { version = "0.25.0", features = ["derive"] }
strum_macros = "0.25.3"
serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]}
serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_order"]}
enum_dispatch = "0.3.8"
serde = { version = "1.0.190", features = ["derive"] }
serde = { version = "1.0.195", features = ["derive"] }
speedate = "0.13.0"
smallvec = "1.11.1"
ahash = "0.8.6"
url = "2.4.1"
smallvec = "1.11.2"
ahash = "0.8.7"
url = "2.5.0"
# idna is already required by url, added here to be explicit
idna = "0.4.0"
base64 = "0.21.5"
idna = "0.5.0"
base64 = "0.21.7"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.5.0"
jiter = {version = "0.0.4", features = ["python"]}
#jiter = {path = "../jiter", features = ["python"]}
uuid = "1.6.1"
jiter = {version = "0.0.6", features = ["python"]}

[lib]
name = "_pydantic_core"
Expand All @@ -63,19 +62,48 @@ strip = true
debug = true
strip = false

[profile.codspeed]
# This is separate to benchmarks because `bench` ends up building testing
# harnesses into code, as it's a special cargo profile.
[profile.profiling]
inherits = "release"
debug = true
strip = false

[dev-dependencies]
pyo3 = { version = "0.20.0", features = ["auto-initialize"] }
pyo3 = { version = "0.20.2", features = ["auto-initialize"] }

[build-dependencies]
version_check = "0.9.4"
# used where logic has to be version/distribution specific, e.g. pypy
pyo3-build-config = { version = "0.20.0" }
pyo3-build-config = { version = "0.20.2" }

[lints.clippy]
dbg_macro = "warn"
print_stdout = "warn"

# in general we lint against the pedantic group, but we will whitelist
# certain lints which we don't want to enforce (for now)
pedantic = { level = "warn", priority = -1 }
cast_possible_truncation = "allow"
cast_possible_wrap = "allow"
cast_precision_loss = "allow"
cast_sign_loss = "allow"
doc_markdown = "allow"
float_cmp = "allow"
fn_params_excessive_bools = "allow"
if_not_else = "allow"
manual_let_else = "allow"
match_bool = "allow"
match_same_arms = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
module_name_repetitions = "allow"
must_use_candidate = "allow"
needless_pass_by_value = "allow"
similar_names = "allow"
single_match_else = "allow"
struct_excessive_bools = "allow"
too_many_lines = "allow"
unnecessary_wraps = "allow"
unused_self = "allow"
used_underscore_binding = "allow"
31 changes: 3 additions & 28 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ endif
build-profiling:
@rm -f python/pydantic_core/*.so
ifneq ($(USE_MATURIN),)
CARGO_PROFILE_RELEASE_STRIP=false CARGO_PROFILE_RELEASE_DEBUG=true maturin develop --release
maturin develop --profile profiling
else
CARGO_PROFILE_RELEASE_STRIP=false CARGO_PROFILE_RELEASE_DEBUG=true pip install -v -e .
pip install -v -e . --config-settings=build-args='--profile profiling'
endif

.PHONY: build-coverage
Expand Down Expand Up @@ -106,32 +106,7 @@ lint-rust:
cargo fmt --version
cargo fmt --all -- --check
cargo clippy --version
cargo clippy --tests -- \
-D warnings \
-W clippy::pedantic \
-A clippy::cast-possible-truncation \
-A clippy::cast-possible-wrap \
-A clippy::cast-precision-loss \
-A clippy::cast-sign-loss \
-A clippy::doc-markdown \
-A clippy::float-cmp \
-A clippy::fn-params-excessive-bools \
-A clippy::if-not-else \
-A clippy::manual-let-else \
-A clippy::match-bool \
-A clippy::match-same-arms \
-A clippy::missing-errors-doc \
-A clippy::missing-panics-doc \
-A clippy::module-name-repetitions \
-A clippy::must-use-candidate \
-A clippy::needless-pass-by-value \
-A clippy::similar-names \
-A clippy::single-match-else \
-A clippy::struct-excessive-bools \
-A clippy::too-many-lines \
-A clippy::unnecessary-wraps \
-A clippy::unused-self \
-A clippy::used-underscore-binding
cargo clippy --tests -- -D warnings

.PHONY: lint
lint: lint-python lint-rust
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ except ValidationError as e:

You'll need rust stable [installed](https://rustup.rs/), or rust nightly if you want to generate accurate coverage.

With rust and python 3.7+ installed, compiling pydantic-core should be possible with roughly the following:
With rust and python 3.8+ installed, compiling pydantic-core should be possible with roughly the following:

```bash
# clone this repo or your fork
Expand Down
4 changes: 2 additions & 2 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def type_dict_schema( # noqa: C901
'type': 'union',
'choices': [
schema_ref_validator,
{'type': 'tuple-positional', 'items_schema': [schema_ref_validator, {'type': 'str'}]},
{'type': 'tuple', 'items_schema': [schema_ref_validator, {'type': 'str'}]},
],
},
}
Expand Down Expand Up @@ -191,7 +191,7 @@ def eval_forward_ref(type_: Any) -> Any:
try:
return type_._evaluate(core_schema.__dict__, None, set())
except TypeError:
# for older python (3.7 at least)
# for Python 3.8
return type_._evaluate(core_schema.__dict__, None)


Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build-backend = 'maturin'

[project]
name = 'pydantic_core'
requires-python = '>=3.7'
requires-python = '>=3.8'
authors = [
{name = 'Samuel Colvin', email = 's@muelcolvin.com'}
]
Expand All @@ -16,7 +16,6 @@ classifiers = [
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
Expand Down
20 changes: 19 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def to_json(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> bytes:
Expand All @@ -373,6 +374,7 @@ def to_json(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand All @@ -385,7 +387,7 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any:
def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any:
"""
Deserialize JSON data to a Python object.
Expand All @@ -394,6 +396,8 @@ def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> A
Arguments:
data: The JSON data to deserialize.
allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.
cache_strings: Whether to cache strings to avoid constructing new Python objects;
this can significantly improve performance while slightly increasing memory usage.
Raises:
ValueError: If deserialization fails.
Expand All @@ -412,6 +416,7 @@ def to_jsonable_python(
round_trip: bool = False,
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
bytes_mode: Literal['utf8', 'base64'] = 'utf8',
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
) -> Any:
Expand All @@ -430,6 +435,7 @@ def to_jsonable_python(
round_trip: Whether to enable serialization and validation round-trip support.
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`.
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`.
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
Expand Down Expand Up @@ -785,6 +791,18 @@ class ValidationError(ValueError):
a JSON string.
"""

def __repr__(self) -> str:
"""
A string representation of the validation error.
Whether or not documentation URLs are included in the repr is controlled by the
environment variable `PYDANTIC_ERRORS_INCLUDE_URL` being set to `1` or
`true`; by default, URLs are shown.
Due to implementation details, this environment variable can only be set once,
before the first validation error is created.
"""

@final
class PydanticCustomError(ValueError):
def __new__(
Expand Down
121 changes: 86 additions & 35 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ def is_instance_schema(
serialization: SerSchema | None = None,
) -> IsInstanceSchema:
"""
Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstnace` method, e.g.:
Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstance` method, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
Expand Down Expand Up @@ -1384,16 +1384,7 @@ def list_schema(
)


class TuplePositionalSchema(TypedDict, total=False):
type: Required[Literal['tuple-positional']]
items_schema: Required[List[CoreSchema]]
extras_schema: CoreSchema
strict: bool
ref: str
metadata: Any
serialization: IncExSeqOrElseSerSchema


# @deprecated('tuple_positional_schema is deprecated. Use pydantic_core.core_schema.tuple_schema instead.')
def tuple_positional_schema(
items_schema: list[CoreSchema],
*,
Expand All @@ -1402,7 +1393,7 @@ def tuple_positional_schema(
ref: str | None = None,
metadata: Any = None,
serialization: IncExSeqOrElseSerSchema | None = None,
) -> TuplePositionalSchema:
) -> TupleSchema:
"""
Returns a schema that matches a tuple of schemas, e.g.:
Expand All @@ -1427,20 +1418,70 @@ def tuple_positional_schema(
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='tuple-positional',
if extras_schema is not None:
variadic_item_index = len(items_schema)
items_schema = items_schema + [extras_schema]
else:
variadic_item_index = None
return tuple_schema(
items_schema=items_schema,
extras_schema=extras_schema,
variadic_item_index=variadic_item_index,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class TupleVariableSchema(TypedDict, total=False):
type: Required[Literal['tuple-variable']]
items_schema: CoreSchema
# @deprecated('tuple_variable_schema is deprecated. Use pydantic_core.core_schema.tuple_schema instead.')
def tuple_variable_schema(
items_schema: CoreSchema | None = None,
*,
min_length: int | None = None,
max_length: int | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: IncExSeqOrElseSerSchema | None = None,
) -> TupleSchema:
"""
Returns a schema that matches a tuple of a given schema, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.tuple_variable_schema(
items_schema=core_schema.int_schema(), min_length=0, max_length=10
)
v = SchemaValidator(schema)
assert v.validate_python(('1', 2, 3)) == (1, 2, 3)
```
Args:
items_schema: The value must be a tuple with items that match this schema
min_length: The value must be a tuple with at least this many items
max_length: The value must be a tuple with at most this many items
strict: Whether the value must already be a tuple (rather than coerced from another sequence)
ref: Optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return tuple_schema(
items_schema=[items_schema or any_schema()],
variadic_item_index=0,
min_length=min_length,
max_length=max_length,
strict=strict,
ref=ref,
metadata=metadata,
serialization=serialization,
)


class TupleSchema(TypedDict, total=False):
type: Required[Literal['tuple']]
items_schema: Required[List[CoreSchema]]
variadic_item_index: int
min_length: int
max_length: int
strict: bool
Expand All @@ -1449,41 +1490,45 @@ class TupleVariableSchema(TypedDict, total=False):
serialization: IncExSeqOrElseSerSchema


def tuple_variable_schema(
items_schema: CoreSchema | None = None,
def tuple_schema(
items_schema: list[CoreSchema],
*,
variadic_item_index: int | None = None,
min_length: int | None = None,
max_length: int | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
serialization: IncExSeqOrElseSerSchema | None = None,
) -> TupleVariableSchema:
) -> TupleSchema:
"""
Returns a schema that matches a tuple of a given schema, e.g.:
Returns a schema that matches a tuple of schemas, with an optional variadic item, e.g.:
```py
from pydantic_core import SchemaValidator, core_schema
schema = core_schema.tuple_variable_schema(
items_schema=core_schema.int_schema(), min_length=0, max_length=10
schema = core_schema.tuple_schema(
[core_schema.int_schema(), core_schema.str_schema(), core_schema.float_schema()],
variadic_item_index=1,
)
v = SchemaValidator(schema)
assert v.validate_python(('1', 2, 3)) == (1, 2, 3)
assert v.validate_python((1, 'hello', 'world', 1.5)) == (1, 'hello', 'world', 1.5)
```
Args:
items_schema: The value must be a tuple with items that match this schema
items_schema: The value must be a tuple with items that match these schemas
variadic_item_index: The index of the schema in `items_schema` to be treated as variadic (following PEP 646)
min_length: The value must be a tuple with at least this many items
max_length: The value must be a tuple with at most this many items
strict: Whether the value must already be a tuple (rather than coerced from another sequence)
ref: optional unique identifier of the schema, used to reference the schema in other places
ref: Optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='tuple-variable',
type='tuple',
items_schema=items_schema,
variadic_item_index=variadic_item_index,
min_length=min_length,
max_length=max_length,
strict=strict,
Expand Down Expand Up @@ -2940,6 +2985,7 @@ class DataclassField(TypedDict, total=False):
name: Required[str]
schema: Required[CoreSchema]
kw_only: bool # default: True
init: bool # default: True
init_only: bool # default: False
frozen: bool # default: False
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
Expand All @@ -2953,6 +2999,7 @@ def dataclass_field(
schema: CoreSchema,
*,
kw_only: bool | None = None,
init: bool | None = None,
init_only: bool | None = None,
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
serialization_alias: str | None = None,
Expand All @@ -2978,6 +3025,7 @@ def dataclass_field(
name: The name to use for the argument parameter
schema: The schema to use for the argument parameter
kw_only: Whether the field can be set with a positional argument as well as a keyword argument
init: Whether the field should be validated during initialization
init_only: Whether the field should be omitted from `__dict__` and passed to `__post_init__`
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
Expand All @@ -2990,6 +3038,7 @@ def dataclass_field(
name=name,
schema=schema,
kw_only=kw_only,
init=init,
init_only=init_only,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
Expand Down Expand Up @@ -3576,12 +3625,13 @@ def definitions_schema(schema: CoreSchema, definitions: list[CoreSchema]) -> Def
class DefinitionReferenceSchema(TypedDict, total=False):
type: Required[Literal['definition-ref']]
schema_ref: Required[str]
ref: str
metadata: Any
serialization: SerSchema


def definition_reference_schema(
schema_ref: str, metadata: Any = None, serialization: SerSchema | None = None
schema_ref: str, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None
) -> DefinitionReferenceSchema:
"""
Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive
Expand All @@ -3606,7 +3656,9 @@ def definition_reference_schema(
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization)
return _dict_not_none(
type='definition-ref', schema_ref=schema_ref, ref=ref, metadata=metadata, serialization=serialization
)


MYPY = False
Expand All @@ -3631,8 +3683,7 @@ def definition_reference_schema(
IsSubclassSchema,
CallableSchema,
ListSchema,
TuplePositionalSchema,
TupleVariableSchema,
TupleSchema,
SetSchema,
FrozenSetSchema,
GeneratorSchema,
Expand Down Expand Up @@ -3686,8 +3737,7 @@ def definition_reference_schema(
'is-subclass',
'callable',
'list',
'tuple-positional',
'tuple-variable',
'tuple',
'set',
'frozenset',
'generator',
Expand Down Expand Up @@ -3787,6 +3837,7 @@ def definition_reference_schema(
'datetime_type',
'datetime_parsing',
'datetime_object_invalid',
'datetime_from_date_parsing',
'datetime_past',
'datetime_future',
'timezone_naive',
Expand Down Expand Up @@ -3862,7 +3913,7 @@ def field_after_validator_function(function: WithInfoValidatorFunction, field_na
@deprecated('`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.')
def general_after_validator_function(*args, **kwargs):
warnings.warn(
'`with_info_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.',
'`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.',
DeprecationWarning,
)
return with_info_after_validator_function(*args, **kwargs)
Expand Down
6 changes: 0 additions & 6 deletions src/argument_markers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,3 @@ impl PydanticUndefinedType {
"PydanticUndefined"
}
}

impl PydanticUndefinedType {
pub fn py_undefined() -> Py<Self> {
Python::with_gil(PydanticUndefinedType::new)
}
}
94 changes: 47 additions & 47 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
Arc, OnceLock, Weak,
},
};

Expand All @@ -28,47 +28,50 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets build.
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<DefinitionInner<T>>);

struct DefinitionInner<T> {
value: OnceLock<T>,
name: LazyName,
struct Definition<T> {
value: Arc<OnceLock<T>>,
name: Arc<LazyName>,
}

/// Reference to a definition.
pub struct DefinitionRef<T> {
name: Arc<String>,
value: Definition<T>,
reference: Arc<String>,
// We use a weak reference to the definition to avoid a reference cycle
// when recursive definitions are used.
value: Weak<OnceLock<T>>,
name: Arc<LazyName>,
}

// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
impl<T> Clone for DefinitionRef<T> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
reference: self.reference.clone(),
value: self.value.clone(),
name: self.name.clone(),
}
}
}

impl<T> DefinitionRef<T> {
pub fn id(&self) -> usize {
Arc::as_ptr(&self.value.0) as usize
Weak::as_ptr(&self.value) as usize
}

pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
match self.value.0.value.get() {
Some(value) => self.value.0.name.get_or_init(|| init(value)),
let Some(definition) = self.value.upgrade() else {
return "...";
};
match definition.get() {
Some(value) => self.name.get_or_init(|| init(value)),
None => "...",
}
}

pub fn get(&self) -> Option<&T> {
self.value.0.value.get()
pub fn read<R>(&self, f: impl FnOnce(Option<&T>) -> R) -> R {
f(self.value.upgrade().as_ref().and_then(|value| value.get()))
}
}

Expand Down Expand Up @@ -96,15 +99,9 @@ impl<T: Debug> Debug for Definitions<T> {
}
}

impl<T> Clone for Definition<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T: Debug> Debug for Definition<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0.value.get() {
match self.value.get() {
Some(value) => value.fmt(f),
None => "...".fmt(f),
}
Expand All @@ -113,7 +110,7 @@ impl<T: Debug> Debug for Definition<T> {

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
if let Some(value) = self.value.0.value.get() {
if let Some(value) = self.value.upgrade().as_ref().and_then(|v| v.get()) {
value.py_gc_traverse(visit)?;
}
Ok(())
Expand All @@ -123,15 +120,15 @@ impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
for value in self.0.values() {
if let Some(value) = value.0.value.get() {
if let Some(value) = value.value.get() {
value.py_gc_traverse(visit)?;
}
}
Ok(())
}
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct DefinitionsBuilder<T> {
definitions: Definitions<T>,
}
Expand All @@ -148,45 +145,48 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
let reference = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(reference.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
Entry::Vacant(entry) => entry.insert(Definition {
value: Arc::new(OnceLock::new()),
name: Arc::new(LazyName::new()),
}),
};
DefinitionRef {
name,
value: value.clone(),
reference,
value: Arc::downgrade(&value.value),
name: value.name.clone(),
}
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
let name = Arc::new(reference);
let value = match self.definitions.0.entry(name.clone()) {
let reference = Arc::new(reference);
let value = match self.definitions.0.entry(reference.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
match definition.value.set(value) {
Ok(()) => definition,
Err(_) => return py_schema_err!("Duplicate ref: `{}`", reference),
}
}
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
Entry::Vacant(entry) => entry.insert(Definition {
value: Arc::new(OnceLock::from(value)),
name: Arc::new(LazyName::new()),
}),
};
Ok(DefinitionRef { name, value })
Ok(DefinitionRef {
reference,
value: Arc::downgrade(&value.value),
name: value.name.clone(),
})
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.value.get().is_none() {
if def.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Expand Down
83 changes: 34 additions & 49 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,54 @@ use crate::input::Input;
use super::location::{LocItem, Location};
use super::types::ErrorType;

pub type ValResult<'a, T> = Result<T, ValError<'a>>;
pub type ValResult<T> = Result<T, ValError>;

pub trait AsErrorValue {
fn as_error_value(&self) -> InputValue;
}

impl<'a, T: Input<'a>> AsErrorValue for T {
fn as_error_value(&self) -> InputValue {
Input::as_error_value(self)
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
pub enum ValError<'a> {
LineErrors(Vec<ValLineError<'a>>),
pub enum ValError {
LineErrors(Vec<ValLineError>),
InternalErr(PyErr),
Omit,
UseDefault,
}

impl<'a> From<PyErr> for ValError<'a> {
impl From<PyErr> for ValError {
fn from(py_err: PyErr) -> Self {
Self::InternalErr(py_err)
}
}

impl<'a> From<PyDowncastError<'_>> for ValError<'a> {
impl From<PyDowncastError<'_>> for ValError {
fn from(py_downcast: PyDowncastError) -> Self {
Self::InternalErr(PyTypeError::new_err(py_downcast.to_string()))
}
}

impl<'a> From<Vec<ValLineError<'a>>> for ValError<'a> {
fn from(line_errors: Vec<ValLineError<'a>>) -> Self {
impl From<Vec<ValLineError>> for ValError {
fn from(line_errors: Vec<ValLineError>) -> Self {
Self::LineErrors(line_errors)
}
}

impl<'a> ValError<'a> {
pub fn new(error_type: ErrorType, input: &'a impl Input<'a>) -> ValError<'a> {
impl ValError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValError {
Self::LineErrors(vec![ValLineError::new(error_type, input)])
}

pub fn new_with_loc(error_type: ErrorType, input: &'a impl Input<'a>, loc: impl Into<LocItem>) -> ValError<'a> {
pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValError {
Self::LineErrors(vec![ValLineError::new_with_loc(error_type, input, loc)])
}

pub fn new_custom_input(error_type: ErrorType, input_value: InputValue<'a>) -> ValError<'a> {
pub fn new_custom_input(error_type: ErrorType, input_value: InputValue) -> ValError {
Self::LineErrors(vec![ValLineError::new_custom_input(error_type, input_value)])
}

Expand All @@ -62,55 +72,45 @@ impl<'a> ValError<'a> {
other => other,
}
}

/// a bit like clone but change the lifetime to match py
pub fn into_owned(self, py: Python<'_>) -> ValError<'_> {
match self {
ValError::LineErrors(errors) => errors.into_iter().map(|e| e.into_owned(py)).collect::<Vec<_>>().into(),
ValError::InternalErr(err) => ValError::InternalErr(err),
ValError::Omit => ValError::Omit,
ValError::UseDefault => ValError::UseDefault,
}
}
}

/// A `ValLineError` is a single error that occurred during validation which is converted to a `PyLineError`
/// to eventually form a `ValidationError`.
/// I don't like the name `ValLineError`, but it's the best I could come up with (for now).
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct ValLineError<'a> {
pub struct ValLineError {
pub error_type: ErrorType,
// location is reversed so that adding an "outer" location item is pushing, it's reversed before showing to the user
pub location: Location,
pub input_value: InputValue<'a>,
pub input_value: InputValue,
}

impl<'a> ValLineError<'a> {
pub fn new(error_type: ErrorType, input: &'a impl Input<'a>) -> ValLineError<'a> {
impl ValLineError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
location: Location::default(),
}
}

pub fn new_with_loc(error_type: ErrorType, input: &'a impl Input<'a>, loc: impl Into<LocItem>) -> ValLineError<'a> {
pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
location: Location::new_some(loc.into()),
}
}

pub fn new_with_full_loc(error_type: ErrorType, input: &'a impl Input<'a>, location: Location) -> ValLineError<'a> {
pub fn new_with_full_loc(error_type: ErrorType, input: &impl AsErrorValue, location: Location) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
location,
}
}

pub fn new_custom_input(error_type: ErrorType, input_value: InputValue<'a>) -> ValLineError<'a> {
pub fn new_custom_input(error_type: ErrorType, input_value: InputValue) -> ValLineError {
Self {
error_type,
input_value,
Expand All @@ -130,35 +130,20 @@ impl<'a> ValLineError<'a> {
self.error_type = error_type;
self
}

/// a bit like clone but change the lifetime to match py, used by ValError.into_owned above
pub fn into_owned(self, py: Python<'_>) -> ValLineError<'_> {
ValLineError {
error_type: self.error_type,
input_value: match self.input_value {
InputValue::PyAny(input) => InputValue::PyAny(input.to_object(py).into_ref(py)),
InputValue::JsonInput(input) => InputValue::JsonInput(input),
InputValue::String(input) => InputValue::PyAny(input.to_object(py).into_ref(py)),
},
location: self.location,
}
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(Clone)]
pub enum InputValue<'a> {
PyAny(&'a PyAny),
JsonInput(JsonValue),
String(&'a str),
pub enum InputValue {
Python(PyObject),
Json(JsonValue),
}

impl<'a> ToPyObject for InputValue<'a> {
impl ToPyObject for InputValue {
fn to_object(&self, py: Python) -> PyObject {
match self {
Self::PyAny(input) => input.into_py(py),
Self::JsonInput(input) => input.to_object(py),
Self::String(input) => input.into_py(py),
Self::Python(input) => input.clone_ref(py),
Self::Json(input) => input.to_object(py),
}
}
}
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod types;
mod validation_exception;
mod value_exception;

pub use self::line_error::{InputValue, ValError, ValLineError, ValResult};
pub use self::line_error::{AsErrorValue, InputValue, ValError, ValLineError, ValResult};
pub use self::location::{AsLocItem, LocItem};
pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number};
pub use self::validation_exception::ValidationError;
Expand Down
7 changes: 6 additions & 1 deletion src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ error_types! {
DatetimeObjectInvalid {
error: {ctx_type: String, ctx_fn: field_from_context},
},
DatetimeFromDateParsing {
error: {ctx_type: Cow<'static, str>, ctx_fn: cow_field_from_context<String, _>},
},
DatetimePast {},
DatetimeFuture {},
// ---------------------
Expand Down Expand Up @@ -529,6 +532,7 @@ impl ErrorType {
Self::DatetimeType {..} => "Input should be a valid datetime",
Self::DatetimeParsing {..} => "Input should be a valid datetime, {error}",
Self::DatetimeObjectInvalid {..} => "Invalid datetime object, got {error}",
Self::DatetimeFromDateParsing {..} => "Input should be a valid datetime or date, {error}",
Self::DatetimePast {..} => "Input should be in the past",
Self::DatetimeFuture {..} => "Input should be in the future",
Self::TimezoneNaive {..} => "Input should not have timezone info",
Expand Down Expand Up @@ -684,6 +688,7 @@ impl ErrorType {
Self::DateFromDatetimeParsing { error, .. } => render!(tmpl, error),
Self::TimeParsing { error, .. } => render!(tmpl, error),
Self::DatetimeParsing { error, .. } => render!(tmpl, error),
Self::DatetimeFromDateParsing { error, .. } => render!(tmpl, error),
Self::DatetimeObjectInvalid { error, .. } => render!(tmpl, error),
Self::TimezoneOffset {
tz_expected, tz_actual, ..
Expand Down Expand Up @@ -781,7 +786,7 @@ impl From<Int> for Number {

impl FromPyObject<'_> for Number {
fn extract(obj: &PyAny) -> PyResult<Self> {
if let Ok(int) = extract_i64(obj) {
if let Some(int) = extract_i64(obj) {
Ok(Number::Int(int))
} else if let Ok(float) = obj.extract::<f64>() {
Ok(Number::Float(float))
Expand Down
68 changes: 47 additions & 21 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,31 @@ impl ValidationError {

static URL_ENV_VAR: GILOnceCell<bool> = GILOnceCell::new();

fn _get_include_url_env() -> bool {
match std::env::var("PYDANTIC_ERRORS_OMIT_URL") {
Ok(val) => val.is_empty(),
Err(_) => true,
}
}

fn include_url_env(py: Python) -> bool {
*URL_ENV_VAR.get_or_init(py, _get_include_url_env)
*URL_ENV_VAR.get_or_init(py, || {
// Check the legacy env var first.
// Using `var_os` here instead of `var` because we don't care about
// the value (or whether we're able to decode it as UTF-8), just
// whether it exists (and if it does, whether it's non-empty).
match std::env::var_os("PYDANTIC_ERRORS_OMIT_URL") {
Some(val) => {
// We don't care whether warning succeeded or not, hence the assignment
let _ = PyErr::warn(
py,
py.get_type::<pyo3::exceptions::PyDeprecationWarning>(),
"PYDANTIC_ERRORS_OMIT_URL is deprecated, use PYDANTIC_ERRORS_INCLUDE_URL instead",
1,
);
// If OMIT_URL exists but is empty, we include the URL:
val.is_empty()
}
// If the legacy env var doesn't exist, check the documented one:
None => match std::env::var("PYDANTIC_ERRORS_INCLUDE_URL") {
Ok(val) => val == "1" || val.to_lowercase() == "true",
Err(_) => true,
},
}
})
}

static URL_PREFIX: GILOnceCell<String> = GILOnceCell::new();
Expand All @@ -225,12 +241,8 @@ fn get_url_prefix(py: Python, include_url: bool) -> Option<&str> {

// used to convert a validation error back to ValError for wrap functions
impl ValidationError {
pub(crate) fn into_val_error(self, py: Python<'_>) -> ValError<'_> {
self.line_errors
.into_iter()
.map(|e| e.into_val_line_error(py))
.collect::<Vec<_>>()
.into()
pub(crate) fn into_val_error(self) -> ValError {
self.line_errors.into_iter().map(Into::into).collect::<Vec<_>>().into()
}
}

Expand Down Expand Up @@ -307,7 +319,7 @@ impl ValidationError {
include_context: bool,
include_input: bool,
) -> PyResult<&'py PyString> {
let state = SerializationState::new("iso8601", "utf8")?;
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None);
let serializer = ValidationErrorSerializer {
py,
Expand Down Expand Up @@ -345,6 +357,20 @@ impl ValidationError {
fn __str__(&self, py: Python) -> String {
self.__repr__(py)
}

fn __reduce__(slf: &PyCell<Self>) -> PyResult<(&PyAny, PyObject)> {
let py = slf.py();
let callable = slf.getattr("from_exception_data")?;
let borrow = slf.try_borrow()?;
let args = (
borrow.title.as_ref(py),
borrow.errors(py, include_url_env(py), true, true)?,
borrow.input_type.into_py(py),
borrow.hide_input,
)
.into_py(slf.py());
Ok((callable, args))
}
}

// TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary
Expand Down Expand Up @@ -416,7 +442,7 @@ pub struct PyLineError {
input_value: PyObject,
}

impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
impl IntoPy<PyLineError> for ValLineError {
fn into_py(self, py: Python<'_>) -> PyLineError {
PyLineError {
error_type: self.error_type,
Expand All @@ -426,13 +452,13 @@ impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
}
}

impl PyLineError {
impl From<PyLineError> for ValLineError {
/// Used to extract line errors from a validation error for wrap functions
fn into_val_line_error(self, py: Python<'_>) -> ValLineError<'_> {
fn from(other: PyLineError) -> ValLineError {
ValLineError {
error_type: self.error_type,
location: self.location,
input_value: InputValue::PyAny(self.input_value.into_ref(py)),
error_type: other.error_type,
location: other.location,
input_value: InputValue::Python(other.input_value),
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/errors/value_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use pyo3::exceptions::{PyException, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};

use crate::input::{Input, InputType};
use crate::input::InputType;
use crate::tools::extract_i64;

use super::line_error::AsErrorValue;
use super::{ErrorType, ValError};

#[pyclass(extends=PyException, module="pydantic_core._pydantic_core")]
Expand Down Expand Up @@ -105,7 +106,7 @@ impl PydanticCustomError {
}

impl PydanticCustomError {
pub fn into_val_error<'a>(self, input: &'a impl Input<'a>) -> ValError<'a> {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
let error_type = ErrorType::CustomError {
error_type: self.error_type,
message_template: self.message_template,
Expand All @@ -121,7 +122,7 @@ impl PydanticCustomError {
let key: &PyString = key.downcast()?;
if let Ok(py_str) = value.downcast::<PyString>() {
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
} else if let Ok(value_int) = extract_i64(value) {
} else if let Some(value_int) = extract_i64(value) {
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
} else {
// fallback for anything else just in case
Expand Down Expand Up @@ -184,7 +185,7 @@ impl PydanticKnownError {
}

impl PydanticKnownError {
pub fn into_val_error<'a>(self, input: &'a impl Input<'a>) -> ValError<'a> {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
ValError::new(self.error_type, input)
}
}
10 changes: 5 additions & 5 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl<'a> EitherDateTime<'a> {
}
}

pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<'a, EitherDate<'a>> {
pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<EitherDate<'a>> {
match Date::parse_bytes(bytes) {
Ok(date) => Ok(date.into()),
Err(err) => Err(ValError::new(
Expand All @@ -303,7 +303,7 @@ pub fn bytes_as_time<'a>(
input: &'a impl Input<'a>,
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<'a, EitherTime<'a>> {
) -> ValResult<EitherTime<'a>> {
match Time::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -326,7 +326,7 @@ pub fn bytes_as_datetime<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<'a, EitherDateTime<'a>> {
) -> ValResult<EitherDateTime<'a>> {
match DateTime::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand Down Expand Up @@ -455,7 +455,7 @@ pub fn float_as_time<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult
int_as_time(input, timestamp.floor() as i64, microseconds.round() as u32)
}

fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError<'a> {
fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError {
ValError::new(
ErrorType::TimeDeltaParsing {
error: Cow::Borrowed(err.get_documentation().unwrap_or_default()),
Expand All @@ -469,7 +469,7 @@ pub fn bytes_as_timedelta<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<'a, EitherTimedelta<'a>> {
) -> ValResult<EitherTimedelta<'a>> {
match Duration::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand Down
26 changes: 9 additions & 17 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use jiter::JsonValue;

use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};
Expand All @@ -14,7 +12,7 @@ use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch};

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputType {
Python,
Json,
Expand Down Expand Up @@ -49,7 +47,7 @@ impl TryFrom<&str> for InputType {
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
fn as_error_value(&'a self) -> InputValue<'a>;
fn as_error_value(&self) -> InputValue;

fn identity(&self) -> Option<usize> {
None
Expand Down Expand Up @@ -85,11 +83,9 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
false
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>>;

fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>;
fn validate_args(&'a self) -> ValResult<GenericArguments<'a>>;

fn parse_json(&'a self) -> ValResult<'a, JsonValue>;
fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a>>;

fn validate_str(
&'a self,
Expand All @@ -99,9 +95,9 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {

fn validate_bytes(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a>>>;

fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>>;
fn validate_bool(&self, strict: bool) -> ValResult<ValidationMatch<bool>>;

fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>>;
fn validate_int(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherInt<'a>>>;

fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.validate_int(true).and_then(|val_match| {
Expand All @@ -121,7 +117,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
})
}

fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>>;
fn validate_float(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherFloat<'a>>>;

fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
if strict {
Expand Down Expand Up @@ -230,15 +226,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
) -> ValResult<ValidationMatch<EitherTimedelta>>;
}

/// The problem to solve here is that iterating a `StringMapping` returns an owned
/// `StringMapping`, but all the other iterators return references. By introducing
/// The problem to solve here is that iterating collections often returns owned
/// values, but inputs are usually taken by reference. By introducing
/// this trait we abstract over whether the return value from the iterator is owned
/// or borrowed; all we care about is that we can borrow it again with `borrow_input`
/// for some lifetime 'a.
///
/// This lifetime `'a` is shorter than the original lifetime `'data` of the input,
/// which is only a problem in error branches. To resolve we have to call `into_owned`
/// to extend out the lifetime to match the original input.
pub trait BorrowInput {
type Input<'a>: Input<'a>
where
Expand Down
55 changes: 29 additions & 26 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::datetime::{
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::return_enums::ValidationMatch;
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
GenericIterator, GenericMapping, Input, JsonArgs,
Expand All @@ -31,10 +31,16 @@ impl AsLocItem for JsonValue {
}
}

impl AsLocItem for &JsonValue {
fn as_loc_item(&self) -> LocItem {
AsLocItem::as_loc_item(*self)
}
}

impl<'a> Input<'a> for JsonValue {
fn as_error_value(&'a self) -> InputValue<'a> {
fn as_error_value(&self) -> InputValue {
// cloning JsonValue is cheap due to use of Arc
InputValue::JsonInput(self.clone())
InputValue::Json(self.clone())
}

fn is_none(&self) -> bool {
Expand All @@ -54,15 +60,15 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
fn validate_args(&'a self) -> ValResult<GenericArguments<'a>> {
match self {
JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()),
JsonValue::Array(array) => Ok(JsonArgs::new(Some(array), None).into()),
_ => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)),
}
}

fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> {
fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<GenericArguments<'a>> {
match self {
JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()),
_ => {
Expand All @@ -78,13 +84,6 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn parse_json(&'a self) -> ValResult<'a, JsonValue> {
match self {
JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)),
_ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)),
}
}

fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
match self {
JsonValue::Str(s) => Ok(s.as_str().into()),
Expand Down Expand Up @@ -118,7 +117,7 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
fn validate_bool(&self, strict: bool) -> ValResult<ValidationMatch<bool>> {
match self {
JsonValue::Bool(b) => Ok(ValidationMatch::exact(*b)),
JsonValue::Str(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax),
Expand All @@ -134,7 +133,7 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
fn validate_int(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherInt<'a>>> {
match self {
JsonValue::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))),
JsonValue::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))),
Expand All @@ -145,7 +144,7 @@ impl<'a> Input<'a> for JsonValue {
}
}

fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
fn validate_float(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherFloat<'a>>> {
match self {
JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))),
JsonValue::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))),
Expand Down Expand Up @@ -326,23 +325,31 @@ impl AsLocItem for String {
}
}

impl AsLocItem for &String {
fn as_loc_item(&self) -> LocItem {
AsLocItem::as_loc_item(*self)
}
}

/// Required for JSON Object keys so the string can behave like an Input
impl<'a> Input<'a> for String {
fn as_error_value(&'a self) -> InputValue<'a> {
InputValue::String(self)
fn as_error_value(&self) -> InputValue {
// Justification for the clone: this is on the error pathway and we are generally ok
// with errors having a performance penalty
InputValue::Json(JsonValue::Str(self.clone()))
}

fn as_kwargs(&'a self, _py: Python<'a>) -> Option<&'a PyDict> {
None
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
fn validate_args(&'a self) -> ValResult<GenericArguments<'a>> {
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

#[cfg_attr(has_coverage_attribute, coverage(off))]
fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> {
fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<GenericArguments<'a>> {
let class_name = class_name.to_string();
Err(ValError::new(
ErrorType::DataclassType {
Expand All @@ -353,10 +360,6 @@ impl<'a> Input<'a> for String {
))
}

fn parse_json(&'a self) -> ValResult<'a, JsonValue> {
JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e))
}

fn validate_str(
&'a self,
_strict: bool,
Expand All @@ -374,18 +377,18 @@ impl<'a> Input<'a> for String {
Ok(ValidationMatch::strict(self.as_bytes().into()))
}

fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {
str_as_bool(self, self).map(ValidationMatch::lax)
}

fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
fn validate_int(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherInt<'a>>> {
match self.parse() {
Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)),
}
}

fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
fn validate_float(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherFloat<'a>>> {
str_as_float(self, self).map(ValidationMatch::lax)
}

Expand Down
50 changes: 19 additions & 31 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use pyo3::types::{
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
use pyo3::{intern, PyTypeInfo};

use jiter::JsonValue;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
Expand All @@ -26,8 +25,7 @@ use super::datetime::{
};
use super::return_enums::ValidationMatch;
use super::shared::{
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
str_as_int,
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
};
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
Expand Down Expand Up @@ -98,17 +96,23 @@ impl AsLocItem for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.downcast::<PyString>() {
py_str.to_string_lossy().as_ref().into()
} else if let Ok(key_int) = extract_i64(self) {
} else if let Some(key_int) = extract_i64(self) {
key_int.into()
} else {
safe_repr(self).to_string().into()
}
}
}

impl AsLocItem for &'_ PyAny {
fn as_loc_item(&self) -> LocItem {
AsLocItem::as_loc_item(*self)
}
}

impl<'a> Input<'a> for PyAny {
fn as_error_value(&'a self) -> InputValue<'a> {
InputValue::PyAny(self)
fn as_error_value(&self) -> InputValue {
InputValue::Python(self.into())
}

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -154,7 +158,7 @@ impl<'a> Input<'a> for PyAny {
self.is_callable()
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
fn validate_args(&'a self) -> ValResult<GenericArguments<'a>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(PyArgs::new(None, Some(dict)).into())
} else if let Ok(args_kwargs) = self.extract::<ArgsKwargs>() {
Expand All @@ -170,7 +174,7 @@ impl<'a> Input<'a> for PyAny {
}
}

fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> {
fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<GenericArguments<'a>> {
if let Ok(dict) = self.downcast::<PyDict>() {
Ok(PyArgs::new(None, Some(dict)).into())
} else if let Ok(args_kwargs) = self.extract::<ArgsKwargs>() {
Expand All @@ -189,22 +193,6 @@ impl<'a> Input<'a> for PyAny {
}
}

fn parse_json(&'a self) -> ValResult<'a, JsonValue> {
let bytes = if let Ok(py_bytes) = self.downcast::<PyBytes>() {
py_bytes.as_bytes()
} else if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_string_str(py_str)?;
str.as_bytes()
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
// Safety: from_slice does not run arbitrary Python code and the GIL is held so the
// bytes array will not be mutated while `JsonValue::parse` is reading it
unsafe { py_byte_array.as_bytes() }
} else {
return Err(ValError::new(ErrorTypeDefaults::JsonType, self));
};
JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e))
}

fn validate_str(
&'a self,
strict: bool,
Expand Down Expand Up @@ -296,15 +284,15 @@ impl<'a> Input<'a> for PyAny {
Err(ValError::new(ErrorTypeDefaults::BytesType, self))
}

fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
fn validate_bool(&self, strict: bool) -> ValResult<ValidationMatch<bool>> {
if let Ok(bool) = self.downcast::<PyBool>() {
return Ok(ValidationMatch::exact(bool.is_true()));
}

if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
return int_as_bool(self, int).map(ValidationMatch::lax);
} else if let Ok(float) = self.extract::<f64>() {
if let Ok(int) = float_as_int(self, float) {
Expand All @@ -319,7 +307,7 @@ impl<'a> Input<'a> for PyAny {
Err(ValError::new(ErrorTypeDefaults::BoolType, self))
}

fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
fn validate_int(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherInt<'a>>> {
if self.is_exact_instance_of::<PyInt>() {
return Ok(ValidationMatch::exact(EitherInt::Py(self)));
} else if self.is_instance_of::<PyInt>() {
Expand Down Expand Up @@ -359,7 +347,7 @@ impl<'a> Input<'a> for PyAny {
Err(ValError::new(ErrorTypeDefaults::IntType, self))
}

fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
fn validate_float(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherFloat<'a>>> {
if let Ok(float) = self.downcast_exact::<PyFloat>() {
return Ok(ValidationMatch::exact(EitherFloat::Py(float)));
}
Expand Down Expand Up @@ -647,7 +635,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::TimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_time(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_time(self, float)
Expand Down Expand Up @@ -681,7 +669,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
Expand Down Expand Up @@ -718,7 +706,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
Ok(int_as_duration(self, int)?.into())
} else if let Ok(float) = self.extract::<f64>() {
Ok(float_as_duration(self, float)?.into())
Expand Down
31 changes: 10 additions & 21 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};

use jiter::JsonValue;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
Expand All @@ -12,7 +11,7 @@ use crate::validators::decimal::create_decimal;
use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
};
use super::shared::{map_json_err, str_as_bool, str_as_float};
use super::shared::{str_as_bool, str_as_float};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
GenericIterator, GenericMapping, Input, ValidationMatch,
Expand All @@ -34,15 +33,15 @@ impl<'py> ToPyObject for StringMapping<'py> {
}

impl<'py> StringMapping<'py> {
pub fn new_key(py_key: &'py PyAny) -> ValResult<'py, StringMapping> {
pub fn new_key(py_key: &'py PyAny) -> ValResult<StringMapping> {
if let Ok(py_str) = py_key.downcast::<PyString>() {
Ok(Self::String(py_str))
} else {
Err(ValError::new(ErrorTypeDefaults::StringType, py_key))
}
}

pub fn new_value(py_value: &'py PyAny) -> ValResult<'py, Self> {
pub fn new_value(py_value: &'py PyAny) -> ValResult<Self> {
if let Ok(py_str) = py_value.downcast::<PyString>() {
Ok(Self::String(py_str))
} else if let Ok(value) = py_value.downcast::<PyDict>() {
Expand All @@ -63,39 +62,29 @@ impl AsLocItem for StringMapping<'_> {
}

impl<'a> Input<'a> for StringMapping<'a> {
fn as_error_value(&'a self) -> InputValue<'a> {
fn as_error_value(&self) -> InputValue {
match self {
Self::String(s) => s.as_error_value(),
Self::Mapping(d) => InputValue::PyAny(d),
Self::Mapping(d) => d.as_error_value(),
}
}

fn as_kwargs(&'a self, _py: Python<'a>) -> Option<&'a PyDict> {
None
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
fn validate_args(&'a self) -> ValResult<GenericArguments<'a>> {
// do we want to support this?
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self))
}

fn validate_dataclass_args(&'a self, _dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>> {
fn validate_dataclass_args(&'a self, _dataclass_name: &str) -> ValResult<GenericArguments<'a>> {
match self {
StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)),
StringMapping::Mapping(m) => Ok(GenericArguments::StringMapping(m)),
}
}

fn parse_json(&'a self) -> ValResult<'a, JsonValue> {
match self {
Self::String(s) => {
let str = py_string_str(s)?;
JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e))
}
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)),
}
}

fn validate_str(
&'a self,
_strict: bool,
Expand All @@ -114,14 +103,14 @@ impl<'a> Input<'a> for StringMapping<'a> {
}
}

fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {
match self {
Self::String(s) => str_as_bool(self, py_string_str(s)?).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}

fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
fn validate_int(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherInt<'a>>> {
match self {
Self::String(s) => match py_string_str(s)?.parse() {
Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))),
Expand All @@ -131,7 +120,7 @@ impl<'a> Input<'a> for StringMapping<'a> {
}
}

fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
fn validate_float(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherFloat<'a>>> {
match self {
Self::String(s) => str_as_float(self, py_string_str(s)?).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
Expand Down
58 changes: 29 additions & 29 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::tools::py_err;
use crate::tools::{extract_i64, py_err};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
Expand Down Expand Up @@ -153,7 +153,7 @@ impl<'a, INPUT: Input<'a>> MaxLengthCheck<'a, INPUT> {
}
}

fn incr(&mut self) -> ValResult<'a, ()> {
fn incr(&mut self) -> ValResult<()> {
if let Some(max_length) = self.max_length {
self.current_length += 1;
if self.current_length > max_length {
Expand Down Expand Up @@ -193,7 +193,7 @@ fn validate_iter_to_vec<'a, 's>(
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
validator: &'s CombinedValidator,
state: &mut ValidationState,
) -> ValResult<'a, Vec<PyObject>> {
) -> ValResult<Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
Expand Down Expand Up @@ -259,7 +259,7 @@ fn validate_iter_to_set<'a, 's>(
max_length: Option<usize>,
validator: &'s CombinedValidator,
state: &mut ValidationState,
) -> ValResult<'a, ()> {
) -> ValResult<()> {
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
Expand Down Expand Up @@ -303,7 +303,7 @@ fn no_validator_iter_to_vec<'a, 's>(
input: &'a (impl Input<'a> + 'a),
iter: impl Iterator<Item = PyResult<&'a (impl Input<'a> + 'a)>>,
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
) -> ValResult<'a, Vec<PyObject>> {
) -> ValResult<Vec<PyObject>> {
iter.enumerate()
.map(|(index, result)| {
let v = result.map_err(|e| any_next_error!(py, e, input, index))?;
Expand Down Expand Up @@ -348,7 +348,7 @@ impl<'a> GenericIterable<'a> {
field_type: &'static str,
validator: &'s CombinedValidator,
state: &mut ValidationState,
) -> ValResult<'a, Vec<PyObject>> {
) -> ValResult<Vec<PyObject>> {
let actual_length = self.generic_len();
let capacity = actual_length.unwrap_or(DEFAULT_CAPACITY);
let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length);
Expand Down Expand Up @@ -381,7 +381,7 @@ impl<'a> GenericIterable<'a> {
field_type: &'static str,
validator: &'s CombinedValidator,
state: &mut ValidationState,
) -> ValResult<'a, ()> {
) -> ValResult<()> {
macro_rules! validate_set {
($iter:expr) => {
validate_iter_to_set(py, set, $iter, input, field_type, max_length, validator, state)
Expand All @@ -406,7 +406,7 @@ impl<'a> GenericIterable<'a> {
input: &'a impl Input<'a>,
field_type: &'static str,
max_length: Option<usize>,
) -> ValResult<'a, Vec<PyObject>> {
) -> ValResult<Vec<PyObject>> {
let actual_length = self.generic_len();
let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length);

Expand Down Expand Up @@ -456,13 +456,13 @@ pub struct DictGenericIterator<'py> {
}

impl<'py> DictGenericIterator<'py> {
pub fn new(dict: &'py PyDict) -> ValResult<'py, Self> {
pub fn new(dict: &'py PyDict) -> ValResult<Self> {
Ok(Self { dict_iter: dict.iter() })
}
}

impl<'py> Iterator for DictGenericIterator<'py> {
type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>;
type Item = ValResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
self.dict_iter.next().map(Ok)
Expand All @@ -475,7 +475,7 @@ pub struct MappingGenericIterator<'py> {
iter: &'py PyIterator,
}

fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) -> ValError<'py> {
fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) -> ValError {
ValError::new(
ErrorType::MappingType {
error: py_err_string(py, err).into(),
Expand All @@ -486,7 +486,7 @@ fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) ->
}

impl<'py> MappingGenericIterator<'py> {
pub fn new(mapping: &'py PyMapping) -> ValResult<'py, Self> {
pub fn new(mapping: &'py PyMapping) -> ValResult<Self> {
let py = mapping.py();
let input: &PyAny = mapping;
let iter = mapping
Expand All @@ -501,7 +501,7 @@ impl<'py> MappingGenericIterator<'py> {
const MAPPING_TUPLE_ERROR: &str = "Mapping items must be tuples of (key, value) pairs";

impl<'py> Iterator for MappingGenericIterator<'py> {
type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>;
type Item = ValResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
Some(match self.iter.next()? {
Expand All @@ -524,14 +524,14 @@ pub struct StringMappingGenericIterator<'py> {
}

impl<'py> StringMappingGenericIterator<'py> {
pub fn new(dict: &'py PyDict) -> ValResult<'py, Self> {
pub fn new(dict: &'py PyDict) -> ValResult<Self> {
Ok(Self { dict_iter: dict.iter() })
}
}

impl<'py> Iterator for StringMappingGenericIterator<'py> {
// key (the first member of the tuple could be a simple String)
type Item = ValResult<'py, (StringMapping<'py>, StringMapping<'py>)>;
type Item = ValResult<(StringMapping<'py>, StringMapping<'py>)>;

fn next(&mut self) -> Option<Self::Item> {
match self.dict_iter.next() {
Expand All @@ -558,7 +558,7 @@ pub struct AttributesGenericIterator<'py> {
}

impl<'py> AttributesGenericIterator<'py> {
pub fn new(py_any: &'py PyAny) -> ValResult<'py, Self> {
pub fn new(py_any: &'py PyAny) -> ValResult<Self> {
Ok(Self {
object: py_any,
attributes_iterator: py_any.dir().into_iter(),
Expand All @@ -567,7 +567,7 @@ impl<'py> AttributesGenericIterator<'py> {
}

impl<'py> Iterator for AttributesGenericIterator<'py> {
type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>;
type Item = ValResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
// loop until we find an attribute who's name does not start with underscore,
Expand Down Expand Up @@ -610,15 +610,15 @@ pub struct JsonObjectGenericIterator<'py> {
}

impl<'py> JsonObjectGenericIterator<'py> {
pub fn new(json_object: &'py JsonObject) -> ValResult<'py, Self> {
pub fn new(json_object: &'py JsonObject) -> ValResult<Self> {
Ok(Self {
object_iter: json_object.iter(),
})
}
}

impl<'py> Iterator for JsonObjectGenericIterator<'py> {
type Item = ValResult<'py, (&'py String, &'py JsonValue)>;
type Item = ValResult<(&'py String, &'py JsonValue)>;

fn next(&mut self) -> Option<Self::Item> {
self.object_iter.next().map(|(key, value)| Ok((key, value)))
Expand Down Expand Up @@ -670,8 +670,8 @@ impl GenericPyIterator {
}
}

pub fn input_as_error_value<'py>(&self, py: Python<'py>) -> InputValue<'py> {
InputValue::PyAny(self.obj.clone_ref(py).into_ref(py))
pub fn input_as_error_value(&self, py: Python<'_>) -> InputValue {
InputValue::Python(self.obj.clone_ref(py))
}

pub fn index(&self) -> usize {
Expand Down Expand Up @@ -699,8 +699,8 @@ impl GenericJsonIterator {
}
}

pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> {
InputValue::JsonInput(JsonValue::Array(self.array.clone()))
pub fn input_as_error_value(&self, _py: Python<'_>) -> InputValue {
InputValue::Json(JsonValue::Array(self.array.clone()))
}

pub fn index(&self) -> usize {
Expand Down Expand Up @@ -758,7 +758,7 @@ pub enum EitherString<'a> {
}

impl<'a> EitherString<'a> {
pub fn as_cow(&self) -> ValResult<'a, Cow<str>> {
pub fn as_cow(&self) -> ValResult<Cow<str>> {
match self {
Self::Cow(data) => Ok(data.clone()),
Self::Py(py_str) => Ok(Cow::Borrowed(py_string_str(py_str)?)),
Expand Down Expand Up @@ -800,7 +800,7 @@ impl<'a> IntoPy<PyObject> for EitherString<'a> {
pub fn py_string_str(py_str: &PyString) -> ValResult<&str> {
py_str
.to_str()
.map_err(|_| ValError::new_custom_input(ErrorTypeDefaults::StringUnicode, InputValue::PyAny(py_str as &PyAny)))
.map_err(|_| ValError::new_custom_input(ErrorTypeDefaults::StringUnicode, InputValue::Python(py_str.into())))
}

#[cfg_attr(debug_assertions, derive(Debug))]
Expand Down Expand Up @@ -863,14 +863,14 @@ pub enum EitherInt<'a> {
impl<'a> EitherInt<'a> {
pub fn upcast(py_any: &'a PyAny) -> ValResult<Self> {
// Safety: we know that py_any is a python int
if let Ok(int_64) = py_any.extract::<i64>() {
if let Some(int_64) = extract_i64(py_any) {
Ok(Self::I64(int_64))
} else {
let big_int: BigInt = py_any.extract()?;
Ok(Self::BigInt(big_int))
}
}
pub fn into_i64(self, py: Python<'a>) -> ValResult<'a, i64> {
pub fn into_i64(self, py: Python<'a>) -> ValResult<i64> {
match self {
EitherInt::I64(i) => Ok(i),
EitherInt::U64(u) => match i64::try_from(u) {
Expand All @@ -893,7 +893,7 @@ impl<'a> EitherInt<'a> {
}
}

pub fn as_int(&self) -> ValResult<'a, Int> {
pub fn as_int(&self) -> ValResult<Int> {
match self {
EitherInt::I64(i) => Ok(Int::I64(*i)),
EitherInt::U64(u) => match i64::try_from(*u) {
Expand Down Expand Up @@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int {

impl<'a> FromPyObject<'a> for Int {
fn extract(obj: &'a PyAny) -> PyResult<Self> {
if let Ok(i) = obj.extract::<i64>() {
if let Some(i) = extract_i64(obj) {
Ok(Int::I64(i))
} else if let Ok(b) = obj.extract::<BigInt>() {
Ok(Int::Big(b))
Expand Down
25 changes: 7 additions & 18 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use jiter::JsonValueError;
use num_bigint::BigInt;

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::errors::{ErrorTypeDefaults, ValError, ValResult};

use super::{EitherFloat, EitherInt, Input};
static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
Expand All @@ -20,17 +19,7 @@ pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
.clone()
}

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
error: error.to_string(),
context: None,
},
input,
)
}

pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<'a, bool> {
pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<bool> {
if str == "0"
|| str.eq_ignore_ascii_case("f")
|| str.eq_ignore_ascii_case("n")
Expand All @@ -52,7 +41,7 @@ pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<'a, bo
}
}

pub fn int_as_bool<'a>(input: &'a impl Input<'a>, int: i64) -> ValResult<'a, bool> {
pub fn int_as_bool<'a>(input: &'a impl Input<'a>, int: i64) -> ValResult<bool> {
if int == 0 {
Ok(false)
} else if int == 1 {
Expand Down Expand Up @@ -82,7 +71,7 @@ fn strip_underscores(s: &str) -> Option<String> {
/// max length of the input is 4300, see
/// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and
/// https://github.com/python/cpython/issues/95778 for more info in that length bound
pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<'s, EitherInt<'s>> {
pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<EitherInt<'s>> {
let len = str.len();
if len > 4300 {
Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input))
Expand All @@ -106,7 +95,7 @@ pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<
}

/// parse a float as a float
pub fn str_as_float<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<'s, EitherFloat<'s>> {
pub fn str_as_float<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<EitherFloat<'s>> {
match str.parse() {
Ok(float) => Ok(EitherFloat::F64(float)),
Err(_) => match strip_underscores(str).and_then(|stripped| stripped.parse().ok()) {
Expand Down Expand Up @@ -140,7 +129,7 @@ fn strip_decimal_zeros(s: &str) -> Option<&str> {
None
}

pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a, EitherInt<'a>> {
pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<EitherInt<'a>> {
if float.is_infinite() || float.is_nan() {
Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input))
} else if float % 1.0 != 0.0 {
Expand All @@ -152,7 +141,7 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a,
}
}

pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<'a, EitherInt<'a>> {
pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<EitherInt<'a>> {
if !decimal.call_method0(intern!(py, "is_finite"))?.extract::<bool>()? {
return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input));
}
Expand Down
22 changes: 10 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ extern crate core;
use std::sync::OnceLock;

use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyByteArray, PyBytes, PyString};
use pyo3::{prelude::*, sync::GILOnceCell};

// parse this first to get access to the contained macro
Expand Down Expand Up @@ -37,17 +36,16 @@ pub use serializers::{
};
pub use validators::{validate_core_schema, PySome, SchemaValidator};

#[pyfunction(signature = (data, *, allow_inf_nan=true))]
pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult<PyObject> {
if let Ok(py_bytes) = data.downcast::<PyBytes>() {
jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan)
} else if let Ok(py_str) = data.downcast::<PyString>() {
jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan)
} else if let Ok(py_byte_array) = data.downcast::<PyByteArray>() {
jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan)
} else {
Err(PyTypeError::new_err("Expected bytes, bytearray or str"))
}
use crate::input::Input;

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool, cache_strings: bool) -> PyResult<PyObject> {
let v_match = data
.validate_bytes(false)
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
let json_either_bytes = v_match.into_inner();
let json_bytes = json_either_bytes.as_slice();
jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings).map_err(|e| jiter::map_json_error(json_bytes, &e))
}

pub fn get_pydantic_core_version() -> &'static str {
Expand Down
41 changes: 15 additions & 26 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,6 @@ impl fmt::Display for LookupKey {
}
}

macro_rules! py_string {
($py:ident, $str:expr) => {
PyString::intern($py, $str).into()
};
}

impl LookupKey {
pub fn from_py(py: Python, value: &PyAny, alt_alias: Option<&str>) -> PyResult<Self> {
if let Ok(alias_py) = value.downcast::<PyString>() {
Expand All @@ -67,7 +61,7 @@ impl LookupKey {
py_key1: alias_py.into_py(py),
path1: LookupPath::from_str(py, alias, Some(alias_py)),
key2: alt_alias.to_string(),
py_key2: py_string!(py, alt_alias),
py_key2: PyString::new(py, alt_alias).into(),
path2: LookupPath::from_str(py, alt_alias, None),
}),
None => Ok(Self::simple(py, alias, Some(alias_py))),
Expand Down Expand Up @@ -98,20 +92,20 @@ impl LookupKey {

fn simple(py: Python, key: &str, opt_py_key: Option<&PyString>) -> Self {
let py_key = match opt_py_key {
Some(py_key) => py_key.into_py(py),
None => py_string!(py, key),
Some(py_key) => py_key,
None => PyString::new(py, key),
};
Self::Simple {
key: key.to_string(),
py_key,
py_key: py_key.into(),
path: LookupPath::from_str(py, key, opt_py_key),
}
}

pub fn py_get_dict_item<'data, 's>(
&'s self,
dict: &'data PyDict,
) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> {
) -> ValResult<Option<(&'s LookupPath, &'data PyAny)>> {
match self {
Self::Simple { py_key, path, .. } => match dict.get_item(py_key)? {
Some(value) => Ok(Some((path, value))),
Expand Down Expand Up @@ -148,7 +142,7 @@ impl LookupKey {
pub fn py_get_string_mapping_item<'data, 's>(
&'s self,
dict: &'data PyDict,
) -> ValResult<'data, Option<(&'s LookupPath, StringMapping<'data>)>> {
) -> ValResult<Option<(&'s LookupPath, StringMapping<'data>)>> {
if let Some((path, py_any)) = self.py_get_dict_item(dict)? {
let value = StringMapping::new_value(py_any)?;
Ok(Some((path, value)))
Expand All @@ -160,7 +154,7 @@ impl LookupKey {
pub fn py_get_mapping_item<'data, 's>(
&'s self,
dict: &'data PyMapping,
) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> {
) -> ValResult<Option<(&'s LookupPath, &'data PyAny)>> {
match self {
Self::Simple { py_key, path, .. } => match dict.get_item(py_key) {
Ok(value) => Ok(Some((path, value))),
Expand Down Expand Up @@ -198,7 +192,7 @@ impl LookupKey {
&'s self,
obj: &'data PyAny,
kwargs: Option<&'data PyDict>,
) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> {
) -> ValResult<Option<(&'s LookupPath, &'data PyAny)>> {
match self._py_get_attr(obj, kwargs) {
Ok(v) => Ok(v),
Err(err) => {
Expand Down Expand Up @@ -266,7 +260,7 @@ impl LookupKey {
pub fn json_get<'data, 's>(
&'s self,
dict: &'data JsonObject,
) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> {
) -> ValResult<Option<(&'s LookupPath, &'data JsonValue)>> {
match self {
Self::Simple { key, path, .. } => match dict.get(key) {
Some(value) => Ok(Some((path, value))),
Expand Down Expand Up @@ -316,7 +310,7 @@ impl LookupKey {
input: &'d impl Input<'d>,
loc_by_alias: bool,
field_name: &str,
) -> ValLineError<'d> {
) -> ValLineError {
if loc_by_alias {
let lookup_path = match self {
Self::Simple { path, .. } => path,
Expand Down Expand Up @@ -348,10 +342,10 @@ impl fmt::Display for LookupPath {
impl LookupPath {
fn from_str(py: Python, key: &str, py_key: Option<&PyString>) -> Self {
let py_key = match py_key {
Some(py_key) => py_key.into_py(py),
None => py_string!(py, key),
Some(py_key) => py_key,
None => PyString::new(py, key),
};
Self(vec![PathItem::S(key.to_string(), py_key)])
Self(vec![PathItem::S(key.to_string(), py_key.into())])
}

fn from_list(obj: &PyAny) -> PyResult<LookupPath> {
Expand All @@ -369,12 +363,7 @@ impl LookupPath {
}
}

pub fn apply_error_loc<'a>(
&self,
mut line_error: ValLineError<'a>,
loc_by_alias: bool,
field_name: &str,
) -> ValLineError<'a> {
pub fn apply_error_loc(&self, mut line_error: ValLineError, loc_by_alias: bool, field_name: &str) -> ValLineError {
if loc_by_alias {
for path_item in self.iter().rev() {
line_error = line_error.with_outer_location(path_item.clone().into());
Expand Down Expand Up @@ -440,7 +429,7 @@ impl PathItem {
} else {
Ok(Self::Pos(usize_key))
}
} else if let Ok(int_key) = extract_i64(obj) {
} else if let Some(int_key) = extract_i64(obj) {
if index == 0 {
py_err!(PyTypeError; "The first item in an alias path should be a string")
} else {
Expand Down
Loading