Skip to content

Commit

Permalink
feat(API): provide a dict for record annotations/predictions (#1658)
Browse files Browse the repository at this point in the history
* feat(API): provide a dict for record annotations/predictions

* tests: fixing partial compatible data model helpers

(cherry picked from commit 6a612ac)

* fix: align agents for predictions annotations

* fix: empty agent for annotations dict but check it on single annotation
  • Loading branch information
frascuchon committed Oct 5, 2022
1 parent 85da336 commit 12b0f83
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 16 deletions.
21 changes: 19 additions & 2 deletions src/rubrix/server/daos/backend/mappings/helpers.py
Expand Up @@ -103,6 +103,10 @@ def nested_field():
def decimal_field():
return {"type": "float"}

@classmethod
def dynamic_field(cls):
    """Mapping fragment for a dynamic ``object`` field (schema-less sub-document)."""
    return dict(dynamic=True, type="object")


def multilingual_stop_analyzer(supported_langs: List[str] = None) -> Dict[str, Any]:
"""Multilingual stop analyzer"""
Expand Down Expand Up @@ -154,6 +158,15 @@ def dynamic_metadata_text():
}


def dynamic_annotations_text(path: str):
    """Build a dynamic-template entry enabling text search on every keyword
    found under ``path`` (e.g. ``predictions.*`` or ``annotations.*``)."""
    wildcard_path = f"{path}.*"
    template = mappings.path_match_keyword_template(
        path=wildcard_path, enable_text_search_in_keywords=True
    )
    return {wildcard_path: template}


def tasks_common_mappings():
"""Commons index mappings"""
return {
Expand All @@ -168,16 +181,20 @@ def tasks_common_mappings():
# so we can build extra metrics based on these fields
"prediction": {"type": "object", "enabled": False},
"annotation": {"type": "object", "enabled": False},
"predictions": mappings.dynamic_field(),
"annotations": mappings.dynamic_field(),
"status": mappings.keyword_field(),
"event_timestamp": {"type": "date"},
"last_updated": {"type": "date"},
"annotated_by": mappings.keyword_field(enable_text_search=True),
"predicted_by": mappings.keyword_field(enable_text_search=True),
"metrics": {"dynamic": True, "type": "object"},
"metadata": {"dynamic": True, "type": "object"},
"metrics": mappings.dynamic_field(),
"metadata": mappings.dynamic_field(),
},
"dynamic_templates": [
dynamic_metadata_text(),
dynamic_metrics_text(),
dynamic_annotations_text(path="predictions"),
dynamic_annotations_text(path="annotations"),
],
}
64 changes: 61 additions & 3 deletions src/rubrix/server/daos/models/records.py
Expand Up @@ -16,12 +16,13 @@
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import uuid4

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
from pydantic.generics import GenericModel

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.models import PredictionStatus, TaskStatus, TaskType
from rubrix.server.daos.backend.search.model import BackendRecordsQuery, SortConfig
from rubrix.server.errors import ValidationError
from rubrix.server.helpers import flatten_dict
from rubrix.utils import limit_value_length

Expand All @@ -38,7 +39,10 @@ class DaoRecordsSearchResults(BaseModel):


class BaseAnnotationDB(BaseModel):
    """Base model for a single annotation/prediction stored in the DB.

    The span contained both the pre- and post-change declaration of ``agent``
    (diff residue); only the post-change optional form is kept.
    """

    # Optional so dict-style records can omit it and use the dict key as the
    # agent instead (the record validator syncs the two representations).
    agent: Optional[str] = Field(
        None,
        max_length=64,
    )


AnnotationDB = TypeVar("AnnotationDB", bound=BaseAnnotationDB)
Expand All @@ -49,9 +53,63 @@ class BaseRecordInDB(GenericModel, Generic[AnnotationDB]):
metadata: Dict[str, Any] = Field(default=None)
event_timestamp: Optional[datetime] = None
status: Optional[TaskStatus] = None
# The span declared `prediction` twice (pre- and post-change diff lines kept
# together); only the post-change declaration is retained.
prediction: Optional[AnnotationDB] = Field(
    None, description="Deprecated. Use `predictions` instead"
)
annotation: Optional[AnnotationDB] = None

