Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 47 additions & 37 deletions pydantic_core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import sys
from datetime import date, datetime, time
from typing import Any, Callable, Dict, List, Sequence, Union
from typing import Any, Callable, Dict, List, Union

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Required
else:
from typing import NotRequired
from typing import NotRequired, Required

if sys.version_info < (3, 8):
from typing_extensions import Literal, TypedDict
Expand All @@ -19,9 +19,10 @@ class AnySchema(TypedDict):
type: Literal['any']


class BoolSchema(TypedDict):
type: Literal['bool']
strict: NotRequired[bool]
class BoolSchema(TypedDict, total=False):
type: Required[Literal['bool']]
strict: bool
ref: str


class ConfigSchema(TypedDict, total=False):
Expand All @@ -39,6 +40,7 @@ class DictSchema(TypedDict, total=False):
min_items: int
max_items: int
strict: bool
ref: str


class FloatSchema(TypedDict, total=False):
Expand All @@ -49,20 +51,22 @@ class FloatSchema(TypedDict, total=False):
lt: float
gt: float
strict: bool
default: float
ref: str


class FunctionSchema(TypedDict):
type: Literal['function']
mode: Literal['before', 'after', 'wrap']
function: Callable[..., Any]
schema: Schema
ref: NotRequired[str]


class FunctionPlainSchema(TypedDict):
type: Literal['function']
mode: Literal['plain']
function: Callable[..., Any]
ref: NotRequired[str]


class IntSchema(TypedDict, total=False):
Expand All @@ -73,6 +77,7 @@ class IntSchema(TypedDict, total=False):
lt: int
gt: int
strict: bool
ref: str


class ListSchema(TypedDict, total=False):
Expand All @@ -81,17 +86,20 @@ class ListSchema(TypedDict, total=False):
min_items: int
max_items: int
strict: bool
ref: str


class LiteralSchema(TypedDict):
type: Literal['literal']
expected: Sequence[Any]
expected: List[Any]
ref: NotRequired[str]


class ModelClassSchema(TypedDict):
type: Literal['model-class']
class_type: type
schema: TypedDictSchema
ref: NotRequired[str]


class TypedDictField(TypedDict, total=False):
Expand All @@ -102,33 +110,30 @@ class TypedDictField(TypedDict, total=False):
aliases: List[List[Union[str, int]]]


class TypedDictSchema(TypedDict):
type: Literal['typed-dict']
fields: Dict[str, TypedDictField]
extra_validator: NotRequired[Schema]
config: NotRequired[ConfigSchema]
return_fields_set: NotRequired[bool]
class TypedDictSchema(TypedDict, total=False):
type: Required[Literal['typed-dict']]
fields: Required[Dict[str, TypedDictField]]
extra_validator: Schema
config: ConfigSchema
return_fields_set: bool
ref: str


class NoneSchema(TypedDict):
type: Literal['none']
ref: NotRequired[str]


class NullableSchema(TypedDict):
type: Literal['nullable']
schema: Schema
strict: NotRequired[bool]
class NullableSchema(TypedDict, total=False):
type: Required[Literal['nullable']]
schema: Required[Schema]
strict: bool
ref: str


class RecursiveReferenceSchema(TypedDict):
type: Literal['recursive-ref']
name: str


class RecursiveContainerSchema(TypedDict):
type: Literal['recursive-container']
name: str
schema: Schema
schema_ref: str


class SetSchema(TypedDict, total=False):
Expand All @@ -137,6 +142,7 @@ class SetSchema(TypedDict, total=False):
min_items: int
max_items: int
strict: bool
ref: str


class FrozenSetSchema(TypedDict, total=False):
Expand All @@ -145,6 +151,7 @@ class FrozenSetSchema(TypedDict, total=False):
min_items: int
max_items: int
strict: bool
ref: str


class StringSchema(TypedDict, total=False):
Expand All @@ -156,20 +163,22 @@ class StringSchema(TypedDict, total=False):
to_lower: bool
to_upper: bool
strict: bool
ref: str


class UnionSchema(TypedDict):
type: Literal['union']
choices: List[Schema]
strict: NotRequired[bool]
default: NotRequired[Any]
class UnionSchema(TypedDict, total=False):
type: Required[Literal['union']]
choices: Required[List[Schema]]
strict: bool
ref: str


class BytesSchema(TypedDict, total=False):
type: Required[Literal['bytes']]
max_length: int
min_length: int
strict: bool
ref: str


class DateSchema(TypedDict, total=False):
Expand All @@ -179,7 +188,7 @@ class DateSchema(TypedDict, total=False):
ge: date
lt: date
gt: date
default: date
ref: str


class TimeSchema(TypedDict, total=False):
Expand All @@ -189,7 +198,7 @@ class TimeSchema(TypedDict, total=False):
ge: time
lt: time
gt: time
default: time
ref: str


class DatetimeSchema(TypedDict, total=False):
Expand All @@ -199,13 +208,14 @@ class DatetimeSchema(TypedDict, total=False):
ge: datetime
lt: datetime
gt: datetime
default: datetime
ref: str


