Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
(cherry picked from commit b1658da) - refactor(listener): support dynamic query parameters (cherry picked from commit 2f54f5b) - docs(listeners): add to python reference (cherry picked from commit 3e225fe) - fix: include missing packages for rb.listener reference (cherry picked from commit 870e71d)
- Loading branch information
1 parent
a21bcf3
commit 65747ab
Showing
15 changed files
with
528 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
.. _python_listeners: | ||
|
||
Listeners | ||
========= | ||
|
||
Here we describe the Rubrix listeners capabilities | ||
|
||
.. automodule:: rubrix.listeners | ||
:members: listener, RBDatasetListener, Metrics, RBListenerContext, Search | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,4 +105,4 @@ dependencies: | |
- webencodings==0.5.1 | ||
- wrapt==1.13.3 | ||
- zipp==3.7.0 | ||
- -e . | ||
- -e ".[listeners]" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Optional | ||
|
||
from rubrix.client.apis import AbstractApi | ||
from rubrix.client.sdk.datasets.models import TaskType | ||
|
||
|
||
class MetricsAPI(AbstractApi): | ||
|
||
_API_URL_PATTERN = "/api/datasets/{task}/{name}/metrics/{metric}:summary" | ||
|
||
def metric_summary( | ||
self, | ||
name: str, | ||
task: TaskType, | ||
metric: str, | ||
query: Optional[str] = None, | ||
**metric_params, | ||
): | ||
url = self._API_URL_PATTERN.format(task=task, name=name, metric=metric) | ||
metric_params = metric_params or {} | ||
query_params = {k: v for k, v in metric_params.items() if v is not None} | ||
if query_params: | ||
url += "?" + "&".join([f"{k}={v}" for k, v in query_params.items()]) | ||
|
||
metric_summary = self.__client__.post(url, json={"query_text": query}) | ||
return metric_summary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import dataclasses | ||
from typing import List, Optional | ||
|
||
from rubrix.client.apis import AbstractApi | ||
from rubrix.client.models import Record | ||
from rubrix.client.sdk.datasets.models import TaskType | ||
from rubrix.client.sdk.text2text.models import Text2TextRecord | ||
from rubrix.client.sdk.text_classification.models import TextClassificationRecord | ||
from rubrix.client.sdk.token_classification.models import TokenClassificationRecord | ||
|
||
|
||
@dataclasses.dataclass | ||
class SearchResults: | ||
total: int | ||
|
||
records: List[Record] | ||
|
||
|
||
class Searches(AbstractApi): | ||
|
||
_API_URL_PATTERN = "/api/datasets/{name}/{task}:search" | ||
|
||
def search_records( | ||
self, | ||
name: str, | ||
task: TaskType, | ||
query: Optional[str], | ||
size: Optional[int] = None, | ||
): | ||
""" | ||
Searches records over a dataset | ||
Args: | ||
name: The dataset name | ||
task: The dataset task type | ||
query: The query string | ||
size: If provided, only the provided number of records will be fetched | ||
Returns: | ||
An instance of ``SearchResults`` class containing the search results | ||
""" | ||
|
||
if task == TaskType.text_classification: | ||
record_class = TextClassificationRecord | ||
elif task == TaskType.token_classification: | ||
record_class = TokenClassificationRecord | ||
elif task == TaskType.text2text: | ||
record_class = Text2TextRecord | ||
else: | ||
raise ValueError(f"Task {task} not supported") | ||
|
||
url = self._API_URL_PATTERN.format(name=name, task=task) | ||
if size: | ||
url += f"{url}?size={size}" | ||
|
||
query_request = {} | ||
if query: | ||
query_request["query_text"] = query | ||
|
||
response = self.__client__.post( | ||
path=url, | ||
json={"query": query_request}, | ||
) | ||
|
||
return SearchResults( | ||
total=response["total"], | ||
records=[ | ||
record_class.parse_obj(r).to_client() for r in response["records"] | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .listener import RBDatasetListener, listener | ||
from .models import Metrics, RBListenerContext, Search |
Oops, something went wrong.