Skip to content

Commit

Permalink
fix(#905): copy dataset with rules (#948)
Browse files Browse the repository at this point in the history
* fix: copy dataset with defined rules

* add missing tests

(cherry picked from commit 70503ff)
  • Loading branch information
frascuchon committed Jan 18, 2022
1 parent 2921e18 commit 8597b83
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
11 changes: 9 additions & 2 deletions src/rubrix/server/datasets/dao.py
Expand Up @@ -21,12 +21,13 @@

from rubrix.server.commons.es_wrapper import ElasticsearchWrapper, create_es_wrapper
from rubrix.server.tasks.commons import TaskType
from .model import DatasetDB

from ..commons.es_settings import (
DATASETS_INDEX_NAME,
DATASETS_INDEX_TEMPLATE,
DATASETS_RECORDS_INDEX_NAME,
)
from .model import DatasetDB


def dataset_records_index(dataset_id: str) -> str:
Expand Down Expand Up @@ -287,10 +288,16 @@ def __dict_to_key_value_list__(data: Dict[str, Any]) -> List[Dict[str, Any]]:
}

def copy(self, source: DatasetDB, target: DatasetDB):
source_doc = self._es.get_document_by_id(
index=DATASETS_INDEX_NAME, doc_id=source.id
)
self._es.add_document(
index=DATASETS_INDEX_NAME,
doc_id=target.id,
document=self._dataset_to_es_doc(target),
document={
**source_doc["_source"], # we copy extended fields from source document
**self._dataset_to_es_doc(target),
},
)
index_from = dataset_records_index(source.id)
index_to = dataset_records_index(target.id)
Expand Down
24 changes: 20 additions & 4 deletions tests/labeling/text_classification/test_rule.py
Expand Up @@ -21,10 +21,7 @@
CreationTextClassificationRecord,
TextClassificationBulkData,
)
from rubrix.labeling.text_classification import (
Rule,
load_rules,
)
from rubrix.labeling.text_classification import Rule, load_rules
from rubrix.labeling.text_classification.rule import RuleNotAppliedError
from tests.server.test_helpers import client, mocking_client

Expand Down Expand Up @@ -127,6 +124,25 @@ def test_load_rules(monkeypatch, log_dataset):
assert rules[0].label == "LALA"


def test_copy_dataset_with_rules(monkeypatch, log_dataset):
    """Check that copying a dataset also copies its labeling rules.

    A rule is posted to the source dataset, the dataset is copied under a
    new name, and the rules loaded from the copy are compared (by query and
    label) against the rules loaded from the source.
    """
    import rubrix as rb

    mocking_client(monkeypatch, client)

    client.post(
        f"/api/datasets/TextClassification/{log_dataset}/labeling/rules",
        json={"query": "a query", "label": "LALA"},
    )

    copied_dataset = f"{log_dataset}_copy"
    # Presumably removes a copy left behind by an earlier run so that the
    # copy below starts from a clean target -- confirm against rb.delete.
    rb.delete(copied_dataset)
    rb.copy(log_dataset, name_of_copy=copied_dataset)

    copied_rules = [{"q": r.query, "l": r.label} for r in load_rules(copied_dataset)]
    source_rules = [{"q": r.query, "l": r.label} for r in load_rules(log_dataset)]

    # Guard against a vacuous pass: if no rule was stored on the source
    # dataset, both lists would be empty and the equality check below would
    # succeed without verifying that rules are actually copied.
    assert source_rules, "source dataset has no rules; copy check would be vacuous"
    assert copied_rules == source_rules


@pytest.mark.parametrize(
["rule", "expected_metrics"],
[
Expand Down

0 comments on commit 8597b83

Please sign in to comment.