Skip to content

Commit

Permalink
fix(NER): create record annotation from tags (also in from_datasets) (#…
Browse files Browse the repository at this point in the history
…1283)

* fix(ner): build record annotation from tags

* fix(ner): parse tags in from_datasets method

(cherry picked from commit 65da06f)
  • Loading branch information
frascuchon committed Mar 30, 2022
1 parent 8052aa8 commit adcf1b1
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions src/rubrix/client/datasets.py
Expand Up @@ -66,6 +66,14 @@ class DatasetBase:

_RECORD_TYPE = None

@classmethod
def _record_init_args(cls) -> List[str]:
"""
Helper the returns the field list available for creation of inner records.
The ``_RECORD_TYPE.__fields__`` will be returned as default
"""
return [field for field in cls._RECORD_TYPE.__fields__]

def __init__(self, records: Optional[List[Record]] = None):
if self._RECORD_TYPE is None:
raise NotImplementedError(
Expand Down Expand Up @@ -185,9 +193,7 @@ def from_datasets(
)

not_supported_columns = [
col
for col in dataset.column_names
if col not in cls._RECORD_TYPE.__fields__
col for col in dataset.column_names if col not in cls._record_init_args()
]
if not_supported_columns:
_LOGGER.warning(
Expand Down Expand Up @@ -251,11 +257,12 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset":
The imported records in a Rubrix Dataset.
"""
not_supported_columns = [
col for col in dataframe.columns if col not in cls._RECORD_TYPE.__fields__
col for col in dataframe.columns if col not in cls._record_init_args()
]
if not_supported_columns:
_LOGGER.warning(
f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model and are ignored: {not_supported_columns}"
f"Following columns are not supported by the {cls._RECORD_TYPE.__name__} model "
f"and are ignored: {not_supported_columns}"
)
dataframe = dataframe.drop(columns=not_supported_columns)

Expand Down Expand Up @@ -638,6 +645,12 @@ class DatasetForTokenClassification(DatasetBase):

_RECORD_TYPE = TokenClassificationRecord

@classmethod
def _record_init_args(cls) -> List[str]:
"""Adds the `tags` argument to default record init arguments"""
parent_fields = super(DatasetForTokenClassification, cls)._record_init_args()
return parent_fields + ["tags"] # compute annotation from tags

def __init__(self, records: Optional[List[TokenClassificationRecord]] = None):
# we implement this to have more specific type hints
super().__init__(records=records)
Expand Down Expand Up @@ -871,8 +884,8 @@ def _parse_tags_field(
import datasets

labels = dataset.features[field]
if isinstance(labels, list):
labels = labels[0]
if isinstance(labels, datasets.Sequence):
labels = labels.feature
int2str = (
labels.int2str if isinstance(labels, datasets.ClassLabel) else lambda x: x
)
Expand Down

0 comments on commit adcf1b1

Please sign in to comment.