diff --git a/docs/source/api_reference/api_team.rst b/docs/source/api_reference/api_team.rst index 15f8130c..2217365e 100644 --- a/docs/source/api_reference/api_team.rst +++ b/docs/source/api_reference/api_team.rst @@ -15,3 +15,5 @@ Team .. automethod:: superannotate.SAClient.resume_user_activity .. automethod:: superannotate.SAClient.get_user_scores .. automethod:: superannotate.SAClient.set_user_scores +.. automethod:: superannotate.SAClient.set_contributors_categories +.. automethod:: superannotate.SAClient.remove_contributors_categories diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index dec26352..f6beab25 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -474,7 +474,7 @@ def list_users( self, *, project: Union[int, str] = None, - include: List[Literal["custom_fields"]] = None, + include: List[Literal["custom_fields", "categories"]] = None, **filters, ): """ @@ -488,7 +488,10 @@ def list_users( Possible values are - - "custom_fields": Includes custom fields and scores assigned to each user. + - "custom_fields": Includes custom fields and scores assigned to each user. + - "categories": Includes a list of categories assigned to each project contributor. + Note: 'project' parameter must be specified when including 'categories'. + :type include: list of str, optional :param filters: Specifies filtering criteria, with all conditions combined using logical AND. @@ -860,6 +863,103 @@ def set_user_scores( ) logger.info("Scores successfully set.") + def set_contributors_categories( + self, + project: Union[NotEmptyStr, int], + contributors: List[Union[int, str]], + categories: Union[List[str], Literal["*"]], + ): + """ + Assign one or more categories to a contributor with an assignable role (Annotator, QA or custom role) + in a Multimodal project. Project Admins are not eligible for category assignments. "*" in the category + list will match all categories defined in the project. + + + :param project: The name or ID of the project. + :type project: Union[NotEmptyStr, int] + + :param contributors: A list of emails or IDs of the contributor. + :type contributors: List[Union[int, str]] + + :param categories: A list of category names to assign. Accepts "*" to indicate all available categories in the project. + :type categories: Union[List[str], Literal["*"]] + + Request Example: + :: + + client.set_contributor_categories( + project="product-review-mm", + contributors=["test@superannotate.com","contributor@superannotate.com"], + categories=["Shoes", "T-Shirt"] + ) + + client.set_contributor_categories( + project="product-review-mm", + contributors=["test@superannotate.com","contributor@superannotate.com"] + categories="*" + ) + """ + project = ( + self.controller.get_project_by_id(project).data + if isinstance(project, int) + else self.controller.get_project(project) + ) + self.controller.check_multimodal_project_categorization(project) + + self.controller.work_management.set_remove_contributor_categories( + project=project, + contributors=contributors, + categories=categories, + operation="set", + ) + + def remove_contributors_categories( + self, + project: Union[NotEmptyStr, int], + contributors: List[Union[int, str]], + categories: Union[List[str], Literal["*"]], + ): + """ + Remove one or more categories for a contributor. "*" in the category list will match all categories defined in the project. + + :param project: The name or ID of the project. + :type project: Union[NotEmptyStr, int] + + :param contributors: A list of emails or IDs of the contributor. + :type contributors: List[Union[int, str]] + + :param categories: A list of category names to remove. Accepts "*" to indicate all available categories in the project. + :type categories: Union[List[str], Literal["*"]] + + Request Example: + :: + + client.remove_contributor_categories( + project="product-review-mm", + contributors=["test@superannotate.com","contributor@superannotate.com"], + categories=["Shoes", "T-Shirt", "Jeans"] + ) + + client.remove_contributor_categories( + project="product-review-mm", + contributors=["test@superannotate.com","contributor@superannotate.com"] + categories="*" + ) + """ + project = ( + self.controller.get_project_by_id(project).data + if isinstance(project, int) + else self.controller.get_project(project) + ) + self.controller.check_multimodal_project_categorization(project) + + self.controller.work_management.set_remove_contributor_categories( + project=project, + contributors=contributors, + categories=categories, + operation="remove", + ) + def get_component_config(self, project: Union[NotEmptyStr, int], component_id: str): """ Retrieves the configuration for a given project and component ID. @@ -1320,6 +1420,7 @@ def remove_categories( ) self.controller.check_multimodal_project_categorization(project) + categories_to_remove = None query = EmptyQuery() if categories == "*": query &= Filter("id", [0], OperatorEnum.GT) @@ -1335,14 +1436,13 @@ def remove_categories( else: raise AppException("Categories should be a list of strings or '*'.") - response = ( - self.controller.service_provider.work_management.remove_project_categories( + if categories_to_remove: + response = self.controller.service_provider.work_management.remove_project_categories( project_id=project.id, query=query ) - ) - logger.info( - f"{len(response.data)} categories successfully removed from the project." - ) + logger.info( + f"{len(response.data)} categories successfully removed from the project." + ) def create_folder(self, project: NotEmptyStr, folder_name: NotEmptyStr): """ diff --git a/src/superannotate/lib/core/entities/work_managament.py b/src/superannotate/lib/core/entities/work_managament.py index 8e23f38b..57c880b6 100644 --- a/src/superannotate/lib/core/entities/work_managament.py +++ b/src/superannotate/lib/core/entities/work_managament.py @@ -129,11 +129,12 @@ def json(self, **kwargs): class WMProjectUserEntity(TimedBaseModel): id: Optional[int] team_id: Optional[int] - role: int + role: Optional[int] email: Optional[str] state: Optional[WMUserStateEnum] custom_fields: Optional[dict] = Field(dict(), alias="customField") permissions: Optional[dict] + categories: Optional[list[dict]] class Config: extra = Extra.ignore diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 6672383b..49d7ffc9 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -228,6 +228,17 @@ def create_score( def delete_score(self, score_id: int) -> ServiceResponse: raise NotImplementedError + @abstractmethod + def set_remove_contributor_categories( + self, + project_id: int, + contributor_ids: List[int], + category_ids: List[int], + operation: Literal["set", "remove"], + chunk_size=100, + ) -> list[dict]: + raise NotImplementedError + class BaseProjectService(SuperannotateServiceProvider): @abstractmethod diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index dd531a32..967d4f2c 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -68,6 +68,9 @@ from typing_extensions import Unpack +logger = logging.getLogger("sa") + + def build_condition(**kwargs) -> Condition: condition = Condition.get_empty_condition() if any(kwargs.values()): @@ -177,7 +180,10 @@ def set_custom_field_value( ) def list_users( - self, include: List[Literal["custom_fields"]] = None, project=None, **filters + self, + include: List[Literal["custom_fields", "categories"]] = None, + project=None, + **filters, ): context = {"team_id": self.service_provider.client.team_id} if project: @@ -205,6 +211,10 @@ def list_users( ] ) query = chain.handle(filters, EmptyQuery()) + + if project and include and "categories" in include: + query &= Join("categories") + if include and "custom_fields" in include: response = self.service_provider.work_management.list_users( query, @@ -401,6 +411,76 @@ def set_user_scores( res.res_error = "Please provide valid score values." res.raise_for_status() + def set_remove_contributor_categories( + self, + project: ProjectEntity, + contributors: List[Union[int, str]], + categories: Union[List[str], Literal["*"]], + operation: Literal["set", "remove"], + ): + if categories and contributors: + all_categories = ( + self.service_provider.work_management.list_project_categories( + project_id=project.id, entity=ProjectCategoryEntity # noqa + ).data + ) + if categories == "*": + category_ids = [c.id for c in all_categories] + else: + categories = [c.lower() for c in categories] + category_ids = [ + c.id for c in all_categories if c.name.lower() in categories + ] + + if isinstance(contributors[0], str): + project_contributors = self.list_users( + project=project, email__in=contributors + ) + elif isinstance(contributors[0], int): + project_contributors = self.list_users( + project=project, id__in=contributors + ) + else: + raise AppException("Contributors not found.") + + if len(project_contributors) < len(contributors): + raise AppException("Contributors not found.") + + contributor_ids = [ + c.id + for c in project_contributors + if c.role != 3 # exclude Project Admins + ] + + if category_ids and contributor_ids: + response = self.service_provider.work_management.set_remove_contributor_categories( + project_id=project.id, + contributor_ids=contributor_ids, + category_ids=category_ids, + operation=operation, + ) + + success_processed = 0 + for contributor in response: + contributor_category_ids = [ + category["id"] for category in contributor["categories"] + ] + if operation == "set": + if set(category_ids).issubset(contributor_category_ids): + success_processed += len(category_ids) + else: + if not set(category_ids).intersection(contributor_category_ids): + success_processed += len(category_ids) + + if success_processed / len(contributor_ids) == len(category_ids): + action_for_log = ( + "added to" if operation == "set" else "removed from" + ) + logger.info( + f"{len(category_ids)} categories successfully {action_for_log} " + f"{len(contributor_ids)} contributors." + ) + class ProjectManager(BaseManager): def __init__(self, service_provider: ServiceProvider, team: TeamEntity): diff --git a/src/superannotate/lib/infrastructure/services/work_management.py b/src/superannotate/lib/infrastructure/services/work_management.py index 2edf6fce..861af83f 100644 --- a/src/superannotate/lib/infrastructure/services/work_management.py +++ b/src/superannotate/lib/infrastructure/services/work_management.py @@ -14,6 +14,7 @@ from lib.core.entities.work_managament import WMUserEntity from lib.core.enums import CustomFieldEntityEnum from lib.core.exceptions import AppException +from lib.core.jsx_conditions import EmptyQuery from lib.core.jsx_conditions import Filter from lib.core.jsx_conditions import OperatorEnum from lib.core.jsx_conditions import Query @@ -71,6 +72,7 @@ class WorkManagementService(BaseWorkManagementService): URL_SEARCH_PROJECT_USERS = "projectusers/search" URL_SEARCH_PROJECTS = "projects/search" URL_RESUME_PAUSE_USER = "teams/editprojectsusers" + URL_CONTRIBUTORS_CATEGORIES = "customentities/edit" @staticmethod def _generate_context(**kwargs): @@ -475,3 +477,46 @@ def delete_score(self, score_id: int) -> ServiceResponse: ), }, ) + + def set_remove_contributor_categories( + self, + project_id: int, + contributor_ids: List[int], + category_ids: List[int], + operation: Literal["set", "remove"], + chunk_size=100, + ) -> List[dict]: + params = { + "entity": "Contributor", + "parentEntity": "Project", + } + if operation == "set": + params["action"] = "addcontributorcategory" + else: + params["action"] = "removecontributorcategory" + + from lib.infrastructure.utils import divide_to_chunks + + success_contributors = [] + + for chunk in divide_to_chunks(contributor_ids, chunk_size): + body_query = EmptyQuery() + body_query &= Filter("id", chunk, OperatorEnum.IN) + response = self.client.request( + url=self.URL_CONTRIBUTORS_CATEGORIES, + method="post", + params=params, + data={ + "query": body_query.body_builder(), + "body": {"categories": [{"id": i} for i in category_ids]}, + }, + headers={ + "x-sa-entity-context": self._generate_context( + team_id=self.client.team_id, project_id=project_id + ), + }, + ) + response.raise_for_status() + success_contributors.extend(response.data["data"]) + + return success_contributors diff --git a/tests/integration/work_management/test_contributors_categories.py b/tests/integration/work_management/test_contributors_categories.py new file mode 100644 index 00000000..46c2f55f --- /dev/null +++ b/tests/integration/work_management/test_contributors_categories.py @@ -0,0 +1,266 @@ +import json +import os +import time +from pathlib import Path +from unittest import TestCase + +from lib.core.exceptions import AppException +from src.superannotate import SAClient + +sa = SAClient() + + +class TestContributorsCategories(TestCase): + PROJECT_NAME = "TestContributorsCategories" + PROJECT_TYPE = "Multimodal" + PROJECT_DESCRIPTION = "DESCRIPTION" + EDITOR_TEMPLATE_PATH = os.path.join( + Path(__file__).parent.parent.parent, + "data_set/editor_templates/form1.json", + ) + CLASSES_TEMPLATE_PATH = os.path.join( + Path(__file__).parent.parent.parent, + "data_set/editor_templates/form1_classes.json", + ) + + @classmethod + def setUpClass(cls, *args, **kwargs) -> None: + cls.tearDownClass() + cls._project = sa.create_project( + cls.PROJECT_NAME, + cls.PROJECT_DESCRIPTION, + cls.PROJECT_TYPE, + settings=[ + {"attribute": "TemplateState", "value": 1}, + {"attribute": "CategorizeItems", "value": 1}, + ], + ) + team = sa.controller.team + project = sa.controller.get_project(cls.PROJECT_NAME) + time.sleep(2) + + with open(cls.EDITOR_TEMPLATE_PATH) as f: + template_data = json.load(f) + res = sa.controller.service_provider.projects.attach_editor_template( + team, project, template=template_data + ) + assert res.ok + sa.create_annotation_classes_from_classes_json( + cls.PROJECT_NAME, cls.CLASSES_TEMPLATE_PATH + ) + users = sa.list_users() + scapegoat = [ + u for u in users if u["role"] == "Contributor" and u["state"] == "Confirmed" + ][0] + cls.scapegoat = scapegoat + sa.add_contributors_to_project( + cls.PROJECT_NAME, [scapegoat["email"]], "Annotator" + ) + + @classmethod + def tearDownClass(cls) -> None: + projects = sa.search_projects(cls.PROJECT_NAME, return_metadata=True) + for project in projects: + try: + sa.delete_project(project) + except Exception: + pass + + def tearDown(self): + # cleanup categories + sa.remove_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories="*", + ) + sa.remove_categories(project=self.PROJECT_NAME, categories="*") + + def test_set_contributors_categories(self): + test_categories = ["Category_A", "Category_B", "Category_C"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + categories = sa.list_categories(project=self.PROJECT_NAME) + assert len(categories) == len(test_categories) + + with self.assertLogs("sa", level="INFO") as cm: + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=["Category_A", "Category_B"], + ) + assert ( + "INFO:sa:2 categories successfully added to 1 contributors." + == cm.output[0] + ) + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 2 + assert project_users[0]["categories"][0]["name"] == "Category_A" + assert project_users[0]["categories"][1]["name"] == "Category_B" + + def test_set_contributors_categories_all(self): + test_categories = ["Category_A", "Category_B", "Category_C", "Category_D"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + categories = sa.list_categories(project=self.PROJECT_NAME) + assert len(categories) == len(test_categories) + + with self.assertLogs("sa", level="INFO") as cm: + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories="*", + ) + assert ( + "INFO:sa:4 categories successfully added to 1 contributors." + == cm.output[0] + ) + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 4 + assert project_users[0]["categories"][0]["name"] == "Category_A" + assert project_users[0]["categories"][1]["name"] == "Category_B" + assert project_users[0]["categories"][2]["name"] == "Category_C" + assert project_users[0]["categories"][3]["name"] == "Category_D" + + def test_set_contributors_categories_by_id(self): + # Test assigning categories using contributor ID + test_categories = ["ID_Cat_A", "ID_Cat_B"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["id"]], + categories=test_categories, + ) + + # Verify categories were assigned + project_users = sa.list_users( + project=self.PROJECT_NAME, id=self.scapegoat["id"], include=["categories"] + ) + assigned_categories = [cat["name"] for cat in project_users[0]["categories"]] + for category in test_categories: + assert category in assigned_categories + + def test_set_contributors_categories_nonexistent(self): + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=["NonExistentCategory"], + ) + + def test_remove_contributors_categories(self): + test_categories = ["RemoveCat_A", "RemoveCat_B", "RemoveCat_C"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=test_categories, + ) + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 3 + + with self.assertLogs("sa", level="INFO") as cm: + sa.remove_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=["RemoveCat_A", "RemoveCat_B"], + ) + assert "INFO:sa:2 categories successfully removed" in cm.output[0] + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 1 + assert project_users[0]["categories"][0]["name"] == "RemoveCat_C" + + def test_remove_all_contributors_categories(self): + test_categories = ["AllRemove_X", "AllRemove_Y", "AllRemove_Z"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + # First assign all categories + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=test_categories, + ) + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 3 + + with self.assertLogs("sa", level="INFO") as cm: + sa.remove_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories="*", + ) + assert "INFO:sa:3 categories successfully removed" in cm.output[0] + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 0 + + def test_set_categories_with_invalid_contributor(self): + test_categories = ["Category_A", "Category_B", "Category_C"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + with self.assertRaisesRegexp(AppException, "Contributors not found.") as cm: + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"], "invalid_email@mail.com"], + categories=["Category_A", "Category_B"], + ) + + def test_set_contributors_with_invalid_categories(self): + test_categories = ["Category_A", "Category_B", "Category_C"] + sa.create_categories(project=self.PROJECT_NAME, categories=test_categories) + + sa.set_contributors_categories( + project=self.PROJECT_NAME, + contributors=[self.scapegoat["email"]], + categories=[ + "Category_A", + "Category_C", + "InvalidCategory_1", + "InvalidCategory_2", + ], + ) + + project_users = sa.list_users( + project=self.PROJECT_NAME, + email=self.scapegoat["email"], + include=["categories"], + ) + assert len(project_users) == 1 + assert len(project_users[0]["categories"]) == 2 + assert project_users[0]["categories"][0]["name"] == "Category_A" + assert project_users[0]["categories"][1]["name"] == "Category_C"