Skip to content

Commit

Permalink
feat(Listeners): allow using query params in the condition through se…
Browse files Browse the repository at this point in the history
…arch parameter (#1627)

Closes #1622

(cherry picked from commit e6f28fe)
  • Loading branch information
frascuchon committed Aug 22, 2022
1 parent 6109648 commit a0a245d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
29 changes: 27 additions & 2 deletions src/rubrix/listeners/listener.py
@@ -1,4 +1,6 @@
import copy
import dataclasses
import functools
import logging
import threading
import time
Expand Down Expand Up @@ -76,6 +78,23 @@ def is_running(self):
"""True if listener is running"""
return self.__listener_job__ is not None

def __catch_exceptions__(self, cancel_on_failure=False):
def catch_exceptions_decorator(job_func):
@functools.wraps(job_func)
def wrapper(*args, **kwargs):
try:
return job_func(*args, **kwargs)
except:
import traceback

print(traceback.format_exc())
if cancel_on_failure:
self.stop() # We stop the scheduler

return wrapper

return catch_exceptions_decorator

def start(self, *action_args, **action_kwargs):
"""
Start listen to changes in the dataset. Additionally, args and kwargs can be passed to action
Expand All @@ -87,9 +106,13 @@ def start(self, *action_args, **action_kwargs):
if self.is_running():
raise ValueError("Listener is already running")

job_step = self.__catch_exceptions__(cancel_on_failure=True)(
self.__listener_iteration_job__
)

self.__listener_job__ = self.__scheduler__.every(
self.interval_in_seconds
).seconds.do(self.__listener_iteration_job__, *action_args, **action_kwargs)
).seconds.do(job_step, *action_args, **action_kwargs)

class _ScheduleThread(threading.Thread):
_WAIT_EVENT = threading.Event()
Expand Down Expand Up @@ -158,7 +181,9 @@ def __listener_iteration_job__(self, *args, **kwargs):
name=self.dataset, task=dataset.task, query=self.formatted_query, size=0
)

ctx.search = Search(total=search_results.total)
ctx.search = Search(
total=search_results.total, query_params=copy.deepcopy(ctx.query_params)
)
condition_args = [ctx.search]
if self.metrics:
condition_args.append(ctx.metrics)
Expand Down
4 changes: 3 additions & 1 deletion src/rubrix/listeners/models.py
Expand Up @@ -13,9 +13,11 @@ class Search:
Args:
total: The total number of records affected by the listener query
query_params: The query parameters applied to the executed search
"""

total: int
query_params: Optional[Dict[str, Any]] = None


class Metrics(Prodict):
Expand Down Expand Up @@ -63,7 +65,7 @@ def query(self) -> Optional[str]:
return self.__listener__.formatted_query


ListenerCondition = Callable[[Search, Metrics], bool]
ListenerCondition = Callable[[Search, Optional[RBListenerContext]], bool]
ListenerAction = Union[
Callable[[List[Record], RBListenerContext], bool],
Callable[[RBListenerContext], bool],
Expand Down
9 changes: 8 additions & 1 deletion tests/listeners/test_listener.py
Expand Up @@ -8,6 +8,12 @@
from rubrix.client.models import Record


def condition_check_params(search):
if search:
assert "param" in search.query_params and search.query_params["param"] == 100
return True


@pytest.mark.parametrize(
argnames=["dataset", "query", "metrics", "condition", "query_params"],
argvalues=[
Expand All @@ -17,7 +23,7 @@
("dataset", "val", None, None, None),
("dataset", None, ["F1"], lambda search, metrics: False, None),
("dataset", "val", None, lambda q: False, None),
("dataset", "val + {param}", None, lambda q: True, {"param": 100}),
("dataset", "val + {param}", None, condition_check_params, {"param": 100}),
],
)
def test_listener_with_parameters(
Expand Down Expand Up @@ -67,6 +73,7 @@ def action(self, records: List[Record], ctx: RBListenerContext):
test.action.start()

time.sleep(1.5)
assert test.action.is_running()
test.action.stop()
assert not test.action.is_running()

Expand Down

0 comments on commit a0a245d

Please sign in to comment.