Skip to content

Commit

Permalink
feat: Enable metadata length field config by environment variable (#1923
Browse files Browse the repository at this point in the history
)

Closes #1761
  • Loading branch information
frascuchon committed Nov 22, 2022
1 parent 53a57f7 commit 0ff2de7
Show file tree
Hide file tree
Showing 16 changed files with 115 additions and 51 deletions.
Expand Up @@ -23,7 +23,7 @@ You can set following environment variables to further configure your server and

### Server

- `ELASTICSEARCH`: URL of the connection endpoint of the Elasticsearch instance (Default: `http://localhost:9200`).
- `ARGILLA_ELASTICSEARCH`: URL of the connection endpoint of the Elasticsearch instance (Default: `http://localhost:9200`).

- `ARGILLA_ELASTICSEARCH_SSL_VERIFY`: If "False", disables SSL certificate verification when connection to the Elasticsearch backend.

Expand All @@ -35,7 +35,9 @@ You can set following environment variables to further configure your server and

- `ARGILLA_EXACT_ES_SEARCH_ANALYZER`: Default analyzer for `*.exact` fields in textual information (Default: "whitespace").

- `METADATA_FIELDS_LIMIT`: Max number of fields in the metadata (Default: 50, max: 100).
- `ARGILLA_METADATA_FIELDS_LIMIT`: Max number of fields in the metadata (Default: 50, max: 100).

- `ARGILLA_METADATA_FIELD_LENGTH`: Max length supported for the string metadata fields. Higher values will be truncated. Abusing this may lead to Elastic performance issues (Default: 128).

- `CORS_ORIGINS`: List of host patterns for CORS origin access.

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/_constants.py
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

MAX_KEYWORD_LENGTH = 128
DEFAULT_MAX_KEYWORD_LENGTH = 128


API_KEY_HEADER_NAME = "X-Argilla-Api-Key"
Expand Down
18 changes: 18 additions & 0 deletions src/argilla/_messages.py
@@ -0,0 +1,18 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARGILLA_METADATA_FIELD_WARNING_MESSAGE = (
"You can configure this length in the server with the ARGILLA_METADATA_FIELD_LENGTH "
"environment variable. Note that, setting this too high may lead to Elastic performance issues."
)
30 changes: 20 additions & 10 deletions src/argilla/client/models.py
Expand Up @@ -26,8 +26,8 @@
from deprecated import deprecated
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla.utils import limit_value_length
from argilla import _messages
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.utils.span_utils import SpanUtils

_LOGGER = logging.getLogger(__name__)
Expand All @@ -37,16 +37,26 @@ class _Validators(BaseModel):
"""Base class for our record models that takes care of general validations"""

