-
-
Notifications
You must be signed in to change notification settings - Fork 35
/
sessions.py
363 lines (315 loc) · 11.5 KB
/
sessions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"""prawcore.sessions: Provides prawcore.Session and prawcore.session."""
import logging
import random
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin
from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout
from requests.status_codes import codes
from .auth import BaseAuthorizer
from .const import TIMEOUT
from .exceptions import (
BadJSON,
BadRequest,
Conflict,
InvalidInvocation,
NotFound,
Redirect,
RequestException,
ServerError,
SpecialError,
TooLarge,
TooManyRequests,
UnavailableForLegalReasons,
URITooLong,
)
from .rate_limit import RateLimiter
from .util import authorization_error_class
if TYPE_CHECKING:
from io import BufferedReader
from requests.models import Response
from .auth import Authorizer
from .requestor import Requestor
# Logger named after the enclosing package; the retry strategies and Session
# in this module emit their debug and warning messages through it.
log = logging.getLogger(__package__)
class RetryStrategy(ABC):
    """Abstract base for objects that schedule request retries.

    A strategy governs how many retry attempts are made and how long to wait
    between them. Instances of this class are immutable.

    """

    @abstractmethod
    def _sleep_seconds(self) -> Optional[float]:
        """Return the delay in seconds before the next attempt, or ``None``."""

    def sleep(self) -> None:
        """Block until the next attempt may be made.

        Returns immediately when the strategy reports no delay (``None``).

        """
        delay = self._sleep_seconds()
        if delay is None:
            return
        log.debug(f"Sleeping: {delay:0.2f} seconds prior to retry")
        time.sleep(delay)
class FiniteRetryStrategy(RetryStrategy):
    """A ``RetryStrategy`` that gives up after a fixed number of attempts."""

    def __init__(self, retries: int = 3) -> None:
        """Initialize the strategy.

        :param retries: Number of times to attempt a request (default: ``3``).

        """
        self._retries = retries

    def _sleep_seconds(self) -> Optional[float]:
        """Return a randomized delay based on the remaining attempt count.

        With three or more attempts remaining there is no delay. With exactly
        two remaining the delay is drawn from ``[0, 2)`` seconds; with fewer
        than two it is drawn from ``[2, 4)`` seconds.

        """
        remaining = self._retries
        if remaining >= 3:
            return None
        base = 2 if remaining != 2 else 0
        return base + 2 * random.random()

    def consume_available_retry(self) -> "FiniteRetryStrategy":
        """Return a new strategy of the same type with one fewer attempt."""
        strategy_class = type(self)
        return strategy_class(self._retries - 1)

    def should_retry_on_failure(self) -> bool:
        """Return ``True`` if and only if more than one attempt remains."""
        return 1 < self._retries
