Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/api_reference/api_project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ Projects
.. automethod:: superannotate.SAClient.get_project_steps
.. automethod:: superannotate.SAClient.set_project_steps
.. automethod:: superannotate.SAClient.get_component_config
.. automethod:: superannotate.SAClient.create_categories
.. automethod:: superannotate.SAClient.list_categories
.. automethod:: superannotate.SAClient.remove_categories
150 changes: 150 additions & 0 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
from lib.app.serializers import WMProjectSerializer
from lib.core.entities.work_managament import WMUserTypeEnum
from lib.core.jsx_conditions import EmptyQuery
from lib.core.entities.items import ProjectCategoryEntity


logger = logging.getLogger("sa")

Expand Down Expand Up @@ -1194,6 +1196,154 @@ def clone_project(
)
return data

def create_categories(
self, project: Union[NotEmptyStr, int], categories: List[str]
):
"""
Create one or more categories in a project.

:param project: The name or ID of the project.
:type project: Union[NotEmptyStr, int]

:param categories: A list of categories to create
:type categories: list of str

Request Example:
::

client.create_categories(
project="product-review-mm",
categories=["Shoes", "T-Shirt"]
)
"""
project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
else self.controller.get_project(project)
)
self.controller.check_multimodal_project_categorization(project)

response = (
self.controller.service_provider.work_management.create_project_categories(
project_id=project.id, categories=categories
)
)
logger.info(
f"{len(response.data)} categories successfully added to the project."
)

def list_categories(self, project: Union[NotEmptyStr, int]):
"""
List all categories in the project.

:param project: The name or ID of the project.
:type project: Union[NotEmptyStr, int]

:return: List of categories
:rtype: list of dict

Request Example:
::

client.list_categories(
project="product-review-mm"
)

Response Example:
::

[
{
"createdAt": "2025-01-29T13:51:39.000Z",
"updatedAt": "2025-01-29T13:51:39.000Z",
"id": 328577,
"name": "category1",
"project_id": 1234
},
{
"createdAt": "2025-01-29T13:51:39.000Z",
"updatedAt": "2025-01-29T13:51:39.000Z",
"id": 328577,
"name": "category2",
"project_id": 1234
},
]

"""
project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
else self.controller.get_project(project)
)
self.controller.check_multimodal_project_categorization(project)

response = (
self.controller.service_provider.work_management.list_project_categories(
project_id=project.id, entity=ProjectCategoryEntity
)
)
return BaseSerializer.serialize_iterable(response.data)

def remove_categories(
self,
project: Union[NotEmptyStr, int],
categories: Union[List[str], Literal["*"]],
):
"""
Remove one or more categories in a project. "*" in the category list will match all categories defined in the project.


:param project: The name or ID of the project.
:type project: Union[NotEmptyStr, int]

:param categories: A list of categories to remove, Accepts "*" to indicate all available categories in the project.
:type categories: Union[List[str], Literal["*"]]

Request Example:
::

client.remove_categories(
project="product-review-mm",
categories=["Shoes", "T-Shirt"]
)

# To remove all categories
client.remove_categories(
project="product-review-mm",
categories="*"
)
"""
project = (
self.controller.get_project_by_id(project).data
if isinstance(project, int)
else self.controller.get_project(project)
)
self.controller.check_multimodal_project_categorization(project)

query = EmptyQuery()
if categories == "*":
query &= Filter("id", [0], OperatorEnum.GT)
elif categories and isinstance(categories, list):
categories = [c.lower() for c in categories]
all_categories = self.controller.service_provider.work_management.list_project_categories(
project_id=project.id, entity=ProjectCategoryEntity
)
categories_to_remove = [
c for c in all_categories.data if c.name.lower() in categories
]
query &= Filter("id", [c.id for c in categories_to_remove], OperatorEnum.IN)
else:
raise AppException("Categories should be a list of strings or '*'.")

response = (
self.controller.service_provider.work_management.remove_project_categories(
project_id=project.id, query=query
)
)
logger.info(
f"{len(response.data)} categories successfully removed from the project."
)