@validator("metadata", check_fields=False)
def _check_value_length(cls, v):
"""Checks metadata values length and apply value truncation for large values"""
new_metadata = limit_value_length(v, max_length=MAX_KEYWORD_LENGTH)
if new_metadata != v:
warnings.warn(
"Some metadata values exceed the max length. Those values will be"
f" truncated by keeping only the last {MAX_KEYWORD_LENGTH} characters."
def _check_value_length(cls, metadata):
"""Checks metadata values length and warn message for large values"""
if not metadata:
return metadata

default_length_exceeded = False
for v in metadata.values():
if isinstance(v, str) and len(v) > DEFAULT_MAX_KEYWORD_LENGTH:
default_length_exceeded = True
break

if default_length_exceeded:
message = (
"Some metadata values could exceed the max length. For those cases, values will be"
f" truncated by keeping only the last {DEFAULT_MAX_KEYWORD_LENGTH} characters. "
+ _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
)
warnings.warn(message, UserWarning)

return new_metadata
return metadata

@validator("metadata", check_fields=False)
def _none_to_empty_dict(cls, v):
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/sdk/token_classification/models.py
Expand Up @@ -17,7 +17,7 @@

from pydantic import BaseModel, Field, validator

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.client.models import (
TokenClassificationRecord as ClientTokenClassificationRecord,
)
Expand All @@ -35,7 +35,7 @@
class EntitySpan(BaseModel):
start: int
end: int
label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
label: str = Field(min_length=1, max_length=DEFAULT_MAX_KEYWORD_LENGTH)
score: float = Field(default=1.0, ge=0.0, le=1.0)


Expand Down
3 changes: 2 additions & 1 deletion src/argilla/server/daos/backend/elasticsearch.py
Expand Up @@ -930,7 +930,8 @@ def _configure_metadata_fields(self, id: str, metadata_values: Dict[str, Any]):
def check_metadata_length(metadata_length: int = 0):
if metadata_length > settings.metadata_fields_limit:
raise MetadataLimitExceededError(
length=metadata_length, limit=settings.metadata_fields_limit
length=metadata_length,
limit=settings.metadata_fields_limit,
)

def detect_nested_type(v: Any) -> bool:
Expand Down
15 changes: 8 additions & 7 deletions src/argilla/server/daos/backend/mappings/helpers.py
Expand Up @@ -14,7 +14,6 @@

from typing import Any, Dict, List

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla.server.settings import settings

EXTENDED_ANALYZER_REF = "extended_analyzer"
Expand All @@ -26,12 +25,12 @@

class mappings:
@staticmethod
def keyword_field(enable_text_search: bool = False):
def keyword_field(
enable_text_search: bool = False,
):
"""Mappings config for keyword field"""
mapping = {
"type": "keyword",
# TODO: Use environment var and align with fields validators
"ignore_above": MAX_KEYWORD_LENGTH,
}
if enable_text_search:
text_field = mappings.text_field()
Expand All @@ -41,14 +40,15 @@ def keyword_field(enable_text_search: bool = False):

@staticmethod
def path_match_keyword_template(
path: str, enable_text_search_in_keywords: bool = False
path: str,
enable_text_search_in_keywords: bool = False,
):
"""Dynamic template mappings config for keyword field based on path match"""
return {
"path_match": path,
"match_mapping_type": "string",
"mapping": mappings.keyword_field(
enable_text_search=enable_text_search_in_keywords
enable_text_search=enable_text_search_in_keywords,
),
}

Expand Down Expand Up @@ -167,7 +167,8 @@ def dynamic_metrics_text():
def dynamic_metadata_text():
return {
"metadata.*": mappings.path_match_keyword_template(
path="metadata.*", enable_text_search_in_keywords=True
path="metadata.*",
enable_text_search_in_keywords=True,
)
}

Expand Down
16 changes: 14 additions & 2 deletions src/argilla/server/daos/models/records.py
Expand Up @@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from datetime import datetime
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import uuid4

from pydantic import BaseModel, Field, root_validator, validator
from pydantic.generics import GenericModel

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla import _messages
from argilla.server.commons.models import PredictionStatus, TaskStatus, TaskType
from argilla.server.daos.backend.search.model import BackendRecordsQuery, SortConfig
from argilla.server.helpers import flatten_dict
from argilla.server.settings import settings
from argilla.utils import limit_value_length


Expand Down Expand Up @@ -138,7 +140,17 @@ def flatten_metadata(cls, metadata: Dict[str, Any]):
"""
if metadata:
metadata = flatten_dict(metadata, drop_empty=True)
metadata = limit_value_length(metadata, max_length=MAX_KEYWORD_LENGTH)
new_metadata = limit_value_length(
data=metadata,
max_length=settings.metadata_field_length,
)
message = (
"Some metadata values exceed the max length. Those values will be"
f" truncated by keeping only the last {settings.metadata_field_length} characters. "
+ _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
)
warnings.warn(message, UserWarning)
metadata = new_metadata
return metadata

@classmethod
Expand Down
Expand Up @@ -18,7 +18,7 @@

from pydantic import BaseModel, Field, root_validator, validator

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.commons.models import PredictionStatus, TaskStatus, TaskType
from argilla.server.helpers import flatten_dict
from argilla.server.services.datasets import ServiceBaseDataset
Expand Down Expand Up @@ -97,9 +97,9 @@ class ClassPrediction(BaseModel):
@validator("class_label")
def check_label_length(cls, class_label):
if isinstance(class_label, str):
assert 1 <= len(class_label) <= MAX_KEYWORD_LENGTH, (
f"Class name '{class_label}' exceeds max length of {MAX_KEYWORD_LENGTH}"
if len(class_label) > MAX_KEYWORD_LENGTH
assert 1 <= len(class_label) <= DEFAULT_MAX_KEYWORD_LENGTH, (
f"Class name '{class_label}' exceeds max length of {DEFAULT_MAX_KEYWORD_LENGTH}"
if len(class_label) > DEFAULT_MAX_KEYWORD_LENGTH
else f"Class name must not be empty"
)
return class_label
Expand Down
Expand Up @@ -12,14 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple

from pydantic import BaseModel, Field, validator

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.commons.models import PredictionStatus, TaskType
from argilla.server.services.datasets import ServiceBaseDataset
from argilla.server.services.search.model import (
Expand Down Expand Up @@ -57,7 +54,7 @@ class EntitySpan(BaseModel):

start: int
end: int
label: str = Field(min_length=1, max_length=MAX_KEYWORD_LENGTH)
label: str = Field(min_length=1, max_length=DEFAULT_MAX_KEYWORD_LENGTH)
score: float = Field(default=1.0, ge=0.0, le=1.0)

@validator("end")
Expand Down
14 changes: 12 additions & 2 deletions src/argilla/server/settings.py
Expand Up @@ -20,7 +20,9 @@
from typing import List, Optional
from urllib.parse import urlparse

from pydantic import BaseSettings, Field, validator
from pydantic import BaseSettings, Field

from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH


class ApiSettings(BaseSettings):
Expand Down Expand Up @@ -82,7 +84,15 @@ class ApiSettings(BaseSettings):
es_records_index_replicas: int = 0

metadata_fields_limit: int = Field(
default=50, gt=0, le=100, description="Max number of fields in metadata"
default=50,
gt=0,
le=100,
description="Max number of fields in metadata",
)
metadata_field_length: int = Field(
default=DEFAULT_MAX_KEYWORD_LENGTH,
description="Max length supported for the string metadata fields."
" Values containing higher than this will be truncated",
)

enable_telemetry: bool = True
Expand Down
5 changes: 4 additions & 1 deletion tests/client/test_dataset.py
Expand Up @@ -260,7 +260,10 @@ def test_to_from_datasets(self, records, request):
"metrics",
]
assert dataset_ds.features["prediction"] == [
{"label": datasets.Value("string"), "score": datasets.Value("float64")}
{
"label": datasets.Value("string"),
"score": datasets.Value("float64"),
}
]

dataset = ar.DatasetForTextClassification.from_datasets(dataset_ds)
Expand Down
26 changes: 17 additions & 9 deletions tests/client/test_models.py
Expand Up @@ -21,7 +21,7 @@
import pytest
from pydantic import ValidationError

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.client.models import (
Text2TextRecord,
TextClassificationRecord,
Expand Down Expand Up @@ -202,20 +202,28 @@ def test_token_classification_prediction_validator(prediction, expected):
def test_text_classification_record_none_inputs():
"""Test validation error for None in inputs"""
with pytest.raises(ValidationError):
TextClassificationRecord(inputs={"text": None})
TextClassificationRecord.parse_obj(dict(inputs={"text": None}))


def test_metadata_values_length():
text = "oh yeah!"
metadata = {"too_long": "a" * 200}
expected_length = 200
metadata = {"too_long": "a" * expected_length}

record = TextClassificationRecord(inputs={"text": text}, metadata=metadata)
assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
with pytest.warns(expected_warning=UserWarning):
record = TextClassificationRecord(
inputs={"text": text},
metadata=metadata,
)
assert len(record.metadata["too_long"]) == expected_length

record = TokenClassificationRecord(
text=text, tokens=text.split(), metadata=metadata
)
assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
with pytest.warns(expected_warning=UserWarning):
record = TokenClassificationRecord(
text=text,
tokens=text.split(),
metadata=metadata,
)
assert len(record.metadata["too_long"]) == expected_length


def test_model_serialization_with_numpy_nan():
Expand Down
Expand Up @@ -100,7 +100,9 @@ def test_delete_records_with_unmatched_records(mocked_client):
name=dataset,
records=[
ar.TextClassificationRecord(
id=i, text="This is the text", metadata=dict(idx=i)
id=i,
text="This is the text",
metadata=dict(idx=i),
)
for i in range(0, 50)
],
Expand Down
4 changes: 2 additions & 2 deletions tests/server/text_classification/test_model.py
Expand Up @@ -15,7 +15,7 @@
import pytest
from pydantic import ValidationError

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.apis.v0.models.text_classification import (
TextClassificationAnnotation,
TextClassificationQuery,
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_too_long_metadata():
}
)

assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
assert len(record.metadata["too_long"]) == DEFAULT_MAX_KEYWORD_LENGTH


def test_too_long_label():
Expand Down
4 changes: 2 additions & 2 deletions tests/server/token_classification/test_model.py
Expand Up @@ -16,7 +16,7 @@
import pytest
from pydantic import ValidationError

from argilla._constants import MAX_KEYWORD_LENGTH
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.apis.v0.models.token_classification import (
TokenClassificationAnnotation,
TokenClassificationQuery,
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_too_long_metadata():
}
)

assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
assert len(record.metadata["too_long"]) == DEFAULT_MAX_KEYWORD_LENGTH


def test_entity_label_too_long():
Expand Down

0 comments on commit 0ff2de7

Please sign in to comment.