From 04904fb7310d04095069c33be4d4ebe22eccdb52 Mon Sep 17 00:00:00 2001
From: Vaghinak Basentsyan
Date: Tue, 4 Mar 2025 13:35:52 +0400
Subject: [PATCH] Added ability to filter by item category

---
 .../lib/app/interface/sdk_interface.py        |  2 +-
 .../lib/core/entities/__init__.py             |  2 +-
 .../lib/core/entities/filters.py              |  2 +
 src/superannotate/lib/core/entities/items.py  | 10 ++-
 .../lib/core/entities/project.py              |  5 --
 .../lib/core/serviceproviders.py              |  6 ++
 .../lib/core/usecases/annotations.py          |  8 ++-
 .../lib/infrastructure/annotation_adapter.py  |  4 +-
 .../lib/infrastructure/query_builder.py       | 11 ++++
 .../lib/infrastructure/serviceprovider.py     |  7 +++
 src/superannotate/lib/infrastructure/utils.py | 26 ++++++++
 tests/integration/items/test_list_items.py    | 61 +++++++++++++++++++
 12 files changed, 132 insertions(+), 12 deletions(-)

diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py
index 2dfe694b6..4eb343861 100644
--- a/src/superannotate/lib/app/interface/sdk_interface.py
+++ b/src/superannotate/lib/app/interface/sdk_interface.py
@@ -3418,7 +3418,7 @@ def list_items(
         exclude = {"meta", "annotator_email", "qa_email"}
         if not include_custom_metadata:
             exclude.add("custom_metadata")
-        return BaseSerializer.serialize_iterable(res, exclude=exclude)
+        return BaseSerializer.serialize_iterable(res, exclude=exclude, by_alias=False)
 
     def list_projects(
         self,
diff --git a/src/superannotate/lib/core/entities/__init__.py b/src/superannotate/lib/core/entities/__init__.py
index 008822c41..c4666973d 100644
--- a/src/superannotate/lib/core/entities/__init__.py
+++ b/src/superannotate/lib/core/entities/__init__.py
@@ -4,6 +4,7 @@
 from lib.core.entities.classes import AnnotationClassEntity
 from lib.core.entities.folder import FolderEntity
 from lib.core.entities.integrations import IntegrationEntity
+from lib.core.entities.items import CategoryEntity
 from lib.core.entities.items import ClassificationEntity
 from lib.core.entities.items import DocumentEntity
 from lib.core.entities.items import ImageEntity
@@ -12,7 +13,6 @@
 from lib.core.entities.items import TiledEntity
 from lib.core.entities.items import VideoEntity
 from lib.core.entities.project import AttachmentEntity
-from lib.core.entities.project import CategoryEntity
 from lib.core.entities.project import ContributorEntity
 from lib.core.entities.project import CustomFieldEntity
 from lib.core.entities.project import ProjectEntity
diff --git a/src/superannotate/lib/core/entities/filters.py b/src/superannotate/lib/core/entities/filters.py
index 8f6d2369d..0fe71a1d3 100644
--- a/src/superannotate/lib/core/entities/filters.py
+++ b/src/superannotate/lib/core/entities/filters.py
@@ -29,6 +29,8 @@ class ItemFilters(BaseFilters):
     assignments__user_role__in: Optional[List[str]]
     assignments__user_role__ne: Optional[str]
     assignments__user_role__notin: Optional[List[str]]
+    categories__value: Optional[str]
+    categories__value__in: Optional[List[str]]
 
 
 class ProjectFilters(BaseFilters):
diff --git a/src/superannotate/lib/core/entities/items.py b/src/superannotate/lib/core/entities/items.py
index 3d9f25d19..477cab8dd 100644
--- a/src/superannotate/lib/core/entities/items.py
+++ b/src/superannotate/lib/core/entities/items.py
@@ -18,9 +18,17 @@ class Config:
         extra = Extra.ignore
 
 
+class CategoryEntity(TimedBaseModel):
+    id: int
+    value: str = Field(None, alias="name")
+
+    class Config:
+        extra = Extra.ignore
+
+
 class MultiModalItemCategoryEntity(TimedBaseModel):
     id: int = Field(None, alias="category_id")
-    name: str = Field(None, alias="category_name")
+    value: str = Field(None, alias="category_name")
 
     class Config:
         extra = Extra.ignore
diff --git a/src/superannotate/lib/core/entities/project.py b/src/superannotate/lib/core/entities/project.py
index f8306d698..840abf56b 100644
--- a/src/superannotate/lib/core/entities/project.py
+++ b/src/superannotate/lib/core/entities/project.py
@@ -187,8 +187,3 @@ def is_system(self):
 
     class Config:
         extra = Extra.ignore
-
-
-class CategoryEntity(BaseModel):
-    id: Optional[int]
-    name: Optional[str]
diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py
index d8fbab28e..5de7fc235 100644
--- a/src/superannotate/lib/core/serviceproviders.py
+++ b/src/superannotate/lib/core/serviceproviders.py
@@ -696,6 +696,12 @@ class BaseServiceProvider:
     def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int:
         raise NotImplementedError
 
+    @abstractmethod
+    def get_category_id(
+        self, project: entities.ProjectEntity, category_name: str
+    ) -> int:
+        raise NotImplementedError
+
     @abstractmethod
     def get_role_name(self, project: entities.ProjectEntity, role_id: int) -> str:
         raise NotImplementedError
diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py
index faafbb63c..90bd5ba67 100644
--- a/src/superannotate/lib/core/usecases/annotations.py
+++ b/src/superannotate/lib/core/usecases/annotations.py
@@ -2098,8 +2098,10 @@ def execute(self):
         if categorization_enabled:
             item_id_category_map = {}
             for item_name in uploaded_annotations:
-                category = name_annotation_map[item_name]["metadata"].get(
-                    "item_category"
-                )
+                # "item_category" may be missing or null; treat both as no category.
+                category = (
+                    name_annotation_map[item_name]["metadata"].get("item_category")
+                    or {}
+                ).get("value")
                 if category:
                     item_id_category_map[name_item_map[item_name].id] = category
diff --git a/src/superannotate/lib/infrastructure/annotation_adapter.py b/src/superannotate/lib/infrastructure/annotation_adapter.py
index c333231cf..13436510a 100644
--- a/src/superannotate/lib/infrastructure/annotation_adapter.py
+++ b/src/superannotate/lib/infrastructure/annotation_adapter.py
@@ -44,7 +44,9 @@ def get_component_value(self, component_id: str):
         return None
 
     def set_component_value(self, component_id: str, value: Any):
-        self.annotation.setdefault("data", {}).setdefault(component_id, {})["value"] = value
+        self.annotation.setdefault("data", {}).setdefault(component_id, {})[
+            "value"
+        ] = value
         return self
 
 
diff --git a/src/superannotate/lib/infrastructure/query_builder.py b/src/superannotate/lib/infrastructure/query_builder.py
index 4473e502b..bb83e3009 100644
--- a/src/superannotate/lib/infrastructure/query_builder.py
+++ b/src/superannotate/lib/infrastructure/query_builder.py
@@ -113,6 +113,9 @@ def handle(self, filters: Dict[str, Any], query: Query = None) -> Query:
         for key, val in filters.items():
             _keys = key.split("__")
             val = self._handle_special_fields(_keys, val)
+            # Category names were resolved to ids above, so filter on the id field.
+            if _keys[:2] == ["categories", "value"]:
+                _keys[1] = "category_id"
             condition, _key = determine_condition_and_key(_keys)
             query &= Filter(_key, val, condition)
         return super().handle(filters, query)
@@ -147,6 +150,14 @@ def _handle_special_fields(self, keys: List[str], val):
                 ]
             else:
                 val = self._service_provider.get_role_id(self._project, val)
+        elif keys[:2] == ["categories", "value"]:
+            if isinstance(val, list):
+                val = [
+                    self._service_provider.get_category_id(self._project, i)
+                    for i in val
+                ]
+            else:
+                val = self._service_provider.get_category_id(self._project, val)
         return val
 
 
diff --git a/src/superannotate/lib/infrastructure/serviceprovider.py b/src/superannotate/lib/infrastructure/serviceprovider.py
index 641111a49..01ab57a78 100644
--- a/src/superannotate/lib/infrastructure/serviceprovider.py
+++ b/src/superannotate/lib/infrastructure/serviceprovider.py
@@ -79,6 +79,13 @@ def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]:
             self.client.team_id, entity=entity
         )
 
