Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def set_contributors_categories(
self,
project: Union[NotEmptyStr, int],
contributors: List[Union[int, str]],
categories: Union[List[str], Literal["*"]],
categories: Union[List[NotEmptyStr], Literal["*"]],
):
"""
Assign one or more categories to a contributor with an assignable role (Annotator, QA or custom role)
Expand Down Expand Up @@ -899,6 +899,9 @@ def set_contributors_categories(
categories="*"
)
"""
if not categories:
AppException("Categories should be a list of strings or '*'.")

project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
Expand All @@ -917,7 +920,7 @@ def remove_contributors_categories(
self,
project: Union[NotEmptyStr, int],
contributors: List[Union[int, str]],
categories: Union[List[str], Literal["*"]],
categories: Union[List[NotEmptyStr], Literal["*"]],
):
"""
Remove one or more categories for a contributor. "*" in the category list will match all categories defined in the project.
Expand Down Expand Up @@ -946,6 +949,9 @@ def remove_contributors_categories(
categories="*"
)
"""
if not categories:
AppException("Categories should be a list of strings or '*'.")

project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
Expand Down Expand Up @@ -1297,7 +1303,7 @@ def clone_project(
return data

def create_categories(
self, project: Union[NotEmptyStr, int], categories: List[str]
self, project: Union[NotEmptyStr, int], categories: List[NotEmptyStr]
):
"""
Create one or more categories in a project.
Expand All @@ -1316,6 +1322,9 @@ def create_categories(
categories=["Shoes", "T-Shirt"]
)
"""
if not categories:
raise AppException("Categories should be a list of strings.")

project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
Expand Down Expand Up @@ -1387,7 +1396,7 @@ def list_categories(self, project: Union[NotEmptyStr, int]):
def remove_categories(
self,
project: Union[NotEmptyStr, int],
categories: Union[List[str], Literal["*"]],
categories: Union[List[NotEmptyStr], Literal["*"]],
):
"""
Remove one or more categories in a project. "*" in the category list will match all categories defined in the project.
Expand All @@ -1413,14 +1422,16 @@ def remove_categories(
categories="*"
)
"""
if not categories:
AppException("Categories should be a list of strings or '*'.")

project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
else self.controller.get_project(project)
)
self.controller.check_multimodal_project_categorization(project)

categories_to_remove = None
query = EmptyQuery()
if categories == "*":
query &= Filter("id", [0], OperatorEnum.GT)
Expand All @@ -1432,17 +1443,21 @@ def remove_categories(
categories_to_remove = [
c for c in all_categories.data if c.name.lower() in categories
]
query &= Filter("id", [c.id for c in categories_to_remove], OperatorEnum.IN)
if categories_to_remove:
query &= Filter(
"id", [c.id for c in categories_to_remove], OperatorEnum.IN
)
else:
raise AppException("Categories should be a list of strings or '*'.")

if categories_to_remove:
if query.condition_set:
response = self.controller.service_provider.work_management.remove_project_categories(
project_id=project.id, query=query
)
logger.info(
f"{len(response.data)} categories successfully removed from the project."
)
if response.data:
logger.info(
f"{len(response.data)} categories successfully removed from the project."
)

def create_folder(self, project: NotEmptyStr, folder_name: NotEmptyStr):
"""
Expand Down Expand Up @@ -4497,7 +4512,7 @@ def set_items_category(
self,
project: Union[NotEmptyStr, Tuple[int, int], Tuple[str, str]],
items: List[Union[int, str]],
category: str,
category: NotEmptyStr,
):
"""
Add categories to one or more items.
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/work_management/test_project_categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from unittest import TestCase

from lib.core.exceptions import AppException
from src.superannotate import SAClient

sa = SAClient()
Expand Down Expand Up @@ -144,3 +145,22 @@ def test_delete_all_categories_with_asterisk(self):
sa.remove_categories(project=self.PROJECT_NAME, categories="*")
categories = sa.list_categories(project=self.PROJECT_NAME)
assert len(categories) == 0

def test_delete_categories_with_empty_list(self):
with self.assertRaisesRegexp(
AppException, "Categories should be a list of strings or '*'"
):
sa.remove_categories(project=self.PROJECT_NAME, categories=[])

def test_delete_invalid_categories(self):
# silent skip
sa.remove_categories(
project=self.PROJECT_NAME,
categories=["invalid_category_1", "invalid_category_2"],
)

def test_create_categories_with_empty_categories(self):
with self.assertRaisesRegexp(
AppException, "Categories should be a list of strings."
):
sa.create_categories(project=self.PROJECT_NAME, categories=[])