Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,16 @@
from lib.core import LIMITED_FUNCTIONS
from lib.core.entities import AttachmentEntity
from lib.core.entities import SettingEntity
from lib.core.entities.classes import AnnotationClassEntity
from lib.core.entities.integrations import IntegrationEntity
from lib.core.entities.project_entities import AnnotationClassEntity
from lib.core.enums import ImageQuality
from lib.core.exceptions import AppException
from lib.core.types import AttributeGroup
from lib.core.types import MLModel
from lib.core.types import PriorityScore
from lib.core.types import Project
from lib.infrastructure.controller import Controller
from lib.infrastructure.validators import wrap_error
from pydantic import conlist
from pydantic import parse_obj_as
from pydantic import StrictBool
Expand Down Expand Up @@ -580,8 +581,7 @@ def search_annotation_classes(
"""
project_name, folder_name = extract_project_folder(project)
classes = self.controller.search_annotation_classes(project_name, name_contains)
classes = [BaseSerializer(attribute).serialize() for attribute in classes.data]
return classes
return BaseSerializer.serialize_iterable(classes.data)

def set_project_default_image_quality_in_editor(
self,
Expand Down Expand Up @@ -1160,12 +1160,18 @@ def create_annotation_class(
attribute_groups = (
list(map(lambda x: x.dict(), attribute_groups)) if attribute_groups else []
)
try:
annotation_class = AnnotationClassEntity(
name=name,
color=color, # noqa
attribute_groups=attribute_groups,
type=class_type, # noqa
)
except ValidationError as e:
raise AppException(wrap_error(e))

response = self.controller.create_annotation_class(
project_name=project,
name=name,
color=color,
attribute_groups=attribute_groups,
class_type=class_type,
project_name=project, annotation_class=annotation_class
)
if response.errors:
raise AppException(response.errors)
Expand Down Expand Up @@ -1238,9 +1244,8 @@ def create_annotation_classes_from_classes_json(
classes_json = json.load(data)
try:
annotation_classes = parse_obj_as(List[AnnotationClassEntity], classes_json)
except ValidationError:
except ValidationError as _:
raise AppException("Couldn't validate annotation classes.")
logger.info(f"Creating annotation classes in project {project}.")
response = self.controller.create_annotation_classes(
project_name=project,
annotation_classes=annotation_classes,
Expand Down
2 changes: 2 additions & 0 deletions src/superannotate/lib/app/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def _serialize(
flat: bool = False,
exclude: Set[str] = None,
):
if not entity:
return None
if isinstance(entity, dict):
return entity
if isinstance(entity, BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion src/superannotate/lib/core/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from lib.core.entities.base import ProjectEntity
from lib.core.entities.base import SettingEntity
from lib.core.entities.base import SubSetEntity
from lib.core.entities.classes import AnnotationClassEntity
from lib.core.entities.integrations import IntegrationEntity
from lib.core.entities.items import DocumentEntity
from lib.core.entities.items import TmpImageEntity
from lib.core.entities.items import VideoEntity
from lib.core.entities.project_entities import AnnotationClassEntity
from lib.core.entities.project_entities import BaseEntity
from lib.core.entities.project_entities import ConfigEntity
from lib.core.entities.project_entities import FolderEntity
Expand Down
95 changes: 95 additions & 0 deletions src/superannotate/lib/core/entities/classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from enum import Enum
from typing import List
from typing import Optional

from lib.core.enums import BaseTitledEnum
from lib.core.enums import ClassTypeEnum
from pydantic import BaseModel as BasePydanticModel
from pydantic import Extra
from pydantic import StrictInt
from pydantic import StrictStr
from pydantic import validator
from pydantic.color import Color
from pydantic.color import ColorType


DATE_REGEX = r"\d{4}-[01]\d-[0-3]\dT[0-2]\d:[0-5]\d:[0-5]\d(?:\.\d{3})Z"
DATE_TIME_FORMAT_ERROR_MESSAGE = (
"does not match expected format YYYY-MM-DDTHH:MM:SS.fffZ"
)


class HexColor(BasePydanticModel):
__root__: ColorType

@validator("__root__")
def validate_color(cls, v):
return "#{:02X}{:02X}{:02X}".format(*Color(v).as_rgb_tuple())


class BaseModel(BasePydanticModel):
class Config:
extra = Extra.allow
error_msg_templates = {
"type_error.integer": "integer type expected",
"type_error.string": "str type expected",
"value_error.missing": "field required",
}

def dict(self, *args, fill_enum_values=False, **kwargs):
data = super().dict(*args, **kwargs)
if fill_enum_values:
data = self._fill_enum_values(data)
return data

@staticmethod
def _fill_enum_values(data: dict) -> dict:
for key, val in data.items():
if isinstance(val, BaseTitledEnum):
data[key] = val.__doc__
return data


class GroupTypeEnum(str, Enum):
RADIO = "radio"
CHECKLIST = "checklist"
NUMERIC = "numeric"
TEXT = "text"


class Attribute(BaseModel):
id: Optional[StrictInt]
group_id: Optional[StrictInt]
project_id: Optional[StrictInt]
name: StrictStr

def __hash__(self):
return hash(f"{self.id}{self.group_id}{self.name}")


class AttributeGroup(BaseModel):
id: Optional[StrictInt]
group_type: Optional[GroupTypeEnum]
class_id: Optional[StrictInt]
name: StrictStr
is_multiselect: Optional[bool]
attributes: Optional[List[Attribute]]

def __hash__(self):
return hash(f"{self.id}{self.class_id}{self.name}")


class AnnotationClassEntity(BaseModel):
id: Optional[StrictInt]
project_id: Optional[StrictInt]
type: ClassTypeEnum = ClassTypeEnum.OBJECT
name: StrictStr
color: HexColor
attribute_groups: List[AttributeGroup] = []

def __hash__(self):
return hash(f"{self.id}{self.type}{self.name}")

class Config:
validate_assignment = True
exclude_none = True
2 changes: 1 addition & 1 deletion src/superannotate/lib/core/service_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class ServiceResponse(BaseModel):
content: Union[bytes, str]
data: Any
count: Optional[int] = 0
_error: str
_error: str = None

class Config:
extra = Extra.allow
Expand Down
10 changes: 10 additions & 0 deletions src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ def get_annotations(
) -> List[dict]:
raise NotImplementedError

@abstractmethod
def create_annotation_classes(self, project_id: int, team_id: int, data: List):
raise NotImplementedError

@abstractmethod
async def download_annotations(
self,
Expand Down Expand Up @@ -400,3 +404,9 @@ def delete_custom_fields(
items: List[Dict[str, List[str]]],
) -> ServiceResponse:
raise NotImplementedError

@abstractmethod
def list_annotation_classes(
self, project_id: int, team_id: int, query_string: str = None
):
raise NotImplementedError
1 change: 1 addition & 0 deletions src/superannotate/lib/core/usecases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lib.core.usecases.annotations import * # noqa: F403 F401
from lib.core.usecases.classes import * # noqa: F403 F401
from lib.core.usecases.custom_fields import * # noqa: F403 F401
from lib.core.usecases.folders import * # noqa: F403 F401
from lib.core.usecases.images import * # noqa: F403 F401
Expand Down
Loading