Switch from psycopg v2 to v3 (#164)
* change dependencies

* Switch to psycopg^3

* refactor

* disable postgresClock

* cancel async-leak-task when exit test

* refactor

* refactor

* update extra

* update lock file

* up

* up

* remove scope

* up

* remove cancel

* re-organize functions & modules

* refactor

* remove clock-exception
vutran1710 committed Mar 17, 2024
1 parent e38daed commit 7d25d43
Showing 14 changed files with 304 additions and 257 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -508,15 +508,15 @@ bucket = SQLiteBucket(rates, conn, table)
 
 #### PostgresBucket
 
-Postgres is supported, but you have to install `psycopg2` or `asyncpg` either as an extra or as a separate package.
+Postgres is supported, but you have to install `psycopg[pool]` either as an extra or as a separate package.
 
 You can use Postgres's built-in **CURRENT_TIMESTAMP** as the time source with `PostgresClock`, or use an external custom time source.
 
 ```python
 from pyrate_limiter import PostgresBucket, Rate, PostgresClock
-from psycopg2.pool import ThreadedConnectionPool
+from psycopg_pool import ConnectionPool
 
-connection_pool = ThreadedConnectionPool(5, 10, 'postgresql://postgres:postgres@localhost:5432')
+connection_pool = ConnectionPool('postgresql://postgres:postgres@localhost:5432')
 
 clock = PostgresClock(connection_pool)
 rates = [Rate(3, 1000), Rate(4, 1500)]
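The README excerpt above ends before the bucket and limiter are actually created. Below is a minimal sketch of how the example might continue under psycopg v3; the `'ratelimit'` table name and the `Limiter` wiring are illustrative assumptions, not lines from this diff.

```python
# Sketch only: completes the README example above under assumed names.
from pyrate_limiter import Limiter, PostgresBucket, PostgresClock, Rate
from psycopg_pool import ConnectionPool

connection_pool = ConnectionPool('postgresql://postgres:postgres@localhost:5432')

clock = PostgresClock(connection_pool)
rates = [Rate(3, 1000), Rate(4, 1500)]
bucket = PostgresBucket(connection_pool, 'ratelimit', rates)  # table name is an assumption

limiter = Limiter(bucket, clock=clock, raise_when_fail=False)
limiter.try_acquire('my-item')
```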
2 changes: 1 addition & 1 deletion noxfile.py
@@ -4,7 +4,7 @@
 # Reuse virtualenv created by poetry instead of creating new ones
 nox.options.reuse_existing_virtualenvs = True
 
-PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=8"]
+PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=auto"]
 COVERAGE_ARGS = ["--cov", "--cov-report=term", "--cov-report=xml", "--cov-report=html"]
 
 
97 changes: 78 additions & 19 deletions poetry.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pyrate-limiter"
-version = "3.5.1"
+version = "3.6.0"
 description = "Python Rate-Limiter using Leaky-Bucket Algorithm"
 authors = ["vutr <me@vutr.io>"]
 license = "MIT"
@@ -29,7 +29,7 @@ python = "^3.8"
 # Optional backend dependencies
 filelock = {optional=true, version=">=3.0"}
 redis = {optional=true, version="^5.0.0"}
-psycopg2 = {version = "^2.9.9", optional = true}
+psycopg = {extras = ["pool"], version = "^3.1.18", optional = true}
 
 # Documentation dependencies needed for Readthedocs builds
 furo = {optional=true, version="^2022.3.4"}
@@ -40,7 +40,7 @@ sphinx-copybutton = {optional=true, version=">=0.5"}
 sphinxcontrib-apidoc = {optional=true, version="^0.3"}
 
 [tool.poetry.extras]
-all = ["filelock", "redis", "psycopg2"]
+all = ["filelock", "redis", "psycopg"]
 docs = ["furo", "myst-parser", "sphinx", "sphinx-autodoc-typehints",
         "sphinx-copybutton", "sphinxcontrib-apidoc"]
 
@@ -58,7 +58,7 @@ coverage = "6"
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.1.1"
 pytest-asyncio = "^0.23.5.post1"
-psycopg2 = "^2.9.9"
+psycopg = {extras = ["pool"], version = "^3.1.18"}
 
 [tool.black]
 line-length = 120
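Because `psycopg` is declared as an optional dependency (pulled in via the `all` extra or by installing `psycopg[pool]` directly), the Postgres backend is only importable when that dependency is present. A hypothetical guard of the kind optional backends commonly use; this snippet is not part of the commit.

```python
# Hypothetical import guard; not code from this repository.
try:
    from psycopg_pool import ConnectionPool  # present with the "all" extra or psycopg[pool]
except ImportError:
    ConnectionPool = None  # Postgres backend stays unavailable without the dependency
```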
2 changes: 1 addition & 1 deletion pyrate_limiter/__init__.py
@@ -3,5 +3,5 @@
 from .buckets import *
 from .clocks import *
 from .exceptions import *
-from .limiter import Limiter
+from .limiter import *
 from .utils import *
7 changes: 6 additions & 1 deletion pyrate_limiter/abstracts/bucket.py
@@ -123,6 +123,7 @@ def __init__(self, leak_interval: int):
         self.async_buckets = defaultdict()
         self.clocks = defaultdict()
         self.leak_interval = leak_interval
+        self._task = None
         super().__init__()
 
     def register(self, bucket: AbstractBucket, clock: AbstractClock):
@@ -171,7 +172,7 @@ async def _leak(self, sync=True) -> None:
     def leak_async(self):
         if self.async_buckets and not self.is_async_leak_started:
             self.is_async_leak_started = True
-            asyncio.create_task(self._leak(sync=False))
+            self._task = asyncio.create_task(self._leak(sync=False))
 
     def run(self) -> None:
         assert self.sync_buckets
@@ -181,6 +182,10 @@ def start(self) -> None:
         if self.sync_buckets and not self.is_alive():
             super().start()
 
+    def cancel(self) -> None:
+        if self._task:
+            self._task.cancel()
+
 
 class BucketFactory(ABC):
     """Asbtract BucketFactory class.
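The bucket.py change keeps a reference to the task created in `leak_async` so the new `cancel()` method can stop it (the "cancel async-leak-task when exit test" item in the commit message). Below is a standalone sketch of that store-and-cancel pattern; the class and method names are illustrative, not the library's.

```python
# Standalone sketch of storing an asyncio task so it can be cancelled later.
# PeriodicLeaker and _leak are illustrative names, not the library's classes.
import asyncio
from typing import Optional


class PeriodicLeaker:
    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None

    async def _leak(self) -> None:
        while True:
            await asyncio.sleep(10)  # periodically drop expired items here

    def start(self) -> None:
        # Keeping the Task reference allows later cancellation and prevents the
        # task from being garbage-collected while it is still running.
        self._task = asyncio.create_task(self._leak())

    def cancel(self) -> None:
        if self._task:
            self._task.cancel()
```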
56 changes: 25 additions & 31 deletions pyrate_limiter/buckets/postgres.py
@@ -14,7 +14,7 @@
 from ..abstracts import RateItem
 
 if TYPE_CHECKING:
-    from psycopg2.pool import AbstractConnectionPool
+    from psycopg_pool import ConnectionPool
 
 
 class Queries:
@@ -54,9 +54,9 @@ class Queries:
 
 class PostgresBucket(AbstractBucket):
     table: str
-    pool: AbstractConnectionPool
+    pool: ConnectionPool
 
-    def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]):
+    def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
         self.table = table.lower()
         self.pool = pool
         assert rates
@@ -65,21 +65,15 @@ def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]):
         self._create_table()
 
     @contextmanager
-    def _get_conn(self, autocommit=False):
-        with self.pool._getconn() as conn:
-            with conn.cursor() as cur:
-                yield cur
-
-            if autocommit:
-                conn.commit()
-
-            self.pool._putconn(conn)
+    def _get_conn(self):
+        with self.pool.connection() as conn:
+            yield conn
 
     def _create_table(self):
-        with self._get_conn(autocommit=True) as cur:
-            cur.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
+        with self._get_conn() as conn:
+            conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
             index_name = f'timestampIndex_{self.table}'
-            cur.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
+            conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
 
     def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
         """Put an item (typically the current time) in the bucket
@@ -88,12 +82,12 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
         if item.weight == 0:
             return True
 
-        with self._get_conn(autocommit=True) as cur:
+        with self._get_conn() as conn:
             for rate in self.rates:
                 bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
                 query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
-                cur.execute(query)
-                count = int(cur.fetchone()[0])
+                conn = conn.execute(query)
+                count = int(conn.fetchone()[0])
 
                 if rate.limit - count < item.weight:
                     self.failing_rate = rate
@@ -103,7 +97,7 @@
 
             query = Queries.PUT.format(table=self._full_tbl)
             arguments = [(item.name, item.weight, item.timestamp / 1000)] * item.weight
-            cur.executemany(query, tuple(arguments))
+            conn.executemany(query, tuple(arguments))
 
         return True
 
@@ -120,12 +114,12 @@ def leak(
 
         count = 0
 
-        with self._get_conn(autocommit=True) as cur:
-            cur.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
-            result = cur.fetchone()
+        with self._get_conn() as conn:
+            conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
+            result = conn.fetchone()
 
             if result:
-                cur.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
+                conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
                 count = int(result[0])
 
         return count
@@ -134,18 +128,18 @@ def flush(self) -> Union[None, Awaitable[None]]:
         """Flush the whole bucket
         - Must remove `failing-rate` after flushing
         """
-        with self._get_conn(autocommit=True) as cur:
-            cur.execute(Queries.FLUSH.format(table=self._full_tbl))
+        with self._get_conn() as conn:
+            conn.execute(Queries.FLUSH.format(table=self._full_tbl))
             self.failing_rate = None
 
         return None
 
     def count(self) -> Union[int, Awaitable[int]]:
         """Count number of items in the bucket"""
         count = 0
-        with self._get_conn() as cur:
-            cur.execute(Queries.COUNT.format(table=self._full_tbl))
-            result = cur.fetchone()
+        with self._get_conn() as conn:
+            conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
+            result = conn.fetchone()
             assert result
             count = int(result[0])
 
@@ -158,9 +152,9 @@ def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
         """
         item = None
 
-        with self._get_conn() as cur:
-            cur.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
-            result = cur.fetchone()
+        with self._get_conn() as conn:
+            conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
+            result = conn.fetchone()
             if result:
                 name, weight, timestamp = result[0], int(result[1]), int(result[2])
                 item = RateItem(name=name, weight=weight, timestamp=timestamp)
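The `_get_conn` rewrite relies on psycopg v3's pool API: `ConnectionPool.connection()` is a context manager that checks a connection out, commits on clean exit (or rolls back on error), and returns the connection to the pool, which is why the explicit `autocommit`/`_putconn` handling from the psycopg2 version could be dropped. It also explains the `conn = conn.execute(query)` reassignments above: `Connection.execute()` returns a cursor, and the code reuses the same name to fetch from it. Here is a minimal sketch of the pattern outside the bucket class; the DSN and table name are placeholders.

```python
# Minimal psycopg v3 pool sketch; DSN and table name are placeholders.
from psycopg_pool import ConnectionPool

pool = ConnectionPool('postgresql://postgres:postgres@localhost:5432')

with pool.connection() as conn:
    # Connection.execute() returns a Cursor, so results can be fetched directly.
    cursor = conn.execute('SELECT COUNT(*) FROM my_table')
    print(cursor.fetchone()[0])
# Leaving the block commits the transaction (or rolls back on error) and
# returns the connection to the pool; no manual commit() or putconn() needed.
```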
