/
utils.py
278 lines (222 loc) · 8.44 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
import inspect
import re
import sys
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
from importlib import import_module
from textwrap import dedent
from typing import _eval_type # type: ignore
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generator, List, Optional, Pattern, Tuple, Type, Union
from . import errors
try:
import email_validator
except ImportError:
email_validator = None
try:
from typing import _TypingBase as typing_base # type: ignore
except ImportError:
from typing import _Final as typing_base # type: ignore
try:
from typing import ForwardRef # type: ignore
except ImportError:
# python 3.6
ForwardRef = None
if TYPE_CHECKING: # pragma: no cover
from .main import BaseModel # noqa: F401
if sys.version_info < (3, 7):
from typing import Callable
AnyCallable = Callable[..., Any]
else:
from collections.abc import Callable
from typing import Callable as TypingCallable
AnyCallable = TypingCallable[..., Any]
PRETTY_REGEX = re.compile(r'([\w ]*?) *<(.*)> *')
AnyType = Type[Any]
def validate_email(value: str) -> Tuple[str, str]:
"""
Brutally simple email address validation. Note unlike most email address validation
* raw ip address (literal) domain parts are not allowed.
* "John Doe <local_part@domain.com>" style "pretty" email addresses are processed
* the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better
solution is really possible.
* spaces are striped from the beginning and end of addresses but no error is raised
See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email!
"""
if email_validator is None:
raise ImportError('email-validator is not installed, run `pip install pydantic[email]`')
m = PRETTY_REGEX.fullmatch(value)
name: Optional[str] = None
if m:
name, value = m.groups()
email = value.strip()
try:
email_validator.validate_email(email, check_deliverability=False)
except email_validator.EmailNotValidError as e:
raise errors.EmailError() from e
return name or email[: email.index('@')], email.lower()
def _rfc_1738_quote(text: str) -> str:
return re.sub(r'[:@/]', lambda m: '%{:X}'.format(ord(m.group(0))), text)
def make_dsn(
*,
driver: str,
user: str = None,
password: str = None,
host: str = None,
port: str = None,
name: str = None,
query: Dict[str, Any] = None,
) -> str:
"""
Create a DSN from from connection settings.
Stolen approximately from sqlalchemy/engine/url.py:URL.
"""
s = driver + '://'
if user is not None:
s += _rfc_1738_quote(user)
if password is not None:
s += ':' + _rfc_1738_quote(password)
s += '@'
if host is not None:
if ':' in host:
s += '[{}]'.format(host)
else:
s += host
if port is not None:
s += ':{}'.format(int(port))
if name is not None:
s += '/' + name
query = query or {}
if query:
keys = list(query)
keys.sort()
s += '?' + '&'.join('{}={}'.format(k, query[k]) for k in keys)
return s
def import_string(dotted_path: str) -> Any:
"""
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import fails.
"""
try:
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
except ValueError as e:
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
module = import_module(module_path)
try:
return getattr(module, class_name)
except AttributeError as e:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
def truncate(v: str, *, max_len: int = 80) -> str:
"""
Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
"""
if isinstance(v, str) and len(v) > (max_len - 2):
# -3 so quote + string + … + quote has correct length
return repr(v[: (max_len - 3)] + '…')
v = repr(v)
if len(v) > max_len:
v = v[: max_len - 1] + '…'
return v
def display_as_type(v: AnyType) -> str:
if not isinstance(v, typing_base) and not isinstance(v, type):
v = type(v)
if lenient_issubclass(v, Enum):
if issubclass(v, int):
return 'int'
elif issubclass(v, str):
return 'str'
else:
return 'enum'
try:
return v.__name__
except AttributeError:
# happens with unions
return str(v)
ExcType = Type[Exception]
@contextmanager
def change_exception(raise_exc: ExcType, *except_types: ExcType) -> Generator[None, None, None]:
try:
yield
except except_types as e:
raise raise_exc from e
def clean_docstring(d: str) -> str:
return dedent(d).strip(' \r\n\t')
def sequence_like(v: AnyType) -> bool:
return isinstance(v, (list, tuple, set)) or inspect.isgenerator(v)
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
"""
Ensure that the field's name does not shadow an existing attribute of the model.
"""
for base in bases:
if getattr(base, field_name, None):
raise NameError(
f'Field name "{field_name}" shadows a BaseModel attribute; '
f'use a different field name with "alias=\'{field_name}\'".'
)
@lru_cache(maxsize=None)
def url_regex_generator(*, relative: bool, require_tld: bool) -> Pattern[str]:
"""
Url regex generator taken from Marshmallow library,
for details please follow library source code:
https://github.com/marshmallow-code/marshmallow/blob/298870ef6c089fb4d91efae9ca4168453ffe00d2/marshmallow/validate.py#L37
"""
return re.compile(
r''.join(
(
r'^',
r'(' if relative else r'',
r'(?:[a-z0-9\.\-\+]*)://', # scheme is validated separately
r'(?:[^:@]+?:[^:@]*?@|)', # basic auth
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+',
r'(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|', # domain...
r'localhost|', # localhost...
(
r'(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.?)|' if not require_tld else r''
), # allow dotless hostnames
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|', # ...or ipv4
r'\[[A-F0-9]*:[A-F0-9:]+\])', # ...or ipv6
r'(?::\d+)?', # optional port
r')?' if relative else r'', # host is optional, allow for relative URLs
r'(?:/?|[/?]\S+)$',
)
),
re.IGNORECASE,
)
def lenient_issubclass(cls: Any, class_or_tuple: Union[AnyType, Tuple[AnyType, ...]]) -> bool:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
def in_ipython() -> bool:
"""
Check whether we're in an ipython environment, including jupyter notebooks.
"""
try:
__IPYTHON__ # type: ignore
except NameError:
return False
else: # pragma: no cover
return True
def resolve_annotations(raw_annotations: Dict[str, AnyType], module_name: Optional[str]) -> Dict[str, AnyType]:
"""
Partially taken from typing.get_type_hints.
Resolve string or ForwardRef annotations into type objects if possible.
"""
if module_name:
base_globals: Optional[Dict[str, Any]] = sys.modules[module_name].__dict__
else:
base_globals = None
annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
return annotations
def is_callable_type(type_: AnyType) -> bool:
return type_ is Callable or getattr(type_, '__origin__', None) is Callable
def _check_classvar(v: AnyType) -> bool:
return type(v) == type(ClassVar) and (sys.version_info < (3, 7) or getattr(v, '_name', None) == 'ClassVar')
def is_classvar(ann_type: AnyType) -> bool:
return _check_classvar(ann_type) or _check_classvar(getattr(ann_type, '__origin__', None))