class Session:
    """The low-level connection interface to Reddit's API."""

    # Underlying transport exceptions (carried by ``RequestException`` as
    # ``original_exception``) that are retried while attempts remain.
    RETRY_EXCEPTIONS = (ChunkedEncodingError, ConnectionError, ReadTimeout)
    # Response status codes that are retried while attempts remain.
    RETRY_STATUSES = {
        520,
        522,
        codes["bad_gateway"],
        codes["gateway_timeout"],
        codes["internal_server_error"],
        codes["request_timeout"],
        codes["service_unavailable"],
    }
    # Maps a status code to a callable that is invoked with the response; the
    # result of that call is raised once the request is not (or no longer)
    # retried.
    STATUS_EXCEPTIONS = {
        codes["bad_gateway"]: ServerError,
        codes["bad_request"]: BadRequest,
        codes["conflict"]: Conflict,
        codes["found"]: Redirect,
        codes["forbidden"]: authorization_error_class,
        codes["gateway_timeout"]: ServerError,
        codes["internal_server_error"]: ServerError,
        codes["media_type"]: SpecialError,
        codes["moved_permanently"]: Redirect,
        codes["not_found"]: NotFound,
        codes["request_entity_too_large"]: TooLarge,
        codes["request_uri_too_large"]: URITooLong,
        codes["service_unavailable"]: ServerError,
        codes["too_many_requests"]: TooManyRequests,
        codes["unauthorized"]: authorization_error_class,
        codes["unavailable_for_legal_reasons"]: UnavailableForLegalReasons,
        # Cloudflare's status (not named in requests)
        520: ServerError,
        522: ServerError,
    }
    # Status codes whose response body is returned to the caller.
    SUCCESS_STATUSES = {codes["accepted"], codes["created"], codes["ok"]}

    @staticmethod
    def _log_request(
        data: Optional[List[Tuple[str, str]]],
        method: str,
        params: Dict[str, int],
        url: str,
    ) -> None:
        """Emit debug log lines describing the request about to be sent."""
        log.debug(f"Fetching: {method} {url}")
        log.debug(f"Data: {data}")
        log.debug(f"Params: {params}")

    def __init__(
        self,
        authorizer: Optional[BaseAuthorizer],
    ) -> None:
        """Prepare the connection to Reddit's API.

        :param authorizer: An instance of :class:`.Authorizer`.

        :raises: :class:`.InvalidInvocation` if ``authorizer`` is not an instance
            of ``BaseAuthorizer`` (this includes ``None``).

        """
        if not isinstance(authorizer, BaseAuthorizer):
            raise InvalidInvocation(f"invalid Authorizer: {authorizer}")
        self._authorizer = authorizer
        self._rate_limiter = RateLimiter()
        self._retry_strategy_class = FiniteRetryStrategy

    def __enter__(self) -> "Session":
        """Allow this object to be used as a context manager."""
        return self

    def __exit__(self, *_args) -> None:
        """Allow this object to be used as a context manager."""
        self.close()

    def _do_retry(
        self,
        data: Optional[List[Tuple[str, Any]]],
        files: Optional[Dict[str, "BufferedReader"]],
        json: Optional[Dict[str, Any]],
        method: str,
        params: Dict[str, int],
        response: Optional["Response"],
        retry_strategy_state: "FiniteRetryStrategy",
        saved_exception: Optional[Exception],
        timeout: float,
        url: str,
    ) -> Optional[Union[Dict[str, Any], str]]:
        """Log the reason for a retry and re-issue the request.

        The request is re-issued through :meth:`_request_with_retries` using
        the strategy returned by ``retry_strategy_state.consume_available_retry()``.

        :param response: The response that triggered the retry; only consulted
            when ``saved_exception`` is falsy.
        :param saved_exception: The transport exception that triggered the
            retry, if any.

        :returns: Whatever the re-issued request returns.

        """
        if saved_exception:
            status = repr(saved_exception)
        else:
            status = response.status_code
        log.warning(f"Retrying due to {status} status: {method} {url}")
        next_state = retry_strategy_state.consume_available_retry()
        return self._request_with_retries(
            data=data,
            files=files,
            json=json,
            method=method,
            params=params,
            timeout=timeout,
            url=url,
            retry_strategy_state=next_state,
        )

    def _make_request(
        self,
        data: Optional[List[Tuple[str, Any]]],
        files: Optional[Dict[str, "BufferedReader"]],
        json: Optional[Dict[str, Any]],
        method: str,
        params: Dict[str, Any],
        retry_strategy_state: "FiniteRetryStrategy",
        timeout: float,
        url: str,
    ) -> Union[Tuple["Response", None], Tuple[None, Exception]]:
        """Issue a single request through the rate limiter.

        :returns: ``(response, None)`` when a response was obtained, or
            ``(None, original_exception)`` when a retryable transport exception
            occurred and the strategy still permits a retry.

        :raises: :class:`.RequestException` when the strategy permits no further
            retry or the wrapped exception is not one of ``RETRY_EXCEPTIONS``.

        """
        try:
            response = self._rate_limiter.call(
                self._requestor.request,
                self._set_header_callback,
                method,
                url,
                allow_redirects=False,
                data=data,
                files=files,
                json=json,
                params=params,
                timeout=timeout,
            )
            log.debug(
                f"Response: {response.status_code}"
                f" ({response.headers.get('content-length')} bytes)"
            )
            return response, None
        except RequestException as exception:
            retryable = isinstance(exception.original_exception, self.RETRY_EXCEPTIONS)
            if not retry_strategy_state.should_retry_on_failure() or not retryable:
                raise
            return None, exception.original_exception

    def _request_with_retries(
        self,
        data: Optional[List[Tuple[str, Any]]],
        files: Optional[Dict[str, "BufferedReader"]],
        json: Optional[Dict[str, Any]],
        method: str,
        params: Dict[str, int],
        timeout: float,
        url: str,
        retry_strategy_state: Optional["FiniteRetryStrategy"] = None,
    ) -> Optional[Union[Dict[str, Any], str]]:
        """Send the request, retrying according to the retry strategy.

        :param retry_strategy_state: The strategy state for this attempt. When
            ``None`` a fresh instance of ``self._retry_strategy_class`` is used.

        :returns: The decoded JSON body for a success status, ``""`` when the
            ``content-length`` header is ``"0"``, or ``None`` for a "no content"
            status.

        :raises: The exception produced by ``STATUS_EXCEPTIONS`` for the status
            code, :class:`.BadJSON` when the body cannot be decoded, or
            ``AssertionError`` for a status code that is neither mapped nor a
            success status.

        """
        if retry_strategy_state is None:
            retry_strategy_state = self._retry_strategy_class()

        retry_strategy_state.sleep()
        self._log_request(data, method, params, url)
        response, saved_exception = self._make_request(
            data,
            files,
            json,
            method,
            params,
            retry_strategy_state,
            timeout,
            url,
        )

        do_retry = False
        if response is not None and response.status_code == codes["unauthorized"]:
            # Discard the rejected token; retry only if the authorizer is able
            # to obtain a new one.
            self._authorizer._clear_access_token()
            if hasattr(self._authorizer, "refresh"):
                do_retry = True

        # ``response`` is ``None`` only when ``_make_request`` returned a saved
        # exception, which it does only while the strategy still allows a
        # retry, so the ``elif`` branches below always see a real response.
        if retry_strategy_state.should_retry_on_failure() and (
            do_retry or response is None or response.status_code in self.RETRY_STATUSES
        ):
            return self._do_retry(
                data,
                files,
                json,
                method,
                params,
                response,
                retry_strategy_state,
                saved_exception,
                timeout,
                url,
            )
        elif response.status_code in self.STATUS_EXCEPTIONS:
            raise self.STATUS_EXCEPTIONS[response.status_code](response)
        elif response.status_code == codes["no_content"]:
            return None

        # Raised explicitly rather than via ``assert`` so that the check is
        # still performed when Python runs with ``-O``.
        if response.status_code not in self.SUCCESS_STATUSES:
            raise AssertionError(f"Unexpected status code: {response.status_code}")
        if response.headers.get("content-length") == "0":
            return ""
        try:
            return response.json()
        except ValueError as error:
            raise BadJSON(response) from error

    def _set_header_callback(self) -> Dict[str, str]:
        """Return the ``Authorization`` header for the next request.

        The authorizer is refreshed first when it reports itself invalid and
        provides a ``refresh`` method.

        """
        if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"):
            self._authorizer.refresh()
        return {"Authorization": f"bearer {self._authorizer.access_token}"}

    @property
    def _requestor(self) -> "Requestor":
        """Return the requestor reached through the authorizer's authenticator."""
        return self._authorizer._authenticator._requestor

    def close(self) -> None:
        """Close the session and perform any clean up."""
        self._requestor.close()

    def request(
        self,
        method: str,
        path: str,
        data: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, "BufferedReader"]] = None,
        json: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        timeout: float = TIMEOUT,
    ) -> Optional[Union[Dict[str, Any], str]]:
        """Return the json content from the resource at ``path``.

        :param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``.
        :param path: The path of the request. This path will be combined with the
            ``oauth_url`` of the Requestor.
        :param data: Dictionary, bytes, or file-like object to send in the body of the
            request.
        :param files: Dictionary, mapping ``filename`` to file-like object.
        :param json: Object to be serialized to JSON in the body of the request.
        :param params: The query parameters to send with the request.
        :param timeout: Specifies a particular timeout, in seconds.

        Automatically refreshes the access token if it becomes invalid and a refresh
        token is available.

        :raises: :class:`.InvalidInvocation` in such a case if a refresh token is not
            available.

        """
        # Work on copies so the caller's dictionaries are never mutated.
        params = deepcopy(params) or {}
        params["raw_json"] = 1
        if isinstance(data, dict):
            data = deepcopy(data)
            data["api_type"] = "json"
            # Sent as a list of (key, value) pairs ordered by key.
            data = sorted(data.items())
        if isinstance(json, dict):
            json = deepcopy(json)
            json["api_type"] = "json"
        url = urljoin(self._requestor.oauth_url, path)
        return self._request_with_retries(
            data=data,
            files=files,
            json=json,
            method=method,
            params=params,
            timeout=timeout,
            url=url,
        )
def session(authorizer: Optional["Authorizer"] = None) -> Session:
    """Return a :class:`.Session` instance.

    :param authorizer: An instance of :class:`.Authorizer`.

    :raises: :class:`.InvalidInvocation` if ``authorizer`` is not an instance of
        ``BaseAuthorizer``; the default of ``None`` is therefore rejected by
        :class:`.Session`.

    """
    return Session(authorizer=authorizer)