Skip to content

Commit

Permalink
Support repeated calls to Table.convert()
Browse files Browse the repository at this point in the history
* Test repeated calls to Table.convert()
* Register Table.convert() functions under their own `lambda_hash` name
* Raise exception on registering identical function names

Refs #525
  • Loading branch information
mcarpenter committed May 8, 2023
1 parent 6500fed commit 02f5c4d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
18 changes: 13 additions & 5 deletions sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ class AlterError(Exception):
pass


class FunctionAlreadyRegistered(Exception):
"A function with this name and arity was already registered"
pass


class NoObviousTable(Exception):
"Could not tell which table this operation refers to"
pass
Expand Down Expand Up @@ -409,7 +414,7 @@ def register(fn):
fn_name = name or fn.__name__
arity = len(inspect.signature(fn).parameters)
if not replace and (fn_name, arity) in self._registered_functions:
return fn
raise FunctionAlreadyRegistered(f'Already registered function with name "{fn_name}" and identical arity')
kwargs = {}
registered = False
if deterministic:
Expand All @@ -434,7 +439,7 @@ def register(fn):

def register_fts4_bm25(self):
"Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
self.register_function(rank_bm25, deterministic=True)
self.register_function(rank_bm25, deterministic=True, replace=True)

def attach(self, alias: str, filepath: Union[str, pathlib.Path]):
"""
Expand Down Expand Up @@ -2687,13 +2692,16 @@ def convert_value(v):
return v
return jsonify_if_needed(fn(v))

self.db.register_function(convert_value)
fn_name = fn.__name__
if fn_name == '<lambda>':
fn_name = f'lambda_{hash(fn)}'
self.db.register_function(convert_value, name=fn_name)
sql = "update [{table}] set {sets}{where};".format(
table=self.name,
sets=", ".join(
[
"[{output_column}] = convert_value([{column}])".format(
output_column=output or column, column=column
"[{output_column}] = {fn_name}([{column}])".format(
output_column=output or column, column=column, fn_name=fn_name
)
for column in columns
]
Expand Down
9 changes: 9 additions & 0 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,12 @@ def test_convert_multi_exception(fresh_db):
table.insert({"title": "Mixed Case"})
with pytest.raises(BadMultiValues):
table.convert("title", lambda v: v.upper(), multi=True)


def test_convert_repeated(fresh_db):
table = fresh_db["table"]
col = "num"
table.insert({col: 1})
table.convert(col, lambda x: x*2)
table.convert(col, lambda _x: 0)
assert table.get(1) == {col: 0}
17 changes: 13 additions & 4 deletions tests/test_register_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from unittest.mock import MagicMock, call
from sqlite_utils.utils import sqlite3

from sqlite_utils.db import FunctionAlreadyRegistered

def test_register_function(fresh_db):
@fresh_db.register_function
Expand Down Expand Up @@ -85,9 +85,10 @@ def one():
assert "one" == fresh_db.execute("select one()").fetchone()[0]

# This will fail to replace the function:
@fresh_db.register_function()
def one(): # noqa
return "two"
with pytest.raises(FunctionAlreadyRegistered):
@fresh_db.register_function()
def one(): # noqa
return "two"

assert "one" == fresh_db.execute("select one()").fetchone()[0]

Expand All @@ -97,3 +98,11 @@ def one(): # noqa
return "two"

assert "two" == fresh_db.execute("select one()").fetchone()[0]


def test_register_function_duplicate(fresh_db):
def to_lower(s):
return s.lower()
fresh_db.register_function(to_lower)
with pytest.raises(FunctionAlreadyRegistered):
fresh_db.register_function(to_lower)

0 comments on commit 02f5c4d

Please sign in to comment.