diff --git a/docs/en/docs/tutorial/handling-errors.md b/docs/en/docs/tutorial/handling-errors.md index 98ac55d1f7722..175820be8e0ab 100644 --- a/docs/en/docs/tutorial/handling-errors.md +++ b/docs/en/docs/tutorial/handling-errors.md @@ -163,7 +163,7 @@ path -> item_id !!! warning These are technical details that you might skip if it's not important for you now. -`RequestValidationError` is a sub-class of Pydantic's `ValidationError`. +`RequestValidationError` is morally a sub-class of Pydantic's `ValidationError`. **FastAPI** uses it so that, if you use a Pydantic model in `response_model`, and your data has an error, you will see the error in your log. diff --git a/fastapi/_compat.py b/fastapi/_compat.py index ad1410158fbd0..14b9a3cba97a0 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -17,12 +17,19 @@ Union, ) -from fastapi.exceptions import RequestErrorModel +from fastapi.exceptions import ErrorDetails, RequestErrorModel from fastapi.types import IncEx, ModelNameMap, UnionType from pydantic import BaseModel, create_model from pydantic.version import VERSION as P_VERSION from starlette.datastructures import UploadFile -from typing_extensions import Annotated, Literal, get_args, get_origin +from typing_extensions import ( + Annotated, + Literal, + TypeAlias, + assert_never, + get_args, + get_origin, +) # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION @@ -82,6 +89,9 @@ class BaseConfig: class ErrorWrapper(Exception): pass + # See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/error_wrappers.py#L45-L47 + ErrorList: TypeAlias = Union[Sequence["ErrorList"], ErrorWrapper] + @dataclass class ModelField: field_info: FieldInfo @@ -115,22 +125,25 @@ def get_default(self) -> Any: return Undefined return self.field_info.get_default(call_default_factory=True) + # See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/fields.py#L850-L852 for the signature. def validate( self, value: Any, values: Dict[str, Any] = {}, # noqa: B006 *, loc: Tuple[Union[int, str], ...] = (), - ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + ) -> Tuple[Any, Union[ErrorList, Sequence[ErrorDetails], None]]: try: return ( self._type_adapter.validate_python(value, from_attributes=True), None, ) except ValidationError as exc: - return None, _regenerate_error_with_loc( - errors=exc.errors(include_url=False), loc_prefix=loc - ) + errors: List[ErrorDetails] = [ + {**err, "loc": loc + err.get("loc", ())} # type: ignore [typeddict-unknown-key] + for err in exc.errors(include_url=False) + ] + return None, errors def serialize( self, @@ -167,8 +180,16 @@ def get_annotation_from_field_info( ) -> Any: return annotation - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: - return errors # type: ignore[return-value] + def _normalize_errors( + errors: Union[ErrorList, Sequence[ErrorDetails]], + ) -> Sequence[ErrorDetails]: + assert isinstance(errors, Sequence), type(errors) + use_errors: List[ErrorDetails] = [] + for error in errors: + assert not isinstance(error, ErrorWrapper) + assert not isinstance(error, Sequence) + use_errors.append(error) + return use_errors def _model_rebuild(model: Type[BaseModel]) -> None: model.model_rebuild() @@ -265,12 +286,12 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: assert issubclass(origin_type, sequence_types) # type: ignore[arg-type] return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails: error = ValidationError.from_exception_data( "Field required", [{"type": "missing", "loc": loc, "input": {}}] ).errors(include_url=False)[0] error["input"] = None - return error # type: ignore[return-value] + return error def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -289,6 +310,9 @@ def create_body_model( from pydantic.class_validators import ( # type: ignore[no-redef] Validator as Validator, ) + from pydantic.error_wrappers import ( # type: ignore[no-redef] + ErrorList as ErrorList, + ) from pydantic.error_wrappers import ( # type: ignore[no-redef] ErrorWrapper as ErrorWrapper, ) @@ -418,18 +442,23 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool: return True return False - def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: - use_errors: List[Any] = [] - for error in errors: - if isinstance(error, ErrorWrapper): - new_errors = ValidationError( # type: ignore[call-arg] - errors=[error], model=RequestErrorModel + def _normalize_errors( + errors: Union[ErrorList, Sequence[ErrorDetails]], + ) -> Sequence[ErrorDetails]: + use_errors: List[ErrorDetails] = [] + if isinstance(errors, ErrorWrapper): + use_errors.extend( + ValidationError( # type: ignore[call-arg] + errors=[errors], model=RequestErrorModel ).errors() - use_errors.extend(new_errors) - elif isinstance(error, list): + ) + elif isinstance(errors, Sequence): + for error in errors: + assert not isinstance(error, dict) use_errors.extend(_normalize_errors(error)) - else: - use_errors.append(error) + return use_errors + else: + assert_never(errors) # pragma: no cover return use_errors def _model_rebuild(model: Type[BaseModel]) -> None: @@ -500,10 +529,10 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined] - def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails: missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg] new_error = ValidationError([missing_field_error], RequestErrorModel) - return new_error.errors()[0] # type: ignore[return-value] + return new_error.errors()[0] def create_body_model( *, fields: Sequence[ModelField], model_name: str @@ -514,17 +543,6 @@ def create_body_model( return BodyModel -def _regenerate_error_with_loc( - *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] -) -> List[Dict[str, Any]]: - updated_loc_errors: List[Any] = [ - {**err, "loc": loc_prefix + err.get("loc", ())} - for err in _normalize_errors(errors) - ] - - return updated_loc_errors - - def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: if lenient_issubclass(annotation, (str, bytes)): return False diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 4f984177a4085..77917d26fa820 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -21,11 +21,10 @@ from fastapi import params from fastapi._compat import ( PYDANTIC_V2, - ErrorWrapper, ModelField, Required, Undefined, - _regenerate_error_with_loc, + _normalize_errors, copy_field_info, create_body_model, evaluate_forwardref, @@ -50,6 +49,7 @@ contextmanager_in_threadpool, ) from fastapi.dependencies.models import Dependant, SecurityRequirement +from fastapi.exceptions import ErrorDetails from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes @@ -533,13 +533,13 @@ async def solve_dependencies( async_exit_stack: AsyncExitStack, ) -> Tuple[ Dict[str, Any], - List[Any], + Sequence[ErrorDetails], Optional[StarletteBackgroundTasks], Response, Dict[Tuple[Callable[..., Any], Tuple[str]], Any], ]: values: Dict[str, Any] = {} - errors: List[Any] = [] + errors: List[ErrorDetails] = [] if response is None: response = Response() del response.headers["content-length"] @@ -620,7 +620,8 @@ async def solve_dependencies( values.update(query_values) values.update(header_values) values.update(cookie_values) - errors += path_errors + query_errors + header_errors + cookie_errors + for errors_ in (path_errors, query_errors, header_errors, cookie_errors): + errors.extend(errors_) if dependant.body_params: ( body_values, @@ -652,9 +653,9 @@ async def solve_dependencies( def request_params_to_args( required_params: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], -) -> Tuple[Dict[str, Any], List[Any]]: +) -> Tuple[Dict[str, Any], Sequence[ErrorDetails]]: values = {} - errors = [] + errors: List[ErrorDetails] = [] for field in required_params: if is_scalar_sequence_field(field) and isinstance( received_params, (QueryParams, Headers) @@ -674,11 +675,8 @@ def request_params_to_args( values[field.name] = deepcopy(field.default) continue v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) - errors.extend(new_errors) + if errors_ is not None: + errors.extend(_normalize_errors(errors_)) else: values[field.name] = v_ return values, errors @@ -687,9 +685,9 @@ def request_params_to_args( async def request_body_to_args( required_params: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], -) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: +) -> Tuple[Dict[str, Any], Sequence[ErrorDetails]]: values = {} - errors: List[Dict[str, Any]] = [] + errors: List[ErrorDetails] = [] if required_params: field = required_params[0] field_info = field.field_info @@ -756,11 +754,8 @@ async def process_fn( value = serialize_sequence_value(field=field, value=results) v_, errors_ = field.validate(value, values, loc=loc) - - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) + if errors_ is not None: + errors.extend(_normalize_errors(errors_)) else: values[field.name] = v_ return values, errors diff --git a/fastapi/exceptions.py b/fastapi/exceptions.py index 44d4ada86d7e4..dd4de78d7d424 100644 --- a/fastapi/exceptions.py +++ b/fastapi/exceptions.py @@ -1,9 +1,22 @@ -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from pydantic import BaseModel, create_model from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import WebSocketException as StarletteWebSocketException -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc, TypedDict + + +class ErrorDetails(TypedDict): + """ + The common subset shared by `ErrorDict` in Pydantic V1[0] and `ErrorDetails` in Pydantic V2[1]. + + [0] https://github.com/pydantic/pydantic/blob/4d7bef62aeff10985fe67d13477fe666b13ba070/pydantic/v1/error_wrappers.py#L21-L22 + [1] https://github.com/pydantic/pydantic-core/blob/e1fc99dd3207157aad77defc20ab6873fd268b5b/python/pydantic_core/__init__.py#L73-L91 + """ + + loc: Tuple[Union[int, str], ...] + msg: str + type: str class HTTPException(StarletteHTTPException): @@ -147,15 +160,15 @@ class FastAPIError(RuntimeError): class ValidationException(Exception): - def __init__(self, errors: Sequence[Any]) -> None: + def __init__(self, errors: Sequence[ErrorDetails]) -> None: self._errors = errors - def errors(self) -> Sequence[Any]: + def errors(self) -> Sequence[ErrorDetails]: return self._errors class RequestValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None: super().__init__(errors) self.body = body @@ -165,7 +178,7 @@ class WebSocketRequestValidationError(ValidationException): class ResponseValidationError(ValidationException): - def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None: super().__init__(errors) self.body = body diff --git a/fastapi/routing.py b/fastapi/routing.py index fa1351859fb91..25ba9dd47267e 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -132,7 +132,6 @@ async def serialize_response( is_coroutine: bool = True, ) -> Any: if field: - errors = [] if not hasattr(field, "serialize"): # pydantic v1 response_content = _prepare_response_content( @@ -142,15 +141,11 @@ async def serialize_response( exclude_none=exclude_none, ) if is_coroutine: - value, errors_ = field.validate(response_content, {}, loc=("response",)) + value, errors = field.validate(response_content, {}, loc=("response",)) else: - value, errors_ = await run_in_threadpool( + value, errors = await run_in_threadpool( field.validate, response_content, {}, loc=("response",) ) - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) if errors: raise ResponseValidationError( errors=_normalize_errors(errors), body=response_content @@ -251,7 +246,7 @@ async def app(request: Request) -> Response: "msg": "JSON decode error", "input": {}, "ctx": {"error": e.msg}, - } + } # type: ignore [typeddict-unknown-key] ], body=e.doc, ) @@ -264,7 +259,6 @@ async def app(request: Request) -> Response: status_code=400, detail="There was an error parsing the body" ) raise http_error from e - errors: List[Any] = [] async with AsyncExitStack() as async_exit_stack: solved_result = await solve_dependencies( request=request, @@ -309,9 +303,7 @@ async def app(request: Request) -> Response: response.body = b"" response.headers.raw.extend(sub_response.headers.raw) if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) + validation_error = RequestValidationError(errors, body=body) raise validation_error if response is None: raise FastAPIError( @@ -343,7 +335,7 @@ async def app(websocket: WebSocket) -> None: ) values, errors, _, _2, _3 = solved_result if errors: - raise WebSocketRequestValidationError(_normalize_errors(errors)) + raise WebSocketRequestValidationError(errors) assert dependant.call is not None, "dependant.call must be a function" await dependant.call(**values)