Skip to content

Commit

Permalink
fix(datasets): prevent error when no annotated records found in dataset (#1284)
Browse files Browse the repository at this point in the history

* fix(datasets): prevent error when no annotated records found in dataset

* test: improve coverage

(cherry picked from commit d55a9fa)
  • Loading branch information
frascuchon committed Mar 30, 2022
1 parent adcf1b1 commit c20028f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/rubrix/client/datasets.py
Expand Up @@ -744,6 +744,15 @@ def prepare_for_training(self) -> "datasets.Dataset":
"""
import datasets

has_annotations = False
for rec in self._records:
if rec.annotation is not None:
has_annotations = True
break

if not has_annotations:
return datasets.Dataset.from_dict({})

class_tags = ["O"]
class_tags.extend(
[
Expand All @@ -765,12 +774,11 @@ def spans2iob(example):
ds = (
self.to_datasets()
.filter(self.__only_annotations__)
.map(
lambda example: {"ner_tags": spans2iob(example)},
)
.map(lambda example: {"ner_tags": spans2iob(example)})
)
new_features = ds.features.copy()
new_features["ner_tags"] = [class_tags]

return ds.cast(new_features)

def __all_labels__(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/client/test_dataset.py
Expand Up @@ -462,6 +462,12 @@ def test_from_to_datasets_id(self):

assert rb.read_datasets(dataset_ds, task="TokenClassification")[0].id is None

# Regression check for this commit: a token-classification dataset whose only
# record is built without an annotation should yield an empty training
# dataset (length 0) from prepare_for_training().
def test_prepare_for_training_empty(self):
dataset = rb.DatasetForTokenClassification(
[rb.TokenClassificationRecord(text="mock", tokens=["mock"])]
)
assert len(dataset.prepare_for_training()) == 0

def test_datasets_empty_metadata(self):
dataset = rb.DatasetForTokenClassification(
[rb.TokenClassificationRecord(text="mock", tokens=["mock"])]
Expand Down

0 comments on commit c20028f

Please sign in to comment.