Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,10 @@
from lib.core.types import MLModel
from lib.core.types import PriorityScoreEntity

from lib.core.pydantic_v1 import ValidationError
from lib.core.pydantic_v1 import constr
from lib.core.pydantic_v1 import conlist
from lib.core.pydantic_v1 import parse_obj_as
from lib.infrastructure.utils import extract_project_folder
from lib.infrastructure.validators import wrap_error

from superannotate_core.core.entities import BaseItemEntity
from superannotate_core.app import Project, Folder
Expand Down Expand Up @@ -743,7 +741,9 @@ def get_project_workflow(self, project: Union[str, dict]):
return workflow.data

def search_annotation_classes(
self, project: Union[NotEmptyStr, dict], name_contains: Optional[str] = None
self,
project: Union[NotEmptyStr, dict, Project],
name_contains: Optional[str] = None,
):
"""Searches annotation classes by name_prefix (case-insensitive)

Expand All @@ -757,17 +757,20 @@ def search_annotation_classes(
:return: annotation classes of the project
:rtype: list of dicts
"""
project_name, folder_name = extract_project_folder(project)
if isinstance(project, Project):
project = project.dict()

project_name, _ = extract_project_folder(project)
project = self.controller.get_project(project_name)
condition = Condition("project_id", project.id, EQ)
if name_contains:
condition &= Condition("name", name_contains, EQ) & Condition(
"pattern", True, EQ
)
response = self.controller.annotation_classes.list(condition)
if response.errors:
raise AppException(response.errors)
return response.data
condition = (
Condition("name", name_contains, EQ) & Condition("pattern", True, EQ)
if name_contains
else None
)
return [
annotation_class.dict()
for annotation_class in project.list_annotation_classes(condition)
]

def set_project_status(self, project: NotEmptyStr, status: PROJECT_STATUS):
"""Set project status
Expand Down Expand Up @@ -1353,8 +1356,7 @@ def upload_video_to_project(

def create_annotation_class(
self,
# todo project: Union[Project, NotEmptyStr],
project: Union[NotEmptyStr],
project: Union[NotEmptyStr, dict, Project],
name: NotEmptyStr,
color: NotEmptyStr,
attribute_groups: Optional[List[AttributeGroupSchema]] = None,
Expand Down Expand Up @@ -1447,38 +1449,41 @@ def create_annotation_class(
)

"""
# todo if isinstance(project, Project):
# if isinstance(project, Project):
# project = project.dict()
_class_type = ClassTypeEnum(class_type)
# try:
# annotation_class = AnnotationClass(
# {
# "name": name,
# "color": color,
# "type": ClassTypeEnum(class_type),
# "attribute_groups": attribute_groups,
# }
# )
# except ValidationError as e:
# raise AppException(wrap_error(e))
project = self.controller.get_project(project)
if isinstance(project, Project):
project = project.dict()

project_name, _ = extract_project_folder(project)
project = self.controller.get_project(project_name)

if project.type == ProjectType.Pixel.value and any(
"default_value" in attr_group.keys() for attr_group in attribute_groups
):
raise AppException(
'The "default_value" key is not supported for project type Pixel.'
)

