Skip to content

Commit

Permalink
modify: Add some typeshed codes
Browse files Browse the repository at this point in the history
  • Loading branch information
tanbro committed Apr 10, 2024
1 parent 814c9e3 commit 0c6604a
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
exclude: ^(src/sqlalchemy_dlock(/asyncio)?/_sa_types(_backward)?\.py)|(tests/.+)$
exclude: ^(src/sqlalchemy_dlock(/asyncio)?/_sa_types(_backward)?\.py)$

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
2 changes: 2 additions & 0 deletions src/sqlalchemy_dlock/asyncio/_sa_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_scoped_session

__all__ = ["TAsyncConnectionOrSession"]

type TAsyncConnectionOrSession = AsyncConnection | AsyncSession | async_scoped_session
2 changes: 2 additions & 0 deletions src/sqlalchemy_dlock/asyncio/_sa_types_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@

from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_scoped_session

__all__ = ["TAsyncConnectionOrSession"]

TAsyncConnectionOrSession: TypeAlias = Union[AsyncConnection, AsyncSession, async_scoped_session]
17 changes: 10 additions & 7 deletions src/sqlalchemy_dlock/asyncio/factory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from importlib import import_module
from typing import Union
import sys
from typing import Union, Type

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

from ..utils import pascal_case, safe_name
from .lock.base import BaseAsyncSadLock, TAsyncConnectionOrSession
from .lock.base import BaseAsyncSadLock

if sys.version_info >= (3, 12): # pragma: no cover
from ._sa_types import TAsyncConnectionOrSession
else: # pragma: no cover
from ._sa_types_backward import TAsyncConnectionOrSession

__all__ = ["create_async_sadlock"]


