Skip to content

Commit

Permalink
ValidationException.errors() are ErrorDetails
Browse files Browse the repository at this point in the history
Update the documentation to explain that `RequestValidationError` isn't
literally a subclass since Pydantic V2.
  • Loading branch information
tamird committed May 7, 2024
1 parent 25a8af6 commit 338a33e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/en/docs/tutorial/handling-errors.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ path -> item_id
!!! warning
These are technical details that you might skip if it's not important for you now.

`RequestValidationError` is a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.
`RequestValidationError` is morally a sub-class of Pydantic's <a href="https://docs.pydantic.dev/latest/concepts/models/#error-handling" class="external-link" target="_blank">`ValidationError`</a>.

**FastAPI** uses it so that, if you use a Pydantic model in `response_model`, and your data has an error, you will see the error in your log.

Expand Down
27 changes: 17 additions & 10 deletions fastapi/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
Union,
)

from fastapi.exceptions import RequestErrorModel
from fastapi.exceptions import ErrorDetails, RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as P_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
from typing_extensions import Annotated, Literal, TypeAlias, get_args, get_origin

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
Expand Down Expand Up @@ -82,6 +82,9 @@ class BaseConfig:
class ErrorWrapper(Exception):
pass

# See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/error_wrappers.py#L45-L47
ErrorList: TypeAlias = Union[Sequence[Any], ErrorWrapper]

@dataclass
class ModelField:
field_info: FieldInfo
Expand Down Expand Up @@ -115,13 +118,14 @@ def get_default(self) -> Any:
return Undefined
return self.field_info.get_default(call_default_factory=True)

# See https://github.com/pydantic/pydantic/blob/07b6473/pydantic/v1/fields.py#L850-L852 for the signature.
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
) -> Tuple[Any, Union[ErrorList, None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
Expand Down Expand Up @@ -167,7 +171,7 @@ def get_annotation_from_field_info(
) -> Any:
return annotation

def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
def _normalize_errors(errors: Sequence[Any]) -> List[ErrorDetails]:
return errors # type: ignore[return-value]

def _model_rebuild(model: Type[BaseModel]) -> None:
Expand Down Expand Up @@ -265,12 +269,12 @@ def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
assert issubclass(origin_type, sequence_types) # type: ignore[arg-type]
return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]

def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
return error

def create_body_model(
*, fields: Sequence[ModelField], model_name: str
Expand All @@ -289,6 +293,9 @@ def create_body_model(
from pydantic.class_validators import ( # type: ignore[no-redef]
Validator as Validator,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorList as ErrorList,
)
from pydantic.error_wrappers import ( # type: ignore[no-redef]
ErrorWrapper as ErrorWrapper,
)
Expand Down Expand Up @@ -418,7 +425,7 @@ def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
return True
return False

def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
def _normalize_errors(errors: Sequence[Any]) -> List[ErrorDetails]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
Expand Down Expand Up @@ -500,10 +507,10 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return,attr-defined]

def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorDetails:
missing_field_error = ErrorWrapper(MissingError(), loc=loc) # type: ignore[call-arg]
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
return new_error.errors()[0]

def create_body_model(
*, fields: Sequence[ModelField], model_name: str
Expand All @@ -516,7 +523,7 @@ def create_body_model(

def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
) -> List[ErrorDetails]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
Expand Down
26 changes: 14 additions & 12 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.exceptions import ErrorDetails
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
Expand Down Expand Up @@ -539,7 +540,7 @@ async def solve_dependencies(
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[Any] = []
errors: List[ErrorDetails] = []
if response is None:
response = Response()
del response.headers["content-length"]
Expand Down Expand Up @@ -652,9 +653,9 @@ async def solve_dependencies(
def request_params_to_args(
required_params: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
) -> Tuple[Dict[str, Any], List[ErrorDetails]]:
values = {}
errors = []
errors: List[ErrorDetails] = []
for field in required_params:
if is_scalar_sequence_field(field) and isinstance(
received_params, (QueryParams, Headers)
Expand All @@ -675,10 +676,9 @@ def request_params_to_args(
continue
v_, errors_ = field.validate(value, values, loc=loc)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
errors.extend(_regenerate_error_with_loc(errors=[errors_], loc_prefix=()))
elif isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
errors.extend(new_errors)
errors.extend(_regenerate_error_with_loc(errors=errors_, loc_prefix=()))
else:
values[field.name] = v_
return values, errors
Expand All @@ -687,9 +687,9 @@ def request_params_to_args(
async def request_body_to_args(
required_params: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
) -> Tuple[Dict[str, Any], List[ErrorDetails]]:
values = {}
errors: List[Dict[str, Any]] = []
errors: List[ErrorDetails] = []
if required_params:
field = required_params[0]
field_info = field.field_info
Expand Down Expand Up @@ -757,10 +757,12 @@ async def process_fn(

v_, errors_ = field.validate(value, values, loc=loc)

if isinstance(errors_, list):
errors.extend(errors_)
elif errors_:
errors.append(errors_)
if isinstance(errors_, ErrorWrapper):
errors.extend(
_regenerate_error_with_loc(errors=[errors_], loc_prefix=())
)
elif isinstance(errors_, list):
errors.extend(_regenerate_error_with_loc(errors=errors_, loc_prefix=()))
else:
values[field.name] = v_
return values, errors
Expand Down
25 changes: 19 additions & 6 deletions fastapi/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
from typing_extensions import Annotated, Doc
from typing_extensions import Annotated, Doc, TypedDict


class ErrorDetails(TypedDict):
"""
The common subset shared by `ErrorDict` in Pydantic V1[0] and `ErrorDetails` in Pydantic V2[1].
[0] https://github.com/pydantic/pydantic/blob/4d7bef62aeff10985fe67d13477fe666b13ba070/pydantic/v1/error_wrappers.py#L21-L22
[1] https://github.com/pydantic/pydantic-core/blob/e1fc99dd3207157aad77defc20ab6873fd268b5b/python/pydantic_core/__init__.py#L73-L91
"""

loc: Tuple[Union[int, str], ...]
msg: str
type: str


class HTTPException(StarletteHTTPException):
Expand Down Expand Up @@ -147,15 +160,15 @@ class FastAPIError(RuntimeError):


class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
def __init__(self, errors: Sequence[ErrorDetails]) -> None:
self._errors = errors

def errors(self) -> Sequence[Any]:
def errors(self) -> Sequence[ErrorDetails]:
return self._errors


class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body

Expand All @@ -165,7 +178,7 @@ class WebSocketRequestValidationError(ValidationException):


class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
def __init__(self, errors: Sequence[ErrorDetails], *, body: Any = None) -> None:
super().__init__(errors)
self.body = body

Expand Down
2 changes: 1 addition & 1 deletion fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async def app(request: Request) -> Response:
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
}
} # type: ignore [typeddict-unknown-key]
],
body=e.doc,
)
Expand Down

0 comments on commit 338a33e

Please sign in to comment.