Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3418,7 +3418,7 @@ def list_items(
exclude = {"meta", "annotator_email", "qa_email"}
if not include_custom_metadata:
exclude.add("custom_metadata")
return BaseSerializer.serialize_iterable(res, exclude=exclude)
return BaseSerializer.serialize_iterable(res, exclude=exclude, by_alias=False)

def list_projects(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/superannotate/lib/core/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from lib.core.entities.classes import AnnotationClassEntity
from lib.core.entities.folder import FolderEntity
from lib.core.entities.integrations import IntegrationEntity
from lib.core.entities.items import CategoryEntity
from lib.core.entities.items import ClassificationEntity
from lib.core.entities.items import DocumentEntity
from lib.core.entities.items import ImageEntity
Expand All @@ -12,7 +13,6 @@
from lib.core.entities.items import TiledEntity
from lib.core.entities.items import VideoEntity
from lib.core.entities.project import AttachmentEntity
from lib.core.entities.project import CategoryEntity
from lib.core.entities.project import ContributorEntity
from lib.core.entities.project import CustomFieldEntity
from lib.core.entities.project import ProjectEntity
Expand Down
2 changes: 2 additions & 0 deletions src/superannotate/lib/core/entities/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class ItemFilters(BaseFilters):
assignments__user_role__in: Optional[List[str]]
assignments__user_role__ne: Optional[str]
assignments__user_role__notin: Optional[List[str]]
categories__value: Optional[str]
categories__value__in: Optional[List[str]]


class ProjectFilters(BaseFilters):
Expand Down
12 changes: 10 additions & 2 deletions src/superannotate/lib/core/entities/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from lib.core.entities.base import BaseItemEntity
from lib.core.entities.base import TimedBaseModel
from lib.core.entities.project import TimedBaseModel
from lib.core.enums import ApprovalStatus
from lib.core.enums import ProjectType
from lib.core.pydantic_v1 import Extra
Expand All @@ -18,9 +18,17 @@ class Config:
extra = Extra.ignore


class CategoryEntity(TimedBaseModel):
id: int
value: str = Field(None, alias="name")

class Config:
extra = Extra.ignore


class MultiModalItemCategoryEntity(TimedBaseModel):
id: int = Field(None, alias="category_id")
name: str = Field(None, alias="category_name")
value: str = Field(None, alias="category_name")

class Config:
extra = Extra.ignore
Expand Down
5 changes: 0 additions & 5 deletions src/superannotate/lib/core/entities/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,3 @@ def is_system(self):

class Config:
extra = Extra.ignore


class CategoryEntity(BaseModel):
id: Optional[int]
name: Optional[str]
6 changes: 6 additions & 0 deletions src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,12 @@ class BaseServiceProvider:
def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int:
raise NotImplementedError

@abstractmethod
def get_category_id(
self, project: entities.ProjectEntity, category_name: str
) -> int:
raise NotImplementedError

@abstractmethod
def get_role_name(self, project: entities.ProjectEntity, role_id: int) -> str:
raise NotImplementedError
Expand Down
6 changes: 4 additions & 2 deletions src/superannotate/lib/core/usecases/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,8 +2098,10 @@ def execute(self):
if categorization_enabled:
item_id_category_map = {}
for item_name in uploaded_annotations:
category = name_annotation_map[item_name]["metadata"].get(
"item_category"
category = (
name_annotation_map[item_name]["metadata"]
.get("item_category", {})
.get("value")
)
if category:
item_id_category_map[name_item_map[item_name].id] = category
Expand Down
4 changes: 3 additions & 1 deletion src/superannotate/lib/infrastructure/annotation_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def get_component_value(self, component_id: str):
return None

def set_component_value(self, component_id: str, value: Any):
self.annotation.setdefault("data", {}).setdefault(component_id, {})["value"] = value
self.annotation.setdefault("data", {}).setdefault(component_id, {})[
"value"
] = value
return self


Expand Down
10 changes: 10 additions & 0 deletions src/superannotate/lib/infrastructure/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def handle(self, filters: Dict[str, Any], query: Query = None) -> Query:
for key, val in filters.items():
_keys = key.split("__")
val = self._handle_special_fields(_keys, val)
if _keys[0] == "categories" and _keys[1] == "value":
_keys[1] = "category_id"
condition, _key = determine_condition_and_key(_keys)
query &= Filter(_key, val, condition)
return super().handle(filters, query)
Expand Down Expand Up @@ -147,6 +149,14 @@ def _handle_special_fields(self, keys: List[str], val):
]
else:
val = self._service_provider.get_role_id(self._project, val)
elif keys[0] == "categories" and keys[1] == "value":
if isinstance(val, list):
val = [
self._service_provider.get_category_id(self._project, i)
for i in val
]
else:
val = self._service_provider.get_category_id(self._project, val)
return val


Expand Down
7 changes: 7 additions & 0 deletions src/superannotate/lib/infrastructure/serviceprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]:
self.client.team_id, entity=entity
)

