
feat(#932): label models now modify the prediction_agent when calling LabelModel.predict (#1049)

* feat: add prediction agent

* test: add asserts

(cherry picked from commit 867f377)
David Fidalgo authored and frascuchon committed Feb 2, 2022
1 parent eb958bf commit 4a024ee
Showing 2 changed files with 12 additions and 0 deletions.
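
In practice, the new argument lets callers stamp the records produced by a label model with their own agent string. A minimal usage sketch, assuming a running Rubrix instance; the rule, dataset name, and agent string below are purely illustrative:

```python
from rubrix.labeling.text_classification import Rule, Snorkel, WeakLabels

# Illustrative rule and dataset name; any existing text classification
# dataset with weak labeling rules would do.
rule = Rule(query="money OR stock*", label="finance")
weak_labels = WeakLabels(rules=[rule], dataset="news_dataset")

label_model = Snorkel(weak_labels)
label_model.fit()

# The string passed here is written to the `prediction_agent` field of the
# returned records (previously the field was left unset).
records = label_model.predict(prediction_agent="snorkel-v1")
```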
src/rubrix/labeling/text_classification/label_models.py (8 additions, 0 deletions)
@@ -73,13 +73,15 @@ def predict(
self,
include_annotated_records: bool = False,
include_abstentions: bool = False,
prediction_agent: str = "LabelModel",
**kwargs,
) -> List[TextClassificationRecord]:
"""Applies the label model.
Args:
include_annotated_records: Whether or not to include annotated records.
include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+ prediction_agent: String used for the ``prediction_agent`` in the returned records.
Returns:
A list of records that include the predictions of the label model.
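
As the hunks above and below show, the base ``predict`` defaults the agent to "LabelModel", while ``Snorkel`` and ``FlyingSquid`` override it with their own names, so callers who pass nothing still get a meaningful value. A sketch of relying on that default, assuming ``label_model`` is the fitted Snorkel instance from the earlier sketch:

```python
# Without an explicit prediction_agent, the model's default from this diff is
# used ("Snorkel" here; "FlyingSquid" or "LabelModel" for the other classes).
default_records = label_model.predict()
print({rec.prediction_agent for rec in default_records if rec.prediction})  # expected: {"Snorkel"}
```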
@@ -282,13 +284,15 @@ def predict(
self,
include_annotated_records: bool = False,
include_abstentions: bool = False,
prediction_agent: str = "Snorkel",
tie_break_policy: Union[TieBreakPolicy, str] = "abstain",
) -> List[TextClassificationRecord]:
"""Returns a list of records that contain the predictions of the label model
Args:
include_annotated_records: Whether or not to include annotated records.
include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+ prediction_agent: String used for the ``prediction_agent`` in the returned records.
tie_break_policy: Policy to break ties. You can choose among three policies:
- `abstain`: Do not provide any prediction
@@ -355,6 +359,7 @@ def predict(
]

records_with_prediction[-1].prediction = pred_for_rec
+ records_with_prediction[-1].prediction_agent = prediction_agent

return records_with_prediction
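
Because the agent is stored on each record right next to the prediction, it is preserved when the records are logged back to the server. A sketch, continuing the illustrative names from the first example:

```python
import rubrix as rb

# Each record returned by predict() now carries the agent string alongside
# its prediction.
assert all(rec.prediction_agent == "snorkel-v1" for rec in records if rec.prediction)

# Logging the records keeps that provenance visible when inspecting the dataset.
rb.log(records, name="news_dataset_snorkel_predictions")
```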

@@ -537,6 +542,7 @@ def predict(
self,
include_annotated_records: bool = False,
include_abstentions: bool = False,
prediction_agent: str = "FlyingSquid",
verbose: bool = True,
tie_break_policy: str = "abstain",
) -> List[TextClassificationRecord]:
@@ -545,6 +551,7 @@
Args:
include_annotated_records: Whether or not to include annotated records.
include_abstentions: Whether or not to include records in the output, for which the label model abstained.
+ prediction_agent: String used for the ``prediction_agent`` in the returned records.
verbose: If True, print out messages of the progress to stderr.
tie_break_policy: Policy to break ties. You can choose among two policies:
@@ -618,6 +625,7 @@ def predict(

records_with_prediction.append(rec.copy(deep=True))
records_with_prediction[-1].prediction = pred_for_rec
+ records_with_prediction[-1].prediction_agent = prediction_agent

return records_with_prediction
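
``FlyingSquid.predict`` gains the same argument with its own default. A sketch, reusing the illustrative ``weak_labels`` from the first example and the ``tie_break_policy`` default visible in the signature above:

```python
from rubrix.labeling.text_classification import FlyingSquid

# Same weak_labels as in the first sketch; FlyingSquid is the other label
# model touched by this commit.
flying_squid = FlyingSquid(weak_labels)
flying_squid.fit()

# Omitting prediction_agent would fall back to the new "FlyingSquid" default;
# an explicit value is handy for versioning the model.
records_fs = flying_squid.predict(
    tie_break_policy="abstain",
    prediction_agent="flyingsquid-2022-02",
)
```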

tests/labeling/text_classification/test_label_models.py (4 additions, 0 deletions)
@@ -259,6 +259,7 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs):
tie_break_policy=policy,
include_annotated_records=include_annotated_records,
include_abstentions=include_abstentions,
prediction_agent="mock_agent",
)
assert len(records) == expected[0]
assert [
@@ -267,6 +268,7 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs):
assert [
rec.prediction[0][1] if rec.prediction else None for rec in records
] == expected[2]
+ assert records[0].prediction_agent == "mock_agent"

@pytest.mark.parametrize("policy,expected", [("abstain", 0.5), ("random", 2.0 / 3)])
def test_score(self, monkeypatch, weak_labels, policy, expected):
@@ -455,12 +457,14 @@ def __call__(cls, L_matrix, verbose):
include_annotated_records=include_annotated_records,
include_abstentions=include_abstentions,
verbose=verbose,
prediction_agent="mock_agent",
)

assert MockPredict.calls_count == 3
assert len(records) == expected["nr_of_records"]
if records:
assert records[0].prediction == expected["prediction"]
+ assert records[0].prediction_agent == "mock_agent"

def test_predict_binary(self, monkeypatch, weak_labels):
class MockPredict:
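
The assertions added in this file check exactly this propagation: whatever string the test passes as ``prediction_agent`` must show up on the returned records. A self-contained illustration of the field itself, using a plain record rather than the repository's fixtures or mocks:

```python
from rubrix import TextClassificationRecord

# prediction_agent is a free-form string identifying what produced the
# prediction; with this commit the label models fill it in during predict().
record = TextClassificationRecord(inputs="the stock market rallied today")
record.prediction = [("finance", 0.9)]
record.prediction_agent = "mock_agent"

assert record.prediction_agent == "mock_agent"
```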
