From 3f8cabb6d7e28f1dd4c745a7a903459c9713cdad Mon Sep 17 00:00:00 2001 From: Narek Mkhitaryan Date: Wed, 2 Jul 2025 12:43:11 +0400 Subject: [PATCH] added project category related functions --- docs/source/api_reference/api_project.rst | 3 + .../lib/app/interface/sdk_interface.py | 150 ++++++++++++++++++ src/superannotate/lib/core/entities/items.py | 9 ++ src/superannotate/lib/core/service_types.py | 4 + .../lib/core/serviceproviders.py | 16 +- .../lib/core/usecases/annotations.py | 4 +- .../lib/infrastructure/controller.py | 12 ++ .../services/work_management.py | 37 ++++- .../test_project_categories.py | 146 +++++++++++++++++ 9 files changed, 372 insertions(+), 9 deletions(-) create mode 100644 tests/integration/work_management/test_project_categories.py diff --git a/docs/source/api_reference/api_project.rst b/docs/source/api_reference/api_project.rst index 858030333..23f0cdac5 100644 --- a/docs/source/api_reference/api_project.rst +++ b/docs/source/api_reference/api_project.rst @@ -27,3 +27,6 @@ Projects .. automethod:: superannotate.SAClient.get_project_steps .. automethod:: superannotate.SAClient.set_project_steps .. automethod:: superannotate.SAClient.get_component_config +.. automethod:: superannotate.SAClient.create_categories +.. automethod:: superannotate.SAClient.list_categories +.. 
automethod:: superannotate.SAClient.remove_categories
diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py
index 8bcc3d924..856a3f2a6 100644
--- a/src/superannotate/lib/app/interface/sdk_interface.py
+++ b/src/superannotate/lib/app/interface/sdk_interface.py
@@ -73,6 +73,8 @@
 from lib.app.serializers import WMProjectSerializer
 from lib.core.entities.work_managament import WMUserTypeEnum
 from lib.core.jsx_conditions import EmptyQuery
+from lib.core.entities.items import ProjectCategoryEntity
+
 
 logger = logging.getLogger("sa")
 
@@ -1194,6 +1196,160 @@ def clone_project(
         )
         return data
 
