Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AsyncClient to use Semaphores #1916

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tweepy/asynchronous/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from platform import python_version
import time
from collections import defaultdict

import aiohttp
from async_lru import alru_cache
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(

self.return_type = return_type
self.wait_on_rate_limit = wait_on_rate_limit
self.rate_limit_status = defaultdict(TaskRateSemaphore)

self.session = None
self.user_agent = (
Expand Down Expand Up @@ -87,6 +89,11 @@ async def request(
else:
headers["Authorization"] = f"Bearer {self.bearer_token}"

if self.wait_on_rate_limit:
rate_limit_status = self.rate_limit_status[(method, route)]
await rate_limit_status.acquire()


log.debug(
f"Making API request: {method} {url}\n"
f"Parameters: {params}\n"
Expand Down Expand Up @@ -126,6 +133,7 @@ async def request(
"Rate limit exceeded. "
f"Sleeping for {sleep_time} seconds."
)
rate_limit_status.release(int(response.headers["x-rate-limit-remaining"]),reset_time)
await asyncio.sleep(sleep_time)
return await self.request(method, route, params, json, user_auth)
else:
Expand All @@ -135,6 +143,7 @@ async def request(
if not 200 <= response.status < 300:
raise HTTPException(response, response_json=response_json)

rate_limit_status.release(int(response.headers["x-rate-limit-remaining"]),int(response.headers["x-rate-limit-reset"]))
return response

async def _make_request(
Expand Down Expand Up @@ -3314,3 +3323,46 @@ async def create_compliance_job(self, type, *, name=None, resumable=None):
return await self._make_request(
"POST", "/2/compliance/jobs", json=json
)

class TaskRateSemaphore(asyncio.Semaphore):
def __init__(self):
super().__init__(10)
self.max_seen = 10
self.inprogress = 0
self.reset_time = None
self.fut = None
self.reset_callback = None


def reset(self):
self.reset_time = None
for _ in range(self.max_seen):
super().release()


async def acquire(self):
await super().acquire()
self.inprogress +=1

def update(self, remaining, reset_time):
if remaining > self.max_seen:
self.max_seen = remaining
if (reset_time and self.reset_callback and self.reset_time != reset_time) or (reset_time and not self.reset_callback):
if self.reset_callback:
self.reset_callback.cancel()
try:
loop = self._get_loop()
except AttributeError:
loop = self._loop

self.reset_callback = loop.call_later(reset_time - int(time.time()) + 1, self.reset)
self.reset_time = reset_time


def release(self, remaining, reset_time):
self.update(remaining, reset_time)
self.inprogress -= 1

while remaining and (remaining - self.inprogress-1) > self._value:

super().release()