feat: Using new top_k_mentions metrics instead of `entity_consistency` (#1880)

Closes #1834
frascuchon committed Nov 15, 2022
1 parent 28f3bcd commit 42f702d
Showing 9 changed files with 1,128 additions and 326 deletions.
1,269 changes: 999 additions & 270 deletions docs/_source/guides/features/metrics.ipynb

Large diffs are not rendered by default.
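The notebook diff is collapsed, but a minimal usage sketch of the new metric, consistent with the `top_k_mentions` signature introduced in `metrics.py` below (dataset name and label filter are hypothetical):

```python
from argilla.metrics.token_classification import top_k_mentions
from argilla.metrics.token_classification.metrics import Annotations

# Top 30 annotated mentions meeting the label-variability threshold, keeping
# only PERSON-labelled entities in the output (dataset and label are hypothetical).
summary = top_k_mentions(
    name="my-ner-dataset",
    k=30,
    threshold=2,
    compute_for=Annotations,
    post_label_filter={"PERSON"},
)
summary.visualize()
```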

10 changes: 5 additions & 5 deletions docs/_source/guides/steps/4_monitoring.ipynb
@@ -4140,9 +4140,9 @@
}
],
"source": [
"from argilla.metrics.token_classification import entity_consistency\n",
"from argilla.metrics.token_classification import top_k_mentions\n",
"\n",
"entity_consistency(name=\"spacy_sm_wnut17\", mentions=5000, threshold=2).visualize()\n"
"top_k_mentions(name=\"spacy_sm_wnut17\", k=5000, threshold=2).visualize()\n"
]
},
{
@@ -9407,11 +9407,11 @@
}
],
"source": [
"from argilla.metrics.token_classification import entity_consistency\n",
"from argilla.metrics.token_classification import top_k_mentions\n",
"from argilla.metrics.token_classification.metrics import Annotations\n",
"\n",
"entity_consistency(\n",
" name=\"conll2002_es\", mentions=30, threshold=4, compute_for=Annotations\n",
"top_k_mentions(\n",
" name=\"conll2002_es\", k=30, threshold=4, compute_for=Annotations\n",
").visualize()\n"
]
},
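For anyone migrating their own notebooks, the change boils down to swapping the function name and renaming the `mentions` parameter to `k`, as in this sketch based on the cells above:

```python
from argilla.metrics.token_classification import top_k_mentions

# Before this commit:
#   entity_consistency(name="spacy_sm_wnut17", mentions=5000, threshold=2).visualize()
# After (`mentions` is renamed to `k`):
top_k_mentions(name="spacy_sm_wnut17", k=5000, threshold=2).visualize()
```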
1 change: 1 addition & 0 deletions src/argilla/metrics/models.py
@@ -18,6 +18,7 @@
from pydantic import BaseModel, PrivateAttr


# TODO(@frascuchon): Define as dataclasses.dataclass
class MetricSummary(BaseModel):
"""THe metric summary result data model"""

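The TODO above is not implemented in this commit; as a rough, hypothetical sketch, the dataclass variant could look like the following (the real model remains a pydantic `BaseModel`):

```python
import dataclasses
from typing import Any, Callable, Dict


@dataclasses.dataclass
class MetricSummary:
    """The metric summary result data model (hypothetical dataclass variant)."""

    data: Dict[str, Any]
    _build_visualization: Callable[[], Any] = dataclasses.field(default=lambda: None)

    @classmethod
    def new_summary(cls, data, visualization):
        # Mirrors the MetricSummary.new_summary calls used elsewhere in this diff.
        return cls(data=data, _build_visualization=visualization)

    def visualize(self):
        return self._build_visualization()
```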
1 change: 1 addition & 0 deletions src/argilla/metrics/token_classification/__init__.py
@@ -26,4 +26,5 @@
token_frequency,
token_length,
tokens_length,
top_k_mentions,
)
74 changes: 53 additions & 21 deletions src/argilla/metrics/token_classification/metrics.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from enum import Enum
from typing import Optional, Union
from typing import Optional, Set, Union

import deprecated

from argilla.client import api
from argilla.metrics import helpers
@@ -359,14 +361,28 @@ def entity_capitalness(
)


def entity_consistency(
@deprecated.deprecated(reason="Use `top_k_mentions` instead")
def entity_consistency(*args, **kwargs):
message = "This function is not used anymore.\nYou should use the top_k_mentions function instead"
warnings.warn(
message=message,
category=DeprecationWarning,
)
return MetricSummary.new_summary(
data={},
visualization=lambda: helpers.empty_visualization(),
)


def top_k_mentions(
name: str,
query: Optional[str] = None,
compute_for: Union[str, ComputeFor] = Predictions,
mentions: int = 100,
k: int = 100,
threshold: int = 2,
post_label_filter: Optional[Set[str]] = None,
):
"""Computes the consistency for top entity mentions in the dataset.
"""Computes the consistency for top k mentions in the dataset.
Entity consistency defines the label variability for a given mention. For example, a mention `first` identified
in the whole dataset as `Cardinal`, `Person` and `Time` is less consistent than a mention `Peter` identified as
@@ -378,41 +394,57 @@ def entity_consistency(
`query string syntax <https://argilla.readthedocs.io/en/stable/guides/queries.html>`_
compute_for: Metric can be computed for annotations or predictions. Accepted values are
``Annotations`` and ``Predictions``. Default to ``Predictions``
mentions: The number of top mentions to retrieve.
threshold: The entity variability threshold (must be greater or equal to 2).
k: The number of mentions to retrieve.
threshold: The entity variability threshold (must be greater or equal to 1).
post_label_filter: A set of labels used for filtering the results. This filter may affect the expected
number of mentions.
Returns:
The summary entity capitalness distribution
The summary top k mentions distribution
Examples:
>>> from argilla.metrics.token_classification import entity_consistency
>>> summary = entity_consistency(name="example-dataset")
>>> summary = top_k_mentions(name="example-dataset")
>>> summary.visualize()
"""
if threshold < 2:
# TODO: Warning???
threshold = 2

threshold = max(1, threshold)
metric = api.active_api().compute_metric(
name,
metric=f"{_check_compute_for(compute_for)}_entity_consistency",
metric=f"{_check_compute_for(compute_for)}_top_k_mentions_consistency",
query=query,
size=mentions,
size=k,
interval=threshold,
)
mentions = [mention["mention"] for mention in metric.results["mentions"]]
entities = {}

filtered_mentions, mention_values = [], []
for mention in metric.results["mentions"]:
entities = mention["entities"]
if post_label_filter:
entities = [
entity for entity in entities if entity["label"] in post_label_filter
]
if entities:
mention["entities"] = entities
filtered_mentions.append(mention)
mention_values.append(mention["mention"])

entities = {}
for mention in filtered_mentions:
for entity in mention["entities"]:
mentions_for_label = entities.get(entity["label"], [0] * len(mentions))
mentions_for_label[mentions.index(mention["mention"])] = entity["count"]
entities[entity["label"]] = mentions_for_label
label = entity["label"]
mentions_for_label = entities.get(label, [0] * len(filtered_mentions))
mentions_for_label[mention_values.index(mention["mention"])] = entity[
"count"
]
entities[label] = mentions_for_label

return MetricSummary.new_summary(
data=metric.results,
data={"mentions": filtered_mentions},
visualization=lambda: helpers.stacked_bar(
x=mentions, y_s=entities, title=metric.description
x=mention_values,
y_s=entities,
title=metric.description,
),
)

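One consequence of the deprecation shim above: the old entry point no longer touches the server at all; it warns and returns an empty summary. A small sketch of the observable behavior, assuming only what the shim and the tests below show:

```python
import warnings

from argilla.metrics import entity_consistency

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Dataset name is arbitrary: the shim returns without calling the server.
    summary = entity_consistency(name="any-dataset")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert summary.data == {}  # matches the parametrized expectation in the tests below
summary.visualize()  # renders the empty visualization helper
```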
34 changes: 23 additions & 11 deletions src/argilla/server/daos/backend/metrics/token_classification.py
@@ -30,28 +30,32 @@


@dataclasses.dataclass
class EntityConsistency(NestedPathElasticsearchMetric):
class TopKMentionsConsistency(NestedPathElasticsearchMetric):
"""Computes the entity consistency distribution"""

mention_field: str
labels_field: str
chars_length_field: str
tokens_length_field: str

def _inner_aggregation(
self,
size: int,
interval: int = 2,
interval: int = 1,
entity_size: int = _DEFAULT_MAX_ENTITY_BUCKET,
) -> Dict[str, Any]:
size = size or 50
interval = int(max(interval or 2, 2))
interval = interval or 1
return {
"consistency": {
**aggregations.terms_aggregation(
self.compound_nested_field(self.mention_field), size=size
self.compound_nested_field(self.mention_field),
size=size,
),
"aggs": {
"entities": aggregations.terms_aggregation(
self.compound_nested_field(self.labels_field), size=entity_size
self.compound_nested_field(self.labels_field),
size=entity_size,
),
"count": {
"cardinality": {
@@ -201,28 +205,36 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An
id="bi-dimensional", field_x="label", field_y="value"
),
),
"predicted_entity_consistency": EntityConsistency(
id="predicted_entity_consistency",
"predicted_top_k_mentions_consistency": TopKMentionsConsistency(
id="predicted_top_k_mentions_consistency",
nested_path="metrics.predicted.mentions",
mention_field="value",
labels_field="label",
chars_length_field="chars_length",
tokens_length_field="tokens_length",
),
"annotated_entity_consistency": EntityConsistency(
id="annotated_entity_consistency",
"annotated_top_k_mentions_consistency": TopKMentionsConsistency(
id="annotated_top_k_mentions_consistency",
nested_path="metrics.annotated.mentions",
mention_field="value",
labels_field="label",
chars_length_field="chars_length",
tokens_length_field="tokens_length",
),
"predicted_tag_consistency": EntityConsistency(
"predicted_tag_consistency": TopKMentionsConsistency(
id="predicted_tag_consistency",
nested_path="metrics.predicted.tags",
mention_field="value",
labels_field="tag",
chars_length_field="chars_length",
tokens_length_field="tokens_length",
),
"annotated_tag_consistency": EntityConsistency(
"annotated_tag_consistency": TopKMentionsConsistency(
id="annotated_tag_consistency",
nested_path="metrics.annotated.tags",
mention_field="value",
labels_field="tag",
chars_length_field="chars_length",
tokens_length_field="tokens_length",
),
}
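For reference, a sketch of the Elasticsearch aggregation body this class builds for the predicted-mentions metric, assuming `aggregations.terms_aggregation` emits a standard `terms` clause; the `cardinality` target is truncated in this diff, so it is left elided:

```python
# Rough shape of the nested aggregation for "predicted_top_k_mentions_consistency"
# (field paths joined from nested_path + mention_field/labels_field above; the exact
# output of aggregations.terms_aggregation is an assumption).
inner_aggregation = {
    "consistency": {
        "terms": {
            "field": "metrics.predicted.mentions.value",  # mention_field
            "size": 100,  # size, i.e. the k top mentions
        },
        "aggs": {
            "entities": {
                "terms": {
                    "field": "metrics.predicted.mentions.label",  # labels_field
                    "size": 1000,  # _DEFAULT_MAX_ENTITY_BUCKET (value assumed)
                }
            },
            # The cardinality target is truncated in this diff, so it is left elided:
            "count": {"cardinality": ...},
        },
    }
}
```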
@@ -217,7 +217,8 @@ def capitalness(value: str) -> Optional[str]:

@staticmethod
def mentions_metrics(
record: ServiceTokenClassificationRecord, mentions: List[Tuple[str, EntitySpan]]
record: ServiceTokenClassificationRecord,
mentions: List[Tuple[str, EntitySpan]],
):
def mention_tokens_length(entity: EntitySpan) -> int:
"""Compute mention tokens length"""
Expand Down Expand Up @@ -252,7 +253,9 @@ def mention_tokens_length(entity: EntitySpan) -> int:

@classmethod
def build_tokens_metrics(
cls, record: ServiceTokenClassificationRecord, tags: Optional[List[str]] = None
cls,
record: ServiceTokenClassificationRecord,
tags: Optional[List[str]] = None,
) -> List[TokenMetrics]:

return [
@@ -393,7 +396,7 @@ def _compute_iob_tags(
description="Computes predicted mentions distribution against its labels",
),
ServiceBaseMetric(
id="predicted_entity_consistency",
id="predicted_top_k_mentions_consistency",
name="Entity label consistency for predictions",
description="Computes entity label variability for top-k predicted entity mentions",
),
@@ -433,7 +436,7 @@ def _compute_iob_tags(
description="Computes annotated mentions distribution against its labels",
),
ServiceBaseMetric(
id="annotated_entity_consistency",
id="annotated_top_k_mentions_consistency",
name="Entity label consistency for annotations",
description="Computes entity label variability for top-k annotated entity mentions",
),
4 changes: 3 additions & 1 deletion tests/functional_tests/test_log_for_token_classification.py
@@ -19,6 +19,7 @@
from argilla.client import api
from argilla.client.sdk.commons.errors import NotFoundApiError
from argilla.metrics import __all__ as ALL_METRICS
from argilla.metrics import entity_consistency


def test_log_with_empty_text(mocked_client):
@@ -53,9 +54,10 @@ def test_log_with_empty_tokens_list(mocked_client):
def test_call_metrics_with_no_api_client_initialized(mocked_client):

for metric in ALL_METRICS:
if metric == entity_consistency:
continue

api.__ACTIVE_API__ = None

with pytest.raises(NotFoundApiError):
metric("not_found")

50 changes: 36 additions & 14 deletions tests/metrics/test_token_classification.py
@@ -16,10 +16,10 @@

import argilla
import argilla as ar
from argilla.metrics import entity_consistency
from argilla.metrics.token_classification import (
Annotations,
entity_capitalness,
entity_consistency,
entity_density,
entity_labels,
f1,
@@ -28,6 +28,7 @@
token_frequency,
token_length,
tokens_length,
top_k_mentions,
)


@@ -215,14 +216,12 @@ def test_entity_capitalness(mocked_client):
results.visualize()


def test_entity_consistency(mocked_client):
dataset = "test_entity_consistency"
def test_top_k_mentions_consistency(mocked_client):
dataset = "test_top_k_mentions_consistency"
argilla.delete(dataset)
log_some_data(dataset)

results = entity_consistency(dataset, threshold=2)
assert results
assert results.data == {
mentions = {
"mentions": [
{
"mention": "first",
@@ -234,29 +233,52 @@
}
]
}
results.visualize()

results = entity_consistency(dataset, compute_for=Annotations, threshold=2)
assert results
assert results.data == {
filtered_mentions = {
"mentions": [
{
"mention": "first",
"entities": [
{"count": 2, "label": "CARDINAL"},
{"count": 1, "label": "NUMBER"},
{"count": 1, "label": "PERSON"},
],
}
]
}
validate_mentions(
dataset=dataset,
expected_mentions=mentions,
)

validate_mentions(
dataset=dataset,
compute_for=Annotations,
threshold=2,
expected_mentions=mentions,
)

validate_mentions(
dataset=dataset,
post_label_filter={"NUMBER"},
expected_mentions=filtered_mentions,
)


def validate_mentions(
*,
dataset: str,
expected_mentions: dict,
**metric_args,
):
results = top_k_mentions(dataset, **metric_args)
assert results
assert results.data == expected_mentions
results.visualize()


@pytest.mark.parametrize(
("metric", "expected_results"),
[
(entity_consistency, {"mentions": []}),
(top_k_mentions, {"mentions": []}),
(entity_consistency, {}),
(mention_length, {}),
(entity_density, {}),
(entity_capitalness, {}),
