Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🦺 Update implementarion and docs for RequestValidationError in pydantic v2 #11542

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/docs/tutorial/handling-errors.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ These are technical details that you might skip if it's not important for you no

///

`RequestValidationError` is a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.
`RequestValidationError` is morally a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.

**FastAPI** uses it so that, if you use a Pydantic model in `response_model`, and your data has an error, you will see the error in your log.

Expand Down
106 changes: 62 additions & 44 deletions fastapi/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@
Union,
)

from fastapi.exceptions import RequestErrorModel
from fastapi.exceptions import ErrorDetails, RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as P_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
from typing_extensions import (
Annotated,
Literal,
TypeAlias,
assert_never,
get_args,
get_origin,
)

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
Expand Down Expand Up @@ -68,7 +75,7 @@
)
except ImportError: # pragma: no cover
from pydantic_core.core_schema import (
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
general_plain_validator_function as with_info_plain_validator_function,
)

Required = PydanticUndefined
Expand All @@ -83,6 +90,9 @@ class BaseConfig:
class ErrorWrapper(Exception):
pass

# See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/error_wrappers.py#L45-L47
ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper]

@dataclass
class ModelField:
field_info: FieldInfo
Expand Down Expand Up @@ -116,22 +126,25 @@ def get_default(self) -> Any:
return Undefined
return self.field_info.get_default(call_default_factory=True)

# See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/fields.py#L850-L852 for the signature.
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
errors: List[ErrorDetails] = [
{**err, "loc": loc + err.get("loc", ())} # type: ignore [typeddict-unknown-key]
for err in exc.errors(include_url=False)
]
return None, errors

def serialize(
self,
Expand Down Expand Up @@ -168,8 +181,16 @@ def get_annotation_from_field_info(
) -> Any:
return annotation

def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
return errors # type: ignore[return-value]
def _normalize_errors(
errors: Union[ErrorList, Sequence[ErrorDetails]],
) -> List[ErrorDetails]:
assert isinstance(errors, Sequence), type(errors)
use_errors: List[ErrorDetails] = []
for error in errors:
assert not isinstance(error, ErrorWrapper)
assert not isinstance(error, Sequence)
use_errors.append(error)
return use_errors

def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
Expand Down Expand Up @@ -266,12 +287,12 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]

def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
return error

def create_body_model(
*, fields: Sequence[ModelField], model_name: str
Expand All @@ -290,14 +311,17 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
from pydantic import ( # type: ignore[assignment]
BaseConfig as BaseConfig, # noqa: F401
BaseConfig as BaseConfig,
)
from pydantic import ValidationError as ValidationError # noqa: F401
from pydantic import ValidationError as ValidationError
from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator, # noqa: F401
Validator as Validator,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper, # noqa: F401
ErrorList as ErrorList,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper,
)
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
Expand All @@ -311,31 +335,31 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
ModelField as ModelField, # noqa: F401
ModelField as ModelField,
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Required as Required, # noqa: F401
Required as Required,
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Undefined as Undefined,
)
from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
UndefinedType as UndefinedType, # noqa: F401
UndefinedType as UndefinedType,
)
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
model_process_schema,
)
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401
from pydantic.schema import ( # type: ignore[no-redef]
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.typing import ( # type: ignore[no-redef]
evaluate_forwardref as evaluate_forwardref, # noqa: F401
evaluate_forwardref as evaluate_forwardref,
)
from pydantic.utils import ( # type: ignore[no-redef]
lenient_issubclass as lenient_issubclass, # noqa: F401
lenient_issubclass as lenient_issubclass,
)

GetJsonSchemaHandler = Any # type: ignore[assignment,misc]
Expand Down Expand Up @@ -425,18 +449,23 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
return True
return False

def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
def _normalize_errors(
errors: Union[ErrorList, Sequence[ErrorDetails]],
) -> List[ErrorDetails]:
use_errors: List[ErrorDetails] = []
if isinstance(errors, ErrorWrapper):
use_errors.extend(
ValidationError( # type: ignore[call-arg]
errors=[errors], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
)
elif isinstance(errors, Sequence):
for error in errors:
assert not isinstance(error, dict)
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
else:
assert_never(errors) # pragma: no cover
return use_errors

def _model_rebuild(model: Type[BaseModel]) -> None:
Expand Down Expand Up @@ -507,10 +536,10 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]

def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
return new_error.errors()[0]

def create_body_model(
*, fields: Sequence[ModelField], model_name: str
Expand All @@ -524,17 +553,6 @@ def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]


def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]

