/
utils.py
223 lines (201 loc) · 7.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import re
import warnings
from dataclasses import is_dataclass
from typing import (
TYPE_CHECKING,
Any,
Dict,
MutableMapping,
Optional,
Set,
Type,
Union,
cast,
)
from weakref import WeakKeyDictionary
import fastapi
from fastapi._compat import (
PYDANTIC_V2,
BaseConfig,
ModelField,
PydanticSchemaGenerationError,
Undefined,
UndefinedType,
Validator,
lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing_extensions import Literal
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
# Cache for `create_cloned_field`
_CLONED_TYPES_CACHE: MutableMapping[
Type[BaseModel], Type[BaseModel]
] = WeakKeyDictionary()
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
if status_code is None:
return True
# Ref: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#patterned-fields-1
if status_code in {
"default",
"1XX",
"2XX",
"3XX",
"4XX",
"5XX",
}:
return True
current_status_code = int(status_code)
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
def get_path_param_names(path: str) -> Set[str]:
return set(re.findall("{(.*?)}", path))
def create_response_field(
name: str,
type_: Type[Any],
class_validators: Optional[Dict[str, Validator]] = None,
default: Optional[Any] = Undefined,
required: Union[bool, UndefinedType] = Undefined,
model_config: Type[BaseConfig] = BaseConfig,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
) -> ModelField:
"""
Create a new response field. Raises if type_ is invalid.
"""
class_validators = class_validators or {}
if PYDANTIC_V2:
field_info = field_info or FieldInfo(
annotation=type_, default=default, alias=alias
)
else:
field_info = field_info or FieldInfo()
kwargs = {"name": name, "field_info": field_info}
if PYDANTIC_V2:
kwargs.update({"mode": mode})
else:
kwargs.update(
{
"type_": type_,
"class_validators": class_validators,
"default": default,
"required": required,
"model_config": model_config,
"alias": alias,
}
)
try:
return ModelField(**kwargs) # type: ignore[arg-type]
except (RuntimeError, PydanticSchemaGenerationError):
raise fastapi.exceptions.FastAPIError(
"Invalid args for response field! Hint: "
f"check that {type_} is a valid Pydantic field type. "
"If you are using a return type annotation that is not a valid Pydantic "
"field (e.g. Union[Response, dict, None]) you can disable generating the "
"response model from the type annotation with the path operation decorator "
"parameter response_model=None. Read more: "
"https://fastapi.tiangolo.com/tutorial/response-model/"
) from None
def create_cloned_field(
field: ModelField,
*,
cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
if PYDANTIC_V2:
return field
# cloned_types caches already cloned types to support recursive models and improve
# performance by avoiding unnecessary cloning
if cloned_types is None:
cloned_types = _CLONED_TYPES_CACHE
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
use_type = cloned_types.get(original_type)
if use_type is None:
use_type = create_model(original_type.__name__, __base__=original_type)
cloned_types[original_type] = use_type
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(
f, cloned_types=cloned_types
)
new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
new_field.alias = field.alias # type: ignore[misc]
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
new_field.default = field.default # type: ignore[misc]
new_field.required = field.required # type: ignore[misc]
new_field.model_config = field.model_config # type: ignore[attr-defined]
new_field.field_info = field.field_info
new_field.allow_none = field.allow_none # type: ignore[attr-defined]
new_field.validate_always = field.validate_always # type: ignore[attr-defined]
if field.sub_fields: # type: ignore[attr-defined]
new_field.sub_fields = [ # type: ignore[attr-defined]
create_cloned_field(sub_field, cloned_types=cloned_types)
for sub_field in field.sub_fields # type: ignore[attr-defined]
]
if field.key_field: # type: ignore[attr-defined]
new_field.key_field = create_cloned_field( # type: ignore[attr-defined]
field.key_field, # type: ignore[attr-defined]
cloned_types=cloned_types,
)
new_field.validators = field.validators # type: ignore[attr-defined]
new_field.pre_validators = field.pre_validators # type: ignore[attr-defined]
new_field.post_validators = field.post_validators # type: ignore[attr-defined]
new_field.parse_json = field.parse_json # type: ignore[attr-defined]
new_field.shape = field.shape # type: ignore[attr-defined]
new_field.populate_validators() # type: ignore[attr-defined]
return new_field
def generate_operation_id_for_path(
*, name: str, path: str, method: str
) -> str: # pragma: nocover
warnings.warn(
"fastapi.utils.generate_operation_id_for_path() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
stacklevel=2,
)
operation_id = f"{name}{path}"
operation_id = re.sub(r"\W", "_", operation_id)
operation_id = f"{operation_id}_{method.lower()}"
return operation_id
def generate_unique_id(route: "APIRoute") -> str:
operation_id = f"{route.name}{route.path_format}"
operation_id = re.sub(r"\W", "_", operation_id)
assert route.methods
operation_id = f"{operation_id}_{list(route.methods)[0].lower()}"
return operation_id
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
for key, value in update_dict.items():
if (
key in main_dict
and isinstance(main_dict[key], dict)
and isinstance(value, dict)
):
deep_dict_update(main_dict[key], value)
elif (
key in main_dict
and isinstance(main_dict[key], list)
and isinstance(update_dict[key], list)
):
main_dict[key] = main_dict[key] + update_dict[key]
else:
main_dict[key] = value
def get_value_or_default(
first_item: Union[DefaultPlaceholder, DefaultType],
*extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
"""
Pass items or `DefaultPlaceholder`s by descending priority.
The first one to _not_ be a `DefaultPlaceholder` will be returned.
Otherwise, the first item (a `DefaultPlaceholder`) will be returned.
"""
items = (first_item,) + extra_items
for item in items:
if not isinstance(item, DefaultPlaceholder):
return item
return first_item