Skip to content

Commit

Permalink
feat(#955): add default for rules in WeakLabels (#976)
Browse files Browse the repository at this point in the history
* feat: load rules of dataset by default

* test: add tests

* test: fix tests

fix(#1010): fix WeakLabels when not providing rules (#1011)

* fix: bugfix when no rules but a ds is provided

* test: add test

* test: fix test
  • Loading branch information
David Fidalgo authored and frascuchon committed Feb 2, 2022
1 parent a5ed329 commit 34389d3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
43 changes: 33 additions & 10 deletions src/rubrix/labeling/text_classification/weak_labels.py
Expand Up @@ -21,15 +21,16 @@

from rubrix import load
from rubrix.client.models import TextClassificationRecord
from rubrix.labeling.text_classification.rule import Rule
from rubrix.labeling.text_classification.rule import Rule, load_rules


class WeakLabels:
"""Computes the weak labels of a dataset by applying a given list of rules.
Args:
rules: A list of rules (labeling functions). They must return a string, or ``None`` in case of abstention.
dataset: Name of the dataset to which the rules will be applied.
rules: A list of rules (labeling functions). They must return a string, or ``None`` in case of abstention.
If None, we will use the rules of the dataset (Default).
ids: An optional list of record ids to filter the dataset before applying the rules.
query: An optional ElasticSearch query with the
`query string syntax <https://rubrix.readthedocs.io/en/stable/reference/webapp/search_records.html>`_
Expand All @@ -38,14 +39,21 @@ class WeakLabels:
abstention (e.g. ``{None: -1}``). By default, we will build a mapping on the fly when applying the rules.
Raises:
NoRulesFoundError: When you do not provide rules, and the dataset has no rules either.
DuplicatedRuleNameError: When you provided multiple rules with the same name.
NoRecordsFoundError: When the filtered dataset is empty.
MultiLabelError: When trying to get weak labels for a multi-label text classification task.
MissingLabelError: When provided with a ``label2int`` dict, and a
weak label or annotation label is not present in its keys.
Examples:
Get the weak label matrix and a summary of the applied rules:
Get the weak label matrix from a dataset with rules:
>>> weak_labels = WeakLabels(dataset="my_dataset")
>>> weak_labels.matrix()
>>> weak_labels.summary()
Get the weak label matrix from rules defined in Python:
>>> def awesome_rule(record: TextClassificationRecord) -> str:
... return "Positive" if "awesome" in record.inputs["text"] else None
Expand All @@ -54,24 +62,37 @@ class WeakLabels:
>>> weak_labels.matrix()
>>> weak_labels.summary()
Use snorkel's LabelModel:
Use the WeakLabels object with snorkel's LabelModel:
>>> from snorkel.labeling.model import LabelModel
>>> label_model = LabelModel()
>>> label_model.fit(L_train=weak_labels.matrix(has_annotation=False))
>>> label_model.score(L=weak_labels.matrix(has_annotation=True), Y=weak_labels.annotation())
>>> label_model.predict(L=weak_labels.matrix(has_annotation=False))
For a builtin integration with Snorkel, see `rubrix.labeling.text_classification.Snorkel`.
"""

def __init__(
self,
rules: List[Callable],
dataset: str,
rules: Optional[List[Callable]] = None,
ids: Optional[List[Union[int, str]]] = None,
query: Optional[str] = None,
label2int: Optional[Dict[Optional[str], int]] = None,
):
self._rules = rules
if not isinstance(dataset, str):
raise TypeError(
f"The name of the dataset must be a string, but you provided: {dataset}"
)
self._dataset = dataset

self._rules = rules or load_rules(dataset)
if self._rules == []:
raise NoRulesFoundError(
f"No rules were found in the given dataset '{dataset}'"
)

self._rules_index2name = {
# covers our Rule class, snorkel's LabelingFunction class and arbitrary methods
index: (
Expand All @@ -84,11 +105,11 @@ def __init__(
)
or f"rule_{index}"
)
for index, rule in enumerate(rules)
for index, rule in enumerate(self._rules)
}
# raise error if there are duplicates
counts = Counter(self._rules_index2name.values())
if len(counts.keys()) < len(rules):
if len(counts.keys()) < len(self._rules):
raise DuplicatedRuleNameError(
f"Following rule names are duplicated x times: { {key: val for key, val in counts.items() if val > 1} }"
" Please make sure to provide unique rule names."
Expand All @@ -97,8 +118,6 @@ def __init__(
val: key for key, val in self._rules_index2name.items()
}

self._dataset = dataset

# load records and check compatibility
self._records: List[TextClassificationRecord] = load(
dataset, query=query, ids=ids, as_pandas=False
Expand Down Expand Up @@ -499,6 +518,10 @@ class WeakLabelsError(Exception):
pass


class NoRulesFoundError(WeakLabelsError):
pass


class DuplicatedRuleNameError(WeakLabelsError):
pass

Expand Down
39 changes: 34 additions & 5 deletions tests/labeling/text_classification/test_weak_labels.py
Expand Up @@ -30,6 +30,7 @@
MissingLabelError,
MultiLabelError,
NoRecordsFoundError,
NoRulesFoundError,
WeakLabels,
)
from tests.server.test_helpers import client
Expand Down Expand Up @@ -105,7 +106,7 @@ def mock_load(*args, **kwargs):
)

with pytest.raises(MultiLabelError):
WeakLabels(rules=[], dataset="mock")
WeakLabels(rules=[lambda x: None], dataset="mock")


def test_no_records_found_error(monkeypatch):
Expand All @@ -119,21 +120,21 @@ def mock_load(*args, **kwargs):
with pytest.raises(
NoRecordsFoundError, match="No records found in dataset 'mock'."
):
WeakLabels(rules=[], dataset="mock")
WeakLabels(rules=[lambda x: None], dataset="mock")
with pytest.raises(
NoRecordsFoundError,
match="No records found in dataset 'mock' with query 'mock'.",
):
WeakLabels(rules=[], dataset="mock", query="mock")
WeakLabels(rules=[lambda x: None], dataset="mock", query="mock")
with pytest.raises(
NoRecordsFoundError, match="No records found in dataset 'mock' with ids \[-1\]."
):
WeakLabels(rules=[], dataset="mock", ids=[-1])
WeakLabels(rules=[lambda x: None], dataset="mock", ids=[-1])
with pytest.raises(
NoRecordsFoundError,
match="No records found in dataset 'mock' with query 'mock' and with ids \[-1\].",
):
WeakLabels(rules=[], dataset="mock", query="mock", ids=[-1])
WeakLabels(rules=[lambda x: None], dataset="mock", query="mock", ids=[-1])


@pytest.mark.parametrize(
Expand Down Expand Up @@ -406,3 +407,31 @@ def mock_apply(self, *args, **kwargs):
weak_labels.change_mapping(old_mapping)

assert (weak_labels.matrix() == old_wlm).all()


def test_dataset_type_error():
with pytest.raises(TypeError, match="must be a string, but you provided"):
WeakLabels([1, 2, 3])


def test_rules_from_dataset(monkeypatch, log_dataset):
monkeypatch.setattr(httpx, "get", client.get)
monkeypatch.setattr(httpx, "stream", client.stream)

mock_rules = [Rule(query="mock", label="mock")]
monkeypatch.setattr(
"rubrix.labeling.text_classification.weak_labels.load_rules",
lambda x: mock_rules,
)

wl = WeakLabels(log_dataset)
assert wl.rules is mock_rules


def test_norulesfounderror(monkeypatch):
monkeypatch.setattr(
"rubrix.labeling.text_classification.weak_labels.load_rules", lambda x: []
)

with pytest.raises(NoRulesFoundError, match="No rules were found"):
WeakLabels("mock")

0 comments on commit 34389d3

Please sign in to comment.