feat: Add, delete and edit labeling rules from Python client (#1884)
Closes #1855
frascuchon committed Nov 23, 2022
1 parent 31b84cf commit d534a29
Showing 14 changed files with 4,882 additions and 1,282 deletions.
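This commit adds client-side management of a Text Classification dataset's labeling rules. A minimal usage sketch of the new module-level helpers, assuming a running Argilla instance and an existing dataset named "my_dataset" (connection values and names are illustrative):

import argilla as rg
from argilla.labeling.text_classification import (
    Rule,
    add_rules,
    delete_rules,
    load_rules,
    update_rules,
)

rg.init(api_url="http://localhost:6900", api_key="team.apikey")  # illustrative values

rule = Rule(query="help", label="HAM")  # records matching the query get the label
add_rules("my_dataset", [rule])         # create the rule on the server
rule.label = "SPAM"
update_rules("my_dataset", [rule])      # rules are matched by their query string
print(load_rules("my_dataset"))         # fetch all rules defined in the dataset
delete_rules("my_dataset", [rule])      # remove the rule again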
@@ -0,0 +1,6 @@
,query,label
0,your,SPAM
1,rich,SPAM
2,film,HAM
3,meeting,HAM
4,help,HAM
945 changes: 733 additions & 212 deletions docs/_source/guides/techniques/weak_supervision.ipynb

Large diffs are not rendered by default.
23 changes: 23 additions & 0 deletions src/argilla/client/api.py
@@ -608,6 +608,29 @@ def compute_metric(

        return MetricResults(**metric_.dict(), results=response.parsed)

    def add_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
        """Adds the given labeling rules to the dataset"""
        for rule in rules:
            text_classification_api.add_dataset_labeling_rule(
                self._client,
                name=dataset,
                rule=rule,
            )

    def update_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
        """Updates the given labeling rules in the dataset"""
        for rule in rules:
            text_classification_api.update_dataset_labeling_rule(
                self._client, name=dataset, rule=rule
            )

    def delete_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
        """Deletes the given labeling rules from the dataset"""
        for rule in rules:
            text_classification_api.delete_dataset_labeling_rule(
                self._client, name=dataset, rule=rule
            )

    def fetch_dataset_labeling_rules(self, dataset: str) -> List[LabelingRule]:
        """Fetches the labeling rules defined in the dataset"""
        response = text_classification_api.fetch_dataset_labeling_rules(
            self._client, name=dataset
57 changes: 57 additions & 0 deletions src/argilla/client/sdk/text_classification/api.py
@@ -58,6 +58,63 @@ def data(
)


def add_dataset_labeling_rule(
    client: AuthenticatedClient,
    name: str,
    rule: LabelingRule,
) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
    url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format(
        client.base_url, name=name
    )

    response = httpx.post(
        url=url,
        json={"query": rule.query, "labels": rule.labels},
        headers=client.get_headers(),
        cookies=client.get_cookies(),
        timeout=client.get_timeout(),
    )

    return build_typed_response(response, LabelingRule)


def update_dataset_labeling_rule(
    client: AuthenticatedClient,
    name: str,
    rule: LabelingRule,
) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
    url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}".format(
        client.base_url, name=name, query=rule.query
    )

    response = httpx.patch(
        url,
        json={"labels": rule.labels},
        headers=client.get_headers(),
        cookies=client.get_cookies(),
        timeout=client.get_timeout(),
    )

    return build_typed_response(response, LabelingRule)


def delete_dataset_labeling_rule(
    client: AuthenticatedClient,
    name: str,
    rule: LabelingRule,
) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
    url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}".format(
        client.base_url, name=name, query=rule.query
    )

    response = httpx.delete(
        url,
        headers=client.get_headers(),
        cookies=client.get_cookies(),
        timeout=client.get_timeout(),
    )

    return build_typed_response(response, LabelingRule)


def fetch_dataset_labeling_rules(
    client: AuthenticatedClient,
    name: str,
2 changes: 1 addition & 1 deletion src/argilla/client/sdk/text_classification/models.py
@@ -186,7 +186,7 @@ class LabelingRule(BaseModel):
    labels: List[str] = Field(default_factory=list)
    query: str
    description: Optional[str] = None
-   author: str
+   author: Optional[str] = None
    created_at: Optional[datetime] = None
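With author now optional, a LabelingRule payload can be built entirely client-side without resolving the server-side user first. A small sketch (query and labels are illustrative):

from argilla.client.sdk.text_classification.models import LabelingRule

rule = LabelingRule(query="inputs.text:help", labels=["HAM"])
assert rule.author is None and rule.created_at is None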


2 changes: 1 addition & 1 deletion src/argilla/labeling/text_classification/__init__.py
@@ -15,5 +15,5 @@

from .label_errors import find_label_errors
from .label_models import FlyingSquid, MajorityVoter, Snorkel
-from .rule import Rule, load_rules
+from .rule import Rule, add_rules, delete_rules, load_rules, update_rules
from .weak_labels import WeakLabels, WeakMultiLabels
75 changes: 72 additions & 3 deletions src/argilla/labeling/text_classification/rule.py
@@ -62,6 +62,10 @@ def label(self) -> Union[str, List[str]]:
"""The rule label"""
return self._label

    @label.setter
    def label(self, value):
        self._label = value

    @property
    def name(self):
        """The name of the rule."""
@@ -74,6 +78,34 @@ def author(self):
"""Who authored the rule."""
return self._author

    def _convert_to_labeling_rule(self):
        """Converts the rule to a `LabelingRule`"""
        if isinstance(self._label, str):
            labels = [self._label]
        else:
            labels = self._label

        return LabelingRule(query=self.query, labels=labels)

    def add_to_dataset(self, dataset: str):
        """Adds the rule to the given dataset"""
        api.active_api().add_dataset_labeling_rules(
            dataset, rules=[self._convert_to_labeling_rule()]
        )

    def remove_from_dataset(self, dataset: str):
        """Removes the rule from the given dataset"""
        api.active_api().delete_dataset_labeling_rules(
            dataset, rules=[self._convert_to_labeling_rule()]
        )

    def update_at_dataset(self, dataset: str):
        """Updates the rule in the given dataset"""
        api.active_api().update_dataset_labeling_rules(
            dataset, rules=[self._convert_to_labeling_rule()]
        )
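The same operations are available per rule instance. A short sketch, assuming an existing dataset named "my_dataset" (the name is illustrative):

rule = Rule(query="inputs.text:(NOT positive)", label="negative")
rule.add_to_dataset("my_dataset")
rule.label = "positive"                 # uses the new label setter
rule.update_at_dataset("my_dataset")    # the rule is matched by its query
rule.remove_from_dataset("my_dataset")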

    def apply(self, dataset: str):
        """Apply the rule to a dataset and save matching ids of the records.
@@ -101,9 +133,7 @@ def metrics(self, dataset: str) -> Dict[str, Union[int, float]]:
"""
metrics = api.active_api().rule_metrics_for_dataset(
dataset=dataset,
rule=LabelingRule(
query=self.query, label=self.label, author=self.author or "None"
),
rule=LabelingRule(query=self.query, label=self.label),
)

return {
@@ -143,6 +173,45 @@ def __call__(
        return self._label


def add_rules(dataset: str, rules: List[Rule]):
    """Adds the rules to the given dataset.

    Args:
        dataset: Name of the dataset.
        rules: Rules to add to the dataset.
    """
    rules = [rule._convert_to_labeling_rule() for rule in rules]
    api.active_api().add_dataset_labeling_rules(dataset, rules)


def delete_rules(dataset: str, rules: List[Rule]):
    """Deletes the rules from the given dataset.

    Args:
        dataset: Name of the dataset.
        rules: Rules to delete from the dataset.
    """
    rules = [rule._convert_to_labeling_rule() for rule in rules]
    api.active_api().delete_dataset_labeling_rules(dataset, rules)


def update_rules(dataset: str, rules: List[Rule]):
    """Updates the rules in the given dataset.

    Args:
        dataset: Name of the dataset.
        rules: Rules to update in the dataset.
    """
    rules = [rule._convert_to_labeling_rule() for rule in rules]
    api.active_api().update_dataset_labeling_rules(dataset, rules)
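Because the SDK's PATCH endpoint embeds rule.query in the URL, a rule is addressed by its query string: updating replaces the labels of the existing rule with the same query. A brief sketch (dataset and rule values are illustrative):

# Changes the label of an already-added rule whose query is "a query".
update_rules("my_dataset", [Rule(query="a query", label="New Label")])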


def load_rules(dataset: str) -> List[Rule]:
    """Loads the rules defined in the given dataset.
7 changes: 5 additions & 2 deletions tests/client/sdk/conftest.py
@@ -71,14 +71,17 @@ def check_schema_props(client_props, server_props):
            return len(different_props) < len(client_props) / 2

        client_props = self._expands_schema(
-           client_schema["properties"], client_schema["definitions"]
+           client_schema["properties"],
+           client_schema.get("definitions", {}),
        )
        server_props = self._expands_schema(
-           server_schema["properties"], server_schema["definitions"]
+           server_schema["properties"],
+           server_schema.get("definitions", {}),
        )

        if client_props == server_props:
            return True

        return check_schema_props(client_props, server_props)

    def _expands_schema(
4 changes: 1 addition & 3 deletions tests/client/sdk/text_classification/test_models.py
@@ -62,9 +62,7 @@ def test_labeling_rule_schema(helpers):
    client_schema = LabelingRule.schema()
    server_schema = ServerLabelingRule.schema()

-   assert helpers.remove_description(client_schema) == helpers.remove_description(
-       server_schema
-   )
+   assert helpers.are_compatible_api_schemas(client_schema, server_schema)


def test_labeling_rule_metrics_schema(helpers):
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -59,6 +59,7 @@ def whoami_mocked(client):
    monkeypatch.setattr(users_api, "whoami", whoami_mocked)

    monkeypatch.setattr(httpx, "post", client_.post)
    monkeypatch.setattr(httpx, "patch", client_.patch)
    monkeypatch.setattr(httpx.AsyncClient, "post", client_.post_async)
    monkeypatch.setattr(httpx, "get", client_.get)
    monkeypatch.setattr(httpx, "delete", client_.delete)
112 changes: 111 additions & 1 deletion tests/labeling/text_classification/test_rule.py
@@ -21,7 +21,13 @@
    CreationTextClassificationRecord,
    TextClassificationBulkData,
)
-from argilla.labeling.text_classification import Rule, load_rules
+from argilla.labeling.text_classification import (
+    Rule,
+    add_rules,
+    delete_rules,
+    load_rules,
+    update_rules,
+)
from argilla.labeling.text_classification.rule import RuleNotAppliedError
from argilla.server.errors import EntityNotFoundError

@@ -86,6 +92,40 @@ def test_name(name, expected):
    assert rule.name == expected


def test_atomic_crud_operations(monkeypatch, mocked_client, log_dataset):
    rule = Rule(query="inputs.text:(NOT positive)", label="negative")
    with pytest.raises(RuleNotAppliedError):
        rule(TextClassificationRecord(text="test"))

    monkeypatch.setattr(httpx, "get", mocked_client.get)
    monkeypatch.setattr(httpx, "patch", mocked_client.patch)
    monkeypatch.setattr(httpx, "delete", mocked_client.delete)
    monkeypatch.setattr(httpx, "post", mocked_client.post)
    monkeypatch.setattr(httpx, "stream", mocked_client.stream)

    rule.add_to_dataset(log_dataset)

    rules = load_rules(log_dataset)
    assert len(rules) == 1
    assert rules[0].query == "inputs.text:(NOT positive)"
    assert rules[0].label == "negative"

    rule.remove_from_dataset(log_dataset)

    rules = load_rules(log_dataset)
    assert len(rules) == 0

    rule = Rule(query="inputs.text:(NOT positive)", label="negative")
    rule.add_to_dataset(log_dataset)
    rule.label = "positive"
    rule.update_at_dataset(log_dataset)

    rules = load_rules(log_dataset)
    assert len(rules) == 1
    assert rules[0].query == "inputs.text:(NOT positive)"
    assert rules[0].label == "positive"


def test_apply(monkeypatch, mocked_client, log_dataset):
    rule = Rule(query="inputs.text:(NOT positive)", label="negative")
    with pytest.raises(RuleNotAppliedError):
@@ -123,6 +163,76 @@ def test_load_rules(mocked_client, log_dataset):
    assert rules[0].label == "LALA"


def test_add_rules(mocked_client, log_dataset):
    expected_rules = [
        Rule(query="a query", label="La La"),
        Rule(query="another query", label="La La"),
        Rule(query="the other query", label="La La La"),
    ]

    add_rules(log_dataset, expected_rules)

    actual_rules = load_rules(log_dataset)

    assert len(actual_rules) == 3
    for actual_rule, expected_rule in zip(actual_rules, expected_rules):
        assert actual_rule.query == expected_rule.query
        assert actual_rule.label == expected_rule.label


def test_delete_rules(mocked_client, log_dataset):
    rules = [
        Rule(query="a query", label="La La"),
        Rule(query="another query", label="La La"),
        Rule(query="the other query", label="La La La"),
    ]

    add_rules(log_dataset, rules)

    delete_rules(
        log_dataset,
        [
            Rule(query="a query", label="La La"),
        ],
    )

    actual_rules = load_rules(log_dataset)

    assert len(actual_rules) == 2

    for actual_rule, expected_rule in zip(actual_rules, rules[1:]):
        assert actual_rule.label == expected_rule.label
        assert actual_rule.query == expected_rule.query


def test_update_rules(mocked_client, log_dataset):
    rules = [
        Rule(query="a query", label="La La"),
        Rule(query="another query", label="La La"),
        Rule(query="the other query", label="La La La"),
    ]

    add_rules(log_dataset, rules)
    rules_to_update = [
        Rule(query="a query", label="La La La"),
    ]
    update_rules(log_dataset, rules=rules_to_update)

    actual_rules = load_rules(log_dataset)

    assert len(actual_rules) == 3

    assert actual_rules[0].query == "a query"
    assert actual_rules[0].label == "La La La"

    for actual_rule, expected_rule in zip(actual_rules[1:], rules[1:]):
        assert actual_rule.label == expected_rule.label
        assert actual_rule.query == expected_rule.query


def test_copy_dataset_with_rules(mocked_client, log_dataset):
    import argilla as ar

