Skip to content

Commit

Permalink
fix(DAO): datasets dao filter datasets by tasks (#1934)
Browse files Browse the repository at this point in the history
Co-authored-by: Francisco Aranda <francisco@recogn.ai>
  • Loading branch information
frascuchon and frascuchon committed Nov 21, 2022
1 parent 8e3851e commit 937b410
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/argilla/server/daos/backend/search/model.py
Expand Up @@ -53,7 +53,7 @@ class BaseQuery(BaseModel):


class BaseDatasetsQuery(BaseQuery):
Tasks: Optional[List[str]] = None
tasks: Optional[List[str]] = None
owners: Optional[List[str]] = None
include_no_owner: bool = None
name: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/daos/backend/search/query_builder.py
Expand Up @@ -65,7 +65,7 @@ def _datasets_to_es_query(
else:
query_filters.append(owners_filter)

if query.Tasks:
if query.tasks:
query_filters.append(
filters.terms_filter(field="task.keyword", values=query.tasks)
)
Expand Down
36 changes: 36 additions & 0 deletions tests/server/datasets/test_dao.py
Expand Up @@ -37,6 +37,42 @@ def test_retrieve_ownered_dataset_for_no_owner_user():
assert dao.find_by_name(created.name, owner="me") is None


def test_list_datasets_by_task():
dataset = "test_list_datasets_by_task"

all_datasets = dao.list_datasets()
for ds in all_datasets:
dao.delete_dataset(ds)

created_text = dao.create_dataset(
BaseDatasetDB(
name=dataset + "_text",
owner="other",
task=TaskType.text_classification,
),
)

created_token = dao.create_dataset(
BaseDatasetDB(
name=dataset + "_token",
owner="other",
task=TaskType.token_classification,
),
)

datasets = dao.list_datasets(
task2dataset_map={created_text.task: BaseDatasetDB},
)

assert len(datasets) == 1
assert datasets[0].name == created_text.name

datasets = dao.list_datasets(task2dataset_map={created_token.task: BaseDatasetDB})

assert len(datasets) == 1
assert datasets[0].name == created_token.name


def test_close_dataset():
dataset = "test_close_dataset"

Expand Down

0 comments on commit 937b410

Please sign in to comment.