# Fixes typos in the user-facing API descriptions ("ant" -> "and",
# "Using this way" -> "This way"); field names/types are unchanged.
predictions: Optional[Dict[str, AnnotationDB]] = Field(
    None,
    description="Provide the prediction info as a key-value dictionary. "
    "The key will represent the agent and the value the prediction. "
    "This way you can skip passing the agent inside of the prediction",
)
annotations: Optional[Dict[str, AnnotationDB]] = Field(
    None,
    description="Provide the annotation info as a key-value dictionary. "
    "The key will represent the agent and the value the annotation. "
    "This way you can skip passing the agent inside the annotation",
)

@staticmethod
def update_annotation(values, annotation_field: str):
    """Keep the legacy single-agent field and its plural dict counterpart in sync.

    Args:
        values: the record's field dict being validated.
        annotation_field: the singular field name ("prediction" or
            "annotation"); the dict counterpart is ``f"{annotation_field}s"``.

    Returns:
        ``values`` with both representations populated and consistent.

    Raises:
        AssertionError: if the legacy single value is given without an agent.
    """
    field_to_update = f"{annotation_field}s"
    annotation = values.get(annotation_field)
    annotations = values.get(field_to_update) or {}

    # The dict key is the canonical agent: clear any inner agent so the key
    # and the entry's own agent can never diverge.
    for entry in annotations.values():
        entry.agent = None

    if annotation:
        if not annotation.agent:
            raise AssertionError("Agent must be defined!")

        # Fold the legacy single value into the dict, keyed by its agent
        # (the agent is dropped from the stored entry, see above).
        annotations[annotation.agent] = annotation.__class__.parse_obj(
            annotation.dict(exclude={"agent"})
        )
        values[field_to_update] = annotations

    if annotations and not annotation:
        # Only the dict form was provided: expose its first entry through the
        # deprecated single field for backwards compatibility.
        key, value = next(iter(annotations.items()))
        values[annotation_field] = value.__class__(
            agent=key, **value.dict(exclude={"agent"})
        )

    return values

@root_validator()
def prepare_record_for_db(cls, values):
    """Normalize the legacy single-agent fields and their dict counterparts
    so both representations of predictions/annotations stay consistent."""
    for field in ("prediction", "annotation"):
        values = cls.update_annotation(values, field)
    return values

