/
rate_limit.py
107 lines (87 loc) · 3.72 KB
/
rate_limit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Provide the RateLimiter class."""
import asyncio
import logging
import time
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Mapping, Optional
if TYPE_CHECKING:
from aiohttp import ClientResponse
log = logging.getLogger(__package__)
class RateLimiter:
    """Facilitates the rate limiting of requests to Reddit.

    Rate limits are controlled based on feedback from requests to Reddit: after
    each response, :meth:`update` reads the ``x-ratelimit-*`` response headers and
    records the earliest time at which the next request should be sent, and
    :meth:`delay` sleeps until that time before the next request is issued.

    """

    def __init__(self) -> None:
        """Create an instance of the RateLimiter class."""
        # Every field stays ``None`` until ``update`` processes a response that
        # carries rate limit headers.
        self.remaining: Optional[float] = None
        self.next_request_timestamp: Optional[float] = None
        self.reset_timestamp: Optional[float] = None
        self.used: Optional[int] = None
        self.window_size: Optional[float] = None

    async def call(
        self,
        request_function: Callable[..., Awaitable["ClientResponse"]],
        set_header_callback: Callable[[], Awaitable[Dict[str, str]]],
        *args: Any,
        **kwargs: Any,
    ) -> "ClientResponse":
        """Rate limit the call to ``request_function``.

        :param request_function: A function call that returns an HTTP response object.
        :param set_header_callback: A callback function used to set the request headers.
            This callback is called after any necessary sleep time occurs.
        :param args: The positional arguments to ``request_function``.
        :param kwargs: The keyword arguments to ``request_function``. Any ``headers``
            entry is replaced by the result of ``set_header_callback``.

        :returns: The response returned by ``request_function``.

        """
        await self.delay()
        # Headers are produced only after sleeping, as documented for
        # ``set_header_callback`` above.
        kwargs["headers"] = await set_header_callback()
        response = await request_function(*args, **kwargs)
        self.update(response.headers)
        return response

    async def delay(self) -> None:
        """Sleep for an amount of time to remain under the rate limit."""
        if self.next_request_timestamp is None:
            # No rate limit information has been recorded yet.
            return
        sleep_seconds = self.next_request_timestamp - time.time()
        if sleep_seconds <= 0:
            # The scheduled time has already passed.
            return
        message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to call"
        log.debug(message)
        await asyncio.sleep(sleep_seconds)

    def update(self, response_headers: Mapping[str, str]) -> None:
        """Update the state of the rate limiter based on the response headers.

        This method should only be called following an HTTP request to Reddit.

        Response headers that do not contain ``x-ratelimit`` fields will be treated as a
        single request. This behavior is to err on the safe side as such responses
        should trigger exceptions that indicate invalid behavior.

        :param response_headers: The headers of the response that was just received.

        :raises KeyError: If ``x-ratelimit-remaining`` is present but
            ``x-ratelimit-reset`` or ``x-ratelimit-used`` is missing.
        :raises ValueError: If a rate limit header cannot be parsed as a number.

        """
        if "x-ratelimit-remaining" not in response_headers:
            # Count the request against whatever state is known. ``remaining`` and
            # ``used`` are guarded independently because either may still be
            # ``None``.
            if self.remaining is not None:
                self.remaining -= 1
                if self.used is not None:
                    self.used += 1
            return

        now = time.time()

        seconds_to_reset = int(response_headers["x-ratelimit-reset"])
        self.remaining = float(response_headers["x-ratelimit-remaining"])
        self.used = int(response_headers["x-ratelimit-used"])
        self.reset_timestamp = now + seconds_to_reset

        # The window size is estimated from the first response and afterwards only
        # grows, when a response reports a reset further away than the estimate.
        if self.window_size is None:
            self.window_size = seconds_to_reset + self.used
        elif self.window_size < seconds_to_reset:
            self.window_size = seconds_to_reset

        if self.remaining <= 0:
            # Nothing left in this window: wait for the reset.
            self.next_request_timestamp = self.reset_timestamp
            return

        # Share of the window proportional to the fraction of requests already used.
        used_share = self.window_size / (self.remaining + self.used) * self.used
        # Positive when more time is left before the reset than the unused share of
        # the window, i.e. requests were consumed faster than an even pace.
        seconds_ahead = seconds_to_reset - (self.window_size - used_share)
        # Never pause for a negative duration nor for more than 10 seconds.
        pause = min(max(seconds_ahead, 0), 10)
        self.next_request_timestamp = min(self.reset_timestamp, now + pause)