feat(#950): include search keywords as part of record results (#1201)
* chore: include search_keywords in client records

* chore: signatures

* feat: include search_records as part of client records

* fix: add highlight on dataset scan

* test: add missing tests

* test: stabilize tests

* Apply suggestions from code review

Co-authored-by: David Fidalgo <david@recogn.ai>

* test: try to fix push to hf hub

Co-authored-by: David Fidalgo <david@recogn.ai>

(cherry picked from commit 0678043)
frascuchon committed Mar 4, 2022
1 parent 8cb3dca commit 2dd5853
Showing 9 changed files with 118 additions and 9 deletions.
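For context, a minimal sketch of how the new field surfaces to client code once these changes land (dataset name and query are illustrative, mirroring the functional tests below):

```python
import rubrix as rb

# Query a logged dataset; for every matching record the server now returns
# the matched terms in the read-only `search_keywords` field.
df = rb.load("my_dataset", query="lim*")
print(df.search_keywords.head())  # e.g. [["limit"], ["limits", "limited"], ...]
```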
14 changes: 12 additions & 2 deletions src/rubrix/client/models.py
@@ -147,7 +147,9 @@ class TextClassificationRecord(_Validators):
metrics:
READ ONLY! Metrics at record level provided by the server when using `rb.load`.
This attribute will be ignored when using `rb.log`.
search_keywords:
READ ONLY! Relevant record keywords/terms for provided query when using `rb.load`.
This attribute will be ignored when using `rb.log`.
Examples:
>>> import rubrix as rb
>>> record = rb.TextClassificationRecord(
@@ -172,6 +174,7 @@ class TextClassificationRecord(_Validators):
event_timestamp: Optional[datetime.datetime] = None

metrics: Optional[Dict[str, Any]] = None
search_keywords: Optional[List[str]] = None

@validator("inputs", pre=True)
def input_as_dict(cls, inputs):
@@ -213,7 +216,9 @@ class TokenClassificationRecord(_Validators):
metrics:
READ ONLY! Metrics at record level provided by the server when using `rb.load`.
This attribute will be ignored when using `rb.log`.
search_keywords:
READ ONLY! Relevant record keywords/terms for provided query when using `rb.load`.
This attribute will be ignored when using `rb.log`.
Examples:
>>> import rubrix as rb
>>> record = rb.TokenClassificationRecord(
@@ -239,6 +244,7 @@ class TokenClassificationRecord(_Validators):
event_timestamp: Optional[datetime.datetime] = None

metrics: Optional[Dict[str, Any]] = None
search_keywords: Optional[List[str]] = None

@validator("prediction")
def add_default_score(
@@ -283,6 +289,9 @@ class Text2TextRecord(_Validators):
metrics:
READ ONLY! Metrics at record level provided by the server when using `rb.load`.
This attribute will be ignored when using `rb.log`.
search_keywords:
READ ONLY! Relevant record keywords/terms for provided query when using `rb.load`.
This attribute will be ignored when using `rb.log`.
Examples:
>>> import rubrix as rb
@@ -305,6 +314,7 @@ class Text2TextRecord(_Validators):
event_timestamp: Optional[datetime.datetime] = None

metrics: Optional[Dict[str, Any]] = None
search_keywords: Optional[List[str]] = None

@validator("prediction")
def prediction_as_tuples(
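As the docstrings above state, `search_keywords` is read-only: a value set by hand validates against the pydantic model but is ignored on logging, since the server recomputes it per query. A hedged sketch (dataset name illustrative):

```python
import rubrix as rb

record = rb.TextClassificationRecord(
    inputs={"text": "Please raise my credit limit"},
    search_keywords=["limit"],  # accepted by the model, but dropped by rb.log
)
rb.log(records=[record], name="my_dataset")
# keywords are only populated on rb.load(..., query=...) results
```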
1 change: 1 addition & 0 deletions src/rubrix/client/sdk/commons/models.py
@@ -46,6 +46,7 @@ class BaseRecord(GenericModel, Generic[T]):
prediction: Optional[T] = None
annotation: Optional[T] = None
metrics: Dict[str, Any] = Field(default_factory=dict)
search_keywords: Optional[List[str]] = None

# this is a small hack to get a json-compatible serialization on cls.dict(), which we use for the httpx calls.
# they want to build this feature into pydantic, see https://github.com/samuelcolvin/pydantic/issues/1409
1 change: 1 addition & 0 deletions src/rubrix/client/sdk/text2text/models.py
@@ -94,6 +94,7 @@ def to_client(self) -> ClientText2TextRecord:
id=self.id,
event_timestamp=self.event_timestamp,
metrics=self.metrics or None,
search_keywords=self.search_keywords or None,
)


1 change: 1 addition & 0 deletions src/rubrix/client/sdk/text_classification/models.py
@@ -129,6 +129,7 @@ def to_client(self) -> ClientTextClassificationRecord:
if self.explanation
else None,
metrics=self.metrics or None,
search_keywords=self.search_keywords or None,
)


3 changes: 2 additions & 1 deletion src/rubrix/client/sdk/token_classification/models.py
@@ -117,7 +117,8 @@ def to_client(self) -> ClientTokenClassificationRecord:
event_timestamp=self.event_timestamp,
status=self.status,
metadata=self.metadata or {},
-metrics=self.metrics or {},
+metrics=self.metrics or None,
+search_keywords=self.search_keywords or None,
)


7 changes: 7 additions & 0 deletions src/rubrix/server/tasks/commons/api/model.py
@@ -179,6 +179,13 @@ class BaseRecord(GenericModel, Generic[Annotation]):
prediction: Optional[Annotation] = None
annotation: Optional[Annotation] = None
metrics: Dict[str, Any] = Field(default_factory=dict)
search_keywords: Optional[List[str]] = None

@validator("search_keywords")
def remove_duplicated_keywords(cls, value) -> List[str]:
"""Remove duplicated keywords"""
if value:
return list(set(value))

@validator("id", always=True)
def default_id_if_none_provided(cls, id: Optional[str]) -> str:
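Note the `remove_duplicated_keywords` validator deduplicates through a `set`, so keyword ordering is not preserved, and it implicitly returns `None` when the incoming value is empty. A standalone equivalent of that logic, for illustration:

```python
from typing import List, Optional

def remove_duplicated_keywords(value: Optional[List[str]]) -> Optional[List[str]]:
    # Mirrors the validator above: deduplicate (order is lost); empty -> None.
    if value:
        return list(set(value))
    return None

assert sorted(remove_duplicated_keywords(["limit", "limit", "limits"])) == ["limit", "limits"]
assert remove_duplicated_keywords([]) is None
```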
50 changes: 44 additions & 6 deletions src/rubrix/server/tasks/commons/dao/dao.py
@@ -15,6 +15,7 @@

import dataclasses
import datetime
import re
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar

import deprecated
@@ -129,6 +130,12 @@ class DatasetRecordsDAO:
# This info must be provided by each task using dao.register_task_mappings method
_MAPPINGS_BY_TASKS = {}

__HIGHLIGHT_PRE_TAG__ = "<@@-rb-key>"
__HIGHLIGHT_POST_TAG__ = "</@@-rb-key>"
__HIGHLIGHT_VALUES_REGEX__ = re.compile(
rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}"
)

@classmethod
def get_instance(
cls,
@@ -158,7 +165,7 @@ def init(self):
def add_records(
self,
dataset: BaseDatasetDB,
-records: List[BaseRecord],
+records: List[DBRecord],
record_class: Type[DBRecord],
) -> int:
"""
@@ -190,7 +197,9 @@ def add_records(
db_record = record_class.parse_obj(r)
if now:
db_record.last_updated = now
-documents.append(db_record.dict(exclude_none=False))
+documents.append(
+    db_record.dict(exclude_none=False, exclude={"search_keywords"})
+)

index_name = self.create_dataset_index(dataset)
self._configure_metadata_fields(index_name, metadata_values)
@@ -246,6 +255,7 @@ def search_records(
"query": search.query or {"match_all": {}},
"sort": search.sort or [{"_id": {"order": "asc"}}],
"aggs": aggregation_requests,
"highlight": self.__configure_query_highlight__(),
}

try:
@@ -282,7 +292,7 @@

result = RecordSearchResults(
total=total,
-records=list(map(self.esdoc2record, docs)),
+records=list(map(self.__esdoc2record__, docs)),
)
if search_aggregations:
parsed_aggregations = parse_aggregations(search_aggregations)
@@ -319,15 +329,34 @@ def scan_dataset(
search = search or RecordSearch()
es_query = {
"query": search.query,
"highlight": self.__configure_query_highlight__(),
}
docs = self._es.list_documents(
dataset_records_index(dataset.id), query=es_query
)
for doc in docs:
-yield self.esdoc2record(doc)
+yield self.__esdoc2record__(doc)

def __esdoc2record__(self, doc: Dict[str, Any]):
return {
**doc["_source"],
"id": doc["_id"],
"search_keywords": self.__parse_highlight_results__(doc),
}

-def esdoc2record(self, doc):
-    return {**doc["_source"], "id": doc["_id"]}
@classmethod
def __parse_highlight_results__(cls, doc: Dict[str, Any]) -> Optional[List[str]]:
highlight_info = doc.get("highlight")
if not highlight_info:
return None

search_keywords = []
for content in highlight_info.values():
if not isinstance(content, list):
content = [content]
for text in content:
search_keywords.extend(re.findall(cls.__HIGHLIGHT_VALUES_REGEX__, text))
return list(set(search_keywords))

def _configure_metadata_fields(self, index: str, metadata_values: Dict[str, Any]):
def check_metadata_length(metadata_length: int = 0):
@@ -406,6 +435,15 @@ def get_dataset_schema(self, dataset: BaseDatasetDB) -> Dict[str, Any]:
index_name = dataset_records_index(dataset.id)
return self._es.__client__.indices.get_mapping(index=index_name)

@classmethod
def __configure_query_highlight__(cls):
return {
"pre_tags": [cls.__HIGHLIGHT_PRE_TAG__],
"post_tags": [cls.__HIGHLIGHT_POST_TAG__],
"require_field_match": False,
"fields": {"text": {}},
}


_instance: Optional[DatasetRecordsDAO] = None

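To make the highlight round-trip concrete: the DAO asks Elasticsearch to wrap matched terms in the custom pre/post tags, and `__parse_highlight_results__` extracts them back out with the compiled regex. A self-contained sketch against a made-up hit (the `doc` payload is hypothetical; the tag constants are re-declared from the DAO's private attributes):

```python
import re

HIGHLIGHT_PRE_TAG = "<@@-rb-key>"
HIGHLIGHT_POST_TAG = "</@@-rb-key>"
HIGHLIGHT_REGEX = re.compile(rf"{HIGHLIGHT_PRE_TAG}(.+?){HIGHLIGHT_POST_TAG}")

# A hypothetical Elasticsearch hit: "highlight" maps each matched field to a
# list of fragments in which matched terms are wrapped in the configured tags.
doc = {
    "highlight": {
        "text": ["raise my credit <@@-rb-key>limit</@@-rb-key>, please"],
    }
}

keywords = []
for fragments in doc["highlight"].values():
    for fragment in fragments:
        keywords.extend(HIGHLIGHT_REGEX.findall(fragment))
print(sorted(set(keywords)))  # ['limit']
```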
25 changes: 25 additions & 0 deletions tests/functional_tests/test_log_for_text_classification.py
@@ -50,6 +50,31 @@ def test_delete_and_create_for_different_task(mocked_client):
rubrix.load(dataset)


def test_search_keywords(mocked_client):
dataset = "test_search_keywords"
from datasets import load_dataset

dataset_ds = load_dataset("Recognai/sentiment-banking", split="train")
dataset_rb = rubrix.read_datasets(dataset_ds, task="TextClassification")

rubrix.delete(dataset)
rubrix.log(name=dataset, records=dataset_rb)

df = rubrix.load(dataset, query="lim*")
assert not df.empty
assert "search_keywords" in df.columns
top_keywords = set(
[
keyword
for keywords in df.search_keywords.value_counts(sort=True, ascending=False)
.index[:3]
.tolist()
for keyword in keywords
]
)
assert {"limit", "limits", "limited"} == top_keywords, top_keywords


def test_log_records_with_empty_metadata_list(mocked_client):
dataset = "test_log_records_with_empty_metadata_list"

25 changes: 25 additions & 0 deletions tests/functional_tests/test_log_for_token_classification.py
@@ -446,3 +446,28 @@ def test_log_record_that_makes_me_cry(mocked_client):
},
"annotated": {"mentions": []},
}


def test_search_keywords(mocked_client):
dataset = "test_search_keywords"
from datasets import load_dataset

dataset_ds = load_dataset("rubrix/gutenberg_spacy-ner", split="train")
dataset_rb = rubrix.read_datasets(dataset_ds, task="TokenClassification")

rubrix.delete(dataset)
rubrix.log(name=dataset, records=dataset_rb)

df = rubrix.load(dataset, query="lis*")
assert not df.empty
assert "search_keywords" in df.columns
top_keywords = set(
[
keyword
for keywords in df.search_keywords.value_counts(sort=True, ascending=False)
.index[:3]
.tolist()
for keyword in keywords
]
)
assert {"listened", "listen"} == top_keywords, top_keywords