def get_category_id(
self, project: entities.ProjectEntity, category_name: str
) -> int:
return self._cached_work_management_repository.get_category_id(
project, category_name
)

def get_custom_field_id(
self, field_name: str, entity: CustomFieldEntityEnum
) -> int:
Expand Down
24 changes: 24 additions & 0 deletions src/superannotate/lib/infrastructure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ def get(self, key, **kwargs):
return self._K_V_map[key]


class CategoryCache(BaseCachedWorkManagementRepository):
def sync(self, project: ProjectEntity):
response = self.work_management.list_project_categories(project.id)
if not response.ok:
raise AppException(response.error)
categories = response.data
self._K_V_map[project.id] = {
"category_name_id_map": {
category.value: category.id for category in categories
},
"category_id_name_map": {
category.id: category.value for category in categories
},
}
self._update_cache_timestamp(project.id)


class RoleCache(BaseCachedWorkManagementRepository):
def sync(self, project: ProjectEntity):
response = self.work_management.list_workflow_roles(
Expand Down Expand Up @@ -221,6 +238,7 @@ def get(self, key, **kwargs):

class CachedWorkManagementRepository:
def __init__(self, ttl_seconds: int, work_management):
self._category_cache = CategoryCache(ttl_seconds, work_management)
self._role_cache = RoleCache(ttl_seconds, work_management)
self._status_cache = StatusCache(ttl_seconds, work_management)
self._project_custom_field_cache = CustomFieldCache(
Expand All @@ -236,6 +254,12 @@ def __init__(self, ttl_seconds: int, work_management):
CustomFieldEntityEnum.TEAM,
)

def get_category_id(self, project, category_name: str) -> int:
data = self._category_cache.get(project.id, project=project)
if category_name in data["category_name_id_map"]:
return data["category_name_id_map"][category_name]
raise AppException("Invalid category provided.")

def get_role_id(self, project, role_name: str) -> int:
role_data = self._role_cache.get(project.id, project=project)
if role_name in role_data["role_name_id_map"]:
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/items/test_list_items.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
import os
import random
import string
import time
from pathlib import Path

from src.superannotate import AppException
from src.superannotate import SAClient
from tests import DATA_SET_PATH
from tests.integration.base import BaseTestCase

sa = SAClient()
Expand Down Expand Up @@ -61,3 +64,61 @@ def test_list_items_URL_limit(self):
sa.attach_items(self.PROJECT_NAME, items_for_attache)
items = sa.list_items(self.PROJECT_NAME, name__in=item_names)
assert len(items) == 125


class TestListItemsMultimodal(BaseTestCase):
PROJECT_NAME = "TestListItemsMultimodal"
PROJECT_DESCRIPTION = "TestSearchItems"
PROJECT_TYPE = "Multimodal"
TEST_FOLDER_PATH = "data_set/sample_project_vector"
CATEGORIES = ["c_1", "c_2", "c_3"]
ANNOTATIONS = [
{"metadata": {"name": "item_1", "item_category": {"value": "c1"}}, "data": {}},
{"metadata": {"name": "item_2", "item_category": {"value": "c2"}}, "data": {}},
{"metadata": {"name": "item_3", "item_category": {"value": "c3"}}, "data": {}},
]
CLASSES_TEMPLATE_PATH = DATA_SET_PATH / "editor_templates/from1_classes.json"
EDITOR_TEMPLATE_PATH = DATA_SET_PATH / "editor_templates/form1.json"

def setUp(self, *args, **kwargs):
self.tearDown()
self._project = sa.create_project(
self.PROJECT_NAME,
self.PROJECT_DESCRIPTION,
"Multimodal",
settings=[
{"attribute": "CategorizeItems", "value": 1},
{"attribute": "TemplateState", "value": 1},
],
)
project = sa.controller.get_project(self.PROJECT_NAME)
time.sleep(10)
with open(self.EDITOR_TEMPLATE_PATH) as f:
res = sa.controller.service_provider.projects.attach_editor_template(
sa.controller.team, project, template=json.load(f)
)
assert res.ok
sa.create_annotation_classes_from_classes_json(
self.PROJECT_NAME, self.CLASSES_TEMPLATE_PATH
)

def test_list_category_filter(self):
sa.upload_annotations(
self.PROJECT_NAME, self.ANNOTATIONS, data_spec="multimodal"
)
items = sa.list_items(
self.PROJECT_NAME,
include=["categories"],
categories__value__in=["c1", "c2"],
)
assert [i["categories"][0]["value"] for i in items] == ["c1", "c2"]
assert (
len(
sa.list_items(
self.PROJECT_NAME,
include=["categories"],
categories__value__in=["c3"],
)
)
== 1
)