Skip to content

Commit

Permalink
feat(#1602): new rubrix dataset listeners (#1507, #1586, #1583, #1596)
Browse files Browse the repository at this point in the history
(cherry picked from commit b1658da)

- refactor(listener): support dynamic query parameters
(cherry picked from commit 2f54f5b)

- docs(listeners): add to python reference
(cherry picked from commit 3e225fe)

- fix: include missing packages for rb.listener reference
(cherry picked from commit 870e71d)
  • Loading branch information
frascuchon committed Jul 8, 2022
1 parent a21bcf3 commit 65747ab
Show file tree
Hide file tree
Showing 15 changed files with 528 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/reference/python/index.rst
Expand Up @@ -8,6 +8,7 @@ The python reference guide for Rubrix. This section contains:
* :ref:`python_client`: The base client module
* :ref:`python_metrics`: The module for dataset metrics
* :ref:`python_labeling`: A toolbox to enhance your labeling workflow (weak labels, noisy labels, etc.)
* :ref:`python_listeners`: This module contains all you need to define and configure dataset rubrix listeners

.. toctree::
:maxdepth: 2
Expand All @@ -17,3 +18,4 @@ The python reference guide for Rubrix. This section contains:
python_client
python_metrics
python_labeling
python_listeners
12 changes: 12 additions & 0 deletions docs/reference/python/python_listeners.rst
@@ -0,0 +1,12 @@
.. _python_listeners:

Listeners
=========

Here we describe the Rubrix listeners capabilities

.. automodule:: rubrix.listeners
:members: listener, RBDatasetListener, Metrics, RBListenerContext, Search



2 changes: 1 addition & 1 deletion environment_dev.yml
Expand Up @@ -45,4 +45,4 @@ dependencies:
- transformers[torch]~=4.18.0
- loguru
# install Rubrix in editable mode
- -e .[server]
- -e .[server,listeners]
2 changes: 1 addition & 1 deletion environment_docs.yml
Expand Up @@ -105,4 +105,4 @@ dependencies:
- webencodings==0.5.1
- wrapt==1.13.3
- zipp==3.7.0
- -e .
- -e ".[listeners]"
4 changes: 4 additions & 0 deletions pyproject.toml
Expand Up @@ -68,6 +68,10 @@ server = [
"hurry.filesize", # TODO: remove
"psutil ~= 5.8.0",
]
listeners = [
"schedule ~= 1.1.0",
"prodict ~= 0.8.0"
]

[project.urls]
homepage = "https://www.rubrix.ml"
Expand Down
2 changes: 2 additions & 0 deletions src/rubrix/__init__.py
Expand Up @@ -57,6 +57,7 @@
TokenClassificationSettings,
configure_dataset,
)
from rubrix.listeners import Metrics, RBListenerContext, Search, listener
from rubrix.monitoring.model_monitor import monitor
from rubrix.server.server import app

Expand Down Expand Up @@ -85,6 +86,7 @@
"read_pandas",
],
"monitoring.model_monitor": ["monitor"],
"listeners.listener": ["listener", "RBListenerContext", "Search", "Metrics"],
"datasets": [
"configure_dataset",
"TextClassificationSettings",
Expand Down
10 changes: 10 additions & 0 deletions src/rubrix/client/api.py
Expand Up @@ -30,6 +30,8 @@
RUBRIX_WORKSPACE_HEADER_NAME,
)
from rubrix.client.apis.datasets import Datasets
from rubrix.client.apis.metrics import MetricsAPI
from rubrix.client.apis.searches import Searches
from rubrix.client.datasets import (
Dataset,
DatasetForText2Text,
Expand Down Expand Up @@ -161,6 +163,14 @@ def client(self):
def datasets(self) -> Datasets:
return Datasets(client=self._client)

@property
def searches(self):
return Searches(client=self._client)

@property
def metrics(self):
return MetricsAPI(client=self.client)

def set_workspace(self, workspace: str):
"""Sets the active workspace.
Expand Down
26 changes: 26 additions & 0 deletions src/rubrix/client/apis/metrics.py
@@ -0,0 +1,26 @@
from typing import Optional

from rubrix.client.apis import AbstractApi
from rubrix.client.sdk.datasets.models import TaskType


class MetricsAPI(AbstractApi):

_API_URL_PATTERN = "/api/datasets/{task}/{name}/metrics/{metric}:summary"

def metric_summary(
self,
name: str,
task: TaskType,
metric: str,
query: Optional[str] = None,
**metric_params,
):
url = self._API_URL_PATTERN.format(task=task, name=name, metric=metric)
metric_params = metric_params or {}
query_params = {k: v for k, v in metric_params.items() if v is not None}
if query_params:
url += "?" + "&".join([f"{k}={v}" for k, v in query_params.items()])

metric_summary = self.__client__.post(url, json={"query_text": query})
return metric_summary
70 changes: 70 additions & 0 deletions src/rubrix/client/apis/searches.py
@@ -0,0 +1,70 @@
import dataclasses
from typing import List, Optional

from rubrix.client.apis import AbstractApi
from rubrix.client.models import Record
from rubrix.client.sdk.datasets.models import TaskType
from rubrix.client.sdk.text2text.models import Text2TextRecord
from rubrix.client.sdk.text_classification.models import TextClassificationRecord
from rubrix.client.sdk.token_classification.models import TokenClassificationRecord


@dataclasses.dataclass
class SearchResults:
total: int

records: List[Record]


class Searches(AbstractApi):

_API_URL_PATTERN = "/api/datasets/{name}/{task}:search"

def search_records(
self,
name: str,
task: TaskType,
query: Optional[str],
size: Optional[int] = None,
):
"""
Searches records over a dataset
Args:
name: The dataset name
task: The dataset task type
query: The query string
size: If provided, only the provided number of records will be fetched
Returns:
An instance of ``SearchResults`` class containing the search results
"""

if task == TaskType.text_classification:
record_class = TextClassificationRecord
elif task == TaskType.token_classification:
record_class = TokenClassificationRecord
elif task == TaskType.text2text:
record_class = Text2TextRecord
else:
raise ValueError(f"Task {task} not supported")

url = self._API_URL_PATTERN.format(name=name, task=task)
if size:
url += f"{url}?size={size}"

query_request = {}
if query:
query_request["query_text"] = query

response = self.__client__.post(
path=url,
json={"query": query_request},
)

return SearchResults(
total=response["total"],
records=[
record_class.parse_obj(r).to_client() for r in response["records"]
],
)
4 changes: 2 additions & 2 deletions src/rubrix/client/sdk/client.py
Expand Up @@ -85,7 +85,7 @@ def post(self, path: str, *args, **kwargs):
*args,
**kwargs,
)
return build_raw_response(response)
return build_raw_response(response).parsed

def put(self, path: str, *args, **kwargs):
path = self._normalize_path(path)
Expand All @@ -99,7 +99,7 @@ def put(self, path: str, *args, **kwargs):
*args,
**kwargs,
)
return build_raw_response(response)
return build_raw_response(response).parsed

@staticmethod
def _normalize_path(path: str) -> str:
Expand Down
2 changes: 2 additions & 0 deletions src/rubrix/listeners/__init__.py
@@ -0,0 +1,2 @@
from .listener import RBDatasetListener, listener
from .models import Metrics, RBListenerContext, Search

0 comments on commit 65747ab

Please sign in to comment.