feat(Client): validate token classification annotations in client (#1709)

Closes #1579

(cherry picked from commit 6d82717)

fix: Improve misaligned error message
frascuchon committed Oct 5, 2022
1 parent d7cc006 commit 936d1ca
Showing 16 changed files with 627 additions and 296 deletions.
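With this change, entity spans are checked against token boundaries as soon as a record is built on the client, instead of failing later on the server. A minimal sketch of the resulting behavior, inferred from the diff below (the exact error type and message come from `rubrix.utils.span_utils.SpanUtils.validate`):

```python
import rubrix as rb

# Spans aligned with token boundaries build fine:
record = rb.TokenClassificationRecord(
    text="Rubrix is a tool",
    tokens=["Rubrix", "is", "a", "tool"],
    annotation=[("PRODUCT", 0, 6)],  # exactly covers the token "Rubrix"
)

# A span that cuts through a token, e.g. ("PRODUCT", 2, 5) over "Rubrix",
# cannot be auto-corrected and is expected to raise a ValueError
# listing the misaligned spans.
```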
2 changes: 1 addition & 1 deletion src/rubrix/__init__.py
@@ -24,7 +24,7 @@
from rubrix.logging import configure_logging as _configure_logging

from . import _version
from .utils import _LazyRubrixModule
from .utils import LazyRubrixModule as _LazyRubrixModule

__version__ = _version.version

12 changes: 6 additions & 6 deletions src/rubrix/client/datasets.py
@@ -28,6 +28,7 @@
TokenClassificationRecord,
)
from rubrix.client.sdk.datasets.models import TaskType
from rubrix.utils.span_utils import SpanUtils

_LOGGER = logging.getLogger(__name__)

@@ -877,12 +878,11 @@ def _prepare_for_training_with_transformers(self):
class_tags = datasets.ClassLabel(names=class_tags)

def spans2iob(example):
r = TokenClassificationRecord(
text=example["text"],
tokens=example["tokens"],
annotation=self.__entities_to_tuple__(example["annotation"]),
)
return class_tags.str2int(r.spans2iob(r.annotation))
span_utils = SpanUtils(example["text"], example["tokens"])
entity_spans = self.__entities_to_tuple__(example["annotation"])
tags = span_utils.to_tags(entity_spans)

return class_tags.str2int(tags)

ds = (
self.to_datasets()
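The training-data path now builds IOB tags through `SpanUtils` instead of instantiating a throwaway record. A sketch of the helper in isolation, assuming `to_tags` emits IOB2 tags like the `spans2iob` it replaces:

```python
from rubrix.utils.span_utils import SpanUtils

span_utils = SpanUtils("John lives in New York", ["John", "lives", "in", "New", "York"])
tags = span_utils.to_tags([("PER", 0, 4), ("LOC", 14, 22)])
# assumed output: ["B-PER", "O", "O", "B-LOC", "I-LOC"]
```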
209 changes: 83 additions & 126 deletions src/rubrix/client/models.py
@@ -20,14 +20,14 @@
import datetime
import logging
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.utils import limit_value_length
from rubrix.utils.span_utils import SpanUtils

_LOGGER = logging.getLogger(__name__)

@@ -295,8 +295,7 @@ class TokenClassificationRecord(_Validators):
metrics: Optional[Dict[str, Any]] = None
search_keywords: Optional[List[str]] = None

__chars2tokens__: Dict[int, int] = PrivateAttr(default=None)
__tokens2chars__: Dict[int, Tuple[int, int]] = PrivateAttr(default=None)
_span_utils: SpanUtils = PrivateAttr()

def __init__(
self,
@@ -320,52 +319,49 @@ def __init__(
text = " ".join(tokens)

super().__init__(text=text, tokens=tokens, **data)

self._span_utils = SpanUtils(self.text, self.tokens)

if self.annotation:
self.annotation = self._validate_spans(self.annotation)
if self.prediction:
self.prediction = self._validate_spans(self.prediction)

if self.annotation and tags:
_LOGGER.warning("Annotation already provided, `tags` won't be used")
return
if tags:
self.annotation = self.__tags2entities__(tags)

def __tags2entities__(self, tags: List[str]) -> List[Tuple[str, int, int]]:
idx = 0
entities = []
entity_starts = False
while idx < len(tags):
tag = tags[idx]
if tag == "O":
entity_starts = False
if tag != "O":
prefix, entity = tag.split("-")
if prefix in ["B", "U"]:
if prefix == "B":
entity_starts = True
char_start, char_end = self.token_span(token_idx=idx)
entities.append(
{"entity": entity, "start": char_start, "end": char_end + 1}
)
elif prefix in ["I", "L"]:
if not entity_starts:
_LOGGER.warning(
"Detected non-starting tag and first entity token was not found."
f"Assuming {tag} as first entity token"
)
entity_starts = True
char_start, char_end = self.token_span(token_idx=idx)
entities.append(
{"entity": entity, "start": char_start, "end": char_end + 1}
)

_, char_end = self.token_span(token_idx=idx)
entities[-1]["end"] = char_end + 1
idx += 1
return [(value["entity"], value["start"], value["end"]) for value in entities]
elif tags:
self.annotation = self._span_utils.from_tags(tags)
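The hand-rolled tag walking above collapses into a single `from_tags` call. A sketch of the assumed round-trip, using the output format of the removed implementation:

```python
from rubrix.utils.span_utils import SpanUtils

span_utils = SpanUtils("John lives in New York", ["John", "lives", "in", "New", "York"])
spans = span_utils.from_tags(["B-PER", "O", "O", "B-LOC", "I-LOC"])
# assumed output: [("PER", 0, 4), ("LOC", 14, 22)]
```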

def __setattr__(self, name: str, value: Any):
"""Make text and tokens immutable"""
if name in ["text", "tokens"]:
raise AttributeError(f"You cannot assign a new value to `{name}`")
super().__setattr__(name, value)

def _validate_spans(
self, spans: List[Tuple[str, int, int]]
) -> List[Tuple[str, int, int]]:
"""Validates the entity spans with respect to the tokens.
If necessary, also performs an automatic correction of the spans.
Args:
spans: The entity spans to validate.
Returns:
The optionally corrected spans.
Raises:
ValidationError: If spans are not valid or misaligned.
"""
try:
self._span_utils.validate(spans)
except ValueError:
spans = self._span_utils.correct(spans)
self._span_utils.validate(spans)

return spans
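The validate-then-correct fallback can be exercised directly. A sketch under the assumption that `correct()` trims spans back to the nearest token boundaries (for example, stray surrounding whitespace):

```python
from rubrix.utils.span_utils import SpanUtils

span_utils = SpanUtils("Eli Lilly and Co", ["Eli", "Lilly", "and", "Co"])

span_utils.validate([("ORG", 0, 9)])  # "Eli Lilly" ends on a token boundary: passes

# ("ORG", 0, 10) drags in the trailing space, so validate() raises ValueError;
# correct() is assumed to shrink the span back to ("ORG", 0, 9):
corrected = span_utils.correct([("ORG", 0, 10)])
span_utils.validate(corrected)  # passes after correction
```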

@validator("tokens", pre=True)
def _normalize_tokens(cls, value):
if isinstance(value, list):
@@ -375,7 +371,7 @@ def _normalize_tokens(cls, value):
return value

@validator("prediction")
def add_default_score(
def _add_default_score(
cls,
prediction: Optional[
List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]
@@ -391,103 +387,64 @@ def add_default_score(
for pred in prediction
]
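The truncated validator body above appears to normalize three-tuple predictions into four-tuples. A sketch of the assumed effect (the default score value of 0.0 is an assumption, not visible in this hunk):

```python
import rubrix as rb

record = rb.TokenClassificationRecord(
    text="Rubrix is a tool",
    tokens=["Rubrix", "is", "a", "tool"],
    prediction=[("PRODUCT", 0, 6)],  # no score given
)
# assumed: record.prediction == [("PRODUCT", 0, 6, 0.0)]
```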

@staticmethod
def __build_indices_map__(
text: str, tokens: Tuple[str, ...]
) -> Tuple[Dict[int, int], Dict[int, Tuple[int, int]]]:
"""
Build the index mappings between text characters and the tokens they belong to,
and vice versa.
The chars2tokens index maps each character index to the index of the token containing it (if any).
Out-of-token characters won't be included in this map,
so access should be using ``chars2tokens_map.get(i)``
instead of ``chars2tokens_map[i]``.
"""

def chars2tokens_index(text_, tokens_):
chars_map = {}
current_token = 0
current_token_char_start = 0
for idx, char in enumerate(text_):
relative_idx = idx - current_token_char_start
if (
relative_idx < len(tokens_[current_token])
and char == tokens_[current_token][relative_idx]
):
chars_map[idx] = current_token
elif (
current_token + 1 < len(tokens_)
and relative_idx >= len(tokens_[current_token])
and char == tokens_[current_token + 1][0]
):
current_token += 1
current_token_char_start += relative_idx
chars_map[idx] = current_token
return chars_map

def tokens2chars_index(
chars2tokens: Dict[int, int]
) -> Dict[int, Tuple[int, int]]:
tokens2chars_map = defaultdict(list)
for c, t in chars2tokens.items():
tokens2chars_map[t].append(c)

return {
token_idx: (min(chars), max(chars))
for token_idx, chars in tokens2chars_map.items()
}

chars2tokens_idx = chars2tokens_index(text_=text, tokens_=tokens)
return chars2tokens_idx, tokens2chars_index(chars2tokens_idx)
@validator("text")
def _check_if_empty_after_strip(cls, text: str):
assert text.strip(), "The provided `text` contains only whitespaces."
return text
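A quick consequence of the new `text` validator, sketched with a hypothetical whitespace-only record:

```python
import pydantic
import rubrix as rb

try:
    rb.TokenClassificationRecord(text="   ", tokens=["   "])
except pydantic.ValidationError as err:
    print(err)  # includes "The provided `text` contains only whitespaces."
```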

@property
def __chars2tokens__(self) -> Dict[int, int]:
"""DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.chars_to_token_idx`` attribute."""
warnings.warn(
"The `__chars2tokens__` attribute is deprecated and will be removed in a future version. "
"Please use the `rubrix.utils.span_utils.SpanUtils.char_to_token_idx` attribute instead.",
FutureWarning,
)
return self._span_utils.char_to_token_idx

@property
def __tokens2chars__(self) -> Dict[int, Tuple[int, int]]:
"""DEPRECATED, please use the ``rubrix.utils.span_utils.SpanUtils.chars_to_token_idx`` attribute."""
warnings.warn(
"The `__tokens2chars__` attribute is deprecated and will be removed in a future version. "
"Please use the `rubrix.utils.span_utils.SpanUtils.token_to_char_idx` attribute instead.",
FutureWarning,
)
return self._span_utils.token_to_char_idx

def char_id2token_id(self, char_idx: int) -> Optional[int]:
"""
Given a character id, returns the token id it belongs to.
Returns ``None`` otherwise.
"""

if self.__chars2tokens__ is None:
self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__(
self.text, tuple(self.tokens)
)
return self.__chars2tokens__.get(char_idx)
"""DEPRECATED, please use the ``rubrix.utisl.span_utils.SpanUtils.char_to_token_idx`` dict instead."""
warnings.warn(
"The `char_id2token_id` method is deprecated and will be removed in a future version. "
"Please use the `rubrix.utils.span_utils.SpanUtils.char_to_token_idx` dict instead.",
FutureWarning,
)
return self._span_utils.char_to_token_idx.get(char_idx)

def token_span(self, token_idx: int) -> Tuple[int, int]:
"""
Given a token id, returns the start and end characters.
Raises an ``IndexError`` if the token id is out of range.
"""
if self.__tokens2chars__ is None:
self.__chars2tokens__, self.__tokens2chars__ = self.__build_indices_map__(
self.text, tuple(self.tokens)
)
if token_idx not in self.__tokens2chars__:
"""DEPRECATED, please use the ``rubrix.utisl.span_utils.SpanUtils.token_to_char_idx`` dict instead."""
warnings.warn(
"The `token_span` method is deprecated and will be removed in a future version. "
"Please use the `rubrix.utils.span_utils.SpanUtils.token_to_char_idx` dict instead.",
FutureWarning,
)
if token_idx not in self._span_utils.token_to_char_idx:
raise IndexError(f"Token id {token_idx} out of bounds")
return self.__tokens2chars__[token_idx]
return self._span_utils.token_to_char_idx[token_idx]

def spans2iob(
self, spans: Optional[List[Tuple[str, int, int]]] = None
) -> Optional[List[str]]:
"""Build the iob tags sequence for a list of spans annoations"""
"""DEPRECATED, please use the ``rubrix.utils.SpanUtils.to_tags()`` method."""
warnings.warn(
"'spans2iob' is deprecated and will be removed in a future version. "
"Please use the `rubrix.utils.SpanUtils.to_tags()` method instead, and adapt your code accordingly.",
FutureWarning,
)

if spans is None:
return None

tags = ["O"] * len(self.tokens)
for label, start, end in spans:
token_start = self.char_id2token_id(start)
token_end = self.char_id2token_id(end - 1)
assert (
token_start is not None and token_end is not None
), "Provided spans are missaligned at token level"
tags[token_start] = f"B-{label}"
for idx in range(token_start + 1, token_end + 1):
tags[idx] = f"I-{label}"

return tags
return self._span_utils.to_tags(spans)
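Taken together, the deprecation warnings point to one migration path. A sketch following the messages in this diff, with a hypothetical record:

```python
import rubrix as rb
from rubrix.utils.span_utils import SpanUtils

record = rb.TokenClassificationRecord(
    text="Rubrix is a tool",
    tokens=["Rubrix", "is", "a", "tool"],
    annotation=[("PRODUCT", 0, 6)],
)

# Before (deprecated, now emits FutureWarning):
tags = record.spans2iob(record.annotation)

# After:
span_utils = SpanUtils(record.text, record.tokens)
tags = span_utils.to_tags(record.annotation)      # replaces record.spans2iob(...)
token_idx = span_utils.char_to_token_idx.get(0)   # replaces record.char_id2token_id(0)
char_span = span_utils.token_to_char_idx[0]       # replaces record.token_span(0)
```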


class Text2TextRecord(_Validators):
2 changes: 2 additions & 0 deletions src/rubrix/server/apis/v0/models/token_classification.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, root_validator, validator
@@ -35,6 +36,7 @@
from rubrix.server.services.tasks.token_classification.model import (
ServiceTokenClassificationDataset,
)
from rubrix.utils import SpanUtils


class TokenClassificationAnnotation(_TokenClassificationAnnotation):
