Skip to content

Commit

Permalink
feat(#1225): prepare tokenclass dataset for hf training (#1231)
Browse files Browse the repository at this point in the history
* feat(#1225): prepare tokenclass dataset for hf training

* fix: optional search_keywords

* chore: add docstring

* Update src/rubrix/client/datasets.py

Co-authored-by: David Fidalgo <david@recogn.ai>

* Apply suggestions from code review

Co-authored-by: David Fidalgo <david@recogn.ai>

Co-authored-by: David Fidalgo <david@recogn.ai>
(cherry picked from commit 0935db7)
  • Loading branch information
frascuchon committed Mar 25, 2022
1 parent 8b2a07c commit ae5e7cd
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 13 deletions.
107 changes: 98 additions & 9 deletions src/rubrix/client/datasets.py
Expand Up @@ -124,6 +124,9 @@ def to_datasets(self) -> "datasets.Dataset":
import datasets

ds_dict = self._to_datasets_dict()
# TODO: THIS FIELD IS ONLY AT CLIENT API LEVEL. NOT SENSE HERE FOR NOW
if "search_keywords" in ds_dict:
del ds_dict["search_keywords"]

try:
dataset = datasets.Dataset.from_dict(ds_dict)
Expand Down Expand Up @@ -535,6 +538,88 @@ def from_pandas(
) -> "DatasetForTokenClassification":
return super().from_pandas(dataframe)

@_requires_datasets
def prepare_for_training(self) -> "datasets.Dataset":
"""Prepares the dataset for training.
This will return a ``datasets.Dataset`` with all columns returned by ``to_datasets`` method
and an additional *ner_tags* column:
- Records without an annotation are removed.
- The *ner_tags* column corresponds to the iob tags sequences for annotations of the records
- The iob tags are transformed to integers.
Returns:
A datasets Dataset with a *ner_tags* column and all columns returned by ``to_datasets``.
Examples:
>>> import rubrix as rb
>>> rb_dataset = rb.DatasetForTokenClassification([
... rb.TokenClassificationRecord(
... text="The text",
... tokens=["The", "text"],
... annotation=[("TAG", 0, 2)],
... )
... ])
>>> rb_dataset.prepare_for_training().features
{'text': Value(dtype='string'),
'tokens': Sequence(feature=Value(dtype='string'), length=-1),
'prediction': Value(dtype='null'),
'prediction_agent': Value(dtype='null'),
'annotation': [{'end': Value(dtype='int64'),
'label': Value(dtype='string'),
'start': Value(dtype='int64')}],
'annotation_agent': Value(dtype='null'),
'id': Value(dtype='null'),
'metadata': Value(dtype='null'),
'status': Value(dtype='string'),
'event_timestamp': Value(dtype='null'),
'metrics': Value(dtype='null'),
'ner_tags': [ClassLabel(num_classes=3, names=['O', 'B-TAG', 'I-TAG'])]}
"""
import datasets

class_tags = ["O"]
class_tags.extend(
[
f"{pre}-{label}"
for label in sorted(self.__all_labels__())
for pre in ["B", "I"]
]
)
class_tags = datasets.ClassLabel(names=class_tags)

def spans2iob(example):
r = TokenClassificationRecord(
text=example["text"],
tokens=example["tokens"],
annotation=self.__entities_to_tuple__(example["annotation"]),
)
return class_tags.str2int(r.spans2iob(r.annotation))

ds = (
self.to_datasets()
.filter(self.__only_annotations__)
.map(
lambda example: {"ner_tags": spans2iob(example)},
)
)
new_features = ds.features.copy()
new_features["ner_tags"] = [class_tags]
return ds.cast(new_features)

def __all_labels__(self):
all_labels = set()
for record in self._records:
if record.annotation:
all_labels.update([label for label, _, _ in record.annotation])

return list(all_labels)

def __only_annotations__(self, data) -> bool:
return data["annotation"] is not None

def _to_datasets_dict(self) -> Dict:
"""Helper method to put token classification records in a `datasets.Dataset`"""
# create a dict first, where we make the necessary transformations
Expand Down Expand Up @@ -573,24 +658,28 @@ def entities_to_dict(

return ds_dict

@staticmethod
def __entities_to_tuple__(
entities,
) -> List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]:
return [
(ent["label"], ent["start"], ent["end"])
if len(ent) == 3
else (ent["label"], ent["start"], ent["end"], ent["score"] or 1.0)
for ent in entities
]

@classmethod
def _from_datasets(
cls, dataset: "datasets.Dataset"
) -> "DatasetForTokenClassification":
def entities_to_tuple(entities):
return [
(ent["label"], ent["start"], ent["end"])
if len(ent) == 3
else (ent["label"], ent["start"], ent["end"], ent["score"] or 1.0)
for ent in entities
]

records = []
for row in dataset:
if row.get("prediction"):
row["prediction"] = entities_to_tuple(row["prediction"])
row["prediction"] = cls.__entities_to_tuple__(row["prediction"])
if row.get("annotation"):
row["annotation"] = entities_to_tuple(row["annotation"])
row["annotation"] = cls.__entities_to_tuple__(row["annotation"])

records.append(TokenClassificationRecord(**row))

Expand Down
112 changes: 108 additions & 4 deletions tests/client/test_dataset.py
Expand Up @@ -21,7 +21,11 @@
import pytest

import rubrix as rb
from rubrix.client.datasets import DatasetBase, WrongRecordTypeError
from rubrix.client.datasets import (
DatasetBase,
DatasetForTokenClassification,
WrongRecordTypeError,
)
from rubrix.client.models import TextClassificationRecord

_HF_HUB_ACCESS_TOKEN = os.getenv("HF_HUB_ACCESS_TOKEN")
Expand Down Expand Up @@ -209,7 +213,20 @@ def test_to_from_datasets(self, records, request):
dataset_ds = expected_dataset.to_datasets()

assert isinstance(dataset_ds, datasets.Dataset)
assert dataset_ds.column_names == list(expected_dataset[0].__fields__.keys())
assert dataset_ds.column_names == [
"inputs",
"prediction",
"prediction_agent",
"annotation",
"annotation_agent",
"multi_label",
"explanation",
"id",
"metadata",
"status",
"event_timestamp",
"metrics",
]
assert dataset_ds.features["prediction"] == [
{"label": datasets.Value("string"), "score": datasets.Value("float64")}
]
Expand Down Expand Up @@ -325,7 +342,19 @@ def test_to_from_datasets(self, tokenclassification_records):
dataset_ds = expected_dataset.to_datasets()

assert isinstance(dataset_ds, datasets.Dataset)
assert dataset_ds.column_names == list(expected_dataset[0].__fields__.keys())
assert dataset_ds.column_names == [
"text",
"tokens",
"prediction",
"prediction_agent",
"annotation",
"annotation_agent",
"id",
"metadata",
"status",
"event_timestamp",
"metrics",
]
assert dataset_ds.features["prediction"] == [
{
"label": datasets.Value("string"),
Expand Down Expand Up @@ -402,6 +431,70 @@ def test_push_to_hub(self, tokenclassification_records):

assert isinstance(dataset_ds, datasets.Dataset)

@pytest.mark.skipif(
_HF_HUB_ACCESS_TOKEN is None,
reason="You need a HF Hub access token to test the push_to_hub feature",
)
def test_prepare_for_training(self):
ner_dataset = datasets.load_dataset(
"rubrix/gutenberg_spacy-ner",
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
rb_dataset: DatasetForTokenClassification = rb.read_datasets(
ner_dataset, task="TokenClassification"
)
for r in rb_dataset:
r.annotation = [
(label, start, end) for label, start, end, _ in r.prediction
]

train = rb_dataset.prepare_for_training()
assert isinstance(train, datasets.Dataset)
assert "ner_tags" in train.column_names
assert len(train) == 100
assert train.features["ner_tags"] == [
datasets.ClassLabel(
names=[
"O",
"B-CARDINAL",
"I-CARDINAL",
"B-DATE",
"I-DATE",
"B-FAC",
"I-FAC",
"B-GPE",
"I-GPE",
"B-LANGUAGE",
"I-LANGUAGE",
"B-LOC",
"I-LOC",
"B-NORP",
"I-NORP",
"B-ORDINAL",
"I-ORDINAL",
"B-ORG",
"I-ORG",
"B-PERSON",
"I-PERSON",
"B-PRODUCT",
"I-PRODUCT",
"B-QUANTITY",
"I-QUANTITY",
"B-TIME",
"I-TIME",
"B-WORK_OF_ART",
"I-WORK_OF_ART",
]
)
]

train.push_to_hub(
"rubrix/_test_token_classification_training",
token=_HF_HUB_ACCESS_TOKEN,
private=True,
)


class TestDatasetForText2Text:
def test_init(self, text2text_records):
Expand All @@ -415,7 +508,18 @@ def test_to_from_datasets(self, text2text_records):
dataset_ds = expected_dataset.to_datasets()

assert isinstance(dataset_ds, datasets.Dataset)
assert dataset_ds.column_names == list(expected_dataset[0].__fields__.keys())
assert dataset_ds.column_names == [
"text",
"prediction",
"prediction_agent",
"annotation",
"annotation_agent",
"id",
"metadata",
"status",
"event_timestamp",
"metrics",
]
assert dataset_ds.features["prediction"] == [
{
"text": datasets.Value("string"),
Expand Down

0 comments on commit ae5e7cd

Please sign in to comment.