class TupleFixLenSchema(TypedDict):
type: Literal['tuple-fix-len']
items_schema: List[Schema]
strict: NotRequired[bool]
class TupleFixLenSchema(TypedDict, total=False):
type: Required[Literal['tuple-fix-len']]
items_schema: Required[List[Schema]]
strict: bool
ref: str


class TupleVarLenSchema(TypedDict, total=False):
Expand All @@ -214,6 +224,7 @@ class TupleVarLenSchema(TypedDict, total=False):
min_items: int
max_items: int
strict: bool
ref: str


# pydantic allows types to be defined via a simple string instead of dict with just `type`, e.g.
Expand Down Expand Up @@ -256,7 +267,6 @@ class TupleVarLenSchema(TypedDict, total=False):
ModelClassSchema,
NoneSchema,
NullableSchema,
RecursiveContainerSchema,
RecursiveReferenceSchema,
SetSchema,
FrozenSetSchema,
Expand Down
60 changes: 36 additions & 24 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,35 @@ pub trait BuildValidator: Sized {
) -> PyResult<CombinedValidator>;
}

fn build_single_validator<'a, T: BuildValidator>(
val_type: &str,
schema_dict: &'a PyDict,
config: Option<&'a PyDict>,
build_context: &mut BuildContext,
) -> PyResult<(CombinedValidator, &'a PyDict)> {
build_context.incr_check_depth()?;

let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::<String>("ref")? {
let slot_id = build_context.prepare_slot(schema_ref)?;
let inner_val = T::build(schema_dict, config, build_context)
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?;
build_context.complete_slot(slot_id, inner_val);
recursive::RecursiveContainerValidator::create(slot_id)
} else {
T::build(schema_dict, config, build_context)
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?
};

build_context.decr_depth();
Ok((val, schema_dict))
}

// macro to build the match statement for validator selection
macro_rules! validator_match {
($type:ident, $dict:ident, $config:ident, $build_context:ident, $($validator:path,)+) => {
match $type {
$(
<$validator>::EXPECTED_TYPE => {
$build_context.incr_check_depth()?;
let val = <$validator>::build($dict, $config, $build_context).map_err(|err| {
SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", $type, err))
})?;
$build_context.decr_depth();
Ok((val, $dict))
},
<$validator>::EXPECTED_TYPE => build_single_validator::<$validator>($type, $dict, $config, $build_context),
)+
_ => {
return py_error!(r#"Unknown schema type: "{}""#, $type)
Expand Down Expand Up @@ -221,7 +237,6 @@ pub fn build_validator<'a>(
// functions - before, after, plain & wrap
function::FunctionBuilder,
// recursive (self-referencing) models
recursive::RecursiveValidator,
recursive::RecursiveRefValidator,
// literals
literal::LiteralBuilder,
Expand Down Expand Up @@ -294,7 +309,7 @@ pub enum CombinedValidator {
FunctionPlain(function::FunctionPlainValidator),
FunctionWrap(function::FunctionWrapValidator),
// recursive (self-referencing) models
Recursive(recursive::RecursiveValidator),
Recursive(recursive::RecursiveContainerValidator),
RecursiveRef(recursive::RecursiveRefValidator),
// literals
LiteralSingleString(literal::LiteralSingleStringValidator),
Expand Down Expand Up @@ -348,29 +363,26 @@ pub trait Validator: Send + Sync + Clone + Debug {
fn get_name(&self, py: Python) -> String;
}

#[derive(Default)]
pub struct BuildContext {
named_slots: Vec<(Option<String>, Option<CombinedValidator>)>,
depth: usize,
}

const MAX_DEPTH: usize = 100;

impl Default for BuildContext {
fn default() -> Self {
let named_slots: Vec<(Option<String>, Option<CombinedValidator>)> = Vec::new();
BuildContext { named_slots, depth: 0 }
}
}

impl BuildContext {
pub fn add_named_slot(&mut self, name: String, schema: &PyAny, config: Option<&PyDict>) -> PyResult<usize> {
pub fn prepare_slot(&mut self, slot_ref: String) -> PyResult<usize> {
let id = self.named_slots.len();
self.named_slots.push((Some(name), None));
let validator = build_validator(schema, config, self)?.0;
self.named_slots[id] = (None, Some(validator));
self.named_slots.push((Some(slot_ref), None));
Ok(id)
}

pub fn complete_slot(&mut self, slot_id: usize, validator: CombinedValidator) {
let (name, _) = self.named_slots.get(slot_id).unwrap();
self.named_slots[slot_id] = (name.clone(), Some(validator));
}

pub fn incr_check_depth(&mut self) -> PyResult<()> {
self.depth += 1;
if self.depth > MAX_DEPTH {
Expand All @@ -384,14 +396,14 @@ impl BuildContext {
self.depth -= 1;
}

pub fn find_id(&self, name: &str) -> PyResult<usize> {
pub fn find_slot_id(&self, slot_ref: &str) -> PyResult<usize> {
let is_match = |(n, _): &(Option<String>, Option<CombinedValidator>)| match n {
Some(n) => n == name,
Some(n) => n == slot_ref,
None => false,
};
match self.named_slots.iter().position(is_match) {
Some(id) => Ok(id),
None => py_error!("Recursive reference error: ref '{}' not found", name),
None => py_error!("Recursive reference error: ref '{}' not found", slot_ref),
}
}

Expand Down
Loading