return updated_loc_errors


def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
Expand Down
27 changes: 13 additions & 14 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
from fastapi import params
from fastapi._compat import (
PYDANTIC_V2,
ErrorWrapper,
ModelField,
Required,
Undefined,
_regenerate_error_with_loc,
_normalize_errors,
copy_field_info,
create_body_model,
evaluate_forwardref,
Expand All @@ -52,6 +51,7 @@
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import ErrorDetails
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
Expand Down Expand Up @@ -553,7 +553,7 @@ async def solve_generator(
@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
errors: List[ErrorDetails]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
Expand All @@ -572,7 +572,7 @@ async def solve_dependencies(
embed_body_fields: bool,
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
errors: List[ErrorDetails] = []
if response is None:
response = Response()
del response.headers["content-length"]
Expand Down Expand Up @@ -648,7 +648,8 @@ async def solve_dependencies(
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
for errors_ in (path_errors, query_errors, header_errors, cookie_errors):
errors.extend(errors_)
if dependant.body_params:
(
body_values,
Expand Down Expand Up @@ -687,17 +688,15 @@ async def solve_dependencies(

def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
) -> Tuple[Any, List[ErrorDetails]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper):
return None, [errors_]
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
if errors_ is not None:
new_errors = _normalize_errors(errors_)
return None, new_errors
else:
return v_, []
Expand Down Expand Up @@ -730,9 +729,9 @@ def _get_multidict_value(
def request_params_to_args(
fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
) -> Tuple[Dict[str, Any], List[ErrorDetails]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
errors: List[ErrorDetails] = []

if not fields:
return values, errors
Expand Down Expand Up @@ -867,9 +866,9 @@ async def request_body_to_args(
body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
) -> Tuple[Dict[str, Any], List[ErrorDetails]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
errors: List[ErrorDetails] = []
assert body_fields, "request_body_to_args() should be called with fields"
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
first_field = body_fields[0]
Expand Down
25 changes: 19 additions & 6 deletions fastapi/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
from typing_extensions import Annotated, Doc
from typing_extensions import Annotated, Doc, TypedDict


class ErrorDetails(TypedDict):
"""
The common subset shared by `ErrorDict` in Pydantic V1[0] and `ErrorDetails` in Pydantic V2[1].

[0] https://github.com/pydantic/pydantic/blob/4d7bef62aeff10985fe67d13477fe666b13ba070/pydantic/v1/error_wrappers.py#L21-L22
[1] https://github.com/pydantic/pydantic-core/blob/e1fc99dd3207157aad77defc20ab6873fd268b5b/python/pydantic_core/__init__.py#L73-L91
"""

loc: Tuple[Union[int, str], ...]
msg: str
type: str


class HTTPException(StarletteHTTPException):
Expand Down Expand Up @@ -147,15 +160,15 @@ class FastAPIError(RuntimeError):


class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
def __init__(self, errors: Sequence[ErrorDetails]) -> None:
self._errors = errors

def errors(self) -> Sequence[Any]:
def errors(self) -> Sequence[ErrorDetails]:
return self._errors


class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body

Expand All @@ -165,7 +178,7 @@ class WebSocketRequestValidationError(ValidationException):


class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body

Expand Down
Loading