+    def get_category_id(
+        self, project: entities.ProjectEntity, category_name: str
+    ) -> int:
+        return self._cached_work_management_repository.get_category_id(
+            project, category_name
+        )
+
     def get_custom_field_id(
         self, field_name: str, entity: CustomFieldEntityEnum
     ) -> int:
diff --git a/src/superannotate/lib/infrastructure/utils.py b/src/superannotate/lib/infrastructure/utils.py
index eaa75e80d..9508e6c40 100644
--- a/src/superannotate/lib/infrastructure/utils.py
+++ b/src/superannotate/lib/infrastructure/utils.py
@@ -138,6 +138,24 @@ def get(self, key, **kwargs):
         return self._K_V_map[key]
 
 
+class CategoryCache(BaseCachedWorkManagementRepository):
+    """Per-project cache of category name -> id and id -> name maps."""
+    def sync(self, project: ProjectEntity):
+        response = self.work_management.list_project_categories(project.id)
+        if not response.ok:
+            raise AppException(response.error)
+        categories = response.data
+        self._K_V_map[project.id] = {
+            "category_name_id_map": {
+                category.value: category.id for category in categories
+            },
+            "category_id_name_map": {
+                category.id: category.value for category in categories
+            },
+        }
+        self._update_cache_timestamp(project.id)
+
+
 class RoleCache(BaseCachedWorkManagementRepository):
     def sync(self, project: ProjectEntity):
         response = self.work_management.list_workflow_roles(
@@ -221,6 +239,7 @@ def get(self, key, **kwargs):
 
 class CachedWorkManagementRepository:
     def __init__(self, ttl_seconds: int, work_management):
+        self._category_cache = CategoryCache(ttl_seconds, work_management)
         self._role_cache = RoleCache(ttl_seconds, work_management)
         self._status_cache = StatusCache(ttl_seconds, work_management)
         self._project_custom_field_cache = CustomFieldCache(
@@ -236,6 +255,13 @@ def __init__(self, ttl_seconds: int, work_management):
             CustomFieldEntityEnum.TEAM,
         )
 
+    def get_category_id(self, project, category_name: str) -> int:
+        """Return the id of the named project category or raise AppException."""
+        data = self._category_cache.get(project.id, project=project)
+        if category_name in data["category_name_id_map"]:
+            return data["category_name_id_map"][category_name]
+        raise AppException("Invalid category provided.")
+
     def get_role_id(self, project, role_name: str) -> int:
         role_data = self._role_cache.get(project.id, project=project)
         if role_name in role_data["role_name_id_map"]:
diff --git a/tests/integration/items/test_list_items.py b/tests/integration/items/test_list_items.py
index 7d2cdabf9..f0a80fb79 100644
--- a/tests/integration/items/test_list_items.py
+++ b/tests/integration/items/test_list_items.py
@@ -1,10 +1,13 @@
+import json
 import os
 import random
 import string
+import time
 from pathlib import Path
 
 from src.superannotate import AppException
 from src.superannotate import SAClient
+from tests import DATA_SET_PATH
 from tests.integration.base import BaseTestCase
 
 sa = SAClient()
@@ -61,3 +64,61 @@ def test_list_items_URL_limit(self):
         sa.attach_items(self.PROJECT_NAME, items_for_attache)
         items = sa.list_items(self.PROJECT_NAME, name__in=item_names)
         assert len(items) == 125
+
+
+class TestListItemsMultimodal(BaseTestCase):
+    PROJECT_NAME = "TestListItemsMultimodal"
+    PROJECT_DESCRIPTION = "TestSearchItems"
+    PROJECT_TYPE = "Multimodal"
+    TEST_FOLDER_PATH = "data_set/sample_project_vector"
+    CATEGORIES = ["c_1", "c_2", "c_3"]
+    ANNOTATIONS = [
+        {"metadata": {"name": "item_1", "item_category": {"value": "c1"}}, "data": {}},
+        {"metadata": {"name": "item_2", "item_category": {"value": "c2"}}, "data": {}},
+        {"metadata": {"name": "item_3", "item_category": {"value": "c3"}}, "data": {}},
+    ]
+    CLASSES_TEMPLATE_PATH = DATA_SET_PATH / "editor_templates/from1_classes.json"
+    EDITOR_TEMPLATE_PATH = DATA_SET_PATH / "editor_templates/form1.json"
+
+    def setUp(self, *args, **kwargs):
+        self.tearDown()
+        self._project = sa.create_project(
+            self.PROJECT_NAME,
+            self.PROJECT_DESCRIPTION,
+            self.PROJECT_TYPE,
+            settings=[
+                {"attribute": "CategorizeItems", "value": 1},
+                {"attribute": "TemplateState", "value": 1},
+            ],
+        )
+        project = sa.controller.get_project(self.PROJECT_NAME)
+        time.sleep(10)
+        with open(self.EDITOR_TEMPLATE_PATH) as f:
+            res = sa.controller.service_provider.projects.attach_editor_template(
+                sa.controller.team, project, template=json.load(f)
+            )
+        assert res.ok
+        sa.create_annotation_classes_from_classes_json(
+            self.PROJECT_NAME, self.CLASSES_TEMPLATE_PATH
+        )
+
+    def test_list_category_filter(self):
+        sa.upload_annotations(
+            self.PROJECT_NAME, self.ANNOTATIONS, data_spec="multimodal"
+        )
+        items = sa.list_items(
+            self.PROJECT_NAME,
+            include=["categories"],
+            categories__value__in=["c1", "c2"],
+        )
+        assert sorted(i["categories"][0]["value"] for i in items) == ["c1", "c2"]
+        assert (
+            len(
+                sa.list_items(
+                    self.PROJECT_NAME,
+                    include=["categories"],
+                    categories__value__in=["c3"],
+                )
+            )
+            == 1
+        )