_class_type = ClassTypeEnum.get_value(class_type)
if (
project.type != ProjectType.Document
and _class_type == ClassTypeEnum.RELATIONSHIP
and _class_type == ClassTypeEnum.relationship
):
raise AppException(
f"{class_type} class type is not supported in {project.type.name} project."
)
project.create_annotation_class(
annotation_class = project.create_annotation_class(
name=name,
color=color,
class_type=_class_type,
attribute_groups=attribute_groups,
)
if annotation_class:
return annotation_class.dict()
raise AppException("Failed to create annotation class")

def delete_annotation_class(
self, project: NotEmptyStr, annotation_class: Union[dict, NotEmptyStr]
self,
project: Union[NotEmptyStr, dict, Project],
annotation_class: Union[dict, NotEmptyStr],
):
"""Deletes annotation class from project

Expand All @@ -1489,24 +1494,32 @@ def delete_annotation_class(
:type annotation_class: str or dict
"""

if isinstance(annotation_class, str):
try:
annotation_class = AnnotationClassEntity(
name=annotation_class,
color="#ffffff", # noqa Random, just need to serialize
)
except ValidationError as e:
raise AppException(wrap_error(e))
if isinstance(annotation_class, dict) and "name" in annotation_class.keys():
class_name = annotation_class["name"]
elif isinstance(annotation_class, str):
class_name = annotation_class
else:
annotation_class = AnnotationClassEntity(**annotation_class)
project = self.controller.projects.get_by_name(project).data
raise AppException("Invalid value provided for annotation_class.")

self.controller.annotation_classes.delete(
project=project, annotation_class=annotation_class
)
if isinstance(project, Project):
project = project.dict()

project_name, _ = extract_project_folder(project)
project = self.controller.get_project(project_name)

condition = Condition("name", class_name, EQ) & Condition("pattern", True, EQ)
annotation_classes = project.list_annotation_classes(condition=condition)
if annotation_classes:
class_to_delete = annotation_classes[0]
logger.info(
"Deleting annotation class from project %s with name %s",
project.name,
class_to_delete.name,
)
project.delete_annotation_class(class_id=class_to_delete.id)

def download_annotation_classes_json(
self, project: NotEmptyStr, folder: Union[str, Path]
self, project: Union[NotEmptyStr, dict, Project], folder: Union[str, Path]
):
"""Downloads project classes.json to folder

Expand All @@ -1519,20 +1532,28 @@ def download_annotation_classes_json(
:return: path of the download file
:rtype: str
"""
if isinstance(project, Project):
project = project.dict()

project = self.controller.projects.get_by_name(project).data
response = self.controller.annotation_classes.download(
project=project, download_path=folder
project_name, _ = extract_project_folder(project)
project = self.controller.get_project(project_name)
logger.info(
f"Downloading classes.json from project {project.name} to folder {str(folder)}."
)
if response.errors:
raise AppException(response.errors)
return response.data
annotation_classes: List[dict] = [
annotation_class.dict()
for annotation_class in project.list_annotation_classes()
]
json_path = f"{folder}/classes.json"
with open(json_path, "w") as f:
json.dump(annotation_classes, f, indent=4)
return json_path

def create_annotation_classes_from_classes_json(
self,
project: Union[NotEmptyStr, dict],
project: Union[NotEmptyStr, dict, Project],
classes_json: Union[List[AnnotationClassEntity], str, Path],
from_s3_bucket=False,
from_s3_bucket: str = None,
):
"""Creates annotation classes in project from a SuperAnnotate format
annotation classes.json.
Expand All @@ -1549,7 +1570,7 @@ def create_annotation_classes_from_classes_json(
:return: list of created annotation class metadatas
:rtype: list of dicts
"""
if isinstance(classes_json, str) or isinstance(classes_json, Path):
if isinstance(classes_json, (str, Path)):
if from_s3_bucket:
from_session = boto3.Session()
from_s3 = from_session.resource("s3")
Expand All @@ -1561,12 +1582,12 @@ def create_annotation_classes_from_classes_json(
else:
data = open(classes_json, encoding="utf-8")
classes_json = json.load(data)
# try:
# # annotation_classes = parse_obj_as(List[AnnotationClassEntity], classes_json)
# annotation_classes = [AnnotationClassEntity.from_json(i) for i in classes_json]
# except ValidationError as _:
# raise AppException("Couldn't validate annotation classes.")
project = self.controller.get_project(project)

if isinstance(project, Project):
project = project.dict()

project_name, _ = extract_project_folder(project)
project = self.controller.get_project(project_name)
annotation_classes = project.create_annotation_classes(classes_json)
return [i.dict() for i in annotation_classes]

Expand Down Expand Up @@ -1684,7 +1705,10 @@ def download_image(
return response.data

def upload_annotations(
self, project: NotEmptyStr, annotations: List[dict], keep_status: bool = False
self,
project: Union[NotEmptyStr, dict],
annotations: List[dict],
keep_status: bool = False,
):
"""Uploads a list of annotation dicts as annotations to the SuperAnnotate directory.

Expand Down
1 change: 0 additions & 1 deletion src/superannotate/lib/core/usecases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from lib.core.usecases.annotations import * # noqa: F403 F401
from lib.core.usecases.classes import * # noqa: F403 F401
from lib.core.usecases.images import * # noqa: F403 F401
from lib.core.usecases.integrations import * # noqa: F403 F401
from lib.core.usecases.items import * # noqa: F403 F401
Expand Down
Loading