Skip to content

Commit

Permalink
Adding additional constructor parameters to AsyncDatabase as well
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese committed Feb 24, 2024
1 parent ca86759 commit 7cc70b6
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions src/replit/database/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Async and dict-like interfaces for interacting with Repl.it Database."""
"""Async and dict-like interfaces for interacting with Replit Database."""

from collections import abc
import json
Expand Down Expand Up @@ -62,11 +62,23 @@ def dumps(val: Any) -> str:


class AsyncDatabase:
"""Async interface for Repl.it Database."""
"""Async interface for Replit Database.
__slots__ = ("db_url", "sess", "client")
:param str db_url: The Database URL to connect to
:param int retry_count: How many retry attempts we should make
:param get_db_url Callable: A callback that returns the current db_url
:param unbind Callable: Permit additional behavior after Database close
"""

def __init__(self, db_url: str, retry_count: int = 5) -> None:
__slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer")

def __init__(
self,
db_url: str,
retry_count: int = 5,
get_db_url: Optional[Callable[[], Optional[str]]] = None,
unbind: Optional[Callable[[], None]] = None,
) -> None:
"""Initialize database. You shouldn't have to do this manually.
Args:
Expand All @@ -76,10 +88,27 @@ def __init__(self, db_url: str, retry_count: int = 5) -> None:
"""
self.db_url = db_url
self.sess = aiohttp.ClientSession()
self._get_db_url = get_db_url
self._unbind = unbind

retry_options = ExponentialRetry(attempts=retry_count)
self.client = RetryClient(client_session=self.sess, retry_options=retry_options)

if self._get_db_url:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()

def _refresh_db(self):
if self._refresh_timer:
self._refresh_timer.cancel()
self._refresh_timer = None
if self._get_db_url:
db_url = self._get_db_url()
if db_url:
self.update_db_url(db_url)
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()

def update_db_url(self, db_url: str) -> None:
"""Update the database url.
Expand Down Expand Up @@ -240,6 +269,16 @@ async def items(self) -> Tuple[Tuple[str, str], ...]:
"""
return tuple((await self.to_dict()).items())

async def close(self) -> None:
"""Closes the database client connection."""
await self.sess.close()
if self._refresh_timer:
self._refresh_timer.cancel()
self._refresh_timer = None
if self._unbind:
# Permit signaling to surrounding scopes that we have closed
self._unbind()

def __repr__(self) -> str:
"""A representation of the database.
Expand Down

0 comments on commit 7cc70b6

Please sign in to comment.