@validator("id", always=True, pre=True)
def default_id_if_none_provided(cls, id: Optional[str]) -> str:
"""Validates id info and sets a random uuid if not provided"""
Expand Down
15 changes: 9 additions & 6 deletions src/rubrix/server/daos/records.py
Expand Up @@ -86,20 +86,23 @@ def add_records(
"""

now = None
now = datetime.datetime.utcnow()
documents = []
metadata_values = {}
mapping = self._es.get_mappings(dataset.id)

if "last_updated" in record_class.schema()["properties"]:
now = datetime.datetime.utcnow()
exclude_fields = [
name
for name in record_class.schema()["properties"]
if name not in mapping["mappings"]["properties"]
]

for r in records:
metadata_values.update(r.metadata or {})
db_record = record_class.parse_obj(r)
if now:
db_record.last_updated = now
db_record.last_updated = now
documents.append(
db_record.dict(exclude_none=False, exclude={"search_keywords"})
db_record.dict(exclude_none=False, exclude=set(exclude_fields))
)

self._es.create_dataset_index(
Expand Down
15 changes: 15 additions & 0 deletions tests/client/sdk/conftest.py
Expand Up @@ -99,6 +99,21 @@ def _expands_schema(
expanded_props = self._expands_schema(field_props, definitions)
definition["items"] = expanded_props.get("properties", expanded_props)
new_schema[name] = definition
elif "allOf" in definition:
allOf_expanded = [
self._expands_schema(
definitions[def_["$ref"].replace("#/definitions/", "")].get(
"properties", {}
),
definitions,
)
for def_ in definition["allOf"]
if "$ref" in def_
]
if len(allOf_expanded) == 1:
new_schema[name] = allOf_expanded[0]
else:
new_schema[name] = allOf_expanded
else:
new_schema[name] = definition
return new_schema
Expand Down
68 changes: 65 additions & 3 deletions tests/server/text2text/test_api.py
@@ -1,14 +1,14 @@
from rubrix.client.sdk.text2text.models import Text2TextBulkData
from rubrix.server.apis.v0.models.commons.model import BulkResponse
from rubrix.server.apis.v0.models.text2text import (
Text2TextBulkRequest,
Text2TextRecordInputs,
Text2TextSearchResults,
)


def test_search_records(mocked_client):
dataset = "test_search_records"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
delete_dataset(dataset, mocked_client)

records = [
Text2TextRecordInputs.parse_obj(data)
Expand All @@ -32,7 +32,7 @@ def test_search_records(mocked_client):
]
response = mocked_client.post(
f"/api/datasets/{dataset}/Text2Text:bulk",
json=Text2TextBulkData(
json=Text2TextBulkRequest(
tags={"env": "test", "class": "text classification"},
metadata={"config": {"the": "config"}},
records=records,
Expand Down Expand Up @@ -63,3 +63,65 @@ def test_search_records(mocked_client):
"status": {"Default": 2},
"words": {"data": 2, "ånother": 1},
}


def test_api_with_new_predictions_data_model(mocked_client):
    """End-to-end check of the dict-based `predictions`/`annotations` record fields."""
    dataset = "test_api_with_new_predictions_data_model"
    delete_dataset(dataset, mocked_client)

    record_payloads = [
        {
            "text": "This is a text data",
            "predictions": {
                "test": {
                    "sentences": [{"text": "This is a test data", "score": 0.6}]
                },
            },
        },
        {
            "text": "Another data",
            "annotations": {
                "annotator-1": {"sentences": [{"text": "THis is a test data"}]},
                "annotator-2": {"sentences": [{"text": "This IS the test datay"}]},
            },
        },
    ]
    records = [Text2TextRecordInputs.parse_obj(data) for data in record_payloads]

    response = mocked_client.post(
        f"/api/datasets/{dataset}/Text2Text:bulk",
        json=Text2TextBulkRequest(records=records).dict(by_alias=True),
    )
    assert response.status_code == 200, response.json()

    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 2

    # Each query targets exactly one of the two records bulk-loaded above.
    for query_text in (
        "predictions.test.sentences.text.exact:data",
        "_exists_:annotations.annotator-1",
    ):
        response = mocked_client.post(
            f"/api/datasets/{dataset}/Text2Text:search",
            json={"query": {"query_text": query_text}},
        )
        assert response.status_code == 200, response.json()
        results = Text2TextSearchResults.parse_obj(response.json())
        assert results.total == 1, results


def delete_dataset(dataset, mocked_client):
    """Remove *dataset* through the API, asserting the call succeeds."""
    response = mocked_client.delete(f"/api/datasets/{dataset}")
    assert response.status_code == 200
51 changes: 51 additions & 0 deletions tests/server/text2text/test_model.py
Expand Up @@ -50,6 +50,57 @@ def test_model_dict():
{"score": 0.5, "text": "sentence 2"},
],
},
"predictions": {
"test_sentences_sorted_by_score": {
"sentences": [
{"score": 1.0, "text": "sentence " "3"},
{"score": 0.6, "text": "sentence " "1"},
{"score": 0.5, "text": "sentence " "2"},
]
}
},
"status": "Default",
"text": "The input text",
}


def test_model_with_predictions():
    """The dict-style `predictions` field must also populate the legacy
    single-agent `prediction` field on parsing."""
    raw_record = {
        "id": 0,
        "text": "The input text",
        "predictions": {
            "test_sentences_sorted_by_score": {
                "sentences": [
                    {"score": 1.0, "text": "sentence 3"},
                    {"score": 0.6, "text": "sentence 1"},
                    {"score": 0.5, "text": "sentence 2"},
                ]
            }
        },
        "status": "Default",
    }
    record = Text2TextRecord.parse_obj(raw_record)

    expected = {
        "id": 0,
        "metrics": {},
        # Legacy field derived from the single dict entry above.
        "prediction": {
            "agent": "test_sentences_sorted_by_score",
            "sentences": [
                {"score": 1.0, "text": "sentence 3"},
                {"score": 0.6, "text": "sentence 1"},
                {"score": 0.5, "text": "sentence 2"},
            ],
        },
        "predictions": {
            "test_sentences_sorted_by_score": {
                "sentences": [
                    {"score": 1.0, "text": "sentence 3"},
                    {"score": 0.6, "text": "sentence 1"},
                    {"score": 0.5, "text": "sentence 2"},
                ]
            }
        },
        "status": "Default",
        "text": "The input text",
    }
    assert record.dict(exclude_none=True) == expected
Expand Down

0 comments on commit 12b0f83

Please sign in to comment.