def create_async_sadlock(
connection_or_session: TAsyncConnectionOrSession,
key,
contextual_timeout: Union[float, int, None] = None,
**kwargs,
connection_or_session: TAsyncConnectionOrSession, key, contextual_timeout: Union[float, int, None] = None, **kwargs
) -> BaseAsyncSadLock:
if isinstance(connection_or_session, AsyncConnection):
sync_engine = connection_or_session.sync_engine
Expand All @@ -29,5 +32,5 @@ def create_async_sadlock(
mod = import_module(f"..lock.{engine_name}", __name__)
except ImportError as exception: # pragma: no cover
raise NotImplementedError(f"{engine_name}: {exception}")
clz = getattr(mod, f"{pascal_case(engine_name)}AsyncSadLock")
clz: Type[BaseAsyncSadLock] = getattr(mod, f"{pascal_case(engine_name)}AsyncSadLock")
return clz(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs)
19 changes: 12 additions & 7 deletions src/sqlalchemy_dlock/asyncio/lock/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import sys
from typing import Any, Union

if sys.version_info >= (3, 11): # pragma: no cover
from typing import Self
else: # pragma: no cover
from typing_extensions import Self
if sys.version_info >= (3, 12): # pragma: no cover
from .._sa_types import TAsyncConnectionOrSession
else: # pragma: no cover
Expand All @@ -21,17 +25,18 @@ def __init__(
self._key = key
self._contextual_timeout = contextual_timeout

async def __aenter__(self):
async def __aenter__(self) -> Self:
if self._contextual_timeout is None:
await self.acquire()
elif not await self.acquire(timeout=self._contextual_timeout): # the timeout period has elapsed and not acquired
elif not await self.acquire(timeout=self._contextual_timeout):
# the timeout period has elapsed and not acquired
raise TimeoutError()
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.close()

def __str__(self): # pragma: no cover
def __str__(self) -> str:
return "<{} {} key={} at 0x{:x}>".format(
"locked" if self._acquired else "unlocked",
self.__class__.__name__,
Expand All @@ -44,7 +49,7 @@ def connection_or_session(self) -> TAsyncConnectionOrSession:
return self._connection_or_session

@property
def key(self):
def key(self) -> Any:
return self._key

@property
Expand All @@ -57,12 +62,12 @@ async def acquire(
timeout: Union[float, int, None] = None,
*args,
**kwargs,
) -> bool: # pragma: no cover
) -> bool:
raise NotImplementedError()

async def release(self, *args, **kwargs): # pragma: no cover
async def release(self, *args, **kwargs) -> None:
raise NotImplementedError()

async def close(self, *args, **kwargs):
async def close(self, *args, **kwargs) -> None:
if self._acquired:
await self.release(*args, **kwargs)
6 changes: 1 addition & 5 deletions src/sqlalchemy_dlock/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@


def create_sadlock(
connection_or_session: TConnectionOrSession,
key,
/,
contextual_timeout: Union[float, int, None] = None,
**kwargs,
connection_or_session: TConnectionOrSession, key, /, contextual_timeout: Union[float, int, None] = None, **kwargs
) -> BaseSadLock:
"""Create a database distributed lock object
Expand Down
31 changes: 12 additions & 19 deletions src/sqlalchemy_dlock/lock/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import sys
from threading import local
from typing import Union
from typing import Any, Union

if sys.version_info >= (3, 11): # pragma: no cover
from typing import Self
else: # pragma: no cover
from typing_extensions import Self
if sys.version_info >= (3, 12): # pragma: no cover
from .._sa_types import TConnectionOrSession
else: # pragma: no cover
Expand Down Expand Up @@ -37,12 +41,7 @@ class BaseSadLock(local):
""" # noqa: E501

def __init__(
self,
connection_or_session: TConnectionOrSession,
key,
/,
contextual_timeout: Union[float, int, None] = None,
**kwargs,
self, connection_or_session: TConnectionOrSession, key, /, contextual_timeout: Union[float, int, None] = None, **kwargs
):
"""
Args:
Expand Down Expand Up @@ -77,7 +76,7 @@ def __init__(
self._key = key
self._contextual_timeout = contextual_timeout

def __enter__(self):
def __enter__(self) -> Self:
if self._contextual_timeout is None: # timeout period is infinite
self.acquire()
elif not self.acquire(timeout=self._contextual_timeout): # the timeout period has elapsed and not acquired
Expand All @@ -87,7 +86,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_tb):
self.close()

def __str__(self): # pragma: no cover
def __str__(self) -> str:
return "<{} {} key={} at 0x{:x}>".format(
"locked" if self._acquired else "unlocked",
self.__class__.__name__,
Expand All @@ -104,7 +103,7 @@ def connection_or_session(self) -> TConnectionOrSession:
return self._connection_or_session

@property
def key(self):
def key(self) -> Any:
"""ID or name of the SQL locking function
It returns ``key`` parameter of the class's constructor"""
Expand All @@ -118,13 +117,7 @@ def locked(self) -> bool:
"""
return self._acquired

def acquire(
self,
block: bool = True,
timeout: Union[float, int, None] = None,
*args,
**kwargs,
) -> bool: # pragma: no cover
def acquire(self, block: bool = True, timeout: Union[float, int, None] = None, *args, **kwargs) -> bool:
"""Acquire a lock, blocking or non-blocking.
* With the ``block`` argument set to :data:`True` (the default), the method call will block until the lock is in an unlocked state, then set it to locked and return :data:`True`.
Expand All @@ -140,7 +133,7 @@ def acquire(
""" # noqa: E501
raise NotImplementedError()

def release(self, *args, **kwargs): # pragma: no cover
def release(self, *args, **kwargs) -> None:
"""Release a lock.
Since the class is thread-local, this cannot be called from other thread or process,
Expand All @@ -156,7 +149,7 @@ def release(self, *args, **kwargs): # pragma: no cover
""" # noqa: E501
raise NotImplementedError()

def close(self, *args, **kwargs):
def close(self, *args, **kwargs) -> None:
"""Same as :meth:`release`
Except that the :class:`ValueError` is **NOT** raised when invoked on an unlocked lock.
Expand Down
6 changes: 5 additions & 1 deletion src/sqlalchemy_dlock/lock/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def convert(value) -> str:
raise ValueError(f"MySQL enforces a maximum length on lock names of {MYSQL_LOCK_NAME_MAX_LENGTH} characters.")
self._actual_key = key

@property
def actual_key(self) -> str:
return self._actual_key


class MysqlSadLock(MysqlSadLockMixin, BaseSadLock):
"""A distributed lock implemented by MySQL named-lock
Expand All @@ -88,7 +92,7 @@ def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs):
**kwargs: other named parameters pass to :class:`.BaseSadLock` and :class:`.MysqlSadLockMixin`
"""
MysqlSadLockMixin.__init__(self, key=key, **kwargs)
BaseSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs)
BaseSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs)

def acquire(self, block: bool = True, timeout: Union[float, int, None] = None) -> bool:
if self._acquired:
Expand Down
10 changes: 7 additions & 3 deletions src/sqlalchemy_dlock/lock/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,16 @@ def convert(val: Any) -> int:
self._stmt_try_lock = TRY_LOCK_XACT_SHARED.params(key=key)

@property
def shared(self):
def actual_key(self) -> int:
return self._actual_key

@property
def shared(self) -> bool:
"""Is the advisory lock shared or exclusive"""
return self._shared

@property
def xact(self):
def xact(self) -> bool:
"""Is the advisory lock transaction level or session level"""
return self._xact

Expand Down Expand Up @@ -124,7 +128,7 @@ def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs):
**kwargs: other named parameters pass to :class:`.BaseSadLock` and :class:`.PostgresqlSadLockMixin`
""" # noqa: E501
PostgresqlSadLockMixin.__init__(self, key=key, **kwargs)
BaseSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs)
BaseSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs)

def acquire(
self,
Expand Down
3 changes: 2 additions & 1 deletion src/sqlalchemy_dlock/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import re
from hashlib import blake2b
from sys import byteorder
from typing import Any


def safe_name(s: str) -> str:
return re.sub(r"[^A-Za-z0-9_]+", "_", s).strip().lower()


def to_int64_key(k) -> int:
def to_int64_key(k: Any) -> int:
if isinstance(k, str):
k = k.encode()
if isinstance(k, (bytearray, bytes, memoryview)):
Expand Down
2 changes: 1 addition & 1 deletion tests/asyncio/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class SessionTestCase(IsolatedAsyncioTestCase):
sessions = []
sessions = [] # type: ignore[var-annotated]

def setUp(self):
create_engines()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class SessionTestCase(TestCase):
Sessions = []
Sessions = [] # type: ignore[var-annotated]

@classmethod
def setUpClass(cls):
Expand Down

0 comments on commit 0c6604a

Please sign in to comment.