Skip to content

Commit

Permalink
bugfix: wrong pg key; wrong mysql type error; wrong annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
tanbro committed May 6, 2024
1 parent cc053f4 commit ba83161
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 42 deletions.
6 changes: 4 additions & 2 deletions src/sqlalchemy_dlock/lock/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def convert(value) -> str:
""" # noqa: E501
if convert:
self._actual_key = convert(key)
elif not isinstance(key, str):
elif isinstance(key, str):
self._actual_key = key
else:
self._actual_key = default_convert(key)
if not isinstance(key, str):
if not isinstance(self._actual_key, str):
raise TypeError("MySQL named lock requires the key given by string")
if len(self._actual_key) > MYSQL_LOCK_NAME_MAX_LENGTH:
raise ValueError(f"MySQL enforces a maximum length on lock names of {MYSQL_LOCK_NAME_MAX_LENGTH} characters.")
Expand Down
22 changes: 11 additions & 11 deletions src/sqlalchemy_dlock/lock/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ def convert(val: Any) -> int:
#
self._stmt_unlock = None
if not shared and not xact:
self._stmt_lock = LOCK.params(key=key)
self._stmt_try_lock = TRY_LOCK.params(key=key)
self._stmt_unlock = UNLOCK.params(key=key)
self._stmt_lock = LOCK.params(key=self._actual_key)
self._stmt_try_lock = TRY_LOCK.params(key=self._actual_key)
self._stmt_unlock = UNLOCK.params(key=self._actual_key)
elif shared and not xact:
self._stmt_lock = LOCK_SHARED.params(key=key)
self._stmt_try_lock = TRY_LOCK_SHARED.params(key=key)
self._stmt_unlock = UNLOCK_SHARED.params(key=key)
self._stmt_lock = LOCK_SHARED.params(key=self._actual_key)
self._stmt_try_lock = TRY_LOCK_SHARED.params(key=self._actual_key)
self._stmt_unlock = UNLOCK_SHARED.params(key=self._actual_key)
elif not shared and xact:
self._stmt_lock = LOCK_XACT.params(key=key)
self._stmt_try_lock = TRY_LOCK_XACT.params(key=key)
self._stmt_lock = LOCK_XACT.params(key=self._actual_key)
self._stmt_try_lock = TRY_LOCK_XACT.params(key=self._actual_key)
else:
self._stmt_lock = LOCK_XACT_SHARED.params(key=key)
self._stmt_try_lock = TRY_LOCK_XACT_SHARED.params(key=key)
self._stmt_lock = LOCK_XACT_SHARED.params(key=self._actual_key)
self._stmt_try_lock = TRY_LOCK_XACT_SHARED.params(key=self._actual_key)

@property
def actual_key(self) -> int:
Expand Down Expand Up @@ -145,7 +145,7 @@ def acquire(
if block:
if timeout is None:
# None: set the timeout period to infinite.
_ = self.connection_or_session.execute(self._stmt_lock).all()
self.connection_or_session.execute(self._stmt_lock).all()
self._acquired = True
else:
# negative value for `timeout` are equivalent to a `timeout` of zero.
Expand Down
4 changes: 3 additions & 1 deletion src/sqlalchemy_dlock/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import re
from hashlib import blake2b
from sys import byteorder
from typing import Union, TYPE_CHECKING

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from _typeshed import ReadableBuffer


Expand Down
15 changes: 1 addition & 14 deletions tests/asyncio/test_key_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,7 @@ async def test_mysql_key_gt_max_length(self):
create_async_sadlock(conn, key)

async def test_mysql_key_not_a_string(self):
keys = (
None,
1,
0,
-1,
0.1,
True,
False,
(),
[],
set(),
{},
object(),
)
keys = None, 1, 0, -1, 0.1, True, False, (), [], set(), {}, object()

for engine in get_engines():
if engine.name != "mysql":
Expand Down
15 changes: 1 addition & 14 deletions tests/test_key_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,7 @@ def test_mysql_key_gt_max_length(self):
create_sadlock(conn, key)

def test_mysql_key_not_a_string(self):
keys = (
None,
1,
0,
-1,
0.1,
True,
False,
(),
[],
set(),
{},
object(),
)
keys = None, 1, 0, -1, 0.1, True, False, (), [], set(), {}, object()

for engine in ENGINES:
if engine.name != "mysql":
Expand Down

0 comments on commit ba83161

Please sign in to comment.