Skip to content

Commit

Permalink
fix: async bucket_factory.get() (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyuhu committed Apr 30, 2024
1 parent 0e4895f commit c6d5f88
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyrate-limiter"
version = "3.6.0"
version = "3.6.1"
description = "Python Rate-Limiter using Leaky-Bucket Algorithm"
authors = ["vutr <me@vutr.io>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion pyrate_limiter/abstracts/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def wrap_item(
"""

@abstractmethod
def get(self, item: RateItem) -> AbstractBucket:
def get(self, item: RateItem) -> Union[AbstractBucket, Awaitable[AbstractBucket]]:
"""Get the corresponding bucket to this item"""

def create(
Expand Down
25 changes: 21 additions & 4 deletions pyrate_limiter/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def wrap_item(self, name: str, weight: int = 1):
async def wrap_async():
return RateItem(name, await now, weight=weight)

def wrap_sycn():
def wrap_sync():
return RateItem(name, now, weight=weight)

return wrap_async() if isawaitable(now) else wrap_sycn()
return wrap_async() if isawaitable(now) else wrap_sync()

def get(self, _: RateItem) -> AbstractBucket:
return self.bucket
Expand Down Expand Up @@ -252,15 +252,15 @@ async def _put_async():
return _handle_result(acquire) # type: ignore

def try_acquire(self, name: str, weight: int = 1) -> Union[bool, Awaitable[bool]]:
"""Try accquiring an item with name & weight
"""Try acquiring an item with name & weight
Return true on success, false on failure
"""
with self.lock:
assert weight >= 0, "item's weight must be >= 0"

if weight == 0:
# NOTE: if item is weightless, just let it go through
# NOTE: this might change in the futre
# NOTE: this might change in the future
return True

item = self.bucket_factory.wrap_item(name, weight)
Expand All @@ -271,6 +271,8 @@ async def _handle_async():
nonlocal item
item = await item
bucket = self.bucket_factory.get(item)
if isawaitable(bucket):
bucket = await bucket
assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
result = self.handle_bucket_put(bucket, item)

Expand All @@ -283,6 +285,21 @@ async def _handle_async():

assert isinstance(item, RateItem) # NOTE: this is to silence mypy warning
bucket = self.bucket_factory.get(item)
if isawaitable(bucket):

async def _handle_async_bucket():
nonlocal bucket
bucket = await bucket
assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
result = self.handle_bucket_put(bucket, item)

while isawaitable(result):
result = await result

return result

return _handle_async_bucket()

assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
result = self.handle_bucket_put(bucket, item)

Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from pyrate_limiter import TimeAsyncClock
from pyrate_limiter import TimeClock


# Make log messages visible on test failure (or with pytest -s)
basicConfig(level="INFO")
# Uncomment for more verbose output:
Expand Down Expand Up @@ -72,7 +71,7 @@ async def create_async_redis_bucket(rates: List[Rate]):
from redis.asyncio import ConnectionPool as AsyncConnectionPool
from redis.asyncio import Redis as AsyncRedis

pool = AsyncConnectionPool.from_url(getenv("REDIS", "redis://localhost:6379"))
pool: AsyncConnectionPool = AsyncConnectionPool.from_url(getenv("REDIS", "redis://localhost:6379"))
redis_db: AsyncRedis = AsyncRedis(connection_pool=pool)
bucket_key = f"test-bucket/{id_generator()}"
await redis_db.delete(bucket_key)
Expand Down
58 changes: 58 additions & 0 deletions tests/demo_bucket_factory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from inspect import isawaitable
from os import getenv
from typing import Dict
from typing import Optional

from redis.asyncio import ConnectionPool as AsyncConnectionPool
from redis.asyncio import Redis as AsyncRedis

from .conftest import DEFAULT_RATES
from .helpers import flushing_bucket
from pyrate_limiter import AbstractBucket
from pyrate_limiter import AbstractClock
from pyrate_limiter import BucketFactory
from pyrate_limiter import id_generator
from pyrate_limiter import InMemoryBucket
from pyrate_limiter import RateItem
from pyrate_limiter import RedisBucket


class DemoBucketFactory(BucketFactory):
Expand Down Expand Up @@ -54,3 +61,54 @@ def get(self, item: RateItem) -> AbstractBucket:
def schedule_leak(self, *args):
if self.auto_leak:
super().schedule_leak(*args)


class DemoAsyncGetBucketFactory(BucketFactory):
"""Async multi-bucket factory used for testing schedule-leaks"""

def __init__(self, bucket_clock: AbstractClock, auto_leak=False, **buckets: AbstractBucket):
self.auto_leak = auto_leak
self.clock = bucket_clock
self.buckets = {}
self.leak_interval = 300

for item_name_pattern, bucket in buckets.items():
assert isinstance(bucket, AbstractBucket)
self.schedule_leak(bucket, bucket_clock)
self.buckets[item_name_pattern] = bucket

def wrap_item(self, name: str, weight: int = 1):
now = self.clock.now()

async def wrap_async():
return RateItem(name, await now, weight=weight)

def wrap_sync():
return RateItem(name, now, weight=weight)

return wrap_async() if isawaitable(now) else wrap_sync()

async def get(self, item: RateItem) -> AbstractBucket:
assert self.buckets is not None

if item.name in self.buckets:
bucket = self.buckets[item.name]
assert isinstance(bucket, AbstractBucket)
return bucket

pool: AsyncConnectionPool = AsyncConnectionPool.from_url(getenv("REDIS", "redis://localhost:6379"))
redis_db: AsyncRedis = AsyncRedis(connection_pool=pool)
key = f"test-bucket/{id_generator()}"
await redis_db.delete(key)
bucket = await RedisBucket.init(DEFAULT_RATES, redis_db, key)
self.schedule_leak(bucket, self.clock)
self.buckets.update({item.name: bucket})
return bucket

def schedule_leak(self, *args):
if self.auto_leak:
super().schedule_leak(*args)

async def flush(self):
for bucket in self.buckets.values():
await flushing_bucket(bucket)
81 changes: 81 additions & 0 deletions tests/test_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .conftest import DEFAULT_RATES
from .conftest import logger
from .demo_bucket_factory import DemoAsyncGetBucketFactory
from .demo_bucket_factory import DemoBucketFactory
from .helpers import async_acquire
from .helpers import concurrent_acquire
Expand Down Expand Up @@ -167,6 +168,86 @@ async def test_limiter_01(
assert not acquire_ok


@pytest.mark.asyncio
async def test_limiter_async_factory_get(
clock,
limiter_should_raise,
limiter_delay,
):
factory = DemoAsyncGetBucketFactory(clock)
limiter = Limiter(
factory,
raise_when_fail=limiter_should_raise,
max_delay=limiter_delay,
)
item = "demo"

logger.info("If weight = 0, it just passes thru")
acquire_ok, cost = await async_acquire(limiter, item, weight=0)
assert acquire_ok
assert cost <= 10

logger.info("Limiter Test #1")
await prefilling_bucket(limiter, 0.3, item)

if not limiter_should_raise:
acquire_ok, cost = await async_acquire(limiter, item)
if limiter_delay is None:
assert cost <= 50
assert not acquire_ok
else:
assert acquire_ok
else:
if limiter_delay is None:
with pytest.raises(BucketFullException):
acquire_ok, cost = await async_acquire(limiter, item)
else:
acquire_ok, cost = await async_acquire(limiter, item)
assert cost > 400
assert acquire_ok

# # Flush before testing again
await factory.flush()
logger.info("Limiter Test #2")
await prefilling_bucket(limiter, 0, item)

if limiter_should_raise:
if limiter_delay == 500:
with pytest.raises(LimiterDelayException) as err:
await async_acquire(limiter, item)
assert err.meta_info["max_delay"] == 500
assert err.meta_info["actual_delay"] > 600
assert err.meta_info["name"] == item
elif limiter_delay == 2000:
acquire_ok, cost = await async_acquire(limiter, item)
assert acquire_ok
elif limiter_delay == Duration.MINUTE:
acquire_ok, cost = await async_acquire(limiter, item)
assert acquire_ok
else:
with pytest.raises(BucketFullException) as err:
await async_acquire(limiter, item)
else:
acquire_ok, cost = await async_acquire(limiter, item)
if limiter_delay == 500 or limiter_delay is None:
assert not acquire_ok
else:
assert acquire_ok

# Flush before testing again
await factory.flush()
logger.info("Limiter Test #3: exceeding weight")
await prefilling_bucket(limiter, 0, item)

if limiter_should_raise:
with pytest.raises(BucketFullException) as err:
await async_acquire(limiter, item, 5)
else:
acquire_ok, cost = await async_acquire(limiter, item, 5)
assert cost <= 50
assert not acquire_ok


@pytest.mark.asyncio
async def test_limiter_concurrency(
clock,
Expand Down

0 comments on commit c6d5f88

Please sign in to comment.