Skip to content

Commit

Permalink
fix(#1027): Improve client models by reordering fields + forbidding e…
Browse files Browse the repository at this point in the history
…xtra args (#1032)

* feat: introduce RootValidators + correct order

* feat: forbid extras + test

(cherry picked from commit c1e32d1)
  • Loading branch information
David Fidalgo authored and frascuchon committed Jan 31, 2022
1 parent c24fdad commit 6c1ae7f
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 62 deletions.
127 changes: 68 additions & 59 deletions src/rubrix/client/models.py
Expand Up @@ -21,12 +21,63 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.helpers import limit_value_length


class _RootValidators(BaseModel):
    """Base class for our record models that takes care of root validations.

    Subclasses inherit three root validators that run after field validation
    (metadata value length check, agent-without-value warnings, default
    status) and a config that rejects unknown keyword arguments.
    """

    @root_validator
    def _check_value_length(cls, values):
        """Checks metadata values length and apply value truncation for large values.

        Args:
            values: The field values collected so far for the model.

        Returns:
            The same ``values`` dict, with ``metadata`` replaced by the output
            of ``limit_value_length`` when it is present.
        """
        # Post root validators are called even when a field failed validation,
        # and a failed (or undeclared) field is absent from `values`. Indexing
        # it would raise a bare KeyError that masks the real ValidationError,
        # so there is nothing to check in that case.
        if "metadata" not in values:
            return values

        new_metadata = limit_value_length(
            values["metadata"], max_length=MAX_KEYWORD_LENGTH
        )
        if new_metadata != values["metadata"]:
            warnings.warn(
                "Some metadata values exceed the max length. "
                f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
            )
            values["metadata"] = new_metadata

        return values

    @root_validator
    def _check_agents(cls, values):
        """Triggers a warning when ONLY agents are provided.

        Args:
            values: The field values collected so far for the model.

        Returns:
            ``values`` unchanged; this validator only emits warnings.
        """
        if (
            values.get("annotation_agent") is not None
            and values.get("annotation") is None
        ):
            warnings.warn(
                "You provided an `annotation_agent`, but no `annotation`. The `annotation_agent` will not be logged to the server."
            )
        if (
            values.get("prediction_agent") is not None
            and values.get("prediction") is None
        ):
            warnings.warn(
                "You provided an `prediction_agent`, but no `prediction`. The `prediction_agent` will not be logged to the server."
            )

        return values

    @root_validator
    def _check_and_update_status(cls, values):
        """Updates the status if an annotation is provided and no status is specified.

        Args:
            values: The field values collected so far for the model.

        Returns:
            ``values`` with ``status`` filled in: an explicit (truthy) status
            is kept, otherwise "Validated" when an annotation is present and
            "Default" when it is not.
        """
        values["status"] = values.get("status") or (
            "Default" if values.get("annotation") is None else "Validated"
        )

        return values

    class Config:
        # Reject keyword arguments that do not match a declared field.
        extra = "forbid"


class BulkResponse(BaseModel):
"""Summary response when logging records to the Rubrix server.
Expand Down Expand Up @@ -55,7 +106,7 @@ class TokenAttributions(BaseModel):
attributions: Dict[str, float] = Field(default_factory=dict)


class TextClassificationRecord(BaseModel):
class TextClassificationRecord(_RootValidators):
"""Record for text classification
Args:
Expand All @@ -64,10 +115,10 @@ class TextClassificationRecord(BaseModel):
prediction:
A list of tuples containing the predictions for the record.
The first entry of the tuple is the predicted label, the second entry is its corresponding score.
annotation:
A string or a list of strings (multilabel) corresponding to the annotation (gold label) for the record.
prediction_agent:
Name of the prediction agent. By default, this is set to the hostname of your machine.
annotation:
A string or a list of strings (multilabel) corresponding to the annotation (gold label) for the record.
annotation_agent:
Name of the annotation agent. By default, this is set to the hostname of your machine.
multi_label:
Expand Down Expand Up @@ -99,17 +150,18 @@ class TextClassificationRecord(BaseModel):
inputs: Union[str, List[str], Dict[str, Union[str, List[str]]]]

prediction: Optional[List[Tuple[str, float]]] = None
annotation: Optional[Union[str, List[str]]] = None
prediction_agent: Optional[str] = None
annotation: Optional[Union[str, List[str]]] = None
annotation_agent: Optional[str] = None
multi_label: bool = False

multi_label: bool = False
explanation: Optional[Dict[str, List[TokenAttributions]]] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None

metrics: Optional[Dict[str, Any]] = None

@validator("inputs", pre=True)
Expand All @@ -119,20 +171,8 @@ def input_as_dict(cls, inputs):
return inputs
return dict(text=inputs)

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
"""Custom init to handle dynamic defaults"""
# noinspection PyArgumentList
super().__init__(*args, **kwargs)
self.status = self.status or (
"Default" if self.annotation is None else "Validated"
)


class TokenClassificationRecord(BaseModel):
class TokenClassificationRecord(_RootValidators):
"""Record for a token classification task
Args:
Expand All @@ -145,11 +185,11 @@ class TokenClassificationRecord(BaseModel):
A list of tuples containing the predictions for the record. The first entry of the tuple is the name of
predicted entity, the second and third entry correspond to the start and stop character index of the entity.
EXPERIMENTAL: The fourth entry is optional and corresponds to the score of the entity.
prediction_agent:
Name of the prediction agent. By default, this is set to the hostname of your machine.
annotation:
A list of tuples containing annotations (gold labels) for the record. The first entry of the tuple is the
name of the entity, the second and third entry correspond to the start and stop char index of the entity.
prediction_agent:
Name of the prediction agent. By default, this is set to the hostname of your machine.
annotation_agent:
Name of the annotation agent. By default, this is set to the hostname of your machine.
id:
Expand Down Expand Up @@ -180,29 +220,19 @@ class TokenClassificationRecord(BaseModel):
prediction: Optional[
List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]
] = None
annotation: Optional[List[Tuple[str, int, int]]] = None
prediction_agent: Optional[str] = None
annotation: Optional[List[Tuple[str, int, int]]] = None
annotation_agent: Optional[str] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None
metrics: Optional[Dict[str, Any]] = None

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
"""Custom init to handle dynamic defaults"""
super().__init__(*args, **kwargs)
self.status = self.status or (
"Default" if self.annotation is None else "Validated"
)
metrics: Optional[Dict[str, Any]] = None


class Text2TextRecord(BaseModel):
class Text2TextRecord(_RootValidators):
"""Record for a text to text task
Args:
Expand All @@ -211,10 +241,10 @@ class Text2TextRecord(BaseModel):
prediction:
A list of strings or tuples containing predictions for the input text.
If tuples, the first entry is the predicted text, the second entry is its corresponding score.
annotation:
A string representing the expected output text for the given input text.
prediction_agent:
Name of the prediction agent. By default, this is set to the hostname of your machine.
annotation:
A string representing the expected output text for the given input text.
annotation_agent:
Name of the annotation agent. By default, this is set to the hostname of your machine.
id:
Expand All @@ -241,14 +271,15 @@ class Text2TextRecord(BaseModel):
text: str

prediction: Optional[List[Union[str, Tuple[str, float]]]] = None
annotation: Optional[str] = None
prediction_agent: Optional[str] = None
annotation: Optional[str] = None
annotation_agent: Optional[str] = None

id: Optional[Union[int, str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
status: Optional[str] = None
event_timestamp: Optional[datetime.datetime] = None

metrics: Optional[Dict[str, Any]] = None

@validator("prediction")
Expand All @@ -262,27 +293,5 @@ def prediction_as_tuples(
return prediction
return [(text, 1.0) for text in prediction]

@validator("metadata", pre=True)
def check_value_length(cls, metadata):
return _limit_metadata_values(metadata)

def __init__(self, *args, **kwargs):
"""Custom init to handle dynamic defaults"""
super().__init__(*args, **kwargs)
self.status = self.status or (
"Default" if self.annotation is None else "Validated"
)


def _limit_metadata_values(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the keyword length limit to metadata values, warning on any change.

    Args:
        metadata: The metadata dictionary to check.

    Returns:
        The result of ``limit_value_length`` for ``metadata`` with
        ``MAX_KEYWORD_LENGTH`` as the maximum length.
    """
    limited = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
    values_were_changed = limited != metadata
    if values_were_changed:
        # Let the user know that what gets stored differs from what was passed in.
        message = (
            "Some metadata values exceed the max length. "
            f"Those values will be truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
        )
        warnings.warn(message)
    return limited


Record = Union[TextClassificationRecord, TokenClassificationRecord, Text2TextRecord]
43 changes: 40 additions & 3 deletions tests/client/test_models.py
Expand Up @@ -13,14 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Optional

import numpy
import pytest
from pydantic import ValidationError
from rubrix._constants import MAX_KEYWORD_LENGTH

from rubrix.client.models import Text2TextRecord, TextClassificationRecord
from rubrix.client.models import TokenClassificationRecord
from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.client.models import (
Text2TextRecord,
TextClassificationRecord,
TokenClassificationRecord,
_RootValidators,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -92,3 +97,35 @@ def test_model_serialization_with_numpy_nan():
)

json_record = json.loads(record.json())


def test_warning_when_only_agent():
    """An agent given without its prediction/annotation must raise a UserWarning."""

    class MockRecord(_RootValidators):
        # Minimal stand-in declaring just the fields the root validators read.
        prediction: Optional[Any] = None
        prediction_agent: Optional[str] = None
        annotation: Optional[Any] = None
        annotation_agent: Optional[str] = None
        metadata: Optional[Any] = None
        status: Optional[str] = None

    # Prediction agent without a prediction.
    expect_prediction_warning = pytest.warns(
        UserWarning, match="`prediction_agent` will not be logged to the server."
    )
    with expect_prediction_warning:
        MockRecord(prediction_agent="mock")

    # Annotation agent without an annotation.
    expect_annotation_warning = pytest.warns(
        UserWarning, match="`annotation_agent` will not be logged to the server."
    )
    with expect_annotation_warning:
        MockRecord(annotation_agent="mock")


def test_forbid_extra():
    """Keyword arguments that match no declared field must be rejected."""

    class MockRecord(_RootValidators):
        # Minimal stand-in declaring just the fields the root validators read.
        prediction: Optional[Any] = None
        prediction_agent: Optional[str] = None
        annotation: Optional[Any] = None
        annotation_agent: Optional[str] = None
        metadata: Optional[Any] = None
        status: Optional[str] = None

    undeclared_kwargs = {"extra": "mock"}
    with pytest.raises(ValidationError):
        MockRecord(**undeclared_kwargs)

0 comments on commit 6c1ae7f

Please sign in to comment.