Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pixels/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ To avoid having one person fill the entire board, we set up rate limits, leverag
One of the main requirements is that the request count must be removed after a set amount of time.
Redis [TTL functionality](https://redis.io/commands/TTL) is perfect to automatically remove expired requests.

We insert request counts in Redis as soon as we receive them, set the expiry time to be the reset time, and set dummy values for users under cooldown.
The lack of needing a background clearing task makes it fast and efficient, even if we have to handle a lot of concurrent connections.
Another requirement is to have a rolling window mechanism, so requesting the same endpoint every X seconds or in burst will result in the same speed. With that in mind, a sorted set is created for each bucket. Each entry contains a random dummy value and its score is set to be the current timestamp plus the rate limit duration.
Before counting entries, we simply remove entries from `-inf` to the current timestamp using a [ZREMBYSCORE](https://redis.io/commands/zremrangebyscore) operation, allowing it to stay O(log n).
47 changes: 33 additions & 14 deletions pixels/utils/ratelimits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import itertools
import logging
import typing
import uuid
from collections import namedtuple
from dataclasses import dataclass
from time import time

import fastapi
from aioredis import Redis
Expand Down Expand Up @@ -116,15 +118,12 @@ async def head_endpoint(request: requests.Request) -> Response:
await self._init_state(request_id, request)

if await self._check_cooldown(request_id):
response.headers["CooldownReset"] = str(
await self._get_remaining_cooldown(request_id)
response.headers.append(
"Cooldown-Reset",
str(await self._get_remaining_cooldown(request_id))
)
else:
remaining_requests = await self.get_remaining_requests(request_id)

response.headers.append("Requests-Remaining", str(remaining_requests))
response.headers.append("Requests-Limit", str(self.LIMITS.requests))
response.headers.append("Requests-Reset", str(self.LIMITS.time_unit))
await self.add_headers(response, request_id)
return response

# functools.wraps is used here to wrap the endpoint while maintaining the signature
Expand Down Expand Up @@ -170,11 +169,7 @@ async def caller(*_args, **_kwargs) -> typing.Union[JSONResponse, Response]:
clean_result = jsonable_encoder(result)
response = JSONResponse(content=clean_result)

remaining_requests = await self.get_remaining_requests(request_id)

response.headers.append("Requests-Remaining", str(remaining_requests))
response.headers.append("Requests-Limit", str(self.LIMITS.requests))
response.headers.append("Requests-Reset", str(self.LIMITS.time_unit))
await self.add_headers(response, request_id)

# Setup post interaction tasks
state = self.state[request_id]
Expand All @@ -193,6 +188,16 @@ async def caller(*_args, **_kwargs) -> typing.Union[JSONResponse, Response]:

return caller

async def add_headers(self, response: Response, request_id: int) -> None:
"""Add ratelimit headers to the provided request."""
remaining_requests = await self.get_remaining_requests(request_id)
request_reset = await self._reset_time(request_id)

response.headers.append("Requests-Remaining", str(remaining_requests))
response.headers.append("Requests-Limit", str(self.LIMITS.requests))
response.headers.append("Requests-Period", str(self.LIMITS.time_unit))
response.headers.append("Requests-Reset", str(request_reset)[:6])

async def _increment(self, request_id: int) -> None:
"""Reduce remaining quota, and check if a cooldown is needed."""
if await self._check_cooldown(request_id):
Expand Down Expand Up @@ -239,6 +244,10 @@ async def _get_remaining_cooldown(self, request_id: int) -> int:
"""Return the time, in seconds, until a cooldown ends."""
raise NotImplementedError()

async def _reset_time(self, request_id: int) -> int:
"""Return the time, in seconds, before getting every interaction back."""
raise NotImplementedError()


class UserRedis(__BucketBase):
"""A per user request bucket backed by Redis."""
Expand Down Expand Up @@ -273,13 +282,15 @@ async def _record_interaction(self, request_id: int) -> None:
key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"
log.debug(f"Recorded interaction of user {self.state[request_id].user_id} on {self.ROUTE_NAME}.")

await self.redis.incr(key)
await self.redis.zadd(key, time() + self.LIMITS.time_unit, str(uuid.uuid4()))
await self.redis.expire(key, self.LIMITS.time_unit)

async def _calculate_remaining_requests(self, request_id: int) -> int:
key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"

remaining = self.LIMITS.requests - int(await self.redis.get(key) or 0)
# Cleanup expired entries
await self.redis.zremrangebyscore(key, max=time())
remaining = self.LIMITS.requests - int(await self.redis.zcount(key) or 0)

log.debug(f"Remaining interactions of user {self.state[request_id].user_id} on {self.ROUTE_NAME}: {remaining}.")
return remaining
Expand All @@ -305,3 +316,11 @@ async def _get_remaining_cooldown(self, request_id: int) -> int:
key = f"cooldown-{self.ROUTE_NAME}-{self.state[request_id].user_id}"

return await self.redis.ttl(key)

async def _reset_time(self, request_id: int) -> int:
key = f"interaction-{self.ROUTE_NAME}-{self.state[request_id].user_id}"

if not (newest_uuid := await self.redis.zrange(key, 0, 0)):
return -1

return await self.redis.zscore(key, newest_uuid[0]) - time()