diff --git a/pixels/README.md b/pixels/README.md index 69c0798..751debd 100644 --- a/pixels/README.md +++ b/pixels/README.md @@ -36,5 +36,5 @@ To avoid having one person fill the entire board, we set up rate limits, leverag One of the main requirements is that the request count must be removed after a set amount of time. Redis [TTL functionality](https://redis.io/commands/TTL) is perfect to automatically remove expired requests. -We insert request counts in Redis as soon as we receive them, set the expiry time to be the reset time, and set dummy values for users under cooldown. -The lack of needing a background clearing task makes it fast and efficient, even if we have to handle a lot of concurrent connections. +Another requirement is a rolling-window mechanism, so that requesting the same endpoint every X seconds or in bursts results in the same overall rate. With that in mind, a sorted set is created for each bucket. Each entry holds a random dummy value, and its score is set to the current timestamp plus the rate limit duration. +Before counting entries, we simply remove the entries scored from `-inf` to the current timestamp using a [ZREMRANGEBYSCORE](https://redis.io/commands/zremrangebyscore) operation, which costs O(log(N)+M), where N is the size of the set and M is the number of expired entries removed.
diff --git a/pixels/utils/ratelimits.py b/pixels/utils/ratelimits.py
index ac61c8c..283b9ba 100644
--- a/pixels/utils/ratelimits.py
+++ b/pixels/utils/ratelimits.py
@@ -7,8 +7,10 @@
 import itertools
 import logging
 import typing
+import uuid
 from collections import namedtuple
 from dataclasses import dataclass
+from time import time
 
 import fastapi
 from aioredis import Redis
@@ -116,15 +118,12 @@ async def head_endpoint(request: requests.Request) -> Response:
             await self._init_state(request_id, request)
 
             if await self._check_cooldown(request_id):
-                response.headers["CooldownReset"] = str(
-                    await self._get_remaining_cooldown(request_id)
+                response.headers.append(
+                    "Cooldown-Reset",
+                    str(await self._get_remaining_cooldown(request_id))
                 )
             else:
-                remaining_requests = await self.get_remaining_requests(request_id)
-
-                response.headers.append("Requests-Remaining", str(remaining_requests))
-                response.headers.append("Requests-Limit", str(self.LIMITS.requests))
-                response.headers.append("Requests-Reset", str(self.LIMITS.time_unit))
+                await self.add_headers(response, request_id)
             return response
 
         # functools.wraps is used here to wrap the endpoint while maintaining the signature
@@ -170,11 +169,7 @@ async def caller(*_args, **_kwargs) -> typing.Union[JSONResponse, Response]:
             clean_result = jsonable_encoder(result)
             response = JSONResponse(content=clean_result)
 
-            remaining_requests = await self.get_remaining_requests(request_id)
-
-            response.headers.append("Requests-Remaining", str(remaining_requests))
-            response.headers.append("Requests-Limit", str(self.LIMITS.requests))
-            response.headers.append("Requests-Reset", str(self.LIMITS.time_unit))
+            await self.add_headers(response, request_id)
 
             # Setup post interaction tasks
             state = self.state[request_id]
@@ -193,6 +188,21 @@ async def caller(*_args, **_kwargs) -> typing.Union[JSONResponse, Response]:
 
         return caller
 
+    async def add_headers(self, response: Response, request_id: int) -> None:
+        """
+        Add ratelimit headers to the provided response.
+
+        Sets `Requests-Remaining`, `Requests-Limit`, `Requests-Period` and `Requests-Reset`.
+        """
+        remaining_requests = await self.get_remaining_requests(request_id)
+        request_reset = await self._reset_time(request_id)
+
+        response.headers.append("Requests-Remaining", str(remaining_requests))
+        response.headers.append("Requests-Limit", str(self.LIMITS.requests))
+        response.headers.append("Requests-Period", str(self.LIMITS.time_unit))
+        # Round instead of slicing the string: slicing mangles exponent notation and large values.
+        response.headers.append("Requests-Reset", str(round(request_reset, 3)))
+
     async def _increment(self, request_id: int) -> None:
         """Reduce remaining quota, and check if a cooldown is needed."""
         if await self._check_cooldown(request_id):
@@ -239,6 +249,10 @@ async def _get_remaining_cooldown(self, request_id: int) -> int:
         """Return the time, in seconds, until a cooldown ends."""
         raise NotImplementedError()
 
+    async def _reset_time(self, request_id: int) -> float:
+        """Return the time, in seconds, before getting every interaction back (-1 if none are pending)."""
+        raise NotImplementedError()
+
 
 class UserRedis(__BucketBase):
     """A per user request bucket backed by Redis."""
@@ -273,13 +287,15 @@ async def _record_interaction(self, request_id: int) -> None:
         key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"
 
         log.debug(f"Recorded interaction of user {self.state[request_id].user_id} on {self.ROUTE_NAME}.")
-        await self.redis.incr(key)
+        await self.redis.zadd(key, time() + self.LIMITS.time_unit, str(uuid.uuid4()))
         await self.redis.expire(key, self.LIMITS.time_unit)
 
     async def _calculate_remaining_requests(self, request_id: int) -> int:
         key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"
 
-        remaining = self.LIMITS.requests - int(await self.redis.get(key) or 0)
+        # Cleanup expired entries
+        await self.redis.zremrangebyscore(key, max=time())
+        remaining = self.LIMITS.requests - int(await self.redis.zcount(key) or 0)
         log.debug(f"Remaining interactions of user {self.state[request_id].user_id} on {self.ROUTE_NAME}: {remaining}.")
         return remaining
 
@@ -305,3 +321,13 @@ async def _get_remaining_cooldown(self, request_id: int) -> int:
         key = f"cooldown-{self.ROUTE_NAME}-{self.state[request_id].user_id}"
 
         return await self.redis.ttl(key)
+
+    async def _reset_time(self, request_id: int) -> float:
+        """Return the time, in seconds, until the newest interaction expires, or -1 if there is none."""
+        key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"
+
+        # Scores are expiry timestamps, so the highest-ranked member (index -1) is the last one to expire.
+        if not (newest_uuid := await self.redis.zrange(key, -1, -1)):
+            return -1
+
+        return await self.redis.zscore(key, newest_uuid[0]) - time()