def create_folder(self, project: NotEmptyStr, folder_name: NotEmptyStr):
"""
Create a new folder in the project.
Expand Down
9 changes: 9 additions & 0 deletions src/superannotate/lib/core/entities/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class Config:
extra = Extra.ignore


class ProjectCategoryEntity(TimedBaseModel):
id: int
name: str
project_id: int

class Config:
extra = Extra.ignore


class MultiModalItemCategoryEntity(TimedBaseModel):
id: int = Field(None, alias="category_id")
value: str = Field(None, alias="category_name")
Expand Down
4 changes: 4 additions & 0 deletions src/superannotate/lib/core/service_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ class ListCategoryResponse(ServiceResponse):
res_data: List[entities.CategoryEntity] = None


class ListProjectCategoryResponse(ServiceResponse):
res_data: List[entities.items.ProjectCategoryEntity] = None


class WorkflowResponse(ServiceResponse):
res_data: entities.WorkflowEntity = None

Expand Down
16 changes: 14 additions & 2 deletions src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
from typing import List
from typing import Literal
from typing import Optional
from typing import Union

from lib.core import entities
from lib.core.conditions import Condition
from lib.core.entities import CategoryEntity
from lib.core.entities.project_entities import BaseEntity
from lib.core.enums import CustomFieldEntityEnum
from lib.core.jsx_conditions import Query
from lib.core.reporter import Reporter
Expand All @@ -18,6 +21,7 @@
from lib.core.service_types import FolderResponse
from lib.core.service_types import IntegrationListResponse
from lib.core.service_types import ListCategoryResponse
from lib.core.service_types import ListProjectCategoryResponse
from lib.core.service_types import ProjectListResponse
from lib.core.service_types import ProjectResponse
from lib.core.service_types import ServiceResponse
Expand Down Expand Up @@ -137,13 +141,21 @@ def search_projects(
raise NotImplementedError

@abstractmethod
def list_project_categories(self, project_id: int) -> ListCategoryResponse:
def list_project_categories(
self, project_id: int, entity: BaseEntity = CategoryEntity
) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
raise NotImplementedError

@abstractmethod
def remove_project_categories(
self, project_id: int, query: Query
) -> ListProjectCategoryResponse:
raise NotImplementedError

@abstractmethod
def create_project_categories(
self, project_id: int, categories: List[str]
) -> ServiceResponse:
) -> ListProjectCategoryResponse:
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions src/superannotate/lib/core/usecases/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,10 +2165,10 @@ def _attach_categories(self, folder_id: int, item_id_category_map: Dict[int, str
self._service_provider.work_management.create_project_categories(
project_id=self._project.id,
categories=categories_to_create,
).data["data"]
).data
)
for c in _categories:
self._category_name_to_id_map[c["name"]] = c["id"]
self._category_name_to_id_map[c.name] = c.id
for item_id, category_name in item_id_category_map.items():
with suppress(KeyError):
item_id_category_id_map[item_id] = self._category_name_to_id_map[
Expand Down
12 changes: 12 additions & 0 deletions src/superannotate/lib/infrastructure/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,3 +1918,15 @@ def get_item(
return self.get_item_by_id(item_id=item, project=project)
else:
return self.items.get_by_name(project, folder, item)

def check_multimodal_project_categorization(self, project: ProjectEntity):
if project.type != ProjectType.MULTIMODAL:
raise AppException(
"This function is only supported for Multimodal projects."
)
project_settings = self.service_provider.projects.list_settings(project).data
if not next(
(i.value for i in project_settings if i.attribute == "CategorizeItems"),
None,
):
raise AppException("Item Category not enabled for project.")
37 changes: 32 additions & 5 deletions src/superannotate/lib/infrastructure/services/work_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import List
from typing import Literal
from typing import Optional
from typing import Union

from lib.core.entities import CategoryEntity
from lib.core.entities import WorkflowEntity
from lib.core.entities.project_entities import BaseEntity
from lib.core.entities.work_managament import WMProjectEntity
from lib.core.entities.work_managament import WMProjectUserEntity
from lib.core.entities.work_managament import WMScoreEntity
Expand All @@ -16,6 +18,7 @@
from lib.core.jsx_conditions import OperatorEnum
from lib.core.jsx_conditions import Query
from lib.core.service_types import ListCategoryResponse
from lib.core.service_types import ListProjectCategoryResponse
from lib.core.service_types import ServiceResponse
from lib.core.service_types import WMCustomFieldResponse
from lib.core.service_types import WMProjectListResponse
Expand Down Expand Up @@ -74,21 +77,25 @@ def _generate_context(**kwargs):
encoded_context = base64.b64encode(json.dumps(kwargs).encode("utf-8"))
return encoded_context.decode("utf-8")

def list_project_categories(self, project_id: int) -> ListCategoryResponse:
return self.client.paginate(
self.URL_LIST_CATEGORIES,
item_type=CategoryEntity,
def list_project_categories(
self, project_id: int, entity: BaseEntity = CategoryEntity
) -> Union[ListCategoryResponse, ListProjectCategoryResponse]:
response = self.client.paginate(
url=self.URL_LIST_CATEGORIES,
item_type=entity,
query_params={"project_id": project_id},
headers={
"x-sa-entity-context": self._generate_context(
team_id=self.client.team_id
),
},
)
response.raise_for_status()
return response

def create_project_categories(
self, project_id: int, categories: List[str]
) -> ServiceResponse:
) -> ListProjectCategoryResponse:
response = self.client.request(
method="post",
url=self.URL_CREATE_CATEGORIES,
Expand All @@ -99,6 +106,26 @@ def create_project_categories(
team_id=self.client.team_id, project_id=project_id
),
},
content_type=ListProjectCategoryResponse,
dispatcher="data",
)
response.raise_for_status()
return response

def remove_project_categories(
self, project_id: int, query: Query
) -> ListProjectCategoryResponse:

response = self.client.request(
method="delete",
url=f"{self.URL_CREATE_CATEGORIES}?{query.build_query()}",
headers={
"x-sa-entity-context": self._generate_context(
team_id=self.client.team_id, project_id=project_id
),
},
content_type=ListProjectCategoryResponse,
dispatcher="data",
)
response.raise_for_status()
return response
Expand Down
Loading