+    def create_categories(
+        self, project: Union[NotEmptyStr, int], categories: List[str]
+    ):
+        """
+        Create one or more categories in a project.
+
+        :param project: The name or ID of the project.
+        :type project: Union[NotEmptyStr, int]
+
+        :param categories: A list of categories to create
+        :type categories: list of str
+
+        Request Example:
+        ::
+
+            client.create_categories(
+                project="product-review-mm",
+                categories=["Shoes", "T-Shirt"]
+            )
+        """
+        project = (
+            self.controller.get_project_by_id(project).data
+            if isinstance(project, int)
+            else self.controller.get_project(project)
+        )
+        self.controller.check_multimodal_project_categorization(project)
+
+        response = (
+            self.controller.service_provider.work_management.create_project_categories(
+                project_id=project.id, categories=categories
+            )
+        )
+        logger.info(
+            f"{len(response.data)} categories successfully added to the project."
+        )
+
+    def list_categories(self, project: Union[NotEmptyStr, int]):
+        """
+        List all categories in the project.
+
+        :param project: The name or ID of the project.
+        :type project: Union[NotEmptyStr, int]
+
+        :return: List of categories
+        :rtype: list of dict
+
+        Request Example:
+        ::
+
+            client.list_categories(
+                project="product-review-mm"
+            )
+
+        Response Example:
+        ::
+
+            [
+                {
+                    "createdAt": "2025-01-29T13:51:39.000Z",
+                    "updatedAt": "2025-01-29T13:51:39.000Z",
+                    "id": 328577,
+                    "name": "category1",
+                    "project_id": 1234
+                },
+                {
+                    "createdAt": "2025-01-29T13:51:39.000Z",
+                    "updatedAt": "2025-01-29T13:51:39.000Z",
+                    "id": 328578,
+                    "name": "category2",
+                    "project_id": 1234
+                },
+            ]
+
+        """
+        project = (
+            self.controller.get_project_by_id(project).data
+            if isinstance(project, int)
+            else self.controller.get_project(project)
+        )
+        self.controller.check_multimodal_project_categorization(project)
+
+        response = (
+            self.controller.service_provider.work_management.list_project_categories(
+                project_id=project.id, entity=ProjectCategoryEntity
+            )
+        )
+        return BaseSerializer.serialize_iterable(response.data)
+
+    def remove_categories(
+        self,
+        project: Union[NotEmptyStr, int],
+        categories: Union[List[str], Literal["*"]],
+    ):
+        """
+        Remove one or more categories from a project. Passing "*" removes all categories defined in the project.
+
+        Category names are matched case-insensitively; names that do not exist in the project are ignored.
+
+        :param project: The name or ID of the project.
+        :type project: Union[NotEmptyStr, int]
+
+        :param categories: A list of categories to remove. Accepts "*" to indicate all available categories in the project.
+        :type categories: Union[List[str], Literal["*"]]
+
+        Request Example:
+        ::
+
+            client.remove_categories(
+                project="product-review-mm",
+                categories=["Shoes", "T-Shirt"]
+            )
+
+            # To remove all categories
+            client.remove_categories(
+                project="product-review-mm",
+                categories="*"
+            )
+        """
+        project = (
+            self.controller.get_project_by_id(project).data
+            if isinstance(project, int)
+            else self.controller.get_project(project)
+        )
+        self.controller.check_multimodal_project_categorization(project)
+
+        query = EmptyQuery()
+        if categories == "*":
+            query &= Filter("id", [0], OperatorEnum.GT)
+        elif categories and isinstance(categories, list):
+            # A set keeps the case-insensitive membership test O(1) per category.
+            names_to_remove = {c.lower() for c in categories}
+            all_categories = self.controller.service_provider.work_management.list_project_categories(
+                project_id=project.id, entity=ProjectCategoryEntity
+            )
+            ids_to_remove = [
+                c.id for c in all_categories.data if c.name.lower() in names_to_remove
+            ]
+            if not ids_to_remove:
+                # Nothing matched: never send an empty IN filter to the delete endpoint.
+                logger.info("0 categories successfully removed from the project.")
+                return
+            query &= Filter("id", ids_to_remove, OperatorEnum.IN)
+        else:
+            raise AppException("Categories should be a list of strings or '*'.")
+
+        response = (
+            self.controller.service_provider.work_management.remove_project_categories(
+                project_id=project.id, query=query
+            )
+        )
+        logger.info(
+            f"{len(response.data)} categories successfully removed from the project."
+        )
+
     def create_folder(self, project: NotEmptyStr, folder_name: NotEmptyStr):
         """
         Create a new folder in the project.
diff --git a/src/superannotate/lib/core/entities/items.py b/src/superannotate/lib/core/entities/items.py
index 477cab8dd..1beff24cf 100644
--- a/src/superannotate/lib/core/entities/items.py
+++ b/src/superannotate/lib/core/entities/items.py
@@ -26,6 +26,17 @@ class Config:
         extra = Extra.ignore
 
 
+class ProjectCategoryEntity(TimedBaseModel):
+    """A category defined at project level (id, name and owning project id)."""
+
+    id: int
+    name: str
+    project_id: int
+
+    class Config:
+        extra = Extra.ignore
+
+
 class MultiModalItemCategoryEntity(TimedBaseModel):
     id: int = Field(None, alias="category_id")
     value: str = Field(None, alias="category_name")
diff --git a/src/superannotate/lib/core/service_types.py b/src/superannotate/lib/core/service_types.py
index e98c4af62..066def7b7 100644
--- a/src/superannotate/lib/core/service_types.py
+++ b/src/superannotate/lib/core/service_types.py
@@ -234,6 +234,10 @@ class ListCategoryResponse(ServiceResponse):
     res_data: List[entities.CategoryEntity] = None
 
 
+class ListProjectCategoryResponse(ServiceResponse):
+    res_data: List[entities.items.ProjectCategoryEntity] = None
+
+
 class WorkflowResponse(ServiceResponse):
     res_data: entities.WorkflowEntity = None
 
diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py
index 9ffba499f..95b313f13 100644
--- a/src/superannotate/lib/core/serviceproviders.py
+++ b/src/superannotate/lib/core/serviceproviders.py
@@ -7,9 +7,13 @@
 from typing import List
 from typing import Literal
 from typing import Optional
+from typing import Type
+from typing import Union
 
 from lib.core import entities
 from lib.core.conditions import Condition
+from lib.core.entities import CategoryEntity
+from lib.core.entities.project_entities import BaseEntity
 from lib.core.enums import CustomFieldEntityEnum
 from lib.core.jsx_conditions import Query
 from lib.core.reporter import Reporter
@@ -18,6 +22,7 @@
 from lib.core.service_types import FolderResponse
 from lib.core.service_types import IntegrationListResponse
 from lib.core.service_types import ListCategoryResponse
+from lib.core.service_types import ListProjectCategoryResponse
 from lib.core.service_types import ProjectListResponse
 from lib.core.service_types import ProjectResponse
 from lib.core.service_types import ServiceResponse
@@ -137,13 +142,21 @@ def search_projects(
         raise NotImplementedError
 
     @abstractmethod
-    def list_project_categories(self, project_id: int) -> ListCategoryResponse:
+    def list_project_categories(
+        self, project_id: int, entity: Type[BaseEntity] = CategoryEntity
+    ) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_project_categories(
+        self, project_id: int, query: Query
+    ) -> ListProjectCategoryResponse:
         raise NotImplementedError
 
     @abstractmethod
     def create_project_categories(
         self, project_id: int, categories: List[str]
-    ) -> ServiceResponse:
+    ) -> ListProjectCategoryResponse:
         raise NotImplementedError
 
     @abstractmethod
diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py
index 854f1ebd0..0e5365e52 100644
--- a/src/superannotate/lib/core/usecases/annotations.py
+++ b/src/superannotate/lib/core/usecases/annotations.py
@@ -2165,10 +2165,10 @@ def _attach_categories(self, folder_id: int, item_id_category_map: Dict[int, str
                 self._service_provider.work_management.create_project_categories(
                     project_id=self._project.id,
                     categories=categories_to_create,
-                ).data["data"]
+                ).data
             )
             for c in _categories:
-                self._category_name_to_id_map[c["name"]] = c["id"]
+                self._category_name_to_id_map[c.name] = c.id
         for item_id, category_name in item_id_category_map.items():
             with suppress(KeyError):
                 item_id_category_id_map[item_id] = self._category_name_to_id_map[
diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py
index 48e6cea63..ae361ee39 100644
--- a/src/superannotate/lib/infrastructure/controller.py
+++ b/src/superannotate/lib/infrastructure/controller.py
@@ -1918,3 +1918,23 @@ def get_item(
             return self.get_item_by_id(item_id=item, project=project)
         else:
             return self.items.get_by_name(project, folder, item)
+
+    def check_multimodal_project_categorization(self, project: ProjectEntity):
+        """
+        Ensure that item categories can be used with the given project.
+
+        :param project: The project to validate.
+        :raises AppException: If the project is not a Multimodal project, or
+            if its "CategorizeItems" setting is missing or has a falsy value.
+        """
+        if project.type != ProjectType.MULTIMODAL:
+            raise AppException(
+                "This function is only supported for Multimodal projects."
+            )
+        project_settings = self.service_provider.projects.list_settings(project).data
+        categorize_items = next(
+            (i.value for i in project_settings if i.attribute == "CategorizeItems"),
+            None,
+        )
+        if not categorize_items:
+            raise AppException("Item Category not enabled for project.")
diff --git a/src/superannotate/lib/infrastructure/services/work_management.py b/src/superannotate/lib/infrastructure/services/work_management.py
index 375a6932b..2edf6fcec 100644
--- a/src/superannotate/lib/infrastructure/services/work_management.py
+++ b/src/superannotate/lib/infrastructure/services/work_management.py
@@ -3,9 +3,12 @@
 from typing import List
 from typing import Literal
 from typing import Optional
+from typing import Type
+from typing import Union
 
 from lib.core.entities import CategoryEntity
 from lib.core.entities import WorkflowEntity
+from lib.core.entities.project_entities import BaseEntity
 from lib.core.entities.work_managament import WMProjectEntity
 from lib.core.entities.work_managament import WMProjectUserEntity
 from lib.core.entities.work_managament import WMScoreEntity
@@ -16,6 +19,7 @@
 from lib.core.jsx_conditions import OperatorEnum
 from lib.core.jsx_conditions import Query
 from lib.core.service_types import ListCategoryResponse
+from lib.core.service_types import ListProjectCategoryResponse
 from lib.core.service_types import ServiceResponse
 from lib.core.service_types import WMCustomFieldResponse
 from lib.core.service_types import WMProjectListResponse
@@ -74,10 +78,19 @@ def _generate_context(**kwargs):
         encoded_context = base64.b64encode(json.dumps(kwargs).encode("utf-8"))
         return encoded_context.decode("utf-8")
 
-    def list_project_categories(self, project_id: int) -> ListCategoryResponse:
-        return self.client.paginate(
-            self.URL_LIST_CATEGORIES,
-            item_type=CategoryEntity,
+    def list_project_categories(
+        self, project_id: int, entity: Type[BaseEntity] = CategoryEntity
+    ) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
+        """
+        List the categories of a project.
+
+        :param project_id: ID of the project whose categories are listed.
+        :param entity: Entity class passed to the client as the item type.
+        :return: The response, after ``raise_for_status()`` was called on it.
+        """
+        response = self.client.paginate(
+            url=self.URL_LIST_CATEGORIES,
+            item_type=entity,
             query_params={"project_id": project_id},
             headers={
                 "x-sa-entity-context": self._generate_context(
@@ -85,10 +98,12 @@ def list_project_categories(self, project_id: int) -> ListCategoryResponse:
                 ),
             },
         )
+        response.raise_for_status()
+        return response
 
     def create_project_categories(
         self, project_id: int, categories: List[str]
-    ) -> ServiceResponse:
+    ) -> ListProjectCategoryResponse:
         response = self.client.request(
             method="post",
             url=self.URL_CREATE_CATEGORIES,
@@ -99,6 +114,32 @@ def create_project_categories(
                     team_id=self.client.team_id, project_id=project_id
                 ),
             },
+            content_type=ListProjectCategoryResponse,
+            dispatcher="data",
+        )
+        response.raise_for_status()
+        return response
+
+    def remove_project_categories(
+        self, project_id: int, query: Query
+    ) -> ListProjectCategoryResponse:
+        """
+        Delete the project categories selected by ``query``.
+
+        :param project_id: ID of the project the categories belong to.
+        :param query: Filter whose built query string selects the categories.
+        :return: The response, after ``raise_for_status()`` was called on it.
+        """
+        response = self.client.request(
+            method="delete",
+            url=f"{self.URL_CREATE_CATEGORIES}?{query.build_query()}",
+            headers={
+                "x-sa-entity-context": self._generate_context(
+                    team_id=self.client.team_id, project_id=project_id
+                ),
+            },
+            content_type=ListProjectCategoryResponse,
+            dispatcher="data",
         )
         response.raise_for_status()
         return response
diff --git a/tests/integration/work_management/test_project_categories.py b/tests/integration/work_management/test_project_categories.py
new file mode 100644
index 000000000..fbb55ec34
--- /dev/null
+++ b/tests/integration/work_management/test_project_categories.py
@@ -0,0 +1,156 @@
+import json
+import logging
+import os
+import time
+from pathlib import Path
+from unittest import TestCase
+
+from src.superannotate import SAClient
+
+sa = SAClient()
+logger = logging.getLogger(__name__)
+
+
+class TestProjectCategories(TestCase):
+    """Integration tests for creating, listing and removing project categories."""
+
+    PROJECT_NAME = "TestProjectCategories"
+    PROJECT_TYPE = "Multimodal"
+    PROJECT_DESCRIPTION = "DESCRIPTION"
+    EDITOR_TEMPLATE_PATH = os.path.join(
+        Path(__file__).parent.parent.parent,
+        "data_set/editor_templates/form1.json",
+    )
+    CLASSES_TEMPLATE_PATH = os.path.join(
+        Path(__file__).parent.parent.parent,
+        "data_set/editor_templates/form1_classes.json",
+    )
+
+    @classmethod
+    def setUpClass(cls, *args, **kwargs) -> None:
+        # Start clean in case a previous run left the project behind.
+        cls.tearDownClass()
+        cls._project = sa.create_project(
+            cls.PROJECT_NAME,
+            cls.PROJECT_DESCRIPTION,
+            cls.PROJECT_TYPE,
+            settings=[
+                {"attribute": "TemplateState", "value": 1},
+                {"attribute": "CategorizeItems", "value": 1},
+            ],
+        )
+        team = sa.controller.team
+        project = sa.controller.get_project(cls.PROJECT_NAME)
+        time.sleep(5)
+
+        with open(cls.EDITOR_TEMPLATE_PATH) as f:
+            template_data = json.load(f)
+            res = sa.controller.service_provider.projects.attach_editor_template(
+                team, project, template=template_data
+            )
+            assert res.ok
+        sa.create_annotation_classes_from_classes_json(
+            cls.PROJECT_NAME, cls.CLASSES_TEMPLATE_PATH
+        )
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # Best-effort cleanup of every project found under the test name.
+        projects = sa.search_projects(cls.PROJECT_NAME, return_metadata=True)
+        for project in projects:
+            try:
+                sa.delete_project(project)
+            except Exception:
+                # Keep going, but make the leaked project visible in the logs.
+                logger.exception("Failed to delete test project %s", project)
+
+    def tearDown(self) -> None:
+        # Keep tests independent: drop categories left behind by a failed test.
+        if sa.list_categories(project=self.PROJECT_NAME):
+            sa.remove_categories(project=self.PROJECT_NAME, categories="*")
+
+    def test_project_categories_flow(self):
+        with self.assertLogs("sa", level="INFO") as cm:
+            sa.create_categories(
+                project=self.PROJECT_NAME,
+                categories=["SDK_test_category_1", "SDK_test_category_2"],
+            )
+            assert (
+                "INFO:sa:2 categories successfully added to the project."
+                == cm.output[0]
+            )
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        assert len(categories) == 2
+        # Compare as a set so the test does not depend on listing order.
+        assert {category["name"] for category in categories} == {
+            "SDK_test_category_1",
+            "SDK_test_category_2",
+        }
+
+        # Check that each category has the expected keys
+        for category in categories:
+            assert "id" in category
+            assert "project_id" in category
+            assert "createdAt" in category
+            assert "updatedAt" in category
+
+        # delete categories
+        with self.assertLogs("sa", level="INFO") as cm:
+            sa.remove_categories(project=self.PROJECT_NAME, categories="*")
+            assert (
+                "INFO:sa:2 categories successfully removed from the project."
+                == cm.output[0]
+            )
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        assert not categories
+
+    def test_duplicate_categories_handling(self):
+        sa.create_categories(
+            project=self.PROJECT_NAME,
+            categories=[
+                "Category_A",
+                "Category_B",
+                "category_a",
+                "Category_B",
+                "Category_A",
+            ],
+        )
+        # Verify only unique categories were created
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        category_names = [category["name"] for category in categories]
+
+        # Should only have two categories (first occurrences of each unique name)
+        assert len(categories) == 2, f"Expected 2 categories, got {len(categories)}"
+        assert (
+            "Category_A" in category_names
+        ), "Category_A not found in created categories"
+        assert (
+            "Category_B" in category_names
+        ), "Category_B not found in created categories"
+        assert (
+            "category_a" not in category_names
+        ), "Duplicate category_a should not be created"
+
+    def test_category_name_length_limitation(self):
+        long_name = "A" * 250  # 250 characters
+        expected_truncated_length = 200  # Expected length after truncation
+
+        # Create the category with the long name
+        sa.create_categories(project=self.PROJECT_NAME, categories=[long_name])
+
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        assert len(categories) == 1, "Expected 1 category to be created"
+
+        created_category = categories[0]
+        assert len(created_category["name"]) == expected_truncated_length
+        assert created_category["name"] == long_name[:expected_truncated_length]
+
+    def test_delete_all_categories_with_asterisk(self):
+        sa.create_categories(
+            project=self.PROJECT_NAME, categories=["Cat1", "Cat2", "Cat3"]
+        )
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        assert len(categories) == 3
+        sa.remove_categories(project=self.PROJECT_NAME, categories="*")
+        categories = sa.list_categories(project=self.PROJECT_NAME)
+        